Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/mysql/connection.py
798 views
1
# type: ignore
2
# Python implementation of the MySQL client-server protocol
3
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
4
# Error codes:
5
# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
6
import errno
7
import functools
8
import io
9
import os
10
import queue
11
import socket
12
import struct
13
import sys
14
import traceback
15
import warnings
16
from collections.abc import Iterable
17
from typing import Any
18
from typing import Dict
19
20
try:
21
import _singlestoredb_accel
22
except (ImportError, ModuleNotFoundError):
23
_singlestoredb_accel = None
24
25
from . import _auth
26
from ..utils import events
27
28
from .charset import charset_by_name, charset_by_id
29
from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS
30
from . import converters
31
from .cursors import (
32
Cursor,
33
CursorSV,
34
DictCursor,
35
DictCursorSV,
36
NamedtupleCursor,
37
NamedtupleCursorSV,
38
ArrowCursor,
39
ArrowCursorSV,
40
NumpyCursor,
41
NumpyCursorSV,
42
PandasCursor,
43
PandasCursorSV,
44
PolarsCursor,
45
PolarsCursorSV,
46
SSCursor,
47
SSCursorSV,
48
SSDictCursor,
49
SSDictCursorSV,
50
SSNamedtupleCursor,
51
SSNamedtupleCursorSV,
52
SSArrowCursor,
53
SSArrowCursorSV,
54
SSNumpyCursor,
55
SSNumpyCursorSV,
56
SSPandasCursor,
57
SSPandasCursorSV,
58
SSPolarsCursor,
59
SSPolarsCursorSV,
60
)
61
from .optionfile import Parser
62
from .protocol import (
63
dump_packet,
64
MysqlPacket,
65
FieldDescriptorPacket,
66
OKPacketWrapper,
67
EOFPacketWrapper,
68
LoadLocalPacketWrapper,
69
)
70
from . import err
71
from ..config import get_option
72
from .. import fusion
73
from .. import connection
74
from ..connection import Connection as BaseConnection
75
from ..utils.debug import log_query
76
77
try:
    import ssl

    SSL_ENABLED = True
except ImportError:
    # TLS support is optional; callers check SSL_ENABLED before using it.
    ssl = None
    SSL_ENABLED = False

try:
    import getpass

    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # getpass.getuser() fails when the current user cannot be determined
    # (e.g. no entry in the OS user database, common in containers).
    # Older Pythons raise KeyError/ImportError for this; Python 3.13+
    # raises OSError for all such failures. Without OSError here, merely
    # importing this module would crash on those systems.
    DEFAULT_USER = None
93
94
# Debug flag for the connection layer, taken from the package options.
DEBUG = get_option('debug.connection')

# Field types this module groups as "text" types: character/binary
# strings, blobs, BIT, GEOMETRY, BSON and every vector type (in both the
# JSON and the binary representation).
TEXT_TYPES = {
    # Character and binary string / blob types.
    FIELD_TYPE.STRING,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.BIT,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.GEOMETRY,
    FIELD_TYPE.BSON,
    # Vector types, JSON representation.
    FIELD_TYPE.FLOAT16_VECTOR_JSON,
    FIELD_TYPE.FLOAT32_VECTOR_JSON,
    FIELD_TYPE.FLOAT64_VECTOR_JSON,
    FIELD_TYPE.INT8_VECTOR_JSON,
    FIELD_TYPE.INT16_VECTOR_JSON,
    FIELD_TYPE.INT32_VECTOR_JSON,
    FIELD_TYPE.INT64_VECTOR_JSON,
    # Vector types, binary representation.
    FIELD_TYPE.FLOAT16_VECTOR,
    FIELD_TYPE.FLOAT32_VECTOR,
    FIELD_TYPE.FLOAT64_VECTOR,
    FIELD_TYPE.INT8_VECTOR,
    FIELD_TYPE.INT16_VECTOR,
    FIELD_TYPE.INT32_VECTOR,
    FIELD_TYPE.INT64_VECTOR,
}

# Marker string; presumably used as an "argument not given" sentinel
# elsewhere in the module — confirm at the use sites.
UNSET = 'unset'

# Charset used when the caller does not pass one.
DEFAULT_CHARSET = 'utf8mb4'

# Largest value that fits in the 3-byte packet length field (2**24 - 1).
MAX_PACKET_LEN = 0xFFFFFF
130
def _pack_int24(n):
    """Return the three low-order bytes of ``n`` in little-endian order."""
    # Pack as a 4-byte unsigned integer, then drop the most significant byte.
    packed = struct.pack('<I', n)
    return packed[:3]
132
133
134
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
def _lenenc_int(i):
    """
    Encode ``i`` as a MySQL length-encoded integer.

    Values below 0xFB occupy a single byte. Larger values are written as a
    marker byte (0xFC, 0xFD or 0xFE) followed by a 2-, 3- or 8-byte
    little-endian integer.

    Raises
    ------
    ValueError
        If ``i`` is negative or does not fit in 64 bits.

    """
    if i < 0:
        raise ValueError(
            'Encoding %d is less than 0 - no representation in LengthEncodedInteger' % i,
        )
    if i < 0xFB:
        return bytes([i])
    # (exclusive upper bound, marker byte, struct format, payload width)
    for limit, marker, fmt, width in (
        (1 << 16, b'\xfc', '<H', 2),
        (1 << 24, b'\xfd', '<I', 3),
        (1 << 64, b'\xfe', '<Q', 8),
    ):
        if i < limit:
            return marker + struct.pack(fmt, i)[:width]
    raise ValueError(
        'Encoding %x is larger than %x - no representation in LengthEncodedInteger'
        % (i, (1 << 64)),
    )
153
154
155
class Connection(BaseConnection):
    """
    Representation of a socket with a mysql server.

    The proper way to get an instance of this class is to call
    ``connect()``.

    Establish a connection to the SingleStoreDB database.

    Parameters
    ----------
    host : str, optional
        Host where the database server is located.
    user : str, optional
        Username to log in as.
    password : str, optional
        Password to use.
    database : str, optional
        Database to use, None to not use a particular one.
    port : int, optional
        Server port to use, default is usually OK. (default: 3306)
    bind_address : str, optional
        When the client has multiple network interfaces, specify
        the interface from which to connect to the host. Argument can be
        a hostname or an IP address.
    unix_socket : str, optional
        Use a unix socket rather than TCP/IP.
    read_timeout : int, optional
        The timeout for reading from the connection in seconds
        (default: None - no timeout)
    write_timeout : int, optional
        The timeout for writing to the connection in seconds
        (default: None - no timeout)
    charset : str, optional
        Charset to use.
    collation : str, optional
        The charset collation
    sql_mode : str, optional
        Default SQL_MODE to use.
    read_default_file : str, optional
        Specifies my.cnf file to read these parameters from under the
        [client] section.
    conv : Dict[str, Callable[Any]], optional
        Conversion dictionary to use instead of the default one.
        This is used to provide custom marshalling and unmarshalling of types.
        See converters.
    use_unicode : bool, optional
        Whether or not to default to unicode strings.
        This option defaults to true.
    client_flag : int, optional
        Custom flags to send to MySQL. Find potential values in constants.CLIENT.
    cursorclass : type, optional
        Custom cursor class to use.
    init_command : str, optional
        Initial SQL statement to run when connection is established.
    connect_timeout : int, optional
        The timeout for connecting to the database in seconds.
        Must be > 0 and <= 31536000. (default: 10)
    ssl : Dict[str, str], optional
        A dict of arguments similar to mysql_ssl_set()'s parameters or
        an ssl.SSLContext.
    ssl_ca : str, optional
        Path to the file that contains a PEM-formatted CA certificate.
    ssl_cert : str, optional
        Path to the file that contains a PEM-formatted client certificate.
    ssl_cipher : str, optional
        SSL ciphers to allow.
    ssl_disabled : bool, optional
        A boolean value that disables usage of TLS.
    ssl_key : str, optional
        Path to the file that contains a PEM-formatted private key for the
        client certificate.
    ssl_verify_cert : bool or str, optional
        Whether to check the server certificate's validity. Strings
        'none', 'optional' and 'required' (as well as '0'/'1',
        'false'/'true', 'no'/'yes') are also accepted.
    ssl_verify_identity : bool, optional
        Set to true to check the server's identity.
    tls_sni_servername : str, optional
        Set server host name for TLS connection
    read_default_group : str, optional
        Group to read from in the configuration file.
    autocommit : bool, optional
        Autocommit mode. None means use server default. (default: False)
    local_infile : bool, optional
        Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
    max_allowed_packet : int, optional
        Max size of packet sent to server in bytes. (default: 16MB)
        Only used to limit size of "LOAD LOCAL INFILE" data packet smaller
        than default (16KB).
    defer_connect : bool, optional
        Don't explicitly connect on construction - wait for connect call.
        Connecting is also deferred when ``track_env`` is enabled.
        (default: False)
    auth_plugin_map : Dict[str, type], optional
        A dict of plugin names to a class that processes that plugin.
        The class will take the Connection object as the argument to the
        constructor. The class needs an authenticate method taking an
        authentication packet as an argument. For the dialog plugin, a
        prompt(echo, prompt) method can be used (if no authenticate method)
        for returning a string from the user. (experimental)
    server_public_key : str, optional
        SHA256 authentication plugin public key value. (default: None)
    binary_prefix : bool, optional
        Add _binary prefix on bytes and bytearray. (default: False)
    program_name : str, optional
        Value sent as the ``program_name`` connection attribute.
    conn_attrs : Dict[str, str], optional
        Additional connection attributes. Keys that collide with attributes
        the client sets itself (``_os``, ``_pid``, ``_client_name``,
        ``_client_version``, ``program_name``) are ignored.
    multi_statements : bool, optional
        Set the CLIENT.MULTI_STATEMENTS capability flag.
    client_found_rows : bool, optional
        Set the CLIENT.FOUND_ROWS capability flag.
    buffered : bool, optional
        When True (the default) the default cursor class is a buffered
        cursor; when False the corresponding ``SS*`` cursor class is used.
        Ignored if ``cursorclass`` is given.
    results_type : str, optional
        Result format used to pick the default cursor class: 'tuples'
        (default), 'dicts', 'namedtuples', 'numpy', 'arrow', 'pandas' or
        'polars'. Ignored if ``cursorclass`` is given.
    encoding_errors : str, optional
        Error-handling scheme stored on the connection (default: 'strict');
        presumably used when decoding text values — confirm at use sites.
    interpolate_query_with_empty_args : bool, optional
        Stored on the connection; presumably controls whether queries are
        interpolated when no parameters are given — confirm in the cursors.
    compress :
        Not supported; a truthy value raises NotImplementedError.
    named_pipe :
        Not supported; a truthy value raises NotImplementedError.
    db : str, optional
        **DEPRECATED** Alias for database.
    passwd : str, optional
        **DEPRECATED** Alias for password.
    parse_json : bool, optional
        Parse JSON values into Python objects?
    invalid_values : Dict[int, Any], optional
        Dictionary of values to use in place of invalid values
        found during conversion of data. The default is to return the byte content
        containing the invalid value. The keys are the integers associated with
        the column type.
    pure_python : bool, optional
        Should we ignore the C extension even if it's available?
        This can be given explicitly using True or False, or if the value is None,
        the C extension will be loaded if it is available. If set to False and
        the C extension can't be loaded, a NotSupportedError is raised.
    nan_as_null : bool, optional
        Should NaN values be treated as NULLs in parameter substitution including
        uploading data?
    inf_as_null : bool, optional
        Should Inf values be treated as NULLs in parameter substitution including
        uploading data?
    track_env : bool, optional
        Should the connection track the SINGLESTOREDB_URL environment variable?
        Always enabled when the host is 'singlestore.com'.
    enable_extended_data_types : bool, optional
        Should extended data types (BSON, vector) be enabled?
    vector_data_format : str, optional
        Specify the data type of vector values: json or binary
        (case-insensitive; default: 'binary')
    driver : str, optional
        Internal use only.

    See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_
    in the specification.

    """

    # Driver name reported by this implementation.
    driver = 'mysql'
    # DB-API 2.0 parameter style: ``%(name)s`` / ``%s`` placeholders.
    paramstyle = 'pyformat'

    # Class-level defaults so that methods such as _force_close (also used
    # as __del__) can read these attributes even if __init__ never set them.
    _sock = None
    _auth_plugin_name = ''
    _closed = False
    _secure = False
    _tls_sni_servername = None
304
def __init__(  # noqa: C901
    self,
    *,
    user=None,  # The first four arguments are based on DB-API 2.0 recommendation.
    password='',
    host=None,
    database=None,
    unix_socket=None,
    port=0,
    charset='',
    collation=None,
    sql_mode=None,
    read_default_file=None,
    conv=None,
    use_unicode=True,
    client_flag=0,
    cursorclass=None,
    init_command=None,
    connect_timeout=10,
    read_default_group=None,
    autocommit=False,
    local_infile=False,
    max_allowed_packet=16 * 1024 * 1024,
    defer_connect=False,
    auth_plugin_map=None,
    read_timeout=None,
    write_timeout=None,
    bind_address=None,
    binary_prefix=False,
    program_name=None,
    server_public_key=None,
    ssl=None,
    ssl_ca=None,
    ssl_cert=None,
    ssl_cipher=None,
    ssl_disabled=None,
    ssl_key=None,
    ssl_verify_cert=None,
    ssl_verify_identity=None,
    tls_sni_servername=None,
    parse_json=True,
    invalid_values=None,
    pure_python=None,
    buffered=True,
    results_type='tuples',
    compress=None,  # not supported
    named_pipe=None,  # not supported
    passwd=None,  # deprecated
    db=None,  # deprecated
    driver=None,  # internal use
    conn_attrs=None,
    multi_statements=None,
    client_found_rows=None,
    nan_as_null=None,
    inf_as_null=None,
    encoding_errors='strict',
    track_env=False,
    enable_extended_data_types=True,
    vector_data_format='binary',
    interpolate_query_with_empty_args=None,
):
    """
    Initialize the connection and, unless deferred, connect to the server.

    See the class docstring for a description of the parameters.

    Raises
    ------
    NotImplementedError
        If ``compress``/``named_pipe`` is requested, or SSL is requested
        but the ``ssl`` module is unavailable.
    ValueError
        On an invalid port, timeout or ``vector_data_format``.
    NotSupportedError
        If ``pure_python=False`` but the C extension cannot be loaded.

    """
    # Must remain the first statement: it forwards every argument (the
    # current locals) to the base class before any other local exists.
    BaseConnection.__init__(**dict(locals()))

    if db is not None and database is None:
        # We will raise warning in 2022 or later.
        # See https://github.com/PyMySQL/PyMySQL/issues/939
        # warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
        database = db
    if passwd is not None and not password:
        # We will raise warning in 2022 or later.
        # See https://github.com/PyMySQL/PyMySQL/issues/939
        # warnings.warn(
        #     "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
        # )
        password = passwd

    if compress or named_pipe:
        raise NotImplementedError(
            'compress and named_pipe arguments are not supported',
        )

    self._local_infile = bool(local_infile)
    self._local_infile_stream = None
    if self._local_infile:
        client_flag |= CLIENT.LOCAL_FILES
    if multi_statements:
        client_flag |= CLIENT.MULTI_STATEMENTS
    if client_found_rows:
        client_flag |= CLIENT.FOUND_ROWS

    if read_default_group and not read_default_file:
        if sys.platform.startswith('win'):
            read_default_file = 'c:\\my.ini'
        else:
            read_default_file = '/etc/my.cnf'

    if read_default_file:
        if not read_default_group:
            read_default_group = 'client'

        cfg = Parser()
        cfg.read(os.path.expanduser(read_default_file))

        def _config(key, arg):
            # Explicit arguments win over values from the option file.
            if arg:
                return arg
            try:
                return cfg.get(read_default_group, key)
            except Exception:
                return arg

        user = _config('user', user)
        password = _config('password', password)
        host = _config('host', host)
        database = _config('database', database)
        unix_socket = _config('socket', unix_socket)
        port = int(_config('port', port))
        bind_address = _config('bind-address', bind_address)
        charset = _config('default-character-set', charset)
        if not ssl:
            ssl = {}
        if isinstance(ssl, dict):
            for key in ['ca', 'capath', 'cert', 'key', 'cipher']:
                value = _config('ssl-' + key, ssl.get(key))
                if value:
                    ssl[key] = value

    self.ssl = False
    if not ssl_disabled:
        # Individual ssl_* arguments replace any ``ssl`` dict/context.
        if ssl_ca or ssl_cert or ssl_key or ssl_cipher or \
                ssl_verify_cert or ssl_verify_identity:
            ssl = {
                'ca': ssl_ca,
                'check_hostname': bool(ssl_verify_identity),
                'verify_mode': ssl_verify_cert
                if ssl_verify_cert is not None
                else False,
            }
            if ssl_cert is not None:
                ssl['cert'] = ssl_cert
            if ssl_key is not None:
                ssl['key'] = ssl_key
            if ssl_cipher is not None:
                ssl['cipher'] = ssl_cipher
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError('ssl module not found')
            self.ssl = True
            client_flag |= CLIENT.SSL
            self.ctx = self._create_ssl_ctx(ssl)

    self.host = host or 'localhost'
    self.port = port or 3306
    if type(self.port) is not int:
        raise ValueError('port should be of type int')
    self.user = user or DEFAULT_USER
    self.password = password or b''
    if isinstance(self.password, str):
        self.password = self.password.encode('latin1')
    self.db = database
    self.unix_socket = unix_socket
    self.bind_address = bind_address
    if not (0 < connect_timeout <= 31536000):
        raise ValueError('connect_timeout should be >0 and <=31536000')
    self.connect_timeout = connect_timeout or None
    if read_timeout is not None and read_timeout <= 0:
        raise ValueError('read_timeout should be > 0')
    self._read_timeout = read_timeout
    if write_timeout is not None and write_timeout <= 0:
        raise ValueError('write_timeout should be > 0')
    self._write_timeout = write_timeout

    self.charset = charset or DEFAULT_CHARSET
    self.collation = collation
    self.use_unicode = use_unicode
    self.encoding_errors = encoding_errors

    self.encoding = charset_by_name(self.charset).encoding

    client_flag |= CLIENT.CAPABILITIES
    client_flag |= CLIENT.CONNECT_WITH_DB

    self.client_flag = client_flag

    self.pure_python = pure_python
    self.results_type = results_type
    self.resultclass = MySQLResult
    # Pick the default cursor class from ``buffered`` and ``results_type``
    # unless the caller supplied one explicitly.
    if cursorclass is not None:
        self.cursorclass = cursorclass
    elif buffered:
        if 'dict' in self.results_type:
            self.cursorclass = DictCursor
        elif 'namedtuple' in self.results_type:
            self.cursorclass = NamedtupleCursor
        elif 'numpy' in self.results_type:
            self.cursorclass = NumpyCursor
        elif 'arrow' in self.results_type:
            self.cursorclass = ArrowCursor
        elif 'pandas' in self.results_type:
            self.cursorclass = PandasCursor
        elif 'polars' in self.results_type:
            self.cursorclass = PolarsCursor
        else:
            self.cursorclass = Cursor
    else:
        if 'dict' in self.results_type:
            self.cursorclass = SSDictCursor
        elif 'namedtuple' in self.results_type:
            self.cursorclass = SSNamedtupleCursor
        elif 'numpy' in self.results_type:
            self.cursorclass = SSNumpyCursor
        elif 'arrow' in self.results_type:
            self.cursorclass = SSArrowCursor
        elif 'pandas' in self.results_type:
            self.cursorclass = SSPandasCursor
        elif 'polars' in self.results_type:
            self.cursorclass = SSPolarsCursor
        else:
            self.cursorclass = SSCursor

    if self.pure_python is False and _singlestoredb_accel is None:
        # Retry the import only to print the real reason it fails, then
        # refuse to continue. The module is imported under an alias: a
        # plain ``import _singlestoredb_accel`` would make that name local
        # to this whole method, turning the global lookups of it (here and
        # below) into UnboundLocalError.
        try:
            import _singlestoredb_accel as _accel_probe  # noqa: F401
        except Exception:
            traceback.print_exc(file=sys.stderr)
        raise err.NotSupportedError(
            'pure_python=False, but the '
            'C extension can not be loaded',
        )

    if self.pure_python is True:
        pass

    # The C extension handles these types internally.
    elif _singlestoredb_accel is not None:
        self.resultclass = MySQLResultSV
        if self.cursorclass is Cursor:
            self.cursorclass = CursorSV
        elif self.cursorclass is SSCursor:
            self.cursorclass = SSCursorSV
        elif self.cursorclass is DictCursor:
            self.cursorclass = DictCursorSV
            self.results_type = 'dicts'
        elif self.cursorclass is SSDictCursor:
            self.cursorclass = SSDictCursorSV
            self.results_type = 'dicts'
        elif self.cursorclass is NamedtupleCursor:
            self.cursorclass = NamedtupleCursorSV
            self.results_type = 'namedtuples'
        elif self.cursorclass is SSNamedtupleCursor:
            self.cursorclass = SSNamedtupleCursorSV
            self.results_type = 'namedtuples'
        elif self.cursorclass is NumpyCursor:
            self.cursorclass = NumpyCursorSV
            self.results_type = 'numpy'
        elif self.cursorclass is SSNumpyCursor:
            self.cursorclass = SSNumpyCursorSV
            self.results_type = 'numpy'
        elif self.cursorclass is ArrowCursor:
            self.cursorclass = ArrowCursorSV
            self.results_type = 'arrow'
        elif self.cursorclass is SSArrowCursor:
            self.cursorclass = SSArrowCursorSV
            self.results_type = 'arrow'
        elif self.cursorclass is PandasCursor:
            self.cursorclass = PandasCursorSV
            self.results_type = 'pandas'
        elif self.cursorclass is SSPandasCursor:
            self.cursorclass = SSPandasCursorSV
            self.results_type = 'pandas'
        elif self.cursorclass is PolarsCursor:
            self.cursorclass = PolarsCursorSV
            self.results_type = 'polars'
        elif self.cursorclass is SSPolarsCursor:
            self.cursorclass = SSPolarsCursorSV
            self.results_type = 'polars'

    self._result = None
    self._affected_rows = 0
    self.host_info = 'Not connected'

    # specified autocommit mode. None means use server default.
    self.autocommit_mode = autocommit

    if conv is None:
        conv = converters.conversions

    # Copy so the per-connection tweaks below never mutate the shared table.
    conv = conv.copy()

    self.parse_json = parse_json
    self.invalid_values = (invalid_values or {}).copy()

    # Disable JSON parsing for Arrow
    if self.results_type in ['arrow']:
        conv[245] = None
        self.parse_json = False

    # Disable date/time parsing for polars; let polars do the parsing
    elif self.results_type in ['polars']:
        conv[7] = None
        conv[10] = None
        conv[12] = None

    # Need for MySQLdb compatibility.
    self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
    self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
    self.sql_mode = sql_mode
    self.init_command = init_command
    self.max_allowed_packet = max_allowed_packet
    self._auth_plugin_map = auth_plugin_map or {}
    self._binary_prefix = binary_prefix
    self.server_public_key = server_public_key
    self.interpolate_query_with_empty_args = interpolate_query_with_empty_args

    if self.connection_params['nan_as_null'] or \
            self.connection_params['inf_as_null']:
        float_encoder = self.encoders.get(float)
        if float_encoder is not None:
            self.encoders[float] = functools.partial(
                float_encoder,
                nan_as_null=self.connection_params['nan_as_null'],
                inf_as_null=self.connection_params['inf_as_null'],
            )

    from .. import __version__ as VERSION_STRING

    if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
        VERSION_STRING += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']

    self._connect_attrs = {
        '_os': str(sys.platform),
        '_pid': str(os.getpid()),
        '_client_name': 'SingleStoreDB Python Client',
        '_client_version': VERSION_STRING,
    }

    if program_name:
        self._connect_attrs['program_name'] = program_name
    if conn_attrs is not None:
        # do not overwrite the attributes that we set ourselves
        for k, v in conn_attrs.items():
            if k not in self._connect_attrs:
                self._connect_attrs[k] = v

    self._is_committable = True
    self._in_sync = False
    self._tls_sni_servername = tls_sni_servername
    self._track_env = bool(track_env) or self.host == 'singlestore.com'
    self._enable_extended_data_types = enable_extended_data_types
    # Validate case-insensitively and store the normalized (lowercase)
    # value so later comparisons don't depend on the caller's casing.
    if vector_data_format.lower() in ['json', 'binary']:
        self._vector_data_format = vector_data_format.lower()
    else:
        raise ValueError(
            'unknown value for vector_data_format, '
            f'expecting "json" or "binary": {vector_data_format}',
        )
    self._connection_info = {}
    events.subscribe(self._handle_event)

    # With env tracking, connecting is postponed until the URL is known.
    if defer_connect or self._track_env:
        self._sock = None
    else:
        self.connect()
670
def _handle_event(self, data: Dict[str, Any]) -> None:
    """Store a copy of the payload of a portal connection-update event."""
    # Every other event name is ignored.
    if data.get('name', '') != 'singlestore.portal.connection_updated':
        return
    self._connection_info = dict(data)
673
674
@property
def messages(self):
    """
    Messages received from the server for this connection.

    This is intended to back the DB-API 2.0 optional ``Connection.messages``
    extension. Message collection is not implemented yet, so the list is
    always empty. Previously the body evaluated ``[]`` without returning it,
    so the property returned None.

    """
    # TODO: collect server messages; until then report none.
    return []
678
679
def __enter__(self):
    """Enter a ``with`` block; the connection itself is the bound value."""
    conn = self
    return conn
681
682
def __exit__(self, *exc_info):
    """
    Leave a ``with`` block by closing the connection.

    Returns None, so an exception raised inside the block is not suppressed.

    """
    self.close()
685
686
def _raise_mysql_exception(self, data):
    """Hand error data (presumably a raw error packet) to ``err`` to raise."""
    raise_from_packet = err.raise_mysql_exception
    raise_from_packet(data)
688
689
def _create_ssl_ctx(self, sslp):
    """
    Build an ``ssl.SSLContext`` from SSL connection options.

    Parameters
    ----------
    sslp : ssl.SSLContext or dict
        A ready-made context, which is returned unchanged, or a dict with
        optional keys ``ca``, ``capath``, ``cert``, ``key``, ``cipher``,
        ``check_hostname`` and ``verify_mode``. ``verify_mode`` may be a
        bool or one of the strings 'none'/'0'/'false'/'no', 'optional',
        'required'/'1'/'true'/'yes'.

    Returns
    -------
    ssl.SSLContext

    """
    if isinstance(sslp, ssl.SSLContext):
        return sslp
    ca = sslp.get('ca')
    capath = sslp.get('capath')
    hasnoca = ca is None and capath is None
    ctx = ssl.create_default_context(cafile=ca, capath=capath)
    # Hostname checking is only enabled when a CA was supplied.
    check_hostname = not hasnoca and sslp.get('check_hostname', True)

    # Turn hostname checking off while choosing the verify mode. The ssl
    # module rejects CERT_NONE while check_hostname is on, and the default
    # context starts with it on. It is re-applied below.
    ctx.check_hostname = False
    verify_mode_value = sslp.get('verify_mode')
    if verify_mode_value is None:
        ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
    elif isinstance(verify_mode_value, bool):
        ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
    else:
        if isinstance(verify_mode_value, str):
            verify_mode_value = verify_mode_value.lower()
        if verify_mode_value in ('none', '0', 'false', 'no'):
            ctx.verify_mode = ssl.CERT_NONE
        elif verify_mode_value == 'optional':
            ctx.verify_mode = ssl.CERT_OPTIONAL
        elif verify_mode_value in ('required', '1', 'true', 'yes'):
            ctx.verify_mode = ssl.CERT_REQUIRED
        else:
            ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED

    # Apply hostname checking last. Previously, requesting it while the
    # verify mode resolved to CERT_NONE (e.g. ssl_ca + ssl_verify_identity
    # without ssl_verify_cert) raised ValueError. Now the ssl module
    # (Python 3.7+) upgrades the mode to CERT_REQUIRED, because a hostname
    # cannot be checked against an unverified certificate.
    ctx.check_hostname = check_hostname

    if 'cert' in sslp:
        ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
    if 'cipher' in sslp:
        ctx.set_ciphers(sslp['cipher'])
    ctx.options |= ssl.OP_NO_SSLv2
    ctx.options |= ssl.OP_NO_SSLv3
    return ctx
720
721
def close(self):
    """
    Send COM_QUIT to the server and close the socket.

    See `Connection.close()
    <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
    in the specification.

    Nothing is sent when the host is the ``singlestore.com`` placeholder,
    which is used before a real connection URL is known.

    Raises
    ------
    Error : If the connection is already closed.

    """
    self._result = None
    if self.host == 'singlestore.com':
        return
    if self._closed:
        raise err.Error('Already closed')

    events.unsubscribe(self._handle_event)
    self._closed = True

    sock = self._sock
    if sock is None:
        return

    # 4-byte header (length 1, sequence 0) followed by the command byte.
    quit_packet = struct.pack('<iB', 1, COMMAND.COM_QUIT)
    try:
        self._write_bytes(quit_packet)
    except Exception:
        # Best effort only; the socket is force-closed regardless.
        pass
    finally:
        self._force_close()
750
751
@property
def open(self):
    """True while a socket is attached, i.e. the connection is open."""
    sock = self._sock
    return sock is not None
755
756
def is_connected(self):
    """Report whether the connection is open (same as the ``open`` property)."""
    is_open = self.open
    return is_open
759
760
def _force_close(self):
    """
    Close the socket without sending COM_QUIT and drop socket references.

    Errors from closing the socket are ignored, because this method also
    serves as ``__del__``. Only ``Exception`` subclasses are swallowed:
    KeyboardInterrupt and SystemExit are not hidden by a bare ``except``.

    """
    if self._sock:
        try:
            self._sock.close()
        except Exception:
            # Best effort; the socket is discarded either way.
            pass
    self._sock = None
    self._rfile = None

__del__ = _force_close
771
772
def autocommit(self, value):
    """
    Enable or disable autocommit on the server.

    Parameters
    ----------
    value : bool
        Desired autocommit mode; any truthy/falsy value is accepted.

    """
    self.autocommit_mode = bool(value)
    # Compare the normalized flag, not the raw argument. A truthy non-bool
    # such as 'yes' or 2 never equals the server's bool state, so comparing
    # it would send a redundant SET AUTOCOMMIT round trip.
    if self.autocommit_mode != self.get_autocommit():
        self._send_autocommit_mode()
778
779
def get_autocommit(self):
    """Return the autocommit bit from the most recently stored server status."""
    autocommit_bit = SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT
    return (self.server_status & autocommit_bit) != 0
782
783
def _read_ok_packet(self):
    """
    Read the next packet and require it to be an OK packet.

    Stores the packet's server status on the connection and returns the
    wrapped packet.

    Raises
    ------
    OperationalError
        With CR_COMMANDS_OUT_OF_SYNC if the packet is not an OK packet.

    """
    packet = self._read_packet()
    if packet.is_ok_packet():
        wrapped = OKPacketWrapper(packet)
        self.server_status = wrapped.server_status
        return wrapped
    raise err.OperationalError(
        CR.CR_COMMANDS_OUT_OF_SYNC,
        'Command Out of Sync',
    )
793
794
def _send_autocommit_mode(self):
    """Send ``SET AUTOCOMMIT`` for the current ``autocommit_mode`` and await OK."""
    sql = 'SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode)
    log_query(sql)
    self._execute_command(COMMAND.COM_QUERY, sql)
    self._read_ok_packet()
801
802
def begin(self):
    """Start a transaction (logged, but not sent for the placeholder host)."""
    log_query('BEGIN')
    if self.host != 'singlestore.com':
        self._execute_command(COMMAND.COM_QUERY, 'BEGIN')
        self._read_ok_packet()
809
810
def commit(self):
    """
    Commit changes to stable storage.

    See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
    in the specification.

    """
    log_query('COMMIT')
    skip_server = not self._is_committable or self.host == 'singlestore.com'
    if skip_server:
        # Nothing to send (e.g. the last query was a Fusion command, or the
        # host is the placeholder); re-arm the flag for the next statement.
        self._is_committable = True
        return
    self._execute_command(COMMAND.COM_QUERY, 'COMMIT')
    self._read_ok_packet()
824
825
def rollback(self):
    """
    Roll back the current transaction.

    See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
    in the specification.

    """
    log_query('ROLLBACK')
    skip_server = not self._is_committable or self.host == 'singlestore.com'
    if skip_server:
        # Nothing to send (e.g. the last query was a Fusion command, or the
        # host is the placeholder); re-arm the flag for the next statement.
        self._is_committable = True
        return
    self._execute_command(COMMAND.COM_QUERY, 'ROLLBACK')
    self._read_ok_packet()
839
840
def show_warnings(self):
    """Run ``SHOW WARNINGS`` and return the rows it produces."""
    sql = 'SHOW WARNINGS'
    log_query(sql)
    self._execute_command(COMMAND.COM_QUERY, sql)
    warning_result = self.resultclass(self)
    warning_result.read()
    return warning_result.rows
847
848
def select_db(self, db):
    """
    Make ``db`` the current database (COM_INIT_DB) and wait for the OK.

    Parameters
    ----------
    db : str
        Name of the database to switch to.

    """
    command = COMMAND.COM_INIT_DB
    self._execute_command(command, db)
    self._read_ok_packet()
858
859
def escape(self, obj, mapping=None):
    """
    Return ``obj`` as an escaped SQL literal.

    Strings become quoted, escaped literals; bytes/bytearray go through
    ``_quote_bytes``; everything else is handled by ``converters`` using
    ``mapping`` (the connection's encoders when not given).

    Non-standard, for internal use; do not use this in your applications.

    """
    if isinstance(obj, str):
        return "'{}'".format(self.escape_string(obj))
    if isinstance(obj, (bytes, bytearray)):
        return self._quote_bytes(obj)
    encoders = self.encoders if mapping is None else mapping
    return converters.escape_item(obj, self.charset, mapping=encoders)
874
875
def literal(self, obj):
    """
    Escape ``obj`` with the connection's encoders (alias for ``escape``).

    Non-standard, for internal use; do not use this in your applications.

    """
    encoders = self.encoders
    return self.escape(obj, encoders)
883
884
def escape_string(self, s):
    """Escape a string value according to the server's backslash mode."""
    no_backslash_escapes = (
        self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES
    )
    if not no_backslash_escapes:
        return converters.escape_string(s)
    # Backslashes are literal in this SQL mode; only quotes need doubling.
    return s.replace("'", "''")
889
890
def _quote_bytes(self, s):
    """Quote a bytes-like value as an SQL literal."""
    if not (self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES):
        return converters.escape_bytes(s)
    # Without backslash escapes, emit a hex literal, optionally prefixed.
    hex_literal = "X'{}'".format(s.hex())
    return '_binary ' + hex_literal if self._binary_prefix else hex_literal
896
897
def cursor(self):
    """Return a new cursor of type ``self.cursorclass`` bound to this connection."""
    factory = self.cursorclass
    return factory(self)
900
901
# The following methods are INTERNAL USE ONLY (called from Cursor)
902
def query(self, sql, unbuffered=False, infile_stream=None):
    """
    Run a query on the server.

    Internal use only.

    Parameters
    ----------
    sql : str or bytes
        Statement to run. A str is encoded with the connection encoding.
    unbuffered : bool, optional
        Passed through to the result reader.
    infile_stream : optional
        Stream attached to the connection for the duration of this query;
        presumably consumed when the server requests LOAD DATA LOCAL
        content — confirm in the result-reading code.

    Returns
    -------
    int
        Number of affected rows.

    """
    handler = fusion.get_handler(sql)
    if handler is not None:
        # Statements with a Fusion handler are run by the ``fusion`` module.
        # Clearing the flag makes the next commit()/rollback() a no-op.
        self._is_committable = False
        self._result = fusion.execute(self, sql, handler=handler)
        self._affected_rows = self._result.affected_rows
    else:
        self._is_committable = True
        if isinstance(sql, str):
            sql = sql.encode(self.encoding, 'surrogateescape')
        self._local_infile_stream = infile_stream
        try:
            self._execute_command(COMMAND.COM_QUERY, sql)
            self._affected_rows = self._read_query_result(unbuffered=unbuffered)
        finally:
            # Detach the stream even when the query fails, so the connection
            # does not keep a stale reference to it.
            self._local_infile_stream = None
    return self._affected_rows
925
926
def next_result(self, unbuffered=False):
    """
    Read the next result set and return its affected-row count.

    Internal use only.

    """
    affected = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = affected
    return affected
935
936
def affected_rows(self):
    """
    Affected-row count recorded by the last query or next_result call.

    Internal use only.

    """
    count = self._affected_rows
    return count
944
945
def kill(self, thread_id):
    """
    Send COM_PROCESS_KILL for ``thread_id`` and return the OK packet.

    Internal use only.

    """
    # The thread id travels as a 4-byte little-endian unsigned integer.
    payload = struct.pack('<I', thread_id)
    self._execute_command(COMMAND.COM_PROCESS_KILL, payload)
    return self._read_ok_packet()
955
956
def ping(self, reconnect=True):
    """
    Check if the server is alive.

    Parameters
    ----------
    reconnect : bool, optional
        If the connection is closed, or the ping fails, reconnect (once).

    Raises
    ------
    Error : If the connection is closed and reconnect=False.

    """
    if self._sock is None:
        if not reconnect:
            raise err.Error('Already closed')
        self.connect()
        # A freshly opened connection must not be reconnected again below.
        reconnect = False
    try:
        self._execute_command(COMMAND.COM_PING, '')
        self._read_ok_packet()
    except Exception:
        if not reconnect:
            raise
        self.connect()
        self.ping(False)
985
986
def set_charset(self, charset):
    """
    Deprecated alias for ``set_character_set(charset)``.

    Old PyMySQL exposed this name, while MySQLdb uses ``set_character_set``;
    both are kept for compatibility.

    """
    change_charset = self.set_character_set
    change_charset(charset)
993
994
def set_character_set(self, charset, collation=None):
    """
    Set the charset (and optionally the collation) on the server.

    Sends ``SET NAMES charset [COLLATE collation]`` and then updates
    ``charset``, ``encoding`` and ``collation`` on this connection.

    Parameters
    ----------
    charset : str
        The charset to enable.
    collation : str, optional
        The collation value

    """
    # Make sure the charset is supported (and resolve its Python encoding)
    # before anything is sent to the server.
    encoding = charset_by_name(charset).encoding

    query = f'SET NAMES {charset}'
    if collation:
        query += f' COLLATE {collation}'
    self._execute_command(COMMAND.COM_QUERY, query)
    self._read_packet()
    self.charset, self.encoding, self.collation = charset, encoding, collation
1021
1022
def _sync_connection(self):
1023
"""Synchronize connection with env variable."""
1024
if self._in_sync:
1025
return
1026
1027
if not self._track_env:
1028
return
1029
1030
url = self._connection_info.get('connection_url')
1031
if not url:
1032
url = os.environ.get('SINGLESTOREDB_URL')
1033
if not url:
1034
return
1035
1036
out = {}
1037
urlp = connection._parse_url(url)
1038
out.update(urlp)
1039
1040
out = connection._cast_params(out)
1041
1042
# Set default port based on driver.
1043
if 'port' not in out or not out['port']:
1044
out['port'] = int(get_option('port') or 3306)
1045
1046
# If there is no user and the password is empty, remove the password key.
1047
if 'user' not in out and not out.get('password', None):
1048
out.pop('password', None)
1049
1050
if out['host'] == 'singlestore.com':
1051
raise err.InterfaceError(0, 'Connection URL has not been established')
1052
1053
# If it's just a password change, we don't need to reconnect
1054
if self._sock is not None and \
1055
(self.host, self.port, self.user, self.db) == \
1056
(out['host'], out['port'], out['user'], out.get('database')):
1057
return
1058
1059
self.host = out['host']
1060
self.port = out['port']
1061
self.user = out['user']
1062
if isinstance(out['password'], str):
1063
self.password = out['password'].encode('latin-1')
1064
else:
1065
self.password = out['password'] or b''
1066
self.db = out.get('database')
1067
try:
1068
self._in_sync = True
1069
self.connect()
1070
finally:
1071
self._in_sync = False
1072
1073
def connect(self, sock=None):
    """
    Connect to server using existing parameters.

    Opens a UNIX-domain or TCP socket (unless ``sock`` is supplied), then
    reads the server handshake, authenticates, and applies session
    settings: character set/collation, ``sql_mode``, extended type
    metadata, vector data format, ``init_command`` and autocommit.

    Internal use only.

    Parameters
    ----------
    sock : socket, optional
        An already-connected socket to use instead of opening a new one.

    Raises
    ------
    OperationalError : If an OSError occurs while connecting or during the
        initial exchange; the original exception is attached as
        ``original_exception`` and its formatted traceback as ``traceback``.
    Any other exception raised during setup is re-raised unchanged.

    """
    self._closed = False
    try:
        if sock is None:
            if self.unix_socket:
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.settimeout(self.connect_timeout)
                sock.connect(self.unix_socket)
                self.host_info = 'Localhost via UNIX socket'
                # A local UNIX socket is treated as a secure transport.
                self._secure = True
                if DEBUG:
                    print('connected using unix_socket')
            else:
                kwargs = {}
                if self.bind_address is not None:
                    kwargs['source_address'] = (self.bind_address, 0)
                while True:
                    try:
                        sock = socket.create_connection(
                            (self.host, self.port), self.connect_timeout, **kwargs,
                        )
                        break
                    except OSError as e:
                        # Retry when the connect call was interrupted by a signal.
                        if e.errno == errno.EINTR:
                            continue
                        raise
                self.host_info = 'socket %s:%d' % (self.host, self.port)
                if DEBUG:
                    print('connected using socket')
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # connect_timeout only applies to establishing the connection;
            # later reads/writes set their own timeouts.
            sock.settimeout(None)

        self._sock = sock
        self._rfile = sock.makefile('rb')
        self._next_seq_id = 0

        self._get_server_information()
        self._request_authentication()

        # Send "SET NAMES" query on init for:
        # - Ensure charset (and collation) is set to the server.
        #   - collation_id in handshake packet may be ignored.
        # - If collation is not specified, we don't know what is server's
        #   default collation for the charset. For example, default collation
        #   of utf8mb4 is:
        #   - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci
        #   - MySQL 8.0: utf8mb4_0900_ai_ci
        #
        # Reference:
        # - https://github.com/PyMySQL/PyMySQL/issues/1092
        # - https://github.com/wagtail/wagtail/issues/9477
        # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)
        self.set_character_set(self.charset, self.collation)

        if self.sql_mode is not None:
            c = self.cursor()
            c.execute('SET sql_mode=%s', (self.sql_mode,))
            c.close()

        if self._enable_extended_data_types:
            c = self.cursor()
            try:
                c.execute('SET @@SESSION.enable_extended_types_metadata=on')
            except self.OperationalError:
                # Best effort: servers that reject this variable are ignored.
                pass
            c.close()

        if self._vector_data_format:
            c = self.cursor()
            try:
                val = self._vector_data_format
                # NOTE(review): the value is interpolated directly into SQL;
                # assumes it comes from trusted configuration — confirm.
                c.execute(f'SET @@SESSION.vector_type_project_format={val}')
            except self.OperationalError:
                # Best effort: servers that reject this variable are ignored.
                pass
            c.close()

        if self.init_command is not None:
            c = self.cursor()
            c.execute(self.init_command)
            c.close()

        if self.autocommit_mode is not None:
            self.autocommit(self.autocommit_mode)

    except BaseException as e:
        self._rfile = None
        # NOTE(review): if _request_authentication() already replaced
        # self._sock with a TLS-wrapped socket, only the raw `sock` is
        # closed here and self._sock is left set — confirm the wrapper is
        # not leaked.
        if sock is not None:
            try:
                sock.close()
            except:  # noqa
                pass

        if isinstance(e, (OSError, IOError, socket.error)):
            exc = err.OperationalError(
                CR.CR_CONN_HOST_ERROR,
                f'Can\'t connect to MySQL server on {self.host!r} ({e})',
            )
            # Keep original exception and traceback to investigate error.
            exc.original_exception = e
            exc.traceback = traceback.format_exc()
            if DEBUG:
                print(exc.traceback)
            raise exc

        # If e is neither DatabaseError or IOError, It's a bug.
        # But raising AssertionError hides original error.
        # So just reraise it.
        raise
1188
1189
def write_packet(self, payload):
    """
    Write a complete "mysql packet" to the network.

    Prefixes ``payload`` with its 3-byte length and the current sequence
    number, sends it, then advances the sequence number (mod 256).

    """
    # Internal note: when you build a packet manually and call
    # _write_bytes() directly, you must set self._next_seq_id yourself.
    header = _pack_int24(len(payload)) + bytes([self._next_seq_id])
    data = header + payload
    if DEBUG:
        dump_packet(data)
    self._write_bytes(data)
    self._next_seq_id = (self._next_seq_id + 1) % 256
1203
1204
def _read_packet(self, packet_type=MysqlPacket):
    """
    Read an entire "mysql packet" in its entirety from the network.

    Reassembles payloads that the server split across several wire
    packets, checks each wire packet's sequence number, and raises for
    server error packets.

    Parameters
    ----------
    packet_type : class, optional
        Class used to wrap the payload; called as
        ``packet_type(payload_bytes, self.encoding)``.

    Raises
    ------
    OperationalError : If the connection to the MySQL server is lost.
    InternalError : If the packet sequence number is wrong.

    Returns
    -------
    MysqlPacket

    """
    buff = bytearray()
    while True:
        packet_header = self._read_bytes(4)
        # if DEBUG: dump_packet(packet_header)

        # Header: 3-byte little-endian payload length + 1-byte sequence id.
        btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
        bytes_to_read = btrl + (btrh << 16)
        if packet_number != self._next_seq_id:
            self._force_close()
            if packet_number == 0:
                # MariaDB sends error packet with seqno==0 when shutdown
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    'Lost connection to MySQL server during query',
                )
            raise err.InternalError(
                'Packet sequence number wrong - got %d expected %d'
                % (packet_number, self._next_seq_id),
            )
        self._next_seq_id = (self._next_seq_id + 1) % 256

        recv_data = self._read_bytes(bytes_to_read)
        if DEBUG:
            dump_packet(recv_data)
        buff += recv_data
        # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
        # A maximal-length (0xFFFFFF) payload means another fragment follows.
        if bytes_to_read == 0xFFFFFF:
            continue
        # NOTE(review): MAX_PACKET_LEN is defined elsewhere; presumably it
        # equals 0xFFFFFF, which would make this check always true here.
        if bytes_to_read < MAX_PACKET_LEN:
            break

    packet = packet_type(bytes(buff), self.encoding)
    if packet.is_error_packet():
        # Mark an in-progress unbuffered result as finished so it is not
        # drained later.
        if self._result is not None and self._result.unbuffered_active is True:
            self._result.unbuffered_active = False
        packet.raise_for_error()
    return packet
1255
1256
def _read_bytes(self, num_bytes):
1257
if self._read_timeout is not None:
1258
self._sock.settimeout(self._read_timeout)
1259
while True:
1260
try:
1261
data = self._rfile.read(num_bytes)
1262
break
1263
except OSError as e:
1264
if e.errno == errno.EINTR:
1265
continue
1266
self._force_close()
1267
raise err.OperationalError(
1268
CR.CR_SERVER_LOST,
1269
'Lost connection to MySQL server during query (%s)' % (e,),
1270
)
1271
except BaseException:
1272
# Don't convert unknown exception to MySQLError.
1273
self._force_close()
1274
raise
1275
if len(data) < num_bytes:
1276
self._force_close()
1277
raise err.OperationalError(
1278
CR.CR_SERVER_LOST, 'Lost connection to MySQL server during query',
1279
)
1280
return data
1281
1282
def _write_bytes(self, data):
1283
if self._write_timeout is not None:
1284
self._sock.settimeout(self._write_timeout)
1285
try:
1286
self._sock.sendall(data)
1287
except OSError as e:
1288
self._force_close()
1289
raise err.OperationalError(
1290
CR.CR_SERVER_GONE_ERROR, f'MySQL server has gone away ({e!r})',
1291
)
1292
1293
def _read_query_result(self, unbuffered=False):
1294
self._result = None
1295
if unbuffered:
1296
result = self.resultclass(self, unbuffered=unbuffered)
1297
else:
1298
result = self.resultclass(self)
1299
result.read()
1300
self._result = result
1301
if result.server_status is not None:
1302
self.server_status = result.server_status
1303
return result.affected_rows
1304
1305
def insert_id(self):
    """Return the insert id of the last result, or 0 if there is no result."""
    last = self._result
    return last.insert_id if last else 0
1310
1311
def _execute_command(self, command, sql):
    """
    Execute command.

    Synchronizes the connection, drains any unfinished previous result,
    then sends ``command`` followed by ``sql`` (encoded with
    ``self.encoding`` if it is a str), splitting the payload over several
    packets when it does not fit in one.

    Raises
    ------
    InterfaceError : If the connection is closed.
    ValueError : If no username was specified (presumably raised while
        reconnecting via ``_sync_connection``).

    """
    self._sync_connection()

    if self._sock is None:
        raise err.InterfaceError(0, 'The connection has been closed')

    # If the last query was unbuffered, make sure it finishes before
    # sending new commands
    if self._result is not None:
        if self._result.unbuffered_active:
            warnings.warn('Previous unbuffered result was left incomplete')
            self._result._finish_unbuffered_query()
        while self._result.has_next:
            self.next_result()
        self._result = None

    if isinstance(sql, str):
        sql = sql.encode(self.encoding)

    packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

    # tiny optimization: build first packet manually instead of
    # calling self.write_packet()
    # '<i' writes the 3-byte length plus a zero high byte, which doubles
    # as sequence id 0 (packet_size is at most MAX_PACKET_LEN).
    prelude = struct.pack('<iB', packet_size, command)
    packet = prelude + sql[: packet_size - 1]
    self._write_bytes(packet)
    if DEBUG:
        dump_packet(packet)
    self._next_seq_id = 1

    if packet_size < MAX_PACKET_LEN:
        return

    # Payload did not fit: send the remainder in follow-up packets. A final
    # packet shorter than MAX_PACKET_LEN (possibly empty) ends the sequence.
    sql = sql[packet_size - 1:]
    while True:
        packet_size = min(MAX_PACKET_LEN, len(sql))
        self.write_packet(sql[:packet_size])
        sql = sql[packet_size:]
        if not sql and packet_size < MAX_PACKET_LEN:
            break
1360
1361
def _request_authentication(self):  # noqa: C901
    """
    Send the handshake response and complete authentication.

    Builds the HandshakeResponse from the client flags, charset id, user
    name, the auth response for the plugin announced by the server
    (``self._auth_plugin_name``), the database, the plugin name and the
    connection attributes, each only when ``self.server_capabilities``
    allows it. When SSL is enabled and supported, the socket is wrapped in
    TLS before the credentials are sent. Auth-switch requests and extra
    auth data returned by the server are then handled.

    Raises
    ------
    ValueError : If no username was specified.
    OperationalError : If the server sends an auth switch request that
        cannot be handled, or extra auth data for an unsupported method.

    """
    # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
    if int(self.server_version.split('.', 1)[0]) >= 5:
        self.client_flag |= CLIENT.MULTI_RESULTS

    if self.user is None:
        raise ValueError('Did not specify a username')

    charset_id = charset_by_name(self.charset).id
    if isinstance(self.user, str):
        self.user = self.user.encode(self.encoding)

    # Fixed prefix: client flags, max packet size, charset id, 23 zero bytes.
    data_init = struct.pack(
        '<iIB23s', self.client_flag, MAX_PACKET_LEN, charset_id, b'',
    )

    if self.ssl and self.server_capabilities & CLIENT.SSL:
        # Send the prefix in clear text, then upgrade the socket so the
        # rest of the response (including credentials) goes over TLS.
        self.write_packet(data_init)

        hostname = self.host
        if self._tls_sni_servername:
            hostname = self._tls_sni_servername
        self._sock = self.ctx.wrap_socket(self._sock, server_hostname=hostname)
        self._rfile = self._sock.makefile('rb')
        self._secure = True

    data = data_init + self.user + b'\0'

    authresp = b''
    plugin_name = None

    if self._auth_plugin_name == '':
        plugin_name = b''
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == 'mysql_native_password':
        plugin_name = b'mysql_native_password'
        authresp = _auth.scramble_native_password(self.password, self.salt)
    elif self._auth_plugin_name == 'caching_sha2_password':
        plugin_name = b'caching_sha2_password'
        if self.password:
            if DEBUG:
                print('caching_sha2: trying fast path')
            authresp = _auth.scramble_caching_sha2(self.password, self.salt)
        else:
            if DEBUG:
                print('caching_sha2: empty password')
    elif self._auth_plugin_name == 'sha256_password':
        plugin_name = b'sha256_password'
        if self.ssl and self.server_capabilities & CLIENT.SSL:
            authresp = self.password + b'\0'
        elif self.password:
            authresp = b'\1'  # request public key
        else:
            authresp = b'\0'  # empty password

    if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
        data += _lenenc_int(len(authresp)) + authresp
    elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
        data += struct.pack('B', len(authresp)) + authresp
    else:  # pragma: no cover - no testing against servers w/o secure auth (>=5.0)
        data += authresp + b'\0'

    if self.server_capabilities & CLIENT.CONNECT_WITH_DB:
        db = self.db
        if isinstance(db, str):
            db = db.encode(self.encoding)
        data += (db or b'') + b'\0'

    if self.server_capabilities & CLIENT.PLUGIN_AUTH:
        data += (plugin_name or b'') + b'\0'

    if self.server_capabilities & CLIENT.CONNECT_ATTRS:
        connect_attrs = b''
        for k, v in self._connect_attrs.items():
            k = k.encode('utf-8')
            connect_attrs += _lenenc_int(len(k)) + k
            v = v.encode('utf-8')
            connect_attrs += _lenenc_int(len(v)) + v
        data += _lenenc_int(len(connect_attrs)) + connect_attrs

    self.write_packet(data)
    auth_packet = self._read_packet()

    # if authentication method isn't accepted the first byte
    # will have the octet 254
    if auth_packet.is_auth_switch_request():
        if DEBUG:
            print('received auth switch')
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        auth_packet.read_uint8()  # 0xfe packet identifier
        plugin_name = auth_packet.read_string()
        if (
            self.server_capabilities & CLIENT.PLUGIN_AUTH
            and plugin_name is not None
        ):
            auth_packet = self._process_auth(plugin_name, auth_packet)
        else:
            raise err.OperationalError('received unknown auth switch request')
    elif auth_packet.is_extra_auth_data():
        if DEBUG:
            print('received extra data')
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        if self._auth_plugin_name == 'caching_sha2_password':
            auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
        elif self._auth_plugin_name == 'sha256_password':
            auth_packet = _auth.sha256_password_auth(self, auth_packet)
        else:
            # Format the message here; passing the argument separately (as
            # logging does) would leave '%r' unformatted in the error.
            raise err.OperationalError(
                CR.CR_AUTH_PLUGIN_ERR,
                'Received extra packet for auth method %r'
                % (self._auth_plugin_name,),
            )

    if DEBUG:
        print('Succeed to auth')
1474
1475
def _process_auth(self, plugin_name, auth_packet):
    """
    Answer an auth switch request for ``plugin_name``.

    A handler registered in ``self._auth_plugin_map`` takes precedence (its
    ``authenticate`` method is used). Otherwise the built-in plugins are
    handled here. For ``dialog``, prompts are answered in a loop until the
    server sends an OK packet or marks a prompt as the last one.

    Parameters
    ----------
    plugin_name : bytes
        Plugin name read from the auth switch request.
    auth_packet : MysqlPacket
        The auth switch request, positioned after the plugin name.

    Returns
    -------
    MysqlPacket
        The server's final reply packet.

    Raises
    ------
    OperationalError : If the plugin is not configured/supported, or a
        registered handler lacks the required method or returns a
        non-bytes prompt response.

    """
    handler = self._get_auth_plugin_handler(plugin_name)
    if handler:
        try:
            return handler.authenticate(auth_packet)
        except AttributeError:
            # NOTE(review): this also catches AttributeErrors raised inside
            # authenticate() itself, reporting them as a missing method.
            if plugin_name != b'dialog':
                raise err.OperationalError(
                    CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                    "Authentication plugin '%s'"
                    ' not loaded: - %r missing authenticate method'
                    % (plugin_name, type(handler)),
                )
    if plugin_name == b'caching_sha2_password':
        return _auth.caching_sha2_password_auth(self, auth_packet)
    elif plugin_name == b'sha256_password':
        return _auth.sha256_password_auth(self, auth_packet)
    elif plugin_name == b'mysql_native_password':
        data = _auth.scramble_native_password(self.password, auth_packet.read_all())
    elif plugin_name == b'client_ed25519':
        data = _auth.ed25519_password(self.password, auth_packet.read_all())
    elif plugin_name == b'mysql_old_password':
        data = (
            _auth.scramble_old_password(self.password, auth_packet.read_all())
            + b'\0'
        )
    elif plugin_name == b'mysql_clear_password':
        # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
        data = self.password + b'\0'
    elif plugin_name == b'auth_gssapi_client':
        data = _auth.gssapi_auth(auth_packet.read_all())
    elif plugin_name == b'dialog':
        pkt = auth_packet
        while True:
            # Prompt flag byte: bit 0 marks the last prompt; bits 1-2 are
            # tested for the "echo input" prompt type.
            flag = pkt.read_uint8()
            echo = (flag & 0x06) == 0x02
            last = (flag & 0x01) == 0x01
            prompt = pkt.read_all()

            if prompt == b'Password: ':
                self.write_packet(self.password + b'\0')
            elif handler:
                resp = 'no response - TypeError within plugin.prompt method'
                try:
                    resp = handler.prompt(echo, prompt)
                    self.write_packet(resp + b'\0')
                except AttributeError:
                    raise err.OperationalError(
                        CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                        "Authentication plugin '%s'"
                        ' not loaded: - %r missing prompt method'
                        % (plugin_name, handler),
                    )
                except TypeError:
                    # Raised when resp is not bytes (resp + b'\0' fails) or
                    # prompt() has the wrong signature.
                    raise err.OperationalError(
                        CR.CR_AUTH_PLUGIN_ERR,
                        "Authentication plugin '%s'"
                        " %r didn't respond with string. Returned '%r' to prompt %r"
                        % (plugin_name, handler, resp, prompt),
                    )
            else:
                raise err.OperationalError(
                    CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
                    "Authentication plugin '%s' not configured" % (plugin_name,),
                )
            pkt = self._read_packet()
            pkt.check_error()
            if pkt.is_ok_packet() or last:
                break
        return pkt
    else:
        raise err.OperationalError(
            CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
            "Authentication plugin '%s' not configured" % plugin_name,
        )

    # Single-round plugins: send the computed response and read the reply.
    self.write_packet(data)
    pkt = self._read_packet()
    pkt.check_error()
    return pkt
1555
1556
def _get_auth_plugin_handler(self, plugin_name):
1557
plugin_class = self._auth_plugin_map.get(plugin_name)
1558
if not plugin_class and isinstance(plugin_name, bytes):
1559
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
1560
if plugin_class:
1561
try:
1562
handler = plugin_class(self)
1563
except TypeError:
1564
raise err.OperationalError(
1565
CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1566
"Authentication plugin '%s'"
1567
' not loaded: - %r cannot be constructed with connection object'
1568
% (plugin_name, plugin_class),
1569
)
1570
else:
1571
handler = None
1572
return handler
1573
1574
# _mysql support
def thread_id(self):
    """Return the connection id the server sent in its handshake packet."""
    unpacked = self.server_thread_id
    return unpacked[0]
1577
1578
def character_set_name(self):
    """Return the name of the character set configured on this connection."""
    name = self.charset
    return name
1580
1581
def get_host_info(self):
    """Return the description of the transport recorded by ``connect()``."""
    info = self.host_info
    return info
1583
1584
def get_proto_info(self):
    """Return the protocol version byte from the server handshake."""
    version = self.protocol_version
    return version
1586
1587
def _get_server_information(self):
1588
i = 0
1589
packet = self._read_packet()
1590
data = packet.get_all_data()
1591
1592
self.protocol_version = data[i]
1593
i += 1
1594
1595
server_end = data.find(b'\0', i)
1596
self.server_version = data[i:server_end].decode('latin1')
1597
i = server_end + 1
1598
1599
self.server_thread_id = struct.unpack('<I', data[i: i + 4])
1600
i += 4
1601
1602
self.salt = data[i: i + 8]
1603
i += 9 # 8 + 1(filler)
1604
1605
self.server_capabilities = struct.unpack('<H', data[i: i + 2])[0]
1606
i += 2
1607
1608
if len(data) >= i + 6:
1609
lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i: i + 6])
1610
i += 6
1611
# TODO: deprecate server_language and server_charset.
1612
# mysqlclient-python doesn't provide it.
1613
self.server_language = lang
1614
try:
1615
self.server_charset = charset_by_id(lang).name
1616
except KeyError:
1617
# unknown collation
1618
self.server_charset = None
1619
1620
self.server_status = stat
1621
if DEBUG:
1622
print('server_status: %x' % stat)
1623
1624
self.server_capabilities |= cap_h << 16
1625
if DEBUG:
1626
print('salt_len:', salt_len)
1627
salt_len = max(12, salt_len - 9)
1628
1629
# reserved
1630
i += 10
1631
1632
if len(data) >= i + salt_len:
1633
# salt_len includes auth_plugin_data_part_1 and filler
1634
self.salt += data[i: i + salt_len]
1635
i += salt_len
1636
1637
i += 1
1638
# AUTH PLUGIN NAME may appear here.
1639
if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
1640
# Due to Bug#59453 the auth-plugin-name is missing the terminating
1641
# NUL-char in versions prior to 5.5.10 and 5.6.2.
1642
# ref: https://dev.mysql.com/doc/internals/en/
1643
# connection-phase-packets.html#packet-Protocol::Handshake
1644
# didn't use version checks as mariadb is corrected and reports
1645
# earlier than those two.
1646
server_end = data.find(b'\0', i)
1647
if server_end < 0: # pragma: no cover - very specific upstream bug
1648
# not found \0 and last field so take it all
1649
self._auth_plugin_name = data[i:].decode('utf-8')
1650
else:
1651
self._auth_plugin_name = data[i:server_end].decode('utf-8')
1652
1653
def get_server_info(self):
    """Return the server version string parsed from the handshake."""
    version = self.server_version
    return version
1655
1656
Warning = err.Warning
1657
Error = err.Error
1658
InterfaceError = err.InterfaceError
1659
DatabaseError = err.DatabaseError
1660
DataError = err.DataError
1661
OperationalError = err.OperationalError
1662
IntegrityError = err.IntegrityError
1663
InternalError = err.InternalError
1664
ProgrammingError = err.ProgrammingError
1665
NotSupportedError = err.NotSupportedError
1666
1667
1668
class MySQLResult:
    """
    Results of a SQL query.

    Parameters
    ----------
    connection : Connection
        The connection the result came from.
    unbuffered : bool, optional
        Should the reads be unbuffered?

    """

    def __init__(self, connection, unbuffered=False):
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        self.unbuffered_active = False
        self.converters = []
        self.fields = []
        self.encoding_errors = self.connection.encoding_errors
        if unbuffered:
            try:
                self.init_unbuffered_query()
            except Exception:
                # Drop the connection and mark the result inactive so that
                # __del__ does not try to drain a failed query.
                self.connection = None
                self.unbuffered_active = False
                raise

    def __del__(self):
        # Drain rows the server is still sending for an abandoned
        # unbuffered query.
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """
        Read a complete (buffered) response from the connection.

        Handles an OK packet, a LOCAL INFILE request, or a result set. The
        connection reference is always released afterwards.

        """
        try:
            first_packet = self.connection._read_packet()

            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                self._read_load_local_packet(first_packet)
            else:
                self._read_result_packet(first_packet)
        finally:
            self.connection = None

    def init_unbuffered_query(self):
        """
        Initialize an unbuffered query.

        Reads the first response packet and, for a result set, the column
        descriptions; rows are read later one at a time.

        Raises
        ------
        OperationalError : If the connection to the MySQL server is lost.
        InternalError : Other errors.

        """
        self.unbuffered_active = True
        first_packet = self.connection._read_packet()

        if first_packet.is_ok_packet():
            self._read_ok_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        elif first_packet.is_load_local_packet():
            self._read_load_local_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()

            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615

    def _read_ok_packet(self, first_packet):
        """Copy the status fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        """
        Serve a LOAD DATA LOCAL INFILE request from the server.

        Raises
        ------
        RuntimeError : If the connection's ``_local_infile`` option is off.
        OperationalError : If the server does not answer with an OK packet.

        """
        if not self.connection._local_infile:
            raise RuntimeError(
                '**WARN**: Received LOAD_LOCAL packet but local_infile option is false.',
            )
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            sender.send_data()
        except Exception:
            self.connection._read_packet()  # skip ok packet
            raise

        ok_packet = self.connection._read_packet()
        if (
            not ok_packet.is_ok_packet()
        ):  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(
                CR.CR_COMMANDS_OUT_OF_SYNC,
                'Commands Out of Sync',
            )
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True (recording warnings/has_next) if ``packet`` is EOF."""
        if not packet.is_eof_packet():
            return False
        # TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper
        # instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        """Read the column descriptions and then all rows of a result set."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """
        Read the next row of an unbuffered result.

        Returns
        -------
        tuple or None
            The decoded row, or None once the EOF packet has been read (or
            if the query is no longer active).

        """
        # Check if in an active query
        if not self.unbuffered_active:
            return

        # EOF
        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        # rows should be a tuple of rows for MySQL-python compatibility.
        self.rows = (row,)
        return row

    def _finish_unbuffered_query(self):
        """Discard the remaining rows of an unbuffered query up to EOF."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active and self.connection._sock is not None:
            try:
                packet = self.connection._read_packet()
            except err.OperationalError as e:
                if e.args[0] in (
                    ER.QUERY_TIMEOUT,
                    ER.STATEMENT_TIMEOUT,
                ):
                    # if the query timed out we can simply ignore this error
                    self.unbuffered_active = False
                    self.connection = None
                    return

                raise

            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.connection = None  # release reference to kill cyclic reference.
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """
        Decode one row packet using ``self.converters``.

        Each column value is read as a length-coded string; non-NULL values
        are decoded with the column's encoding (if any) and passed through
        its converter (if any). Reading stops early when the packet runs
        out of columns.

        Raises
        ------
        UnicodeDecodeError : If a column cannot be decoded with its
            encoding under ``self.encoding_errors``; the reason names the
            column.

        """
        row = []
        for i, (encoding, converter) in enumerate(self.converters):
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    try:
                        data = data.decode(encoding, errors=self.encoding_errors)
                    except UnicodeDecodeError as exc:
                        # UnicodeDecodeError requires the five-argument form
                        # (encoding, object, start, end, reason); a single
                        # message argument raises TypeError instead. Keep the
                        # position information and put the hint in the reason.
                        raise UnicodeDecodeError(
                            exc.encoding,
                            exc.object,
                            exc.start,
                            exc.end,
                            'failed to decode string value in column '
                            f"'{self.fields[i].name}' using encoding '{encoding}'; " +
                            "use the 'encoding_errors' option on the connection " +
                            'to specify how to handle this error',
                        ) from exc
                if DEBUG:
                    print('DEBUG: DATA = ', data)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result."""
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []

        for _ in range(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data is encoded in ascii
                    encoding = 'ascii'
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is converters.through:
                converter = None
            if DEBUG:
                print(f'DEBUG: field={field}, converter={converter}')
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        # NOTE(review): assert is stripped under `python -O`; this protocol
        # check would then be skipped.
        assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
        self.description = tuple(description)
1924
1925
1926
class MySQLResultSV(MySQLResult):
    """
    Query result whose row reading is delegated to ``_singlestoredb_accel``.

    Collects the connection's decoding settings into ``self.options``
    (omitting any left as ``UNSET``) for use by the extension.

    """

    def __init__(self, connection, unbuffered=False):
        MySQLResult.__init__(self, connection, unbuffered=unbuffered)
        candidates = (
            ('default_converters', converters.decoders),
            ('results_type', connection.results_type),
            ('parse_json', connection.parse_json),
            ('invalid_values', connection.invalid_values),
            ('unbuffered', unbuffered),
            ('encoding_errors', connection.encoding_errors),
        )
        self.options = {}
        for key, value in candidates:
            if value is not UNSET:
                self.options[key] = value

    def _read_rowdata_packet(self, *args, **kwargs):
        # Second argument False selects the buffered (all rows) read.
        buffered_flag = False
        return _singlestoredb_accel.read_rowdata_packet(
            self, buffered_flag, *args, **kwargs,
        )

    def _read_rowdata_packet_unbuffered(self, *args, **kwargs):
        # Second argument True selects the unbuffered (row at a time) read.
        unbuffered_flag = True
        return _singlestoredb_accel.read_rowdata_packet(
            self, unbuffered_flag, *args, **kwargs,
        )
1946
1947
1948
class LoadLocalFile:
    """Send the data for a ``LOAD DATA LOCAL INFILE`` request to the server.

    The data comes either from a local file path (``filename``) or, when
    ``filename`` is ``':stream:'`` / ``b':stream:'``, from the object stored
    on the connection as ``_local_infile_stream``.

    Parameters
    ----------
    filename : str or bytes
        Path of the local file to send, or the ``:stream:`` marker.
    connection : Connection
        Connection the data packets are written to.

    """

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server.

        Every data packet is at most ``min(max_allowed_packet, 16KB)`` bytes.
        When sending finishes, whether it succeeded or not, an empty packet
        is written (if the connection is still open) to mark the end of data.

        Supported ``:stream:`` sources:

        * ``bytes`` / ``bytearray`` / ``memoryview`` / ``str`` -- sent as a
          single payload (``str`` encoded as UTF-8)
        * binary file objects (``io.RawIOBase`` / ``io.BufferedIOBase``)
        * text file objects (``io.TextIOBase``) -- encoded as UTF-8
        * any iterable of ``bytes`` or ``str`` chunks (empty chunks skipped)
        * ``queue.Queue`` of ``bytes`` or ``str`` chunks -- an empty (falsy)
          item ends the stream

        Raises
        ------
        InterfaceError
            If the connection is closed.
        OperationalError
            If ``:stream:`` is requested without a stream, the stream type is
            unrecognized, or the local file cannot be opened or read.

        """
        if not self.connection._sock:
            raise err.InterfaceError(0, 'Connection is closed')

        conn = self.connection
        infile = conn._local_infile_stream

        # 16KB is efficient enough
        packet_size = min(conn.max_allowed_packet, 16 * 1024)

        try:
            if self.filename in [':stream:', b':stream:']:
                self._send_stream(conn, infile, packet_size)
            else:
                self._send_file(conn, packet_size)
        finally:
            if not conn._closed:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b'')

    @staticmethod
    def _write_chunked(conn, data, packet_size):
        """Write ``data`` as packets of at most ``packet_size`` bytes.

        ``str`` data is encoded as UTF-8. Empty data writes nothing, because
        an empty packet is the end-of-data marker.
        """
        if isinstance(data, str):
            data = data.encode('utf8')
        elif isinstance(data, memoryview):
            # Slice raw bytes rather than items of a multi-byte format.
            data = data.tobytes()
        for start in range(0, len(data), packet_size):
            conn.write_packet(data[start:start + packet_size])

    def _send_stream(self, conn, infile, packet_size):
        """Send the ``:stream:`` source ``infile`` in bounded-size packets."""
        if infile is None:
            raise err.OperationalError(
                ER.FILE_NOT_FOUND,
                ':stream: specified for LOCAL INFILE, but no stream was supplied',
            )

        # A bare bytes/str object is one payload. It must be checked before the
        # generic Iterable case: iterating bytes yields ints, and iterating a
        # str would send one packet per character.
        if isinstance(infile, (bytes, bytearray, memoryview, str)):
            self._write_chunked(conn, infile, packet_size)

        # Binary IO, raw or buffered (e.g. open(..., 'rb'), io.BytesIO)
        elif isinstance(infile, (io.RawIOBase, io.BufferedIOBase)):
            while True:
                chunk = infile.read(packet_size)
                if not chunk:
                    break
                self._write_chunked(conn, chunk, packet_size)

        # Text IO: packet_size characters may encode to more bytes, so the
        # encoded chunk is re-split by _write_chunked.
        elif isinstance(infile, io.TextIOBase):
            while True:
                chunk = infile.read(packet_size)
                if not chunk:
                    break
                self._write_chunked(conn, chunk, packet_size)

        # Iterable of bytes or str
        elif isinstance(infile, Iterable):
            for chunk in infile:
                if not chunk:
                    continue
                self._write_chunked(conn, chunk, packet_size)

        # Queue (empty value ends the iteration)
        elif isinstance(infile, queue.Queue):
            while True:
                chunk = infile.get()
                if not chunk:
                    break
                self._write_chunked(conn, chunk, packet_size)

        else:
            raise err.OperationalError(
                ER.FILE_NOT_FOUND,
                ':stream: specified for LOCAL INFILE, ' +
                f'but stream type is unrecognized: {infile}',
            )

    def _send_file(self, conn, packet_size):
        """Send the local file ``self.filename`` in ``packet_size`` chunks."""
        try:
            with open(self.filename, 'rb') as open_file:
                while True:
                    chunk = open_file.read(packet_size)
                    if not chunk:
                        break
                    conn.write_packet(chunk)
        except OSError as exc:
            # Keep the original OSError as the cause for debugging.
            raise err.OperationalError(
                ER.FILE_NOT_FOUND,
                f"Can't find file '{self.filename!s}'",
            ) from exc
2037
2038