Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/mysql/protocol.py
469 views
1
# type: ignore
2
# Python implementation of low level MySQL client-server protocol
3
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
4
import struct
5
import sys
6
7
from . import err
8
from ..config import get_option
9
from ..utils.results import Description
10
from .charset import MBLENGTH
11
from .constants import EXTENDED_TYPE
12
from .constants import FIELD_TYPE
13
from .constants import SERVER_STATUS
14
from .constants import VECTOR_TYPE
15
16
17
# Truthy when low-level protocol debugging output should be printed.
DEBUG = get_option('debug.connection')

# Markers for the first byte of a length-encoded integer:
#   0xFB        -> the value is NULL
#   below 0xFB  -> the byte itself is the value
#   0xFC / 0xFD / 0xFE -> a 2- / 3- / 8-byte little-endian integer follows
NULL_COLUMN = 0xFB
UNSIGNED_CHAR_COLUMN = 0xFB
UNSIGNED_SHORT_COLUMN = 0xFC
UNSIGNED_INT24_COLUMN = 0xFD
UNSIGNED_INT64_COLUMN = 0xFE
24
25
26
def dump_packet(data):  # pragma: no cover
    """
    Print a hex/ASCII dump of ``data`` for debugging.

    Output consists of the packet length, the names and line numbers of up
    to six calling frames, and then at most the first 256 bytes of ``data``
    in rows of 16. Each row shows the byte values in hexadecimal followed by
    an ASCII column where non-printable bytes appear as ``.``.

    """

    def printable(byte):
        # Show printable ASCII as-is; everything else as a dot.
        if 32 <= byte < 127:
            return chr(byte)
        return '.'

    print('packet length:', len(data))
    for i in range(1, 7):
        try:
            f = sys._getframe(i)
        except ValueError:
            # Call stack is shallower than ``i`` frames; list what we have.
            break
        print('call[%d]: %s (line %d)' % (i, f.f_code.co_name, f.f_lineno))
    print('-' * 66)
    dump_data = [data[i: i + 16] for i in range(0, min(len(data), 256), 16)]
    for d in dump_data:
        print(
            ' '.join('{:02X}'.format(x) for x in d)
            # Each byte takes 3 columns ("XX" plus a separator), so pad
            # 3 spaces per missing byte to keep the ASCII column aligned.
            + '   ' * (16 - len(d))
            + ' ' * 2
            + ''.join(printable(x) for x in d),
        )
    print('-' * 66)
    print()
52
53
class MysqlPacket:
    """
    A MySQL response packet paired with a read cursor over its payload.

    Exposes primitive readers (little-endian fixed-width integers,
    NUL-terminated strings, length-encoded integers and strings, arbitrary
    ``struct`` formats) and predicates that classify the packet by its
    first byte.

    """

    __slots__ = ('_position', '_data')

    def __init__(self, data, encoding):
        # ``encoding`` is accepted for signature compatibility with
        # subclasses; the base packet does not use it.
        self._data = data
        self._position = 0

    def get_all_data(self):
        """Return the complete payload, independent of the cursor."""
        return self._data

    def read(self, size):
        """
        Return the next ``size`` bytes and move the cursor past them.

        Raises AssertionError when fewer than ``size`` bytes remain.

        """
        start = self._position
        end = start + size
        chunk = self._data[start:end]
        if len(chunk) == size:
            self._position = end
            return chunk
        message = (
            'Result length not requested length:\n'
            'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
            % (size, len(chunk), start, len(self._data))
        )
        if DEBUG:
            print(message)
            self.dump()
        raise AssertionError(message)

    def read_all(self):
        """
        Return every byte from the cursor to the end of the payload.

        The cursor is invalidated afterwards, so any later read() fails.

        """
        remainder = self._data[self._position:]
        self._position = None  # poison the cursor for subsequent reads
        return remainder

    def advance(self, length):
        """Move the cursor ``length`` bytes (may be negative), within bounds."""
        target = self._position + length
        if not 0 <= target <= len(self._data):
            raise Exception(
                'Invalid advance amount (%s) for cursor. '
                'Position=%s' % (length, target),
            )
        self._position = target

    def rewind(self, position=0):
        """Place the cursor at absolute offset ``position``."""
        if not 0 <= position <= len(self._data):
            raise Exception('Invalid position to rewind cursor to: %s.' % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """
        Return up to ``length`` bytes beginning at payload offset ``position``.

        Offsets are relative to the payload (the 4-byte packet header is not
        part of it). The cursor is not consulted or moved, and no bounds
        checking is done: a request past the end yields a shorter (possibly
        empty) result.

        """
        return self._data[position:position + length]

    def read_uint8(self):
        """Read one unsigned byte."""
        value = self._data[self._position]
        self._position += 1
        return value

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer."""
        (value,) = struct.unpack_from('<H', self._data, self._position)
        self._position += 2
        return value

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer."""
        low, high = struct.unpack_from('<HB', self._data, self._position)
        self._position += 3
        return low | (high << 16)

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer."""
        (value,) = struct.unpack_from('<I', self._data, self._position)
        self._position += 4
        return value

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer."""
        (value,) = struct.unpack_from('<Q', self._data, self._position)
        self._position += 8
        return value

    def read_string(self):
        """
        Read bytes up to the next NUL and skip past the NUL.

        Returns None (leaving the cursor untouched) when no NUL follows.

        """
        terminator = self._data.find(b'\0', self._position)
        if terminator == -1:
            return None
        text = self._data[self._position:terminator]
        self._position = terminator + 1
        return text

    def read_length_encoded_integer(self):
        """
        Read a length-encoded integer occupying 1 to 9 bytes.

        The first byte selects the form: NULL_COLUMN yields None, smaller
        values are the integer itself, and UNSIGNED_SHORT_COLUMN /
        UNSIGNED_INT24_COLUMN / UNSIGNED_INT64_COLUMN introduce a 2-, 3- or
        8-byte integer. Any other first byte also yields None.

        """
        first = self.read_uint8()
        if first == NULL_COLUMN:
            return None
        if first < UNSIGNED_CHAR_COLUMN:
            return first
        if first == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        if first == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        if first == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()
        return None

    def read_length_coded_string(self):
        """
        Read a length-encoded integer N, then N bytes of raw data.

        For example the string "cat" is encoded as "3cat". Returns None when
        the length prefix encodes NULL.

        """
        length = self.read_length_encoded_integer()
        return None if length is None else self.read(length)

    def read_struct(self, fmt):
        """Unpack ``fmt`` at the cursor and advance by the format's size."""
        packer = struct.Struct(fmt)
        values = packer.unpack_from(self._data, self._position)
        self._position += packer.size
        return values

    def is_ok_packet(self):
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        return self._data[0] == 0x00 and len(self._data) >= 7

    def is_eof_packet(self):
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # A leading 0xFE can also start an 8-byte length-encoded integer,
        # so only short packets are treated as EOF.
        return self._data[0] == 0xFE and len(self._data) < 9

    def is_auth_switch_request(self):
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._data[0] == 0xFE

    def is_extra_auth_data(self):
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        return self._data[0] == 0x01

    def is_resultset_packet(self):
        return 1 <= self._data[0] <= 250

    def is_load_local_packet(self):
        return self._data[0] == 0xFB

    def is_error_packet(self):
        return self._data[0] == 0xFF

    def check_error(self):
        """Raise the matching MySQL exception if this is an error packet."""
        if self.is_error_packet():
            self.raise_for_error()

    def raise_for_error(self):
        """Raise the exception described by this error packet."""
        self.rewind()
        self.advance(1)  # skip the 0xFF error marker
        errno = self.read_uint16()
        if DEBUG:
            print('errno =', errno)
        err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a debugging hex dump of the payload."""
        dump_packet(self._data)
242
243
class FieldDescriptorPacket(MysqlPacket):
    """
    A MysqlPacket that represents a specific column's metadata in the result.

    Parsing happens on construction and the results are exposed as public
    attributes: catalog, db, table_name, org_table, name, org_name,
    charsetnr, length, type_code, flags and scale.

    """

    # charsetnr value for which vector columns map to the non-JSON
    # FIELD_TYPE codes (63 is MySQL's ``binary`` collation id).
    _BINARY_CHARSETNR = 63

    def __init__(self, data, encoding):
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """
        Parse the 'Field Descriptor' (Metadata) packet.

        This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).

        Raises
        ------
        TypeError
            If an extended type code or vector element type is not
            recognized.

        """
        self.catalog = self.read_length_coded_string()
        self.db = self.read_length_coded_string()
        self.table_name = self.read_length_coded_string().decode(encoding)
        self.org_table = self.read_length_coded_string().decode(encoding)
        self.name = self.read_length_coded_string().decode(encoding)
        self.org_name = self.read_length_coded_string().decode(encoding)
        (
            n_bytes,
            self.charsetnr,
            self.length,
            self.type_code,
            self.flags,
            self.scale,
        ) = self.read_struct('<BHIBHBxx')

        # 'default' is a length coded binary and is still in the buffer?
        # not used for normal result sets...

        # A fixed-fields length above 12 signals that extended type
        # information follows.
        if n_bytes > 12:
            self._parse_extended_type()

    def _parse_extended_type(self):
        """Read the extended type code and adjust ``type_code``/``length``."""
        ext_type_code = self.read_uint8()
        if ext_type_code == EXTENDED_TYPE.NONE:
            return
        if ext_type_code == EXTENDED_TYPE.BSON:
            self.type_code = FIELD_TYPE.BSON
            return
        if ext_type_code == EXTENDED_TYPE.VECTOR:
            # Vector columns carry their own length and element type.
            (self.length, vec_type) = self.read_struct('<IB')
            self.type_code = self._vector_field_type(vec_type)
            return
        raise TypeError(f'unrecognized extended data type: {ext_type_code}')

    def _vector_field_type(self, vec_type):
        """Map a vector element type to the FIELD_TYPE code for this column."""
        # Each element type maps to (binary-charset code, other-charset code).
        codes = {
            VECTOR_TYPE.FLOAT32: (
                FIELD_TYPE.FLOAT32_VECTOR, FIELD_TYPE.FLOAT32_VECTOR_JSON,
            ),
            VECTOR_TYPE.FLOAT64: (
                FIELD_TYPE.FLOAT64_VECTOR, FIELD_TYPE.FLOAT64_VECTOR_JSON,
            ),
            VECTOR_TYPE.INT8: (
                FIELD_TYPE.INT8_VECTOR, FIELD_TYPE.INT8_VECTOR_JSON,
            ),
            VECTOR_TYPE.INT16: (
                FIELD_TYPE.INT16_VECTOR, FIELD_TYPE.INT16_VECTOR_JSON,
            ),
            VECTOR_TYPE.INT32: (
                FIELD_TYPE.INT32_VECTOR, FIELD_TYPE.INT32_VECTOR_JSON,
            ),
            VECTOR_TYPE.INT64: (
                FIELD_TYPE.INT64_VECTOR, FIELD_TYPE.INT64_VECTOR_JSON,
            ),
        }
        pair = codes.get(vec_type)
        if pair is None:
            raise TypeError(f'unrecognized vector data type: {vec_type}')
        binary_code, json_code = pair
        if self.charsetnr == self._BINARY_CHARSETNR:
            return binary_code
        return json_code

    def description(self):
        """
        Provides a 9-item tuple.

        Standard descriptions only have 7 fields according to the Python
        PEP249 DB Spec, but we need to surface information about unsigned
        types and charsetnr for proper type handling.

        """
        column_length = self.get_column_length()
        precision = column_length
        if self.type_code in (FIELD_TYPE.DECIMAL, FIELD_TYPE.NEWDECIMAL):
            if precision:
                precision -= 1  # for the sign
            if self.scale > 0:
                precision -= 1  # for the decimal point
        return Description(
            self.name,
            self.type_code,
            None,  # TODO: display_length; should this be self.length?
            column_length,  # 'internal_size'
            precision,  # 'precision'
            self.scale,
            self.flags % 2 == 0,  # 'null_ok': lowest flag bit clear
            self.flags,
            self.charsetnr,
        )

    def get_column_length(self):
        """
        Return the column length.

        For VAR_STRING columns the server-reported length is divided by the
        MBLENGTH entry for this charset (defaulting to 1).

        """
        if self.type_code == FIELD_TYPE.VAR_STRING:
            mblen = MBLENGTH.get(self.charsetnr, 1)
            return self.length // mblen
        return self.length

    def __str__(self):
        return '%s %r.%r.%r, type=%s, flags=%x, charsetnr=%s' % (
            self.__class__,
            self.db,
            self.table_name,
            self.name,
            self.type_code,
            self.flags,
            self.charsetnr,
        )
369
370
371
class OKPacketWrapper:
    """
    OK Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing
    useful variables while still providing access to the original packet
    objects variables and methods.

    Attributes set on construction: affected_rows, insert_id,
    server_status, warning_count, message and has_next (truthy when the
    server reports more result sets).

    Raises ValueError if ``from_packet`` is not an OK packet.

    """

    def __init__(self, from_packet):
        if not from_packet.is_ok_packet():
            raise ValueError(
                'Cannot create '
                + str(self.__class__.__name__)
                + ' object from invalid packet type',
            )

        self.packet = from_packet
        # Skip the header byte; presumably the cursor starts at offset 0.
        self.packet.advance(1)

        self.affected_rows = self.packet.read_length_encoded_integer()
        self.insert_id = self.packet.read_length_encoded_integer()
        self.server_status, self.warning_count = self.packet.read_struct('<HH')
        self.message = self.packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only reached when normal lookup fails: delegate to the packet.
        # ``packet`` is fetched from the instance dict directly so that an
        # instance lacking it (e.g. one created via __new__ by copy/pickle)
        # raises AttributeError instead of recursing through __getattr__.
        try:
            packet = vars(self)['packet']
        except KeyError:
            raise AttributeError(key) from None
        return getattr(packet, key)
400
401
402
class EOFPacketWrapper:
    """
    EOF Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing
    useful variables while still providing access to the original packet
    objects variables and methods.

    Attributes set on construction: warning_count, server_status and
    has_next (truthy when the server reports more result sets).

    Raises ValueError if ``from_packet`` is not an EOF packet.

    """

    def __init__(self, from_packet):
        if not from_packet.is_eof_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type",
            )

        self.packet = from_packet
        # Skip the header byte, then read the warning count and status flags.
        # Both are unsigned 2-byte integers on the wire; reading them as
        # signed ('h') made values >= 0x8000 come out negative.
        self.warning_count, self.server_status = self.packet.read_struct('<xHH')
        if DEBUG:
            print('server_status=', self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only reached when normal lookup fails: delegate to the packet.
        # ``packet`` is fetched from the instance dict directly so that an
        # instance lacking it (e.g. one created via __new__ by copy/pickle)
        # raises AttributeError instead of recursing through __getattr__.
        try:
            packet = vars(self)['packet']
        except KeyError:
            raise AttributeError(key) from None
        return getattr(packet, key)
426
427
428
class LoadLocalPacketWrapper:
    """
    Load Local Packet Wrapper.

    It uses an existing packet object, and wraps around it, exposing useful
    variables while still providing access to the original packet
    objects variables and methods.

    Attributes set on construction: filename (the raw payload bytes after
    the leading 0xFB marker byte).

    Raises ValueError if ``from_packet`` is not a LOAD LOCAL packet.

    """

    def __init__(self, from_packet):
        if not from_packet.is_load_local_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type",
            )

        self.packet = from_packet
        # Everything after the marker byte is the requested file name.
        self.filename = self.packet.get_all_data()[1:]
        if DEBUG:
            print('filename=', self.filename)

    def __getattr__(self, key):
        # Only reached when normal lookup fails: delegate to the packet.
        # ``packet`` is fetched from the instance dict directly so that an
        # instance lacking it (e.g. one created via __new__ by copy/pickle)
        # raises AttributeError instead of recursing through __getattr__.
        try:
            packet = vars(self)['packet']
        except KeyError:
            raise AttributeError(key) from None
        return getattr(packet, key)
451
452