Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/mysql/protocol.py
801 views
1
# type: ignore
2
# Python implementation of low level MySQL client-server protocol
3
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
4
import struct
5
import sys
6
7
from . import err
8
from ..config import get_option
9
from ..utils.results import Description
10
from .charset import MBLENGTH
11
from .constants import EXTENDED_TYPE
12
from .constants import FIELD_TYPE
13
from .constants import SERVER_STATUS
14
from .constants import VECTOR_TYPE
15
16
17
DEBUG = get_option('debug.connection')
18
19
NULL_COLUMN = 251
20
UNSIGNED_CHAR_COLUMN = 251
21
UNSIGNED_SHORT_COLUMN = 252
22
UNSIGNED_INT24_COLUMN = 253
23
UNSIGNED_INT64_COLUMN = 254
24
25
26
def dump_packet(data):  # pragma: no cover
    """
    Print a debugging dump of ``data`` to stdout.

    The dump consists of the packet length, the names and line numbers of up
    to six calling frames, and a hex + ASCII listing of at most the first
    256 bytes, 16 bytes per row.

    Parameters
    ----------
    data : bytes-like
        Raw packet bytes; iterating over it must yield integers.

    """

    def printable(byte):
        # Printable ASCII is shown as-is; everything else becomes '.'.
        if 32 <= byte < 127:
            return chr(byte)
        return '.'

    print('packet length:', len(data))
    try:
        for i in range(1, 7):
            f = sys._getframe(i)
            print('call[%d]: %s (line %d)' % (i, f.f_code.co_name, f.f_lineno))
    except ValueError:
        # sys._getframe() raises ValueError once the requested depth exceeds
        # the call stack; simply stop listing callers in that case.
        pass
    print('-' * 66)

    # A full row is 16 two-digit hex values joined by single spaces.
    hex_width = 16 * 3 - 1
    for offset in range(0, min(len(data), 256), 16):
        row = data[offset: offset + 16]
        hex_part = ' '.join('{:02X}'.format(x) for x in row)
        # Pad short rows to the full hex width so the ASCII column of every
        # row starts at the same position.
        print(
            hex_part.ljust(hex_width)
            + ' ' * 2
            + ''.join(printable(x) for x in row),
        )
    print('-' * 66)
    print()
51
52
53
class MysqlPacket:
    """
    A single response packet received from a MySQL server.

    Holds the raw packet bytes together with a read cursor and offers
    helpers for decoding the primitive values stored in the packet.

    """

    __slots__ = ('_position', '_data')

    def __init__(self, data, encoding):
        # ``encoding`` is part of the constructor signature but is not used
        # by this base class.
        self._data = data
        self._position = 0

    def get_all_data(self):
        """Return the complete packet payload, ignoring the cursor."""
        return self._data

    def read(self, size):
        """Return the next ``size`` bytes and move the cursor past them."""
        start = self._position
        chunk = self._data[start: start + size]
        if len(chunk) == size:
            self._position = start + size
            return chunk

        # Fewer bytes than requested were available.
        error = (
            'Result length not requested length:\n'
            'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
            % (size, len(chunk), self._position, len(self._data))
        )
        if DEBUG:
            print(error)
            self.dump()
        raise AssertionError(error)

    def read_all(self):
        """
        Return every byte from the cursor to the end of the packet.

        The cursor is invalidated afterwards, so a later read() fails.

        """
        remaining = self._data[self._position:]
        self._position = None  # ensure no subsequent read()
        return remaining

    def advance(self, length):
        """Move the cursor ``length`` bytes relative to where it is now."""
        target = self._position + length
        if not 0 <= target <= len(self._data):
            raise Exception(
                'Invalid advance amount (%s) for cursor. '
                'Position=%s' % (length, target),
            )
        self._position = target

    def rewind(self, position=0):
        """Place the cursor at the absolute offset ``position``."""
        if not 0 <= position <= len(self._data):
            raise Exception('Invalid position to rewind cursor to: %s.' % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """
        Return ``length`` bytes beginning at offset ``position``.

        Offsets are relative to the start of the payload (the four packet
        header bytes are not part of it), with the first byte at index 0.

        The request is not bounds-checked: asking for bytes beyond the end
        of the buffer yields a short or empty result instead of an error.

        """
        return self._data[position: position + length]

    def read_uint8(self):
        """Read one byte as an unsigned integer and advance the cursor."""
        cursor = self._position
        value = self._data[cursor]
        self._position = cursor + 1
        return value

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer and advance."""
        (value,) = struct.unpack_from('<H', self._data, self._position)
        self._position += 2
        return value

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer and advance."""
        low_word, high_byte = struct.unpack_from('<HB', self._data, self._position)
        self._position += 3
        return (high_byte << 16) + low_word

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer and advance."""
        (value,) = struct.unpack_from('<I', self._data, self._position)
        self._position += 4
        return value

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer and advance."""
        (value,) = struct.unpack_from('<Q', self._data, self._position)
        self._position += 8
        return value

    def read_string(self):
        """
        Read bytes up to the next NUL byte and advance past the NUL.

        Returns None, leaving the cursor untouched, when no NUL byte is
        found at or after the cursor.

        """
        terminator = self._data.find(b'\0', self._position)
        if terminator < 0:
            return None
        text = self._data[self._position: terminator]
        self._position = terminator + 1
        return text

    def read_length_encoded_integer(self):
        """
        Decode a 'Length Coded Binary' number at the cursor.

        The first byte either is the value itself or announces how many
        bytes follow, so the encoding occupies between 1 and 9 bytes.
        Returns None for the NULL marker and for a first byte that matches
        none of the known markers.

        """
        marker = self.read_uint8()
        if marker == NULL_COLUMN:
            return None
        if marker < UNSIGNED_CHAR_COLUMN:
            return marker
        if marker == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        if marker == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        if marker == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()
        return None

    def read_length_coded_string(self):
        """
        Decode a 'Length Coded String' at the cursor.

        The string is stored as a length coded (unsigned) integer of 1-9
        bytes followed by that many bytes of data, e.g. "cat" is "3cat".
        Returns None when the length decodes to None.

        """
        size = self.read_length_encoded_integer()
        return None if size is None else self.read(size)

    def read_struct(self, fmt):
        """Unpack ``fmt`` (a struct format) at the cursor and advance."""
        values = struct.unpack_from(fmt, self._data, self._position)
        self._position += struct.calcsize(fmt)
        return values

    def is_ok_packet(self):
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        header = self._data[0]
        return header == 0 and len(self._data) >= 7

    def is_eof_packet(self):
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # Caution: \xFE may also start a LengthEncodedInteger, in which
        # case 8 more bytes follow; hence the length check.
        header = self._data[0]
        return header == 0xFE and len(self._data) < 9

    def is_auth_switch_request(self):
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._data[0] == 0xFE

    def is_extra_auth_data(self):
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        return self._data[0] == 1

    def is_resultset_packet(self):
        # The first byte is the field count.
        return 1 <= self._data[0] <= 250

    def is_load_local_packet(self):
        return self._data[0] == 0xFB

    def is_error_packet(self):
        return self._data[0] == 0xFF

    def check_error(self):
        """Raise the matching exception if this is an error packet."""
        if self.is_error_packet():
            self.raise_for_error()

    def raise_for_error(self):
        """Raise the MySQL exception described by this error packet."""
        self.rewind()
        self.advance(1)  # field_count == error (we already know that)
        errno = self.read_uint16()
        if DEBUG:
            print('errno =', errno)
        err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a debugging dump of the packet payload."""
        dump_packet(self._data)
241
242
243
class FieldDescriptorPacket(MysqlPacket):
    """
    A MysqlPacket holding the metadata of one column of a result set.

    The packet is parsed during construction and the decoded values are
    exposed as public attributes such as: db, table_name, name, length,
    type_code.

    """

    def __init__(self, data, encoding):
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """
        Decode the 'Field Descriptor' (Metadata) packet into attributes.

        This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).

        ``catalog`` and ``db`` are kept as raw bytes; the table and column
        names are decoded with ``encoding``.

        """
        self.catalog = self.read_length_coded_string()
        self.db = self.read_length_coded_string()
        self.table_name = self.read_length_coded_string().decode(encoding)
        self.org_table = self.read_length_coded_string().decode(encoding)
        self.name = self.read_length_coded_string().decode(encoding)
        self.org_name = self.read_length_coded_string().decode(encoding)

        # Fixed-size section; the leading byte is a byte count that tells
        # whether extended type information follows.
        (
            n_bytes,
            self.charsetnr,
            self.length,
            self.type_code,
            self.flags,
            self.scale,
        ) = self.read_struct('<BHIBHBxx')

        # 'default' is a length coded binary and is still in the buffer?
        # not used for normal result sets...

        # Extended types are only present when the count exceeds 12.
        if n_bytes <= 12:
            return

        ext_type_code = self.read_uint8()
        if ext_type_code == EXTENDED_TYPE.NONE:
            return
        if ext_type_code == EXTENDED_TYPE.BSON:
            self.type_code = FIELD_TYPE.BSON
            return
        if ext_type_code != EXTENDED_TYPE.VECTOR:
            raise TypeError(f'unrecognized extended data type: {ext_type_code}')

        (self.length, vec_type) = self.read_struct('<IB')

        # (element type, type code for charsetnr 63, type code otherwise),
        # checked in this order.
        vector_type_codes = (
            (
                VECTOR_TYPE.FLOAT32,
                FIELD_TYPE.FLOAT32_VECTOR,
                FIELD_TYPE.FLOAT32_VECTOR_JSON,
            ),
            (
                VECTOR_TYPE.FLOAT64,
                FIELD_TYPE.FLOAT64_VECTOR,
                FIELD_TYPE.FLOAT64_VECTOR_JSON,
            ),
            (
                VECTOR_TYPE.INT8,
                FIELD_TYPE.INT8_VECTOR,
                FIELD_TYPE.INT8_VECTOR_JSON,
            ),
            (
                VECTOR_TYPE.INT16,
                FIELD_TYPE.INT16_VECTOR,
                FIELD_TYPE.INT16_VECTOR_JSON,
            ),
            (
                VECTOR_TYPE.INT32,
                FIELD_TYPE.INT32_VECTOR,
                FIELD_TYPE.INT32_VECTOR_JSON,
            ),
            (
                VECTOR_TYPE.INT64,
                FIELD_TYPE.INT64_VECTOR,
                FIELD_TYPE.INT64_VECTOR_JSON,
            ),
            (
                VECTOR_TYPE.FLOAT16,
                FIELD_TYPE.FLOAT16_VECTOR,
                FIELD_TYPE.FLOAT16_VECTOR_JSON,
            ),
        )
        for element_type, plain_code, json_code in vector_type_codes:
            if vec_type == element_type:
                # NOTE(review): 63 is presumably the binary charset id.
                if self.charsetnr == 63:
                    self.type_code = plain_code
                else:
                    self.type_code = json_code
                return
        raise TypeError(f'unrecognized vector data type: {vec_type}')

    def description(self):
        """
        Build the 9-item description of this column.

        The Python PEP249 DB Spec defines only 7 description fields; the
        flags and charsetnr are appended so that unsigned types and
        character sets can be handled properly downstream.

        """
        precision = self.get_column_length()
        is_decimal = self.type_code in (FIELD_TYPE.DECIMAL, FIELD_TYPE.NEWDECIMAL)
        if is_decimal and precision:
            precision -= 1  # for the sign
            if self.scale > 0:
                precision -= 1  # for the decimal point

        internal_size = self.get_column_length()
        null_ok = self.flags % 2 == 0  # lowest flag bit clear
        return Description(
            self.name,
            self.type_code,
            None,  # TODO: display_length; should this be self.length?
            internal_size,
            precision,
            self.scale,
            null_ok,
            self.flags,
            self.charsetnr,
        )

    def get_column_length(self):
        """
        Return the column length.

        For VAR_STRING columns the length is divided by the per-character
        byte length registered for the charset in MBLENGTH (default 1).

        """
        if self.type_code != FIELD_TYPE.VAR_STRING:
            return self.length
        bytes_per_char = MBLENGTH.get(self.charsetnr, 1)
        return self.length // bytes_per_char

    def __str__(self):
        fields = (
            self.__class__,
            self.db,
            self.table_name,
            self.name,
            self.type_code,
            self.flags,
            self.charsetnr,
        )
        return '%s %r.%r.%r, type=%s, flags=%x, charsetnr=%s' % fields
374
375
376
class OKPacketWrapper:
    """
    OK Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing
    useful variables while still providing access to the original packet
    objects variables and methods.

    Attributes set on construction: ``packet``, ``affected_rows``,
    ``insert_id``, ``server_status``, ``warning_count``, ``message`` and
    ``has_next``.

    Raises
    ------
    ValueError
        If ``from_packet.is_ok_packet()`` is false.

    """

    def __init__(self, from_packet):
        if not from_packet.is_ok_packet():
            raise ValueError(
                'Cannot create '
                + str(self.__class__.__name__)
                + ' object from invalid packet type',
            )

        self.packet = from_packet
        self.packet.advance(1)  # skip the leading header byte

        self.affected_rows = self.packet.read_length_encoded_integer()
        self.insert_id = self.packet.read_length_encoded_integer()
        # Two little-endian unsigned 16-bit values; read from the packet
        # directly rather than via the __getattr__ fallback.
        self.server_status, self.warning_count = self.packet.read_struct('<HH')
        self.message = self.packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only called when normal lookup fails.  If 'packet' itself is
        # missing (instance created without __init__, e.g. by copy/pickle),
        # delegating would re-enter this method forever, so fail cleanly.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)
405
406
407
class EOFPacketWrapper:
    """
    EOF Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing
    useful variables while still providing access to the original packet
    objects variables and methods.

    Attributes set on construction: ``packet``, ``warning_count``,
    ``server_status`` and ``has_next``.

    Raises
    ------
    ValueError
        If ``from_packet.is_eof_packet()`` is false.

    """

    def __init__(self, from_packet):
        if not from_packet.is_eof_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type",
            )

        self.packet = from_packet
        # '<xhh': skip one byte, then two little-endian 16-bit integers.
        self.warning_count, self.server_status = self.packet.read_struct('<xhh')
        if DEBUG:
            print('server_status=', self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only called when normal lookup fails.  If 'packet' itself is
        # missing (instance created without __init__, e.g. by copy/pickle),
        # delegating would re-enter this method forever, so fail cleanly.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)
431
432
433
class LoadLocalPacketWrapper:
    """
    Load Local Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing useful
    variables while still providing access to the original packet
    objects variables and methods.

    Attributes set on construction: ``packet`` and ``filename`` (the packet
    payload without its first byte).

    Raises
    ------
    ValueError
        If ``from_packet.is_load_local_packet()`` is false.

    """

    def __init__(self, from_packet):
        if not from_packet.is_load_local_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type",
            )

        self.packet = from_packet
        # Everything after the leading header byte is the file name.
        self.filename = self.packet.get_all_data()[1:]
        if DEBUG:
            print('filename=', self.filename)

    def __getattr__(self, key):
        # Only called when normal lookup fails.  If 'packet' itself is
        # missing (instance created without __init__, e.g. by copy/pickle),
        # delegating would re-enter this method forever, so fail cleanly.
        if key == 'packet':
            raise AttributeError(key)
        return getattr(self.packet, key)
456
457