Path: blob/main/singlestoredb/mysql/protocol.py
469 views
# type: ignore
# Python implementation of low level MySQL client-server protocol
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
import struct
import sys

from . import err
from ..config import get_option
from ..utils.results import Description
from .charset import MBLENGTH
from .constants import EXTENDED_TYPE
from .constants import FIELD_TYPE
from .constants import SERVER_STATUS
from .constants import VECTOR_TYPE


DEBUG = get_option('debug.connection')

# Prefix bytes of a length-encoded integer (see
# ``MysqlPacket.read_length_encoded_integer``). A first byte below
# ``UNSIGNED_CHAR_COLUMN`` is the value itself, 251 (0xFB) means NULL,
# and 252 / 253 / 254 announce a 2-, 3- or 8-byte little-endian integer.
NULL_COLUMN = 251
UNSIGNED_CHAR_COLUMN = 251
UNSIGNED_SHORT_COLUMN = 252
UNSIGNED_INT24_COLUMN = 253
UNSIGNED_INT64_COLUMN = 254

# Character-set number of MySQL's ``binary`` character set. Vector columns
# reported with this charset map to the plain ``*_VECTOR`` field types;
# any other charset maps to the ``*_VECTOR_JSON`` variants.
_BINARY_CHARSET_NR = 63

# VECTOR element type -> (field type for binary charset, field type otherwise).
_VECTOR_FIELD_TYPES = {
    VECTOR_TYPE.FLOAT32: (FIELD_TYPE.FLOAT32_VECTOR, FIELD_TYPE.FLOAT32_VECTOR_JSON),
    VECTOR_TYPE.FLOAT64: (FIELD_TYPE.FLOAT64_VECTOR, FIELD_TYPE.FLOAT64_VECTOR_JSON),
    VECTOR_TYPE.INT8: (FIELD_TYPE.INT8_VECTOR, FIELD_TYPE.INT8_VECTOR_JSON),
    VECTOR_TYPE.INT16: (FIELD_TYPE.INT16_VECTOR, FIELD_TYPE.INT16_VECTOR_JSON),
    VECTOR_TYPE.INT32: (FIELD_TYPE.INT32_VECTOR, FIELD_TYPE.INT32_VECTOR_JSON),
    VECTOR_TYPE.INT64: (FIELD_TYPE.INT64_VECTOR, FIELD_TYPE.INT64_VECTOR_JSON),
}


def dump_packet(data):  # pragma: no cover
    """
    Print a hex / ASCII dump of ``data`` for debugging.

    The packet length and the names and line numbers of up to six calling
    frames are printed first, followed by at most the first 256 bytes of
    ``data`` in rows of 16 bytes.

    Parameters
    ----------
    data : bytes
        Raw packet payload.

    """

    def printable(data):
        # Show printable ASCII as-is; everything else as '.'.
        if 32 <= data < 127:
            return chr(data)
        return '.'

    try:
        print('packet length:', len(data))
        for i in range(1, 7):
            f = sys._getframe(i)
            print('call[%d]: %s (line %d)' % (i, f.f_code.co_name, f.f_lineno))
        print('-' * 66)
    except ValueError:
        # The call stack is shallower than six frames.
        pass
    dump_data = [data[i: i + 16] for i in range(0, min(len(data), 256), 16)]
    for d in dump_data:
        print(
            ' '.join('{:02X}'.format(x) for x in d)
            # Each byte occupies three columns ("XX "), so pad a short final
            # row by three spaces per missing byte to align the ASCII column.
            + '   ' * (16 - len(d))
            + ' ' * 2
            + ''.join(printable(x) for x in d),
        )
    print('-' * 66)
    print()


class MysqlPacket:
    """
    Representation of a MySQL response packet.

    Provides an interface for reading/parsing the packet results. A cursor
    (``_position``) tracks how much of the payload has been consumed; the
    ``read*`` methods return data at the cursor and advance it.

    """

    __slots__ = ('_position', '_data')

    def __init__(self, data, encoding):
        """
        Wrap a packet payload.

        Parameters
        ----------
        data : bytes
            Packet payload (without the four-byte packet header).
        encoding : str
            Unused here; accepted so subclasses such as
            ``FieldDescriptorPacket`` share the same constructor signature.

        """
        self._position = 0
        self._data = data

    def get_all_data(self):
        """Return the whole payload, independent of the cursor position."""
        return self._data

    def read(self, size):
        """
        Read the first 'size' bytes in packet and advance cursor past them.

        Raises
        ------
        AssertionError
            If fewer than ``size`` bytes remain in the payload.

        """
        result = self._data[self._position: (self._position + size)]
        if len(result) != size:
            error = (
                'Result length not requested length:\n'
                'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
                % (size, len(result), self._position, len(self._data))
            )
            if DEBUG:
                print(error)
                self.dump()
            raise AssertionError(error)
        self._position += size
        return result

    def read_all(self):
        """
        Read all remaining data in the packet.

        (Subsequent read() will return errors.)

        """
        result = self._data[self._position:]
        self._position = None  # ensure no subsequent read()
        return result

    def advance(self, length):
        """
        Advance the cursor in data buffer ``length`` bytes.

        Raises
        ------
        Exception
            If the new position would fall outside the payload.

        """
        new_position = self._position + length
        if new_position < 0 or new_position > len(self._data):
            raise Exception(
                'Invalid advance amount (%s) for cursor. '
                'Position=%s' % (length, new_position),
            )
        self._position = new_position

    def rewind(self, position=0):
        """
        Set the position of the data buffer cursor to 'position'.

        Raises
        ------
        Exception
            If ``position`` falls outside the payload.

        """
        if position < 0 or position > len(self._data):
            raise Exception('Invalid position to rewind cursor to: %s.' % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """
        Get 'length' bytes starting at 'position'.

        Position is start of payload (first four packet header bytes are not
        included) starting at index '0'. The cursor is not moved.

        No error checking is done. If requesting outside end of buffer
        an empty string (or string shorter than 'length') may be returned!

        """
        return self._data[position: (position + length)]

    def read_uint8(self):
        """Read one unsigned byte and advance the cursor."""
        result = self._data[self._position]
        self._position += 1
        return result

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer and advance the cursor."""
        result = struct.unpack_from('<H', self._data, self._position)[0]
        self._position += 2
        return result

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer and advance the cursor."""
        low, high = struct.unpack_from('<HB', self._data, self._position)
        self._position += 3
        return low + (high << 16)

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer and advance the cursor."""
        result = struct.unpack_from('<I', self._data, self._position)[0]
        self._position += 4
        return result

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer and advance the cursor."""
        result = struct.unpack_from('<Q', self._data, self._position)[0]
        self._position += 8
        return result

    def read_string(self):
        """
        Read a NUL-terminated byte string and advance past the terminator.

        Returns ``None`` (leaving the cursor unchanged) if no NUL byte
        follows the cursor.

        """
        end_pos = self._data.find(b'\0', self._position)
        if end_pos < 0:
            return None
        result = self._data[self._position: end_pos]
        self._position = end_pos + 1
        return result

    def read_length_encoded_integer(self):
        """
        Read a 'Length Coded Binary' number from the data buffer.

        Length coded numbers can be anywhere from 1 to 9 bytes depending
        on the value of the first byte.

        Returns
        -------
        int or None
            The decoded integer, or ``None`` for the NULL marker (0xFB)
            and for the invalid prefix 0xFF.

        """
        c = self.read_uint8()
        if c == NULL_COLUMN:
            return None
        if c < UNSIGNED_CHAR_COLUMN:
            return c
        if c == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        if c == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        if c == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()
        # Only 0xFF reaches this point; it is not a valid length prefix.
        # Preserve the historical behaviour of returning None.
        return None

    def read_length_coded_string(self):
        """
        Read a 'Length Coded String' from the data buffer.

        A 'Length Coded String' consists first of a length coded
        (unsigned, positive) integer represented in 1-9 bytes followed by
        that many bytes of binary data. (For example "cat" would be "3cat".)

        Returns ``None`` when the length prefix is the NULL marker.

        """
        length = self.read_length_encoded_integer()
        if length is None:
            return None
        return self.read(length)

    def read_struct(self, fmt):
        """Unpack ``fmt`` at the cursor, advance past it, and return the tuple."""
        s = struct.Struct(fmt)
        result = s.unpack_from(self._data, self._position)
        self._position += s.size
        return result

    def is_ok_packet(self):
        """Return True if this looks like an OK packet (0x00 header, >= 7 bytes)."""
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        return self._data[0] == 0 and len(self._data) >= 7

    def is_eof_packet(self):
        """Return True if this looks like an EOF packet (0xFE header, < 9 bytes)."""
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # Caution: \xFE may be LengthEncodedInteger.
        # If \xFE is LengthEncodedInteger header, 8bytes followed.
        return self._data[0] == 0xFE and len(self._data) < 9

    def is_auth_switch_request(self):
        """Return True if the packet starts with the 0xFE auth-switch header."""
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._data[0] == 0xFE

    def is_extra_auth_data(self):
        """Return True if the packet starts with the 0x01 extra-auth-data header."""
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        return self._data[0] == 1

    def is_resultset_packet(self):
        """Return True if the first byte is a plausible column count (1-250)."""
        field_count = self._data[0]
        return 1 <= field_count <= 250

    def is_load_local_packet(self):
        """Return True if the packet starts with the 0xFB LOAD DATA LOCAL header."""
        return self._data[0] == 0xFB

    def is_error_packet(self):
        """Return True if the packet starts with the 0xFF error header."""
        return self._data[0] == 0xFF

    def check_error(self):
        """Raise the corresponding exception if this is an error packet."""
        if self.is_error_packet():
            self.raise_for_error()

    def raise_for_error(self):
        """
        Raise the exception described by this error packet.

        The error number is read only for debug output; the raw payload is
        passed to ``err.raise_mysql_exception`` to build the exception.

        """
        self.rewind()
        self.advance(1)  # field_count == error (we already know that)
        errno = self.read_uint16()
        if DEBUG:
            print('errno =', errno)
        err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a hex dump of the payload (see ``dump_packet``)."""
        dump_packet(self._data)


class FieldDescriptorPacket(MysqlPacket):
    """
    A MysqlPacket that represents a specific column's metadata in the result.

    Parsing is automatically done and the results are exported via public
    attributes on the class such as: db, table_name, name, length, type_code.

    """

    def __init__(self, data, encoding):
        """Wrap ``data`` and parse it immediately; names are decoded with ``encoding``."""
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """
        Parse the 'Field Descriptor' (Metadata) packet.

        This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).

        """
        self.catalog = self.read_length_coded_string()
        self.db = self.read_length_coded_string()
        self.table_name = self.read_length_coded_string().decode(encoding)
        self.org_table = self.read_length_coded_string().decode(encoding)
        self.name = self.read_length_coded_string().decode(encoding)
        self.org_name = self.read_length_coded_string().decode(encoding)
        # Layout: length of the fixed fields, charset, column length, type,
        # flags, decimals, and two filler bytes (12 bytes after the length).
        (
            n_bytes,
            self.charsetnr,
            self.length,
            self.type_code,
            self.flags,
            self.scale,
        ) = self.read_struct('<BHIBHBxx')

        # 'default' is a length coded binary and is still in the buffer?
        # not used for normal result sets...

        # A fixed-field length above the standard 12 bytes means the server
        # appended extended type information after the filler bytes.
        if n_bytes > 12:
            self._parse_extended_type()

    def _parse_extended_type(self):
        """
        Parse the extended-type trailer and refine ``type_code``/``length``.

        Raises
        ------
        TypeError
            For an unknown extended type or an unknown vector element type.

        """
        ext_type_code = self.read_uint8()
        if ext_type_code == EXTENDED_TYPE.NONE:
            return
        if ext_type_code == EXTENDED_TYPE.BSON:
            self.type_code = FIELD_TYPE.BSON
            return
        if ext_type_code == EXTENDED_TYPE.VECTOR:
            self.length, vec_type = self.read_struct('<IB')
            type_codes = _VECTOR_FIELD_TYPES.get(vec_type)
            if type_codes is None:
                raise TypeError(f'unrecognized vector data type: {vec_type}')
            binary_type, json_type = type_codes
            if self.charsetnr == _BINARY_CHARSET_NR:
                self.type_code = binary_type
            else:
                self.type_code = json_type
            return
        raise TypeError(f'unrecognized extended data type: {ext_type_code}')

    def description(self):
        """
        Provides a 9-item tuple.

        Standard descriptions only have 7 fields according to the Python
        PEP249 DB Spec, but we need to surface information about unsigned
        types and charsetnr for proper type handling.

        """
        column_length = self.get_column_length()
        precision = column_length
        if self.type_code in (FIELD_TYPE.DECIMAL, FIELD_TYPE.NEWDECIMAL):
            if precision:
                precision -= 1  # for the sign
                if self.scale > 0:
                    precision -= 1  # for the decimal point
        return Description(
            self.name,
            self.type_code,
            None,  # TODO: display_length; should this be self.length?
            column_length,  # 'internal_size'
            precision,  # 'precision'
            self.scale,
            self.flags % 2 == 0,  # null_ok: bit 0 of flags is NOT_NULL
            self.flags,
            self.charsetnr,
        )

    def get_column_length(self):
        """
        Return the column length in characters for VAR_STRING columns.

        For VAR_STRING the byte length is divided by the charset's maximum
        bytes per character (``MBLENGTH``, default 1); other types return
        ``length`` unchanged.

        """
        if self.type_code == FIELD_TYPE.VAR_STRING:
            mblen = MBLENGTH.get(self.charsetnr, 1)
            return self.length // mblen
        return self.length

    def __str__(self):
        return '%s %r.%r.%r, type=%s, flags=%x, charsetnr=%s' % (
            self.__class__,
            self.db,
            self.table_name,
            self.name,
            self.type_code,
            self.flags,
            self.charsetnr,
        )


class OKPacketWrapper:
    """
    OK Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing
    useful variables while still providing access to the original packet
    objects variables and methods.

    """

    def __init__(self, from_packet):
        """
        Parse an OK packet, starting from its header byte.

        Raises
        ------
        ValueError
            If ``from_packet`` is not an OK packet.

        """
        if not from_packet.is_ok_packet():
            raise ValueError(
                'Cannot create '
                + str(self.__class__.__name__)
                + ' object from invalid packet type',
            )

        self.packet = from_packet
        self.packet.advance(1)  # skip the 0x00 header

        self.affected_rows = self.packet.read_length_encoded_integer()
        self.insert_id = self.packet.read_length_encoded_integer()
        self.server_status, self.warning_count = self.packet.read_struct('<HH')
        self.message = self.packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only reached when normal lookup fails. Without this guard, a
        # missing ``packet`` attribute would recurse forever.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)


class EOFPacketWrapper:
    """
    EOF Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing
    useful variables while still providing access to the original packet
    objects variables and methods.

    """

    def __init__(self, from_packet):
        """
        Parse an EOF packet, starting from its header byte.

        Raises
        ------
        ValueError
            If ``from_packet`` is not an EOF packet.

        """
        if not from_packet.is_eof_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type",
            )

        self.packet = from_packet
        # Skip the 0xFE header, then two unsigned 16-bit fields.
        self.warning_count, self.server_status = self.packet.read_struct('<xHH')
        if DEBUG:
            print('server_status=', self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only reached when normal lookup fails. Without this guard, a
        # missing ``packet`` attribute would recurse forever.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)


class LoadLocalPacketWrapper:
    """
    Load Local Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing useful
    variables while still providing access to the original packet
    objects variables and methods.

    """

    def __init__(self, from_packet):
        """
        Extract the requested filename from a LOAD DATA LOCAL packet.

        Raises
        ------
        ValueError
            If ``from_packet`` is not a LOAD DATA LOCAL packet.

        """
        if not from_packet.is_load_local_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type",
            )

        self.packet = from_packet
        # Everything after the 0xFB header byte is the filename (raw bytes).
        self.filename = self.packet.get_all_data()[1:]
        if DEBUG:
            print('filename=', self.filename)

    def __getattr__(self, key):
        # Only reached when normal lookup fails. Without this guard, a
        # missing ``packet`` attribute would recurse forever.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)