Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/mysql/connection.py
469 views
1
# type: ignore
2
# Python implementation of the MySQL client-server protocol
3
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
4
# Error codes:
5
# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
6
import errno
7
import functools
8
import io
9
import os
10
import queue
11
import socket
12
import struct
13
import sys
14
import traceback
15
import warnings
16
from typing import Any
17
from typing import Dict
18
from typing import Iterable
19
20
try:
21
import _singlestoredb_accel
22
except (ImportError, ModuleNotFoundError):
23
_singlestoredb_accel = None
24
25
from . import _auth
26
from ..utils import events
27
28
from .charset import charset_by_name, charset_by_id
29
from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS
30
from . import converters
31
from .cursors import (
32
Cursor,
33
CursorSV,
34
DictCursor,
35
DictCursorSV,
36
NamedtupleCursor,
37
NamedtupleCursorSV,
38
ArrowCursor,
39
ArrowCursorSV,
40
NumpyCursor,
41
NumpyCursorSV,
42
PandasCursor,
43
PandasCursorSV,
44
PolarsCursor,
45
PolarsCursorSV,
46
SSCursor,
47
SSCursorSV,
48
SSDictCursor,
49
SSDictCursorSV,
50
SSNamedtupleCursor,
51
SSNamedtupleCursorSV,
52
SSArrowCursor,
53
SSArrowCursorSV,
54
SSNumpyCursor,
55
SSNumpyCursorSV,
56
SSPandasCursor,
57
SSPandasCursorSV,
58
SSPolarsCursor,
59
SSPolarsCursorSV,
60
)
61
from .optionfile import Parser
62
from .protocol import (
63
dump_packet,
64
MysqlPacket,
65
FieldDescriptorPacket,
66
OKPacketWrapper,
67
EOFPacketWrapper,
68
LoadLocalPacketWrapper,
69
)
70
from . import err
71
from ..config import get_option
72
from .. import fusion
73
from .. import connection
74
from ..connection import Connection as BaseConnection
75
from ..utils.debug import log_query
76
77
# Optional standard-library features, probed once at import time.
try:
    import ssl

    SSL_ENABLED = True
except ImportError:
    ssl = None
    SSL_ENABLED = False

# Default login name used when no user is supplied.
try:
    import getpass
except ImportError:
    DEFAULT_USER = None
else:
    try:
        DEFAULT_USER = getpass.getuser()
    except (ImportError, KeyError, OSError):
        # No entry in the OS user database for the current user:
        # older Pythons raise KeyError (or ImportError when ``pwd`` is
        # unavailable); Python 3.13+ raises OSError.
        DEFAULT_USER = None
    # Don't leave the helper module in this module's namespace.
    del getpass
93
94
# Connection-level debug switch, read from the package configuration.
DEBUG = get_option('debug.connection')

# Field types grouped as "text" types; how they are consumed is outside
# this chunk of the module.
TEXT_TYPES = {
    # String / blob / bit types.
    FIELD_TYPE.BIT,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.GEOMETRY,
    # Extended data types.
    FIELD_TYPE.BSON,
    # Vector types, JSON representation.
    FIELD_TYPE.FLOAT32_VECTOR_JSON,
    FIELD_TYPE.FLOAT64_VECTOR_JSON,
    FIELD_TYPE.INT8_VECTOR_JSON,
    FIELD_TYPE.INT16_VECTOR_JSON,
    FIELD_TYPE.INT32_VECTOR_JSON,
    FIELD_TYPE.INT64_VECTOR_JSON,
    # Vector types, binary representation.
    FIELD_TYPE.FLOAT32_VECTOR,
    FIELD_TYPE.FLOAT64_VECTOR,
    FIELD_TYPE.INT8_VECTOR,
    FIELD_TYPE.INT16_VECTOR,
    FIELD_TYPE.INT32_VECTOR,
    FIELD_TYPE.INT64_VECTOR,
}

# Sentinel string marking an option that was not given.
UNSET = 'unset'

# Character set used when none is requested.
DEFAULT_CHARSET = 'utf8mb4'

# Largest payload that fits in one packet (3-byte length field).
MAX_PACKET_LEN = (1 << 24) - 1
126
127
128
def _pack_int24(n):
    """
    Pack ``n`` as a 3-byte little-endian unsigned integer.

    Parameters
    ----------
    n : int
        Value in the range ``0 .. 2**24 - 1``.

    Returns
    -------
    bytes
        Exactly three bytes.

    Raises
    ------
    ValueError
        If ``n`` does not fit in 24 bits (previously such values were
        silently truncated to their low three bytes).

    """
    if not 0 <= n <= 0xFFFFFF:
        raise ValueError(f'Value {n} does not fit in a 3-byte unsigned integer')
    return struct.pack('<I', n)[:3]
130
131
132
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
def _lenenc_int(i):
    """
    Encode ``i`` as a MySQL length-encoded integer.

    Values below 0xFB take a single byte. Larger values get a marker byte
    (0xFC, 0xFD or 0xFE) followed by 2, 3 or 8 little-endian bytes.

    Raises
    ------
    ValueError
        If ``i`` is negative or does not fit in 64 bits.

    """
    if i < 0:
        raise ValueError(
            'Encoding %d is less than 0 - no representation in LengthEncodedInteger' % i,
        )
    if i < 0xFB:
        return bytes([i])
    # (exclusive upper bound, marker byte, struct format, payload width)
    for limit, marker, fmt, width in (
        (1 << 16, b'\xfc', '<H', 2),
        (1 << 24, b'\xfd', '<I', 3),
        (1 << 64, b'\xfe', '<Q', 8),
    ):
        if i < limit:
            return marker + struct.pack(fmt, i)[:width]
    raise ValueError(
        'Encoding %x is larger than %x - no representation in LengthEncodedInteger'
        % (i, (1 << 64)),
    )
151
152
153
class Connection(BaseConnection):
    """
    Representation of a socket with a mysql server.

    The proper way to get an instance of this class is to call
    ``connect()``.

    Establish a connection to the SingleStoreDB database.

    Parameters
    ----------
    host : str, optional
        Host where the database server is located.
    user : str, optional
        Username to log in as.
    password : str, optional
        Password to use.
    database : str, optional
        Database to use, None to not use a particular one.
    port : int, optional
        Server port to use, default is usually OK. (default: 3306)
    bind_address : str, optional
        When the client has multiple network interfaces, specify
        the interface from which to connect to the host. Argument can be
        a hostname or an IP address.
    unix_socket : str, optional
        Use a unix socket rather than TCP/IP.
    read_timeout : int, optional
        The timeout for reading from the connection in seconds
        (default: None - no timeout)
    write_timeout : int, optional
        The timeout for writing to the connection in seconds
        (default: None - no timeout)
    charset : str, optional
        Charset to use.
    collation : str, optional
        The charset collation
    sql_mode : str, optional
        Default SQL_MODE to use.
    read_default_file : str, optional
        Specifies my.cnf file to read these parameters from under the
        [client] section.
    conv : Dict[str, Callable[Any]], optional
        Conversion dictionary to use instead of the default one.
        This is used to provide custom marshalling and unmarshalling of types.
        See converters.
    use_unicode : bool, optional
        Whether or not to default to unicode strings.
        This option defaults to true.
    client_flag : int, optional
        Custom flags to send to MySQL. Find potential values in constants.CLIENT.
    cursorclass : type, optional
        Custom cursor class to use. When omitted, the class is chosen from
        ``results_type`` and ``buffered``.
    init_command : str, optional
        Initial SQL statement to run when connection is established.
    connect_timeout : int, optional
        The timeout for connecting to the database in seconds.
        (default: 10, must be > 0 and <= 31536000)
    ssl : Dict[str, str], optional
        A dict of arguments similar to mysql_ssl_set()'s parameters or
        an ssl.SSLContext.
    ssl_ca : str, optional
        Path to the file that contains a PEM-formatted CA certificate.
    ssl_cert : str, optional
        Path to the file that contains a PEM-formatted client certificate.
    ssl_cipher : str, optional
        SSL ciphers to allow.
    ssl_disabled : bool, optional
        A boolean value that disables usage of TLS.
    ssl_key : str, optional
        Path to the file that contains a PEM-formatted private key for the
        client certificate.
    ssl_verify_cert : bool or str, optional
        Set to true to check the server certificate's validity.
    ssl_verify_identity : bool, optional
        Set to true to check the server's identity.
    tls_sni_servername : str, optional
        Set server host name for TLS connection
    read_default_group : str, optional
        Group to read from in the configuration file.
    autocommit : bool, optional
        Autocommit mode. None means use server default. (default: False)
    local_infile : bool, optional
        Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
    max_allowed_packet : int, optional
        Max size of packet sent to server in bytes. (default: 16MB)
        Only used to limit size of "LOAD LOCAL INFILE" data packet smaller
        than default (16KB).
    defer_connect : bool, optional
        Don't explicitly connect on construction - wait for connect call.
        (default: False)
    auth_plugin_map : Dict[str, type], optional
        A dict of plugin names to a class that processes that plugin.
        The class will take the Connection object as the argument to the
        constructor. The class needs an authenticate method taking an
        authentication packet as an argument. For the dialog plugin, a
        prompt(echo, prompt) method can be used (if no authenticate method)
        for returning a string from the user. (experimental)
    server_public_key : str, optional
        SHA256 authentication plugin public key value. (default: None)
    binary_prefix : bool, optional
        Add _binary prefix on bytes and bytearray. (default: False)
    program_name : str, optional
        Sent to the server as the ``program_name`` connection attribute.
    buffered : bool, optional
        Use buffered cursors (default: True). When False, the unbuffered
        ``SS*`` cursor classes are selected.
    results_type : str, optional
        Selects the default cursor: 'tuples' (default), 'dicts',
        'namedtuples', 'numpy', 'arrow', 'pandas' or 'polars'.
    compress :
        Not supported.
    named_pipe :
        Not supported.
    db : str, optional
        **DEPRECATED** Alias for database.
    passwd : str, optional
        **DEPRECATED** Alias for password.
    driver : str, optional
        Internal use only.
    conn_attrs : Dict[str, str], optional
        Extra connection attributes sent to the server. They never replace
        the attributes the client sets itself.
    multi_statements : bool, optional
        Allow multiple statements per query (sets CLIENT.MULTI_STATEMENTS).
    client_found_rows : bool, optional
        Report found rather than changed rows (sets CLIENT.FOUND_ROWS).
    parse_json : bool, optional
        Parse JSON values into Python objects?
    invalid_values : Dict[int, Any], optional
        Dictionary of values to use in place of invalid values
        found during conversion of data. The default is to return the byte content
        containing the invalid value. The keys are the integers associated with
        the column type.
    pure_python : bool, optional
        Should we ignore the C extension even if it's available?
        This can be given explicitly using True or False, or if the value is None,
        the C extension will be loaded if it is available. If set to False and
        the C extension can't be loaded, a NotSupportedError is raised.
    nan_as_null : bool, optional
        Should NaN values be treated as NULLs in parameter substitution including
        uploading data?
    inf_as_null : bool, optional
        Should Inf values be treated as NULLs in parameter substitution including
        uploading data?
    encoding_errors : str, optional
        Stored as ``encoding_errors`` (default: 'strict'); presumably the
        error handler used when decoding text -- confirm against its users.
    track_env : bool, optional
        Should the connection track the SINGLESTOREDB_URL environment variable?
    enable_extended_data_types : bool, optional
        Should extended data types (BSON, vector) be enabled?
    vector_data_format : str, optional
        Specify the data type of vector values: json or binary

    See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_
    in the specification.

    """

    driver = 'mysql'
    paramstyle = 'pyformat'

    # Class-level defaults for per-instance state; instance code assigns
    # its own values (e.g. the socket on connect, the closed flag on close).
    _sock = None
    _auth_plugin_name = ''
    _closed = False
    _secure = False
    _tls_sni_servername = None
301
302
def __init__(  # noqa: C901
    self,
    *,
    user=None,  # The first four arguments follow the DB-API 2.0 recommendation.
    password='',
    host=None,
    database=None,
    unix_socket=None,
    port=0,
    charset='',
    collation=None,
    sql_mode=None,
    read_default_file=None,
    conv=None,
    use_unicode=True,
    client_flag=0,
    cursorclass=None,
    init_command=None,
    connect_timeout=10,
    read_default_group=None,
    autocommit=False,
    local_infile=False,
    max_allowed_packet=16 * 1024 * 1024,
    defer_connect=False,
    auth_plugin_map=None,
    read_timeout=None,
    write_timeout=None,
    bind_address=None,
    binary_prefix=False,
    program_name=None,
    server_public_key=None,
    ssl=None,
    ssl_ca=None,
    ssl_cert=None,
    ssl_cipher=None,
    ssl_disabled=None,
    ssl_key=None,
    ssl_verify_cert=None,
    ssl_verify_identity=None,
    tls_sni_servername=None,
    parse_json=True,
    invalid_values=None,
    pure_python=None,
    buffered=True,
    results_type='tuples',
    compress=None,  # not supported
    named_pipe=None,  # not supported
    passwd=None,  # deprecated
    db=None,  # deprecated
    driver=None,  # internal use
    conn_attrs=None,
    multi_statements=None,
    client_found_rows=None,
    nan_as_null=None,
    inf_as_null=None,
    encoding_errors='strict',
    track_env=False,
    enable_extended_data_types=True,
    vector_data_format='binary',
):
    """
    Configure the connection and, unless deferred, connect to the server.

    All parameters are documented on the class. Connecting is skipped here
    when ``defer_connect`` is true or environment tracking is enabled.

    Raises
    ------
    NotImplementedError
        If ``compress``/``named_pipe`` is requested, or SSL is requested
        but the ``ssl`` module is unavailable.
    ValueError
        For a non-int port, out-of-range timeouts, or an unknown
        ``vector_data_format``.
    NotSupportedError
        If ``pure_python=False`` but the C extension cannot be loaded.

    """
    # Must stay the first statement: ``locals()`` is exactly the call
    # arguments (plus ``self``) at this point.
    BaseConnection.__init__(**dict(locals()))

    # 'db' and 'passwd' are deprecated aliases; a DeprecationWarning is
    # planned (see https://github.com/PyMySQL/PyMySQL/issues/939).
    if db is not None and database is None:
        database = db
    if passwd is not None and not password:
        password = passwd

    if compress or named_pipe:
        raise NotImplementedError(
            'compress and named_pipe arguments are not supported',
        )

    self._local_infile = bool(local_infile)
    self._local_infile_stream = None
    if self._local_infile:
        client_flag |= CLIENT.LOCAL_FILES
    if multi_statements:
        client_flag |= CLIENT.MULTI_STATEMENTS
    if client_found_rows:
        client_flag |= CLIENT.FOUND_ROWS

    # A config group without an explicit file falls back to the
    # platform's default MySQL option file.
    if read_default_group and not read_default_file:
        if sys.platform.startswith('win'):
            read_default_file = 'c:\\my.ini'
        else:
            read_default_file = '/etc/my.cnf'

    if read_default_file:
        if not read_default_group:
            read_default_group = 'client'

        cfg = Parser()
        cfg.read(os.path.expanduser(read_default_file))

        def _config(key, arg):
            # Explicit (truthy) arguments win over the option file; a
            # missing option keeps ``arg``.
            if arg:
                return arg
            try:
                return cfg.get(read_default_group, key)
            except Exception:
                return arg

        user = _config('user', user)
        password = _config('password', password)
        host = _config('host', host)
        database = _config('database', database)
        unix_socket = _config('socket', unix_socket)
        port = int(_config('port', port))
        bind_address = _config('bind-address', bind_address)
        charset = _config('default-character-set', charset)
        if not ssl:
            ssl = {}
        if isinstance(ssl, dict):
            # Work on a copy so option-file values never leak into the
            # dict the caller passed in.
            ssl = dict(ssl)
            for key in ['ca', 'capath', 'cert', 'key', 'cipher']:
                value = _config('ssl-' + key, ssl.get(key))
                if value:
                    ssl[key] = value

    self.ssl = False
    if not ssl_disabled:
        # Individual ssl_* arguments replace any ``ssl`` dict.
        if ssl_ca or ssl_cert or ssl_key or ssl_cipher or \
                ssl_verify_cert or ssl_verify_identity:
            ssl = {
                'ca': ssl_ca,
                'check_hostname': bool(ssl_verify_identity),
                'verify_mode': ssl_verify_cert
                if ssl_verify_cert is not None
                else False,
            }
            if ssl_cert is not None:
                ssl['cert'] = ssl_cert
            if ssl_key is not None:
                ssl['key'] = ssl_key
            if ssl_cipher is not None:
                ssl['cipher'] = ssl_cipher
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError('ssl module not found')
            self.ssl = True
            client_flag |= CLIENT.SSL
            self.ctx = self._create_ssl_ctx(ssl)

    self.host = host or 'localhost'
    self.port = port or 3306
    if type(self.port) is not int:
        raise ValueError('port should be of type int')
    self.user = user or DEFAULT_USER
    self.password = password or b''
    if isinstance(self.password, str):
        self.password = self.password.encode('latin1')
    self.db = database
    self.unix_socket = unix_socket
    self.bind_address = bind_address
    if not (0 < connect_timeout <= 31536000):
        raise ValueError('connect_timeout should be >0 and <=31536000')
    self.connect_timeout = connect_timeout or None
    if read_timeout is not None and read_timeout <= 0:
        raise ValueError('read_timeout should be > 0')
    self._read_timeout = read_timeout
    if write_timeout is not None and write_timeout <= 0:
        raise ValueError('write_timeout should be > 0')
    self._write_timeout = write_timeout

    self.charset = charset or DEFAULT_CHARSET
    self.collation = collation
    self.use_unicode = use_unicode
    self.encoding_errors = encoding_errors

    self.encoding = charset_by_name(self.charset).encoding

    client_flag |= CLIENT.CAPABILITIES
    client_flag |= CLIENT.CONNECT_WITH_DB

    self.client_flag = client_flag

    self.pure_python = pure_python
    self.results_type = results_type
    self.resultclass = MySQLResult
    # Pick the cursor class from results_type; buffered=False selects the
    # unbuffered SS* variants.
    if cursorclass is not None:
        self.cursorclass = cursorclass
    elif buffered:
        if 'dict' in self.results_type:
            self.cursorclass = DictCursor
        elif 'namedtuple' in self.results_type:
            self.cursorclass = NamedtupleCursor
        elif 'numpy' in self.results_type:
            self.cursorclass = NumpyCursor
        elif 'arrow' in self.results_type:
            self.cursorclass = ArrowCursor
        elif 'pandas' in self.results_type:
            self.cursorclass = PandasCursor
        elif 'polars' in self.results_type:
            self.cursorclass = PolarsCursor
        else:
            self.cursorclass = Cursor
    else:
        if 'dict' in self.results_type:
            self.cursorclass = SSDictCursor
        elif 'namedtuple' in self.results_type:
            self.cursorclass = SSNamedtupleCursor
        elif 'numpy' in self.results_type:
            self.cursorclass = SSNumpyCursor
        elif 'arrow' in self.results_type:
            self.cursorclass = SSArrowCursor
        elif 'pandas' in self.results_type:
            self.cursorclass = SSPandasCursor
        elif 'polars' in self.results_type:
            self.cursorclass = SSPolarsCursor
        else:
            self.cursorclass = SSCursor

    if self.pure_python is False and _singlestoredb_accel is None:
        # Retry the import only to show the user why it failed. Bind it to
        # another local name: importing ``_singlestoredb_accel`` itself here
        # would make that name local to this whole method.
        try:
            import _singlestoredb_accel as _accel_retry  # noqa: F401
        except Exception:
            traceback.print_exc(file=sys.stderr)
        raise err.NotSupportedError(
            'pure_python=False, but the '
            'C extension can not be loaded',
        )

    if self.pure_python is True:
        pass

    # The C extension handles these types internally.
    elif _singlestoredb_accel is not None:
        self.resultclass = MySQLResultSV
        if self.cursorclass is Cursor:
            self.cursorclass = CursorSV
        elif self.cursorclass is SSCursor:
            self.cursorclass = SSCursorSV
        elif self.cursorclass is DictCursor:
            self.cursorclass = DictCursorSV
            self.results_type = 'dicts'
        elif self.cursorclass is SSDictCursor:
            self.cursorclass = SSDictCursorSV
            self.results_type = 'dicts'
        elif self.cursorclass is NamedtupleCursor:
            self.cursorclass = NamedtupleCursorSV
            self.results_type = 'namedtuples'
        elif self.cursorclass is SSNamedtupleCursor:
            self.cursorclass = SSNamedtupleCursorSV
            self.results_type = 'namedtuples'
        elif self.cursorclass is NumpyCursor:
            self.cursorclass = NumpyCursorSV
            self.results_type = 'numpy'
        elif self.cursorclass is SSNumpyCursor:
            self.cursorclass = SSNumpyCursorSV
            self.results_type = 'numpy'
        elif self.cursorclass is ArrowCursor:
            self.cursorclass = ArrowCursorSV
            self.results_type = 'arrow'
        elif self.cursorclass is SSArrowCursor:
            self.cursorclass = SSArrowCursorSV
            self.results_type = 'arrow'
        elif self.cursorclass is PandasCursor:
            self.cursorclass = PandasCursorSV
            self.results_type = 'pandas'
        elif self.cursorclass is SSPandasCursor:
            self.cursorclass = SSPandasCursorSV
            self.results_type = 'pandas'
        elif self.cursorclass is PolarsCursor:
            self.cursorclass = PolarsCursorSV
            self.results_type = 'polars'
        elif self.cursorclass is SSPolarsCursor:
            self.cursorclass = SSPolarsCursorSV
            self.results_type = 'polars'

    self._result = None
    self._affected_rows = 0
    self.host_info = 'Not connected'

    # specified autocommit mode. None means use server default.
    self.autocommit_mode = autocommit

    if conv is None:
        conv = converters.conversions

    # Copy so the per-connection tweaks below never touch the shared table.
    conv = conv.copy()

    self.parse_json = parse_json
    self.invalid_values = (invalid_values or {}).copy()

    # Disable JSON parsing for Arrow
    if self.results_type in ['arrow']:
        conv[245] = None
        self.parse_json = False

    # Disable date/time parsing for polars; let polars do the parsing
    elif self.results_type in ['polars']:
        conv[7] = None
        conv[10] = None
        conv[12] = None

    # Need for MySQLdb compatibility.
    self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
    self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
    self.sql_mode = sql_mode
    self.init_command = init_command
    self.max_allowed_packet = max_allowed_packet
    self._auth_plugin_map = auth_plugin_map or {}
    self._binary_prefix = binary_prefix
    self.server_public_key = server_public_key

    if self.connection_params['nan_as_null'] or \
            self.connection_params['inf_as_null']:
        float_encoder = self.encoders.get(float)
        if float_encoder is not None:
            self.encoders[float] = functools.partial(
                float_encoder,
                nan_as_null=self.connection_params['nan_as_null'],
                inf_as_null=self.connection_params['inf_as_null'],
            )

    from .. import __version__ as VERSION_STRING

    if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
        VERSION_STRING += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']

    self._connect_attrs = {
        '_os': str(sys.platform),
        '_pid': str(os.getpid()),
        '_client_name': 'SingleStoreDB Python Client',
        '_client_version': VERSION_STRING,
    }

    if program_name:
        self._connect_attrs['program_name'] = program_name
    if conn_attrs is not None:
        # do not overwrite the attributes that we set ourselves
        for k, v in conn_attrs.items():
            if k not in self._connect_attrs:
                self._connect_attrs[k] = v

    self._is_committable = True
    self._in_sync = False
    self._tls_sni_servername = tls_sni_servername
    self._track_env = bool(track_env) or self.host == 'singlestore.com'
    self._enable_extended_data_types = enable_extended_data_types
    # Validation is case-insensitive, so store the normalized spelling too.
    normalized_vector_format = vector_data_format.lower()
    if normalized_vector_format not in ('json', 'binary'):
        raise ValueError(
            'unknown value for vector_data_format, '
            f'expecting "json" or "binary": {vector_data_format}',
        )
    self._vector_data_format = normalized_vector_format
    self._connection_info = {}
    events.subscribe(self._handle_event)

    if defer_connect or self._track_env:
        self._sock = None
    else:
        self.connect()
665
666
def _handle_event(self, data: Dict[str, Any]) -> None:
667
if data.get('name', '') == 'singlestore.portal.connection_updated':
668
self._connection_info = dict(data)
669
670
@property
def messages(self):
    """
    Return the messages collected on this connection.

    Not implemented yet: always an empty list. (Previously the body was a
    bare ``[]`` expression, so the property returned ``None``.)

    """
    # TODO: populate once message collection is implemented.
    return []
674
675
def __enter__(self):
676
return self
677
678
def __exit__(self, *exc_info):
679
del exc_info
680
self.close()
681
682
def _raise_mysql_exception(self, data):
    """Hand an error packet payload to ``err.raise_mysql_exception`` to raise it."""
    raise_from_packet = err.raise_mysql_exception
    raise_from_packet(data)
684
685
def _create_ssl_ctx(self, sslp):
686
if isinstance(sslp, ssl.SSLContext):
687
return sslp
688
ca = sslp.get('ca')
689
capath = sslp.get('capath')
690
hasnoca = ca is None and capath is None
691
ctx = ssl.create_default_context(cafile=ca, capath=capath)
692
ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
693
verify_mode_value = sslp.get('verify_mode')
694
if verify_mode_value is None:
695
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
696
elif isinstance(verify_mode_value, bool):
697
ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
698
else:
699
if isinstance(verify_mode_value, str):
700
verify_mode_value = verify_mode_value.lower()
701
if verify_mode_value in ('none', '0', 'false', 'no'):
702
ctx.verify_mode = ssl.CERT_NONE
703
elif verify_mode_value == 'optional':
704
ctx.verify_mode = ssl.CERT_OPTIONAL
705
elif verify_mode_value in ('required', '1', 'true', 'yes'):
706
ctx.verify_mode = ssl.CERT_REQUIRED
707
else:
708
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
709
if 'cert' in sslp:
710
ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
711
if 'cipher' in sslp:
712
ctx.set_ciphers(sslp['cipher'])
713
ctx.options |= ssl.OP_NO_SSLv2
714
ctx.options |= ssl.OP_NO_SSLv3
715
return ctx
716
717
def close(self):
    """
    Send the quit message and close the socket.

    See `Connection.close()
    <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
    in the specification.

    Raises
    ------
    Error : If the connection is already closed.

    """
    self._result = None
    # 'singlestore.com' is the placeholder host used until a real
    # connection URL is known; there is nothing to close yet.
    if self.host == 'singlestore.com':
        return
    if self._closed:
        raise err.Error('Already closed')
    events.unsubscribe(self._handle_event)
    self._closed = True
    sock = self._sock
    if sock is None:
        return
    quit_packet = struct.pack('<iB', 1, COMMAND.COM_QUIT)
    try:
        self._write_bytes(quit_packet)
    except Exception:
        # Best effort only: the socket is force-closed below either way.
        pass
    finally:
        self._force_close()
746
747
@property
def open(self):
    """Return True if the connection is open."""
    has_socket = not (self._sock is None)
    return has_socket
751
752
def is_connected(self):
    """Return True if the connection is open (same as the ``open`` property)."""
    connected = self.open
    return connected
755
756
def _force_close(self):
757
"""Close connection without QUIT message."""
758
if self._sock:
759
try:
760
self._sock.close()
761
except: # noqa
762
pass
763
self._sock = None
764
self._rfile = None
765
766
__del__ = _force_close
767
768
def autocommit(self, value):
    """
    Enable or disable autocommit in the server.

    Parameters
    ----------
    value : bool
        Interpreted by truthiness. ``SET AUTOCOMMIT`` is only sent when it
        differs from the server's current autocommit status.

    """
    self.autocommit_mode = bool(value)
    # Compare the normalized flag, so truthy non-bool values (e.g. 2) do not
    # trigger a redundant SET when autocommit is already on.
    if self.autocommit_mode != self.get_autocommit():
        self._send_autocommit_mode()
774
775
def get_autocommit(self):
    """Retrieve autocommit status from the most recently received server status flags."""
    autocommit_flag = SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT
    return (self.server_status & autocommit_flag) != 0
778
779
def _read_ok_packet(self):
    """
    Read the next packet and require it to be an OK packet.

    Updates ``server_status`` from the packet and returns the wrapped OK
    packet. Raises ``OperationalError`` (commands out of sync) otherwise.

    """
    packet = self._read_packet()
    if packet.is_ok_packet():
        ok_packet = OKPacketWrapper(packet)
        self.server_status = ok_packet.server_status
        return ok_packet
    raise err.OperationalError(
        CR.CR_COMMANDS_OUT_OF_SYNC,
        'Command Out of Sync',
    )
789
790
def _send_autocommit_mode(self):
    """Set whether or not to commit after every execute()."""
    statement = 'SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode)
    log_query(statement)
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()
797
798
def begin(self):
    """Begin transaction."""
    log_query('BEGIN')
    # Nothing is sent while the connection still points at the
    # 'singlestore.com' placeholder host.
    if self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, 'BEGIN')
        self._read_ok_packet()
805
806
def commit(self):
    """
    Commit changes to stable storage.

    See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
    in the specification.

    """
    log_query('COMMIT')
    if self._is_committable and self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, 'COMMIT')
        self._read_ok_packet()
    else:
        # Skipped (the last query was a Fusion command, or no real host
        # yet); re-arm so the next commit goes through.
        self._is_committable = True
820
821
def rollback(self):
    """
    Roll back the current transaction.

    See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
    in the specification.

    """
    log_query('ROLLBACK')
    if self._is_committable and self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, 'ROLLBACK')
        self._read_ok_packet()
    else:
        # Skipped (the last query was a Fusion command, or no real host
        # yet); re-arm so the next rollback goes through.
        self._is_committable = True
835
836
def show_warnings(self):
    """Send the "SHOW WARNINGS" SQL command and return the resulting rows."""
    sql = 'SHOW WARNINGS'
    log_query(sql)
    self._execute_command(COMMAND.COM_QUERY, sql)
    warning_result = self.resultclass(self)
    warning_result.read()
    return warning_result.rows
843
844
def select_db(self, db):
    """
    Switch the session's current database (COM_INIT_DB).

    Parameters
    ----------
    db : str
        The name of the db.

    """
    command = COMMAND.COM_INIT_DB
    self._execute_command(command, db)
    self._read_ok_packet()
854
855
def escape(self, obj, mapping=None):
    """
    Escape whatever value is passed.

    Strings are quoted via ``escape_string``, bytes-like values via
    ``_quote_bytes``; everything else goes through the converters using
    ``mapping`` (default: this connection's encoders).

    Non-standard, for internal use; do not use this in your applications.

    """
    if isinstance(obj, str):
        return "'{}'".format(self.escape_string(obj))
    if isinstance(obj, (bytes, bytearray)):
        return self._quote_bytes(obj)
    encoders = self.encoders if mapping is None else mapping
    return converters.escape_item(obj, self.charset, mapping=encoders)
870
871
def literal(self, obj):
    """
    Alias for escape(), always using this connection's encoders.

    Non-standard, for internal use; do not use this in your applications.

    """
    encoders = self.encoders
    return self.escape(obj, encoders)
879
880
def escape_string(self, s):
    """Escape a string value for the server's current quoting mode."""
    no_backslash = self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES
    if not no_backslash:
        return converters.escape_string(s)
    # With NO_BACKSLASH_ESCAPES only doubled single quotes are recognized.
    return s.replace("'", "''")
885
886
def _quote_bytes(self, s):
    """Quote a bytes-like value as an SQL literal for the server's quoting mode."""
    if not (self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES):
        return converters.escape_bytes(s)
    # Backslash escapes are off: emit a hex literal instead.
    hex_literal = "X'{}'".format(s.hex())
    if self._binary_prefix:
        return '_binary ' + hex_literal
    return hex_literal
892
893
def cursor(self):
    """Create a new cursor (of ``self.cursorclass``) to execute queries with."""
    cursor_cls = self.cursorclass
    return cursor_cls(self)
896
897
# The following methods are INTERNAL USE ONLY (called from Cursor)
def query(self, sql, unbuffered=False, infile_stream=None):
    """
    Run a query on the server.

    Fusion commands (recognized by ``fusion.get_handler``) are executed
    client-side; everything else is sent as COM_QUERY.

    Internal use only.

    Parameters
    ----------
    sql : str or bytes
        The statement to run.
    unbuffered : bool, optional
        Passed through to the result reader.
    infile_stream : optional
        Stream made available to LOAD DATA LOCAL handling for this query only.

    Returns
    -------
    int
        The affected row count.

    """
    handler = fusion.get_handler(sql)
    if handler is not None:
        # Mark the connection so the next commit()/rollback() is skipped.
        self._is_committable = False
        self._result = fusion.execute(self, sql, handler=handler)
        self._affected_rows = self._result.affected_rows
        return self._affected_rows

    self._is_committable = True
    if isinstance(sql, str):
        sql = sql.encode(self.encoding, 'surrogateescape')
    self._local_infile_stream = infile_stream
    try:
        self._execute_command(COMMAND.COM_QUERY, sql)
        self._affected_rows = self._read_query_result(unbuffered=unbuffered)
    finally:
        # Never keep the caller's stream past this query, even on failure.
        self._local_infile_stream = None
    return self._affected_rows
921
922
def next_result(self, unbuffered=False):
    """
    Retrieve the next result set and return its affected row count.

    Internal use only.

    """
    row_count = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = row_count
    return row_count
931
932
def affected_rows(self):
    """
    Return the affected row count recorded by the last query or result.

    Internal use only.

    """
    count = self._affected_rows
    return count
940
941
def kill(self, thread_id):
    """
    Execute kill command (COM_PROCESS_KILL) for ``thread_id``.

    Returns the OK packet read from the server.

    Internal use only.

    """
    payload = struct.pack('<I', thread_id)
    self._execute_command(COMMAND.COM_PROCESS_KILL, payload)
    ok_packet = self._read_ok_packet()
    return ok_packet
951
952
def ping(self, reconnect=True):
    """
    Check if the server is alive.

    Parameters
    ----------
    reconnect : bool, optional
        If the connection is closed, reconnect. A failed ping also triggers
        one reconnect followed by a single non-reconnecting retry.

    Raises
    ------
    Error : If the connection is closed and reconnect=False.

    """
    if self._sock is None:
        if not reconnect:
            raise err.Error('Already closed')
        self.connect()
        # Only one reconnect attempt per call.
        reconnect = False
    try:
        self._execute_command(COMMAND.COM_PING, '')
        self._read_ok_packet()
    except Exception:
        if not reconnect:
            raise
        self.connect()
        self.ping(False)
981
982
def set_charset(self, charset):
    """Deprecated. Use set_character_set() instead."""
    # Kept for compatibility: older PyMySQL exposed this name, while
    # MySQLdb calls it set_character_set().
    new_charset = charset
    self.set_character_set(new_charset)
989
990
def set_character_set(self, charset, collation=None):
    """
    Set charset (and collation) on the server.

    Send "SET NAMES charset [COLLATE collation]" query.
    Update Connection.encoding based on charset.

    Parameters
    ----------
    charset : str
        The charset to enable.
    collation : str, optional
        The collation value

    """
    # Looking up the charset first makes sure it is supported.
    encoding = charset_by_name(charset).encoding

    query = f'SET NAMES {charset}'
    if collation:
        query += f' COLLATE {collation}'
    self._execute_command(COMMAND.COM_QUERY, query)
    self._read_packet()
    self.charset, self.encoding, self.collation = charset, encoding, collation
1017
1018
def _sync_connection(self):
    """
    Synchronize the connection with the tracked connection URL.

    The URL comes from the last portal 'connection updated' event, falling
    back to the ``SINGLESTOREDB_URL`` environment variable. The password is
    always refreshed. The connection is re-established only when the host,
    port, user or database changed, or when there is no socket.

    Raises
    ------
    InterfaceError
        If the URL still points at the 'singlestore.com' placeholder host.

    """
    if self._in_sync or not self._track_env:
        return

    url = self._connection_info.get('connection_url')
    if not url:
        url = os.environ.get('SINGLESTOREDB_URL')
    if not url:
        return

    out = {}
    out.update(connection._parse_url(url))
    out = connection._cast_params(out)

    # Set default port based on driver.
    if not out.get('port'):
        out['port'] = int(get_option('port') or 3306)

    # If there is no user and the password is empty, remove the password key.
    if 'user' not in out and not out.get('password', None):
        out.pop('password', None)

    if out['host'] == 'singlestore.com':
        raise err.InterfaceError(0, 'Connection URL has not been established')

    # The URL may omit the user (and thus the password); fall back the same
    # way __init__ does instead of raising KeyError.
    user = out.get('user') or DEFAULT_USER
    password = out.get('password')
    if isinstance(password, str):
        password = password.encode('latin-1')
    else:
        password = password or b''
    database = out.get('database')

    # Always keep the latest password, so a later reconnect (e.g. from
    # ping) authenticates with current credentials.
    self.password = password

    # If it's just a password change, we don't need to reconnect
    if self._sock is not None and \
            (self.host, self.port, self.user, self.db) == \
            (out['host'], out['port'], user, database):
        return

    self.host = out['host']
    self.port = out['port']
    self.user = user
    self.db = database
    try:
        self._in_sync = True
        self.connect()
    finally:
        self._in_sync = False
1068
1069
def connect(self, sock=None):
    """
    Connect to server using existing parameters.

    Internal use only.

    Opens a UNIX-domain or TCP socket (unless ``sock`` is supplied), reads
    the server handshake, authenticates, and then applies session settings.
    These are: character set/collation, ``sql_mode``, extended type metadata,
    vector data format, ``init_command``, and autocommit mode.

    Parameters
    ----------
    sock : socket.socket, optional
        An already-connected socket to use instead of opening a new one.

    Raises
    ------
    OperationalError : ``CR_CONN_HOST_ERROR`` when an ``OSError`` occurs;
        the original exception is attached as ``original_exception``.

    """
    self._closed = False
    # True once self._sock refers to the socket opened by this call, so the
    # error path never closes a socket belonging to a previous connection.
    installed = False
    try:
        if sock is None:
            if self.unix_socket:
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.settimeout(self.connect_timeout)
                sock.connect(self.unix_socket)
                self.host_info = 'Localhost via UNIX socket'
                self._secure = True
                if DEBUG:
                    print('connected using unix_socket')
            else:
                kwargs = {}
                if self.bind_address is not None:
                    kwargs['source_address'] = (self.bind_address, 0)
                while True:
                    try:
                        sock = socket.create_connection(
                            (self.host, self.port), self.connect_timeout, **kwargs,
                        )
                        break
                    except OSError as e:
                        # Retry connects interrupted by a signal.
                        if e.errno == errno.EINTR:
                            continue
                        raise
                self.host_info = 'socket %s:%d' % (self.host, self.port)
                if DEBUG:
                    print('connected using socket')
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            sock.settimeout(None)

        self._sock = sock
        installed = True
        self._rfile = sock.makefile('rb')
        self._next_seq_id = 0

        self._get_server_information()
        self._request_authentication()

        # Send "SET NAMES" query on init for:
        # - Ensure charset (and collation) is set to the server.
        #   - collation_id in handshake packet may be ignored.
        # - If collation is not specified, we don't know what is server's
        #   default collation for the charset. For example, default collation
        #   of utf8mb4 is:
        #   - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci
        #   - MySQL 8.0: utf8mb4_0900_ai_ci
        #
        # Reference:
        # - https://github.com/PyMySQL/PyMySQL/issues/1092
        # - https://github.com/wagtail/wagtail/issues/9477
        # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)
        self.set_character_set(self.charset, self.collation)

        if self.sql_mode is not None:
            c = self.cursor()
            c.execute('SET sql_mode=%s', (self.sql_mode,))
            c.close()

        if self._enable_extended_data_types:
            c = self.cursor()
            try:
                c.execute('SET @@SESSION.enable_extended_types_metadata=on')
            except self.OperationalError:
                # Best effort: the server may not support this variable.
                pass
            c.close()

        if self._vector_data_format:
            c = self.cursor()
            try:
                val = self._vector_data_format
                c.execute(f'SET @@SESSION.vector_type_project_format={val}')
            except self.OperationalError:
                # Best effort: the server may not support this variable.
                pass
            c.close()

        if self.init_command is not None:
            c = self.cursor()
            c.execute(self.init_command)
            c.close()

        if self.autocommit_mode is not None:
            self.autocommit(self.autocommit_mode)

    except BaseException as e:
        self._rfile = None
        # _request_authentication() may replace self._sock with a TLS
        # wrapper around ``sock``; that wrapper may own the descriptor, so
        # close it as well as the raw socket to avoid leaking it.
        if installed and self._sock is not None and self._sock is not sock:
            try:
                self._sock.close()
            except Exception:
                pass
        if sock is not None:
            try:
                sock.close()
            except Exception:
                pass

        if isinstance(e, (OSError, IOError, socket.error)):
            exc = err.OperationalError(
                CR.CR_CONN_HOST_ERROR,
                f'Can\'t connect to MySQL server on {self.host!r} ({e})',
            )
            # Keep original exception and traceback to investigate error.
            exc.original_exception = e
            exc.traceback = traceback.format_exc()
            if DEBUG:
                print(exc.traceback)
            raise exc

        # If e is neither DatabaseError or IOError, It's a bug.
        # But raising AssertionError hides original error.
        # So just reraise it.
        raise
1184
1185
def write_packet(self, payload):
    """
    Send one complete MySQL packet to the server.

    A 3-byte little-endian length and the current sequence number are
    prepended to ``payload``; the sequence number is then advanced
    (wrapping at 256).

    Parameters
    ----------
    payload : bytes
        The packet body.

    """
    # Code that assembles packets by hand and calls _write_bytes()
    # directly must keep self._next_seq_id in sync itself.
    header = _pack_int24(len(payload)) + bytes([self._next_seq_id])
    framed = header + payload
    if DEBUG:
        dump_packet(framed)
    self._write_bytes(framed)
    self._next_seq_id = (self._next_seq_id + 1) % 256
1199
1200
def _read_packet(self, packet_type=MysqlPacket):
    """
    Read an entire "mysql packet" in its entirety from the network.

    Payloads split across several wire packets (length 0xFFFFFF) are
    concatenated before the packet object is built.

    Raises
    ------
    OperationalError : If the connection to the MySQL server is lost.
    InternalError : If the packet sequence number is wrong.

    Returns
    -------
    MysqlPacket

    """
    payload = bytearray()
    while True:
        header = self._read_bytes(4)
        # if DEBUG: dump_packet(header)

        length_lo, length_hi, seq_id = struct.unpack('<HBB', header)
        length = length_lo + (length_hi << 16)

        if seq_id != self._next_seq_id:
            self._force_close()
            if seq_id == 0:
                # MariaDB sends error packet with seqno==0 when shutdown
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    'Lost connection to MySQL server during query',
                )
            raise err.InternalError(
                'Packet sequence number wrong - got %d expected %d'
                % (seq_id, self._next_seq_id),
            )
        self._next_seq_id = (self._next_seq_id + 1) % 256

        chunk = self._read_bytes(length)
        if DEBUG:
            dump_packet(chunk)
        payload += chunk

        # A full-size piece (0xFFFFFF) means more of the payload follows.
        # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
        if length != 0xFFFFFF and length < MAX_PACKET_LEN:
            break

    packet = packet_type(bytes(payload), self.encoding)
    if packet.is_error_packet():
        pending = self._result
        if pending is not None and pending.unbuffered_active is True:
            pending.unbuffered_active = False
        packet.raise_for_error()
    return packet
1251
1252
def _read_bytes(self, num_bytes):
    """
    Read exactly ``num_bytes`` bytes from the server.

    Reads interrupted by a signal (EINTR) are retried. Any other failure,
    or a short read, force-closes the connection.

    Raises
    ------
    OperationalError : If the read fails or fewer bytes than requested
        are returned.

    """
    if self._read_timeout is not None:
        self._sock.settimeout(self._read_timeout)

    while True:
        try:
            data = self._rfile.read(num_bytes)
        except OSError as exc:
            if exc.errno != errno.EINTR:
                self._force_close()
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    'Lost connection to MySQL server during query (%s)' % (exc,),
                )
            # EINTR: fall through and retry the read.
        except BaseException:
            # Don't convert unknown exception to MySQLError.
            self._force_close()
            raise
        else:
            break

    if len(data) >= num_bytes:
        return data

    self._force_close()
    raise err.OperationalError(
        CR.CR_SERVER_LOST, 'Lost connection to MySQL server during query',
    )
1277
1278
def _write_bytes(self, data):
    """
    Write raw bytes to the server socket.

    Applies the write timeout first when one is configured.

    Raises
    ------
    OperationalError : ``CR_SERVER_GONE_ERROR`` if sending fails; the
        connection is force-closed first.

    """
    timeout = self._write_timeout
    if timeout is not None:
        self._sock.settimeout(timeout)

    try:
        self._sock.sendall(data)
    except OSError as exc:
        self._force_close()
        raise err.OperationalError(
            CR.CR_SERVER_GONE_ERROR, f'MySQL server has gone away ({exc!r})',
        )
1288
1289
def _read_query_result(self, unbuffered=False):
    """
    Build a result object for the pending query response.

    In buffered mode the whole response is read immediately. In
    unbuffered mode the result class is constructed with
    ``unbuffered=True`` and rows are read later.

    Returns
    -------
    The ``affected_rows`` value of the new result.

    """
    self._result = None
    if not unbuffered:
        result = self.resultclass(self)
        result.read()
    else:
        result = self.resultclass(self, unbuffered=unbuffered)
    self._result = result

    status = result.server_status
    if status is not None:
        self.server_status = status
    return result.affected_rows
1300
1301
def insert_id(self):
    """Return the insert ID of the most recent result, or 0 if there is none."""
    last = self._result
    return last.insert_id if last else 0
1306
1307
def _execute_command(self, command, sql):
    """
    Execute command.

    Sends ``command`` followed by ``sql``. The payload is split into
    ``MAX_PACKET_LEN``-sized packets when it is too large for one.

    Raises
    ------
    InterfaceError : If the connection is closed.
    ValueError : If no username was specified.

    """
    self._sync_connection()

    if self._sock is None:
        raise err.InterfaceError(0, 'The connection has been closed')

    # If the last query was unbuffered, make sure it finishes before
    # sending new commands. next_result() replaces self._result, so it is
    # re-read on every loop iteration.
    if self._result is not None:
        if self._result.unbuffered_active:
            warnings.warn('Previous unbuffered result was left incomplete')
            self._result._finish_unbuffered_query()
        while self._result.has_next:
            self.next_result()
        self._result = None

    if isinstance(sql, str):
        sql = sql.encode(self.encoding)

    first_len = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

    # tiny optimization: build the first packet manually (length, seq 0,
    # command byte) instead of calling self.write_packet()
    packet = struct.pack('<iB', first_len, command) + sql[:first_len - 1]
    self._write_bytes(packet)
    if DEBUG:
        dump_packet(packet)
    self._next_seq_id = 1

    if first_len < MAX_PACKET_LEN:
        return

    remaining = sql[first_len - 1:]
    while True:
        size = min(MAX_PACKET_LEN, len(remaining))
        self.write_packet(remaining[:size])
        remaining = remaining[size:]
        # A final short (possibly empty) packet terminates the payload.
        if not remaining and size < MAX_PACKET_LEN:
            break
1356
1357
def _request_authentication(self):  # noqa: C901
    """
    Send the handshake response and complete authentication.

    Optionally upgrades the socket to TLS first. It then sends the auth
    response for the server's default plugin and handles any
    auth-switch request or extra auth data the server replies with.

    Raises
    ------
    ValueError : If no username was specified.
    OperationalError : If the server sends an auth-switch request that
        cannot be handled, or extra auth data for an unsupported plugin.

    """
    # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
    if int(self.server_version.split('.', 1)[0]) >= 5:
        self.client_flag |= CLIENT.MULTI_RESULTS

    if self.user is None:
        raise ValueError('Did not specify a username')

    charset_id = charset_by_name(self.charset).id
    if isinstance(self.user, str):
        self.user = self.user.encode(self.encoding)

    data_init = struct.pack(
        '<iIB23s', self.client_flag, MAX_PACKET_LEN, charset_id, b'',
    )

    if self.ssl and self.server_capabilities & CLIENT.SSL:
        # SSL request packet, then upgrade the socket before sending creds.
        self.write_packet(data_init)

        hostname = self.host
        if self._tls_sni_servername:
            hostname = self._tls_sni_servername
        self._sock = self.ctx.wrap_socket(self._sock, server_hostname=hostname)
        self._rfile = self._sock.makefile('rb')
        self._secure = True

    data = data_init + self.user + b'\0'

    authresp = b''
    plugin_name = None

    if self._auth_plugin_name == '':
        plugin_name = b''
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == 'mysql_native_password':
        plugin_name = b'mysql_native_password'
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == 'caching_sha2_password':
        plugin_name = b'caching_sha2_password'
        if self.password:
            if DEBUG:
                print('caching_sha2: trying fast path')
            authresp = _auth.scramble_caching_sha2(self.password, self.salt)
        else:
            if DEBUG:
                print('caching_sha2: empty password')
    elif self._auth_plugin_name == 'sha256_password':
        plugin_name = b'sha256_password'
        if self.ssl and self.server_capabilities & CLIENT.SSL:
            authresp = self.password + b'\0'
        elif self.password:
            authresp = b'\1'  # request public key
        else:
            authresp = b'\0'  # empty password

    if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
        data += _lenenc_int(len(authresp)) + authresp
    elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
        data += struct.pack('B', len(authresp)) + authresp
    else:  # pragma: no cover - no testing against servers w/o secure auth (>=5.0)
        data += authresp + b'\0'

    if self.server_capabilities & CLIENT.CONNECT_WITH_DB:
        db = self.db
        if isinstance(db, str):
            db = db.encode(self.encoding)
        data += (db or b'') + b'\0'

    if self.server_capabilities & CLIENT.PLUGIN_AUTH:
        data += (plugin_name or b'') + b'\0'

    if self.server_capabilities & CLIENT.CONNECT_ATTRS:
        connect_attrs = b''
        for k, v in self._connect_attrs.items():
            k = k.encode('utf-8')
            connect_attrs += _lenenc_int(len(k)) + k
            v = v.encode('utf-8')
            connect_attrs += _lenenc_int(len(v)) + v
        data += _lenenc_int(len(connect_attrs)) + connect_attrs

    self.write_packet(data)
    auth_packet = self._read_packet()

    # if authentication method isn't accepted the first byte
    # will have the octet 254
    if auth_packet.is_auth_switch_request():
        if DEBUG:
            print('received auth switch')
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        auth_packet.read_uint8()  # 0xfe packet identifier
        plugin_name = auth_packet.read_string()
        if (
            self.server_capabilities & CLIENT.PLUGIN_AUTH
            and plugin_name is not None
        ):
            auth_packet = self._process_auth(plugin_name, auth_packet)
        else:
            raise err.OperationalError('received unknown auth switch request')
    elif auth_packet.is_extra_auth_data():
        if DEBUG:
            print('received extra data')
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        if self._auth_plugin_name == 'caching_sha2_password':
            auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
        elif self._auth_plugin_name == 'sha256_password':
            auth_packet = _auth.sha256_password_auth(self, auth_packet)
        else:
            # Format the message (the original passed the format string and
            # its argument as separate exception args, never interpolated).
            raise err.OperationalError(
                CR.CR_AUTH_PLUGIN_ERR,
                'Received extra packet for auth method %r'
                % (self._auth_plugin_name,),
            )

    if DEBUG:
        print('Succeed to auth')
1470
1471
def _process_auth(self, plugin_name, auth_packet):
    """
    Answer an auth-switch request for ``plugin_name``.

    A handler registered in the plugin map is tried first. Otherwise the
    built-in implementation for the named plugin is used.

    Parameters
    ----------
    plugin_name : bytes
        Plugin requested by the server.
    auth_packet : MysqlPacket
        The auth-switch packet, positioned after the plugin name.

    Returns
    -------
    MysqlPacket
        The server's reply after the exchange.

    Raises
    ------
    OperationalError : If the plugin is not configured, a handler cannot be
        used, or a dialog handler returns a non-bytes response.

    """
    handler = self._get_auth_plugin_handler(plugin_name)
    if handler:
        try:
            return handler.authenticate(auth_packet)
        except AttributeError:
            # A dialog handler without authenticate() falls through to the
            # prompt loop below.
            if plugin_name != b'dialog':
                raise err.OperationalError(
                    CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                    "Authentication plugin '%s'"
                    ' not loaded: - %r missing authenticate method'
                    % (plugin_name, type(handler)),
                )

    # These plugins run their own multi-packet exchange.
    if plugin_name == b'caching_sha2_password':
        return _auth.caching_sha2_password_auth(self, auth_packet)
    if plugin_name == b'sha256_password':
        return _auth.sha256_password_auth(self, auth_packet)

    if plugin_name == b'mysql_native_password':
        data = _auth.scramble_native_password(self.password, auth_packet.read_all())
    elif plugin_name == b'client_ed25519':
        data = _auth.ed25519_password(self.password, auth_packet.read_all())
    elif plugin_name == b'mysql_old_password':
        scrambled = _auth.scramble_old_password(self.password, auth_packet.read_all())
        data = scrambled + b'\0'
    elif plugin_name == b'mysql_clear_password':
        # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
        data = self.password + b'\0'
    elif plugin_name == b'auth_gssapi_client':
        data = _auth.gssapi_auth(auth_packet.read_all())
    elif plugin_name == b'dialog':
        pkt = auth_packet
        while True:
            flag = pkt.read_uint8()
            echo = (flag & 0x06) == 0x02
            last = (flag & 0x01) == 0x01
            prompt = pkt.read_all()

            if prompt == b'Password: ':
                self.write_packet(self.password + b'\0')
            elif not handler:
                raise err.OperationalError(
                    CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                    "Authentication plugin '%s' not configured" % (plugin_name,),
                )
            else:
                resp = 'no response - TypeError within plugin.prompt method'
                try:
                    resp = handler.prompt(echo, prompt)
                    self.write_packet(resp + b'\0')
                except AttributeError:
                    raise err.OperationalError(
                        CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                        "Authentication plugin '%s'"
                        ' not loaded: - %r missing prompt method'
                        % (plugin_name, handler),
                    )
                except TypeError:
                    raise err.OperationalError(
                        CR.CR_AUTH_PLUGIN_ERR,
                        "Authentication plugin '%s'"
                        " %r didn't respond with string. Returned '%r' to prompt %r"
                        % (plugin_name, handler, resp, prompt),
                    )

            pkt = self._read_packet()
            pkt.check_error()
            if pkt.is_ok_packet() or last:
                return pkt
    else:
        raise err.OperationalError(
            CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
            "Authentication plugin '%s' not configured" % plugin_name,
        )

    # Single-response plugins: send the computed answer and read the reply.
    self.write_packet(data)
    pkt = self._read_packet()
    pkt.check_error()
    return pkt
1551
1552
def _get_auth_plugin_handler(self, plugin_name):
    """
    Instantiate the user-registered handler class for ``plugin_name``.

    The plugin map is looked up with ``plugin_name`` as given and, for
    bytes names, also with its ASCII-decoded form.

    Returns
    -------
    The handler instance (constructed with this connection), or ``None``
    if no handler is registered.

    Raises
    ------
    OperationalError : If the handler class cannot be constructed with the
        connection object.

    """
    registry = self._auth_plugin_map
    plugin_class = registry.get(plugin_name)
    if not plugin_class and isinstance(plugin_name, bytes):
        plugin_class = registry.get(plugin_name.decode('ascii'))

    if not plugin_class:
        return None

    try:
        return plugin_class(self)
    except TypeError:
        raise err.OperationalError(
            CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
            "Authentication plugin '%s'"
            ' not loaded: - %r cannot be constructed with connection object'
            % (plugin_name, plugin_class),
        )
1569
1570
# Accessors kept for compatibility with the _mysql (mysqlclient) API.
def thread_id(self):
    """Return the connection (thread) id reported by the server handshake."""
    ids = self.server_thread_id
    return ids[0]

def character_set_name(self):
    """Return the name of the connection's character set."""
    name = self.charset
    return name

def get_host_info(self):
    """Return a description of how the connection was established."""
    info = self.host_info
    return info

def get_proto_info(self):
    """Return the protocol version sent by the server."""
    version = self.protocol_version
    return version
1582
1583
def _get_server_information(self):
    """
    Read and parse the server's initial handshake packet.

    Always sets ``protocol_version``, ``server_version``,
    ``server_thread_id``, ``salt`` (first part), and the low 16 bits of
    ``server_capabilities``.

    When the packet carries the extended section, it also sets:

    - ``server_language``
    - ``server_charset`` (``None`` for an unknown collation id)
    - ``server_status``
    - the high capability bits
    - the remainder of ``salt``

    If ``CLIENT.PLUGIN_AUTH`` is advertised, the server's auth plugin name
    is stored in ``_auth_plugin_name``.

    """
    packet = self._read_packet()
    data = packet.get_all_data()

    self.protocol_version = data[0]
    pos = 1

    # NUL-terminated server version string.
    version_end = data.find(b'\0', pos)
    self.server_version = data[pos:version_end].decode('latin1')
    pos = version_end + 1

    self.server_thread_id = struct.unpack('<I', data[pos:pos + 4])
    pos += 4

    # First 8 salt bytes, then a 1-byte filler.
    self.salt = data[pos:pos + 8]
    pos += 9

    self.server_capabilities = struct.unpack('<H', data[pos:pos + 2])[0]
    pos += 2

    if len(data) >= pos + 6:
        lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[pos:pos + 6])
        pos += 6
        # TODO: deprecate server_language and server_charset.
        # mysqlclient-python doesn't provide it.
        self.server_language = lang
        try:
            self.server_charset = charset_by_id(lang).name
        except KeyError:
            # unknown collation
            self.server_charset = None

        self.server_status = stat
        if DEBUG:
            print('server_status: %x' % stat)

        self.server_capabilities |= cap_h << 16
        if DEBUG:
            print('salt_len:', salt_len)
        salt_len = max(12, salt_len - 9)

        # reserved
        pos += 10

        if len(data) >= pos + salt_len:
            # salt_len includes auth_plugin_data_part_1 and filler
            self.salt += data[pos:pos + salt_len]
            pos += salt_len

    pos += 1
    # AUTH PLUGIN NAME may appear here.
    if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= pos:
        # Due to Bug#59453 the auth-plugin-name is missing the terminating
        # NUL-char in versions prior to 5.5.10 and 5.6.2.
        # ref: https://dev.mysql.com/doc/internals/en/
        #   connection-phase-packets.html#packet-Protocol::Handshake
        # didn't use version checks as mariadb is corrected and reports
        # earlier than those two.
        name_end = data.find(b'\0', pos)
        if name_end < 0:  # pragma: no cover - very specific upstream bug
            # not found \0 and last field so take it all
            name_end = len(data)
        self._auth_plugin_name = data[pos:name_end].decode('utf-8')
1648
1649
def get_server_info(self):
    """Return the server version string from the handshake."""
    version = self.server_version
    return version
1651
1652
# Exception classes exposed as connection attributes so callers (and this
# class, e.g. ``self.OperationalError``) can reference them via the instance.
Error = err.Error
Warning = err.Warning
InterfaceError = err.InterfaceError
DatabaseError = err.DatabaseError
DataError = err.DataError
OperationalError = err.OperationalError
IntegrityError = err.IntegrityError
InternalError = err.InternalError
ProgrammingError = err.ProgrammingError
NotSupportedError = err.NotSupportedError
1662
1663
1664
class MySQLResult:
    """
    Results of a SQL query.

    Parameters
    ----------
    connection : Connection
        The connection the result came from.
    unbuffered : bool, optional
        Should the reads be unbuffered?

    """

    def __init__(self, connection, unbuffered=False):
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        self.unbuffered_active = False
        self.converters = []
        self.fields = []
        self.encoding_errors = self.connection.encoding_errors
        if unbuffered:
            try:
                self.init_unbuffered_query()
            except Exception:
                self.connection = None
                self.unbuffered_active = False
                raise

    def __del__(self):
        # Drain any remaining unbuffered rows before the object goes away.
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """
        Read a complete (buffered) response.

        Dispatches on the first packet: an OK packet, a LOCAL INFILE
        request, or a result set. The connection reference is always
        released afterwards.

        """
        try:
            first_packet = self.connection._read_packet()

            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                self._read_load_local_packet(first_packet)
            else:
                self._read_result_packet(first_packet)
        finally:
            self.connection = None

    def init_unbuffered_query(self):
        """
        Initialize an unbuffered query.

        Raises
        ------
        OperationalError : If the connection to the MySQL server is lost.
        InternalError : Other errors.

        """
        self.unbuffered_active = True
        first_packet = self.connection._read_packet()

        if first_packet.is_ok_packet():
            self._read_ok_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        elif first_packet.is_load_local_packet():
            self._read_load_local_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()

            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615

    def _read_ok_packet(self, first_packet):
        """Copy the status fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        """
        Handle a LOCAL INFILE request by streaming the data to the server.

        Raises
        ------
        RuntimeError : If ``local_infile`` is disabled on the connection.
        OperationalError : If the server's reply is not an OK packet.

        """
        if not self.connection._local_infile:
            raise RuntimeError(
                '**WARN**: Received LOAD_LOCAL packet but local_infile option is false.',
            )
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            sender.send_data()
        except Exception:
            self.connection._read_packet()  # skip ok packet
            raise

        ok_packet = self.connection._read_packet()
        if (
            not ok_packet.is_ok_packet()
        ):  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(
                CR.CR_COMMANDS_OUT_OF_SYNC,
                'Commands Out of Sync',
            )
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True (and record warnings/has_next) if ``packet`` is EOF."""
        if not packet.is_eof_packet():
            return False
        # TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper
        # instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        """Read column descriptions and then all rows of a result set."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """Read the next row of an unbuffered result, or None at EOF."""
        # Check if in an active query
        if not self.unbuffered_active:
            return

        # EOF
        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        self.rows = (row,)  # rows should tuple of row for MySQL-python compatibility.
        return row

    def _finish_unbuffered_query(self):
        """Consume and discard the remaining rows of an unbuffered result."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active and self.connection._sock is not None:
            try:
                packet = self.connection._read_packet()
            except err.OperationalError as e:
                if e.args[0] in (
                    ER.QUERY_TIMEOUT,
                    ER.STATEMENT_TIMEOUT,
                ):
                    # if the query timed out we can simply ignore this error
                    self.unbuffered_active = False
                    self.connection = None
                    return

                raise

            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.connection = None  # release reference to kill cyclic reference.
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """
        Decode one row from ``packet`` using ``self.converters``.

        Raises
        ------
        UnicodeDecodeError : If a text value cannot be decoded with the
            column's encoding; the message names the column.

        """
        row = []
        for i, (encoding, converter) in enumerate(self.converters):
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    try:
                        data = data.decode(encoding, errors=self.encoding_errors)
                    except UnicodeDecodeError as exc:
                        # UnicodeDecodeError must be built from all five
                        # arguments; passing only a message raises TypeError
                        # and hides the original decoding problem.
                        raise UnicodeDecodeError(
                            exc.encoding,
                            exc.object,
                            exc.start,
                            exc.end,
                            'failed to decode string value in column '
                            f"'{self.fields[i].name}' using encoding '{encoding}'; " +
                            "use the 'encoding_errors' option on the connection " +
                            'to specify how to handle this error',
                        ) from exc
                if DEBUG:
                    print('DEBUG: DATA = ', data)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result."""
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []

        for i in range(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data is encoded in ascii
                    encoding = 'ascii'
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is converters.through:
                converter = None
            if DEBUG:
                print(f'DEBUG: field={field}, converter={converter}')
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        # NOTE(review): assert is stripped under -O; kept for compatibility.
        assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
        self.description = tuple(description)
1920
1921
1922
class MySQLResultSV(MySQLResult):
    """
    MySQLResult variant whose row data is read by the
    ``_singlestoredb_accel`` extension instead of the Python row parser.
    """

    def __init__(self, connection, unbuffered=False):
        MySQLResult.__init__(self, connection, unbuffered=unbuffered)
        requested = {
            'default_converters': converters.decoders,
            'results_type': connection.results_type,
            'parse_json': connection.parse_json,
            'invalid_values': connection.invalid_values,
            'unbuffered': unbuffered,
            'encoding_errors': connection.encoding_errors,
        }
        # Only options that are not UNSET are kept.
        options = {}
        for name, value in requested.items():
            if value is not UNSET:
                options[name] = value
        self.options = options

    def _read_rowdata_packet(self, *args, **kwargs):
        # Buffered read (second argument False) — presumably the
        # accelerator's "unbuffered" flag; confirm against the extension.
        return _singlestoredb_accel.read_rowdata_packet(self, False, *args, **kwargs)

    def _read_rowdata_packet_unbuffered(self, *args, **kwargs):
        # Same accelerator entry point with the flag set to True.
        return _singlestoredb_accel.read_rowdata_packet(self, True, *args, **kwargs)
1942
1943
1944
class LoadLocalFile:
    """
    Send the data for a ``LOAD DATA LOCAL INFILE`` request to the server.

    Parameters
    ----------
    filename : str or bytes
        File name requested by the server. The special value ``:stream:``
        means the data comes from ``connection._local_infile_stream``.
    connection : Connection
        The connection the data packets are written to.

    """

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """
        Send data packets from the local file to the server.

        Supported ``:stream:`` sources are:

        - binary streams (raw or buffered)
        - text streams (UTF-8 encoded)
        - iterables of bytes/str
        - ``queue.Queue`` objects, where an empty item ends the data

        Every packet is at most ``min(max_allowed_packet, 16 KiB)`` bytes.
        An empty packet marking the end of the data is always sent unless
        the connection was closed.

        Raises
        ------
        InterfaceError : If the connection is closed.
        OperationalError : If no stream was supplied, the stream type is
            unrecognized, or the file cannot be read.

        """
        if not self.connection._sock:
            raise err.InterfaceError(0, 'Connection is closed')

        conn = self.connection
        infile = conn._local_infile_stream

        # 16KB is efficient enough
        packet_size = min(conn.max_allowed_packet, 16 * 1024)

        def write(chunk):
            """Encode str chunks and send them in packet-sized pieces."""
            if isinstance(chunk, str):
                chunk = chunk.encode('utf8')
            # Splitting is safe: the server concatenates the packet payloads.
            # An empty chunk sends nothing, so the terminator is never sent early.
            for start in range(0, len(chunk), packet_size):
                conn.write_packet(chunk[start:start + packet_size])

        try:

            if self.filename in [':stream:', b':stream:']:

                if infile is None:
                    raise err.OperationalError(
                        ER.FILE_NOT_FOUND,
                        ':stream: specified for LOCAL INFILE, but no stream was supplied',
                    )

                # Binary IO: raw (FileIO) or buffered (open(..., 'rb'), BytesIO).
                # Buffered streams are read in chunks rather than by lines.
                elif isinstance(infile, (io.RawIOBase, io.BufferedIOBase)):
                    while True:
                        chunk = infile.read(packet_size)
                        if not chunk:
                            break
                        write(chunk)

                # Text IO
                elif isinstance(infile, io.TextIOBase):
                    while True:
                        chunk = infile.read(packet_size)
                        if not chunk:
                            break
                        write(chunk)

                # Iterable of bytes or str
                elif isinstance(infile, Iterable):
                    for chunk in infile:
                        if not chunk:
                            continue
                        write(chunk)

                # Queue (empty value ends the iteration)
                elif isinstance(infile, queue.Queue):
                    while True:
                        chunk = infile.get()
                        if not chunk:
                            break
                        write(chunk)

                else:
                    raise err.OperationalError(
                        ER.FILE_NOT_FOUND,
                        ':stream: specified for LOCAL INFILE, ' +
                        f'but stream type is unrecognized: {infile}',
                    )

            else:
                try:
                    with open(self.filename, 'rb') as open_file:
                        while True:
                            chunk = open_file.read(packet_size)
                            if not chunk:
                                break
                            write(chunk)
                except OSError:
                    raise err.OperationalError(
                        ER.FILE_NOT_FOUND,
                        f"Can't find file '{self.filename!s}'",
                    )

        finally:
            if not conn._closed:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b'')
2033
2034