Path: blob/main/singlestoredb/mysql/protocol.py
801 views
# type: ignore
# Python implementation of low level MySQL client-server protocol
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
import struct
import sys

from . import err
from ..config import get_option
from ..utils.results import Description
from .charset import MBLENGTH
from .constants import EXTENDED_TYPE
from .constants import FIELD_TYPE
from .constants import SERVER_STATUS
from .constants import VECTOR_TYPE


DEBUG = get_option('debug.connection')

# First-byte markers used by MysqlPacket.read_length_encoded_integer().
# NULL_COLUMN and UNSIGNED_CHAR_COLUMN share the value 251 on purpose:
# 251 itself means NULL, anything below it is the integer value.
NULL_COLUMN = 251
UNSIGNED_CHAR_COLUMN = 251
UNSIGNED_SHORT_COLUMN = 252
UNSIGNED_INT24_COLUMN = 253
UNSIGNED_INT64_COLUMN = 254


def dump_packet(data):  # pragma: no cover
    """
    Print a debugging dump of a packet to stdout.

    Prints the packet length, up to six calling frames (name and line
    number), then a hex + ASCII dump of at most the first 256 bytes,
    sixteen bytes per row.

    Parameters
    ----------
    data : bytes-like
        Raw packet payload; iterating it must yield integers.

    """

    def printable(byte):
        # Show printable ASCII as-is, everything else as a dot.
        if 32 <= byte < 127:
            return chr(byte)
        return '.'

    try:
        print('packet length:', len(data))
        for i in range(1, 7):
            f = sys._getframe(i)
            print('call[%d]: %s (line %d)' % (i, f.f_code.co_name, f.f_lineno))
        print('-' * 66)
    except ValueError:
        # sys._getframe() raises ValueError once the requested depth
        # exceeds the call stack; the frame listing is best-effort.
        pass

    rows = [data[i: i + 16] for i in range(0, min(len(data), 256), 16)]
    for row in rows:
        print(
            ' '.join('{:02X}'.format(x) for x in row)
            # Each missing byte occupies three columns ("XX" + separator),
            # so pad short rows by that much to keep the ASCII column aligned.
            + '   ' * (16 - len(row))
            + '  '
            + ''.join(printable(x) for x in row),
        )
    print('-' * 66)
    print()


class MysqlPacket:
    """
    Representation of a MySQL response packet.

    Provides an interface for reading/parsing the packet results. The
    object keeps the payload in ``_data`` and a read cursor in
    ``_position``; the ``read_*`` methods consume bytes at the cursor and
    move it forward.

    """

    __slots__ = ('_position', '_data')

    def __init__(self, data, encoding):
        # ``encoding`` is accepted for signature compatibility with
        # subclasses; it is not used by the base class.
        self._position = 0
        self._data = data

    def get_all_data(self):
        """Return the complete payload, regardless of the cursor."""
        return self._data

    def read(self, size):
        """
        Read the next 'size' bytes in packet and advance cursor past them.

        Raises
        ------
        AssertionError
            If fewer than ``size`` bytes remain after the cursor.

        """
        result = self._data[self._position: (self._position + size)]
        if len(result) != size:
            error = (
                'Result length not requested length:\n'
                'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
                % (size, len(result), self._position, len(self._data))
            )
            if DEBUG:
                print(error)
                self.dump()
            raise AssertionError(error)
        self._position += size
        return result

    def read_all(self):
        """
        Read all remaining data in the packet.

        (Subsequent read() will return errors.)

        """
        result = self._data[self._position:]
        self._position = None  # ensure no subsequent read()
        return result

    def advance(self, length):
        """
        Advance the cursor in data buffer ``length`` bytes.

        Raises
        ------
        Exception
            If the resulting position would fall outside the buffer.

        """
        new_position = self._position + length
        if new_position < 0 or new_position > len(self._data):
            raise Exception(
                'Invalid advance amount (%s) for cursor. '
                'Position=%s' % (length, new_position),
            )
        self._position = new_position

    def rewind(self, position=0):
        """
        Set the position of the data buffer cursor to 'position'.

        Raises
        ------
        Exception
            If ``position`` is outside the buffer.

        """
        if position < 0 or position > len(self._data):
            raise Exception('Invalid position to rewind cursor to: %s.' % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """
        Get 'length' bytes starting at 'position'.

        Position is start of payload (first four packet header bytes are not
        included) starting at index '0'.

        No error checking is done. If requesting outside end of buffer
        an empty string (or string shorter than 'length') may be returned!

        The cursor is not moved.

        """
        return self._data[position: (position + length)]

    def read_uint8(self):
        """Read one byte at the cursor as an unsigned integer."""
        result = self._data[self._position]
        self._position += 1
        return result

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer at the cursor."""
        result = struct.unpack_from('<H', self._data, self._position)[0]
        self._position += 2
        return result

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer at the cursor."""
        # struct has no 3-byte code: read 2 low bytes + 1 high byte.
        low, high = struct.unpack_from('<HB', self._data, self._position)
        self._position += 3
        return low + (high << 16)

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer at the cursor."""
        result = struct.unpack_from('<I', self._data, self._position)[0]
        self._position += 4
        return result

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer at the cursor."""
        result = struct.unpack_from('<Q', self._data, self._position)[0]
        self._position += 8
        return result

    def read_string(self):
        """
        Read a NUL-terminated byte string starting at the cursor.

        Returns the bytes before the terminator and moves the cursor past
        it. Returns None (cursor unchanged) when no NUL byte is found.

        """
        end_pos = self._data.find(b'\0', self._position)
        if end_pos < 0:
            return None
        result = self._data[self._position: end_pos]
        self._position = end_pos + 1
        return result

    def read_length_encoded_integer(self):
        """
        Read a 'Length Coded Binary' number from the data buffer.

        Length coded numbers can be anywhere from 1 to 9 bytes depending
        on the value of the first byte.

        Returns None for the NULL marker (251) and for a first byte of
        255, which matches none of the known markers.

        """
        c = self.read_uint8()
        if c == NULL_COLUMN:
            return None
        if c < UNSIGNED_CHAR_COLUMN:
            return c
        if c == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        if c == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        if c == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()
        # Unknown marker: keep the historical behavior of returning None.
        return None

    def read_length_coded_string(self):
        """
        Read a 'Length Coded String' from the data buffer.

        A 'Length Coded String' consists first of a length coded
        (unsigned, positive) integer represented in 1-9 bytes followed by
        that many bytes of binary data. (For example "cat" would be "3cat".)

        Returns None when the length prefix decodes to None.

        """
        length = self.read_length_encoded_integer()
        if length is None:
            return None
        return self.read(length)

    def read_struct(self, fmt):
        """
        Unpack ``fmt`` (a ``struct`` format string) at the cursor.

        Returns the unpacked tuple and advances the cursor by the size of
        the format.

        """
        s = struct.Struct(fmt)
        result = s.unpack_from(self._data, self._position)
        self._position += s.size
        return result

    def is_ok_packet(self):
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        return self._data[0] == 0 and len(self._data) >= 7

    def is_eof_packet(self):
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # Caution: \xFE may be LengthEncodedInteger.
        # If \xFE is LengthEncodedInteger header, 8bytes followed.
        return self._data[0] == 0xFE and len(self._data) < 9

    def is_auth_switch_request(self):
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._data[0] == 0xFE

    def is_extra_auth_data(self):
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        return self._data[0] == 1

    def is_resultset_packet(self):
        field_count = self._data[0]
        return 1 <= field_count <= 250

    def is_load_local_packet(self):
        return self._data[0] == 0xFB

    def is_error_packet(self):
        return self._data[0] == 0xFF

    def check_error(self):
        """Raise the server error if this is an error packet."""
        if self.is_error_packet():
            self.raise_for_error()

    def raise_for_error(self):
        """
        Raise the exception described by this error packet.

        The cursor is reset and moved past the header byte and the error
        number; the exception itself is built from the full payload by
        ``err.raise_mysql_exception``.

        """
        self.rewind()
        self.advance(1)  # field_count == error (we already know that)
        errno = self.read_uint16()
        if DEBUG:
            print('errno =', errno)
        err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a debugging dump of the payload (see dump_packet)."""
        dump_packet(self._data)


class FieldDescriptorPacket(MysqlPacket):
    """
    A MysqlPacket that represents a specific column's metadata in the result.

    Parsing is automatically done and the results are exported via public
    attributes on the class such as: db, table_name, name, length, type_code.

    """

    # Element-type names shared by the VECTOR_TYPE members and the
    # FIELD_TYPE ``<name>_VECTOR`` / ``<name>_VECTOR_JSON`` codes, in the
    # order they are tested.
    _VECTOR_ELEMENT_NAMES = (
        'FLOAT32', 'FLOAT64', 'INT8', 'INT16', 'INT32', 'INT64', 'FLOAT16',
    )

    def __init__(self, data, encoding):
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """
        Parse the 'Field Descriptor' (Metadata) packet.

        This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).

        ``catalog`` and ``db`` are kept as raw bytes; the table and column
        names are decoded with ``encoding``.

        Raises
        ------
        TypeError
            If the packet carries an extended or vector type code that is
            not recognized.

        """
        self.catalog = self.read_length_coded_string()
        self.db = self.read_length_coded_string()
        self.table_name = self.read_length_coded_string().decode(encoding)
        self.org_table = self.read_length_coded_string().decode(encoding)
        self.name = self.read_length_coded_string().decode(encoding)
        self.org_name = self.read_length_coded_string().decode(encoding)
        (
            n_bytes,
            self.charsetnr,
            self.length,
            self.type_code,
            self.flags,
            self.scale,
        ) = self.read_struct('<BHIBHBxx')

        # 'default' is a length coded binary and is still in the buffer?
        # not used for normal result sets...

        # Extended types: a leading length byte above 12 signals that
        # extended type information follows the fixed-size fields.
        if n_bytes > 12:
            ext_type_code = self.read_uint8()
            if ext_type_code == EXTENDED_TYPE.NONE:
                pass
            elif ext_type_code == EXTENDED_TYPE.BSON:
                self.type_code = FIELD_TYPE.BSON
            elif ext_type_code == EXTENDED_TYPE.VECTOR:
                (self.length, vec_type) = self.read_struct('<IB')
                self.type_code = self._vector_type_code(vec_type)
            else:
                raise TypeError(f'unrecognized extended data type: {ext_type_code}')

    def _vector_type_code(self, vec_type):
        """Map a vector element type to its FIELD_TYPE code."""
        for element in self._VECTOR_ELEMENT_NAMES:
            if vec_type == getattr(VECTOR_TYPE, element):
                # charsetnr 63 selects the plain vector code; any other
                # charset selects the ``_JSON`` variant.
                suffix = '_VECTOR' if self.charsetnr == 63 else '_VECTOR_JSON'
                return getattr(FIELD_TYPE, element + suffix)
        raise TypeError(f'unrecognized vector data type: {vec_type}')

    def description(self):
        """
        Provides a 9-item tuple.

        Standard descriptions only have 7 fields according to the Python
        PEP249 DB Spec, but we need to surface information about unsigned
        types and charsetnr for proper type handling.

        """
        column_length = self.get_column_length()
        precision = column_length
        if self.type_code in (FIELD_TYPE.DECIMAL, FIELD_TYPE.NEWDECIMAL):
            if precision:
                precision -= 1  # for the sign
            if self.scale > 0:
                precision -= 1  # for the decimal point
        return Description(
            self.name,
            self.type_code,
            None,  # TODO: display_length; should this be self.length?
            column_length,  # 'internal_size'
            precision,  # 'precision'
            self.scale,
            self.flags % 2 == 0,  # true when the lowest flag bit is clear
            self.flags,
            self.charsetnr,
        )

    def get_column_length(self):
        """
        Return the column length.

        For VAR_STRING columns the stored length is divided by the
        per-character byte width looked up in MBLENGTH for the column's
        charset (1 when the charset is not listed).

        """
        if self.type_code == FIELD_TYPE.VAR_STRING:
            mblen = MBLENGTH.get(self.charsetnr, 1)
            return self.length // mblen
        return self.length

    def __str__(self):
        return '%s %r.%r.%r, type=%s, flags=%x, charsetnr=%s' % (
            self.__class__,
            self.db,
            self.table_name,
            self.name,
            self.type_code,
            self.flags,
            self.charsetnr,
        )


class OKPacketWrapper:
    """
    OK Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing
    useful variables while still providing access to the original packet
    objects variables and methods.

    Exposes ``affected_rows``, ``insert_id``, ``server_status``,
    ``warning_count``, ``message`` and ``has_next``.

    """

    def __init__(self, from_packet):
        if not from_packet.is_ok_packet():
            raise ValueError(
                'Cannot create '
                + str(self.__class__.__name__)
                + ' object from invalid packet type',
            )

        self.packet = from_packet
        self.packet.advance(1)  # skip the header byte

        self.affected_rows = self.packet.read_length_encoded_integer()
        self.insert_id = self.packet.read_length_encoded_integer()
        self.server_status, self.warning_count = self.packet.read_struct('<HH')
        self.message = self.packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only reached when normal lookup fails. If ``packet`` itself is
        # missing (e.g. on a bare instance), fail cleanly instead of
        # recursing through ``self.packet`` forever.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)


class EOFPacketWrapper:
    """
    EOF Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing
    useful variables while still providing access to the original packet
    objects variables and methods.

    Exposes ``warning_count``, ``server_status`` and ``has_next``.

    """

    def __init__(self, from_packet):
        if not from_packet.is_eof_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type",
            )

        self.packet = from_packet
        # Skip the header byte, then read two unsigned 16-bit fields
        # (unsigned, like the same fields in OKPacketWrapper).
        self.warning_count, self.server_status = self.packet.read_struct('<xHH')
        if DEBUG:
            print('server_status=', self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only reached when normal lookup fails. If ``packet`` itself is
        # missing (e.g. on a bare instance), fail cleanly instead of
        # recursing through ``self.packet`` forever.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)


class LoadLocalPacketWrapper:
    """
    Load Local Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing useful
    variables while still providing access to the original packet
    objects variables and methods.

    Exposes ``filename``: the payload bytes after the header byte.

    """

    def __init__(self, from_packet):
        if not from_packet.is_load_local_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type",
            )

        self.packet = from_packet
        self.filename = self.packet.get_all_data()[1:]
        if DEBUG:
            print('filename=', self.filename)

    def __getattr__(self, key):
        # Only reached when normal lookup fails. If ``packet`` itself is
        # missing (e.g. on a bare instance), fail cleanly instead of
        # recursing through ``self.packet`` forever.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)