Path: blob/main/singlestoredb/mysql/connection.py
798 views
# type: ignore1# Python implementation of the MySQL client-server protocol2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html3# Error codes:4# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html5import errno6import functools7import io8import os9import queue10import socket11import struct12import sys13import traceback14import warnings15from collections.abc import Iterable16from typing import Any17from typing import Dict1819try:20import _singlestoredb_accel21except (ImportError, ModuleNotFoundError):22_singlestoredb_accel = None2324from . import _auth25from ..utils import events2627from .charset import charset_by_name, charset_by_id28from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS29from . import converters30from .cursors import (31Cursor,32CursorSV,33DictCursor,34DictCursorSV,35NamedtupleCursor,36NamedtupleCursorSV,37ArrowCursor,38ArrowCursorSV,39NumpyCursor,40NumpyCursorSV,41PandasCursor,42PandasCursorSV,43PolarsCursor,44PolarsCursorSV,45SSCursor,46SSCursorSV,47SSDictCursor,48SSDictCursorSV,49SSNamedtupleCursor,50SSNamedtupleCursorSV,51SSArrowCursor,52SSArrowCursorSV,53SSNumpyCursor,54SSNumpyCursorSV,55SSPandasCursor,56SSPandasCursorSV,57SSPolarsCursor,58SSPolarsCursorSV,59)60from .optionfile import Parser61from .protocol import (62dump_packet,63MysqlPacket,64FieldDescriptorPacket,65OKPacketWrapper,66EOFPacketWrapper,67LoadLocalPacketWrapper,68)69from . import err70from ..config import get_option71from .. import fusion72from .. import connection73from ..connection import Connection as BaseConnection74from ..utils.debug import log_query7576try:77import ssl7879SSL_ENABLED = True80except ImportError:81ssl = None82SSL_ENABLED = False8384try:85import getpass8687DEFAULT_USER = getpass.getuser()88del getpass89except (ImportError, KeyError):90# KeyError occurs when there's no entry in OS database for a current user.91DEFAULT_USER = None9293DEBUG = get_option('debug.connection')9495TEXT_TYPES = {96FIELD_TYPE.BIT,97FIELD_TYPE.BLOB,98FIELD_TYPE.LONG_BLOB,99FIELD_TYPE.MEDIUM_BLOB,100FIELD_TYPE.STRING,101FIELD_TYPE.TINY_BLOB,102FIELD_TYPE.VAR_STRING,103FIELD_TYPE.VARCHAR,104FIELD_TYPE.GEOMETRY,105FIELD_TYPE.BSON,106FIELD_TYPE.FLOAT32_VECTOR_JSON,107FIELD_TYPE.FLOAT64_VECTOR_JSON,108FIELD_TYPE.INT8_VECTOR_JSON,109FIELD_TYPE.INT16_VECTOR_JSON,110FIELD_TYPE.INT32_VECTOR_JSON,111FIELD_TYPE.INT64_VECTOR_JSON,112FIELD_TYPE.FLOAT16_VECTOR_JSON,113FIELD_TYPE.FLOAT32_VECTOR,114FIELD_TYPE.FLOAT64_VECTOR,115FIELD_TYPE.INT8_VECTOR,116FIELD_TYPE.INT16_VECTOR,117FIELD_TYPE.INT32_VECTOR,118FIELD_TYPE.INT64_VECTOR,119FIELD_TYPE.FLOAT16_VECTOR,120}121122UNSET = 'unset'123124DEFAULT_CHARSET = 'utf8mb4'125126MAX_PACKET_LEN = 2**24 - 1127128129def _pack_int24(n):130return struct.pack('<I', n)[:3]131132133# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger134def _lenenc_int(i):135if i < 0:136raise ValueError(137'Encoding %d is less than 0 - no representation in LengthEncodedInteger' % i,138)139elif i < 0xFB:140return bytes([i])141elif i < (1 << 16):142return b'\xfc' + struct.pack('<H', i)143elif i < (1 << 24):144return b'\xfd' + struct.pack('<I', i)[:3]145elif i < (1 << 64):146return b'\xfe' + struct.pack('<Q', i)147else:148raise ValueError(149'Encoding %x is larger than %x - no representation in LengthEncodedInteger'150% (i, (1 << 64)),151)152153154class Connection(BaseConnection):155"""156Representation of a socket with a mysql server.157158The proper way to get an instance of this class is to call159``connect()``.160161Establish a connection to the SingleStoreDB database.162163Parameters164----------165host : str, optional166Host where the database server is located.167user : str, optional168Username to log in as.169password : str, optional170Password to use.171database : str, optional172Database to use, None to not use a particular one.173port : int, optional174Server port to use, default is usually OK. (default: 3306)175bind_address : str, optional176When the client has multiple network interfaces, specify177the interface from which to connect to the host. Argument can be178a hostname or an IP address.179unix_socket : str, optional180Use a unix socket rather than TCP/IP.181read_timeout : int, optional182The timeout for reading from the connection in seconds183(default: None - no timeout)184write_timeout : int, optional185The timeout for writing to the connection in seconds186(default: None - no timeout)187charset : str, optional188Charset to use.189collation : str, optional190The charset collation191sql_mode : str, optional192Default SQL_MODE to use.193read_default_file : str, optional194Specifies my.cnf file to read these parameters from under the195[client] section.196conv : Dict[str, Callable[Any]], optional197Conversion dictionary to use instead of the default one.198This is used to provide custom marshalling and unmarshalling of types.199See converters.200use_unicode : bool, optional201Whether or not to default to unicode strings.202This option defaults to true.203client_flag : int, optional204Custom flags to send to MySQL. Find potential values in constants.CLIENT.205cursorclass : type, optional206Custom cursor class to use.207init_command : str, optional208Initial SQL statement to run when connection is established.209connect_timeout : int, optional210The timeout for connecting to the database in seconds.211(default: 10, min: 1, max: 31536000)212ssl : Dict[str, str], optional213A dict of arguments similar to mysql_ssl_set()'s parameters or214an ssl.SSLContext.215ssl_ca : str, optional216Path to the file that contains a PEM-formatted CA certificate.217ssl_cert : str, optional218Path to the file that contains a PEM-formatted client certificate.219ssl_cipher : str, optional220SSL ciphers to allow.221ssl_disabled : bool, optional222A boolean value that disables usage of TLS.223ssl_key : str, optional224Path to the file that contains a PEM-formatted private key for the225client certificate.226ssl_verify_cert : str, optional227Set to true to check the server certificate's validity.228ssl_verify_identity : bool, optional229Set to true to check the server's identity.230tls_sni_servername: str, optional231Set server host name for TLS connection232read_default_group : str, optional233Group to read from in the configuration file.234autocommit : bool, optional235Autocommit mode. None means use server default. (default: False)236local_infile : bool, optional237Boolean to enable the use of LOAD DATA LOCAL command. (default: False)238max_allowed_packet : int, optional239Max size of packet sent to server in bytes. (default: 16MB)240Only used to limit size of "LOAD LOCAL INFILE" data packet smaller241than default (16KB).242defer_connect : bool, optional243Don't explicitly connect on construction - wait for connect call.244(default: False)245auth_plugin_map : Dict[str, type], optional246A dict of plugin names to a class that processes that plugin.247The class will take the Connection object as the argument to the248constructor. The class needs an authenticate method taking an249authentication packet as an argument. For the dialog plugin, a250prompt(echo, prompt) method can be used (if no authenticate method)251for returning a string from the user. (experimental)252server_public_key : str, optional253SHA256 authentication plugin public key value. (default: None)254binary_prefix : bool, optional255Add _binary prefix on bytes and bytearray. (default: False)256compress :257Not supported.258named_pipe :259Not supported.260db : str, optional261**DEPRECATED** Alias for database.262passwd : str, optional263**DEPRECATED** Alias for password.264parse_json : bool, optional265Parse JSON values into Python objects?266invalid_values : Dict[int, Any], optional267Dictionary of values to use in place of invalid values268found during conversion of data. The default is to return the byte content269containing the invalid value. The keys are the integers associtated with270the column type.271pure_python : bool, optional272Should we ignore the C extension even if it's available?273This can be given explicitly using True or False, or if the value is None,274the C extension will be loaded if it is available. If set to False and275the C extension can't be loaded, a NotSupportedError is raised.276nan_as_null : bool, optional277Should NaN values be treated as NULLs in parameter substitution including278uploading data?279inf_as_null : bool, optional280Should Inf values be treated as NULLs in parameter substitution including281uploading data?282track_env : bool, optional283Should the connection track the SINGLESTOREDB_URL environment variable?284enable_extended_data_types : bool, optional285Should extended data types (BSON, vector) be enabled?286vector_data_format : str, optional287Specify the data type of vector values: json or binary288289See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_290in the specification.291292"""293294driver = 'mysql'295paramstyle = 'pyformat'296297_sock = None298_auth_plugin_name = ''299_closed = False300_secure = False301_tls_sni_servername = None302303def __init__( # noqa: C901304self,305*,306user=None, # The first four arguments is based on DB-API 2.0 recommendation.307password='',308host=None,309database=None,310unix_socket=None,311port=0,312charset='',313collation=None,314sql_mode=None,315read_default_file=None,316conv=None,317use_unicode=True,318client_flag=0,319cursorclass=None,320init_command=None,321connect_timeout=10,322read_default_group=None,323autocommit=False,324local_infile=False,325max_allowed_packet=16 * 1024 * 1024,326defer_connect=False,327auth_plugin_map=None,328read_timeout=None,329write_timeout=None,330bind_address=None,331binary_prefix=False,332program_name=None,333server_public_key=None,334ssl=None,335ssl_ca=None,336ssl_cert=None,337ssl_cipher=None,338ssl_disabled=None,339ssl_key=None,340ssl_verify_cert=None,341ssl_verify_identity=None,342tls_sni_servername=None,343parse_json=True,344invalid_values=None,345pure_python=None,346buffered=True,347results_type='tuples',348compress=None, # not supported349named_pipe=None, # not supported350passwd=None, # deprecated351db=None, # deprecated352driver=None, # internal use353conn_attrs=None,354multi_statements=None,355client_found_rows=None,356nan_as_null=None,357inf_as_null=None,358encoding_errors='strict',359track_env=False,360enable_extended_data_types=True,361vector_data_format='binary',362interpolate_query_with_empty_args=None,363):364BaseConnection.__init__(**dict(locals()))365366if db is not None and database is None:367# We will raise warning in 2022 or later.368# See https://github.com/PyMySQL/PyMySQL/issues/939369# warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)370database = db371if passwd is not None and not password:372# We will raise warning in 2022 or later.373# See https://github.com/PyMySQL/PyMySQL/issues/939374# warnings.warn(375# "'passwd' is deprecated, use 'password'", DeprecationWarning, 3376# )377password = passwd378379if compress or named_pipe:380raise NotImplementedError(381'compress and named_pipe arguments are not supported',382)383384self._local_infile = bool(local_infile)385self._local_infile_stream = None386if self._local_infile:387client_flag |= CLIENT.LOCAL_FILES388if multi_statements:389client_flag |= CLIENT.MULTI_STATEMENTS390if client_found_rows:391client_flag |= CLIENT.FOUND_ROWS392393if read_default_group and not read_default_file:394if sys.platform.startswith('win'):395read_default_file = 'c:\\my.ini'396else:397read_default_file = '/etc/my.cnf'398399if read_default_file:400if not read_default_group:401read_default_group = 'client'402403cfg = Parser()404cfg.read(os.path.expanduser(read_default_file))405406def _config(key, arg):407if arg:408return arg409try:410return cfg.get(read_default_group, key)411except Exception:412return arg413414user = _config('user', user)415password = _config('password', password)416host = _config('host', host)417database = _config('database', database)418unix_socket = _config('socket', unix_socket)419port = int(_config('port', port))420bind_address = _config('bind-address', bind_address)421charset = _config('default-character-set', charset)422if not ssl:423ssl = {}424if isinstance(ssl, dict):425for key in ['ca', 'capath', 'cert', 'key', 'cipher']:426value = _config('ssl-' + key, ssl.get(key))427if value:428ssl[key] = value429430self.ssl = False431if not ssl_disabled:432if ssl_ca or ssl_cert or ssl_key or ssl_cipher or \433ssl_verify_cert or ssl_verify_identity:434ssl = {435'ca': ssl_ca,436'check_hostname': bool(ssl_verify_identity),437'verify_mode': ssl_verify_cert438if ssl_verify_cert is not None439else False,440}441if ssl_cert is not None:442ssl['cert'] = ssl_cert443if ssl_key is not None:444ssl['key'] = ssl_key445if ssl_cipher is not None:446ssl['cipher'] = ssl_cipher447if ssl:448if not SSL_ENABLED:449raise NotImplementedError('ssl module not found')450self.ssl = True451client_flag |= CLIENT.SSL452self.ctx = self._create_ssl_ctx(ssl)453454self.host = host or 'localhost'455self.port = port or 3306456if type(self.port) is not int:457raise ValueError('port should be of type int')458self.user = user or DEFAULT_USER459self.password = password or b''460if isinstance(self.password, str):461self.password = self.password.encode('latin1')462self.db = database463self.unix_socket = unix_socket464self.bind_address = bind_address465if not (0 < connect_timeout <= 31536000):466raise ValueError('connect_timeout should be >0 and <=31536000')467self.connect_timeout = connect_timeout or None468if read_timeout is not None and read_timeout <= 0:469raise ValueError('read_timeout should be > 0')470self._read_timeout = read_timeout471if write_timeout is not None and write_timeout <= 0:472raise ValueError('write_timeout should be > 0')473self._write_timeout = write_timeout474475self.charset = charset or DEFAULT_CHARSET476self.collation = collation477self.use_unicode = use_unicode478self.encoding_errors = encoding_errors479480self.encoding = charset_by_name(self.charset).encoding481482client_flag |= CLIENT.CAPABILITIES483client_flag |= CLIENT.CONNECT_WITH_DB484485self.client_flag = client_flag486487self.pure_python = pure_python488self.results_type = results_type489self.resultclass = MySQLResult490if cursorclass is not None:491self.cursorclass = cursorclass492elif buffered:493if 'dict' in self.results_type:494self.cursorclass = DictCursor495elif 'namedtuple' in self.results_type:496self.cursorclass = NamedtupleCursor497elif 'numpy' in self.results_type:498self.cursorclass = NumpyCursor499elif 'arrow' in self.results_type:500self.cursorclass = ArrowCursor501elif 'pandas' in self.results_type:502self.cursorclass = PandasCursor503elif 'polars' in self.results_type:504self.cursorclass = PolarsCursor505else:506self.cursorclass = Cursor507else:508if 'dict' in self.results_type:509self.cursorclass = SSDictCursor510elif 'namedtuple' in self.results_type:511self.cursorclass = SSNamedtupleCursor512elif 'numpy' in self.results_type:513self.cursorclass = SSNumpyCursor514elif 'arrow' in self.results_type:515self.cursorclass = SSArrowCursor516elif 'pandas' in self.results_type:517self.cursorclass = SSPandasCursor518elif 'polars' in self.results_type:519self.cursorclass = SSPolarsCursor520else:521self.cursorclass = SSCursor522523if self.pure_python is False and _singlestoredb_accel is None:524try:525import _singlestortedb_accel # noqa: F401526except Exception:527import traceback528traceback.print_exc(file=sys.stderr)529finally:530raise err.NotSupportedError(531'pure_python=False, but the '532'C extension can not be loaded',533)534535if self.pure_python is True:536pass537538# The C extension handles these types internally.539elif _singlestoredb_accel is not None:540self.resultclass = MySQLResultSV541if self.cursorclass is Cursor:542self.cursorclass = CursorSV543elif self.cursorclass is SSCursor:544self.cursorclass = SSCursorSV545elif self.cursorclass is DictCursor:546self.cursorclass = DictCursorSV547self.results_type = 'dicts'548elif self.cursorclass is SSDictCursor:549self.cursorclass = SSDictCursorSV550self.results_type = 'dicts'551elif self.cursorclass is NamedtupleCursor:552self.cursorclass = NamedtupleCursorSV553self.results_type = 'namedtuples'554elif self.cursorclass is SSNamedtupleCursor:555self.cursorclass = SSNamedtupleCursorSV556self.results_type = 'namedtuples'557elif self.cursorclass is NumpyCursor:558self.cursorclass = NumpyCursorSV559self.results_type = 'numpy'560elif self.cursorclass is SSNumpyCursor:561self.cursorclass = SSNumpyCursorSV562self.results_type = 'numpy'563elif self.cursorclass is ArrowCursor:564self.cursorclass = ArrowCursorSV565self.results_type = 'arrow'566elif self.cursorclass is SSArrowCursor:567self.cursorclass = SSArrowCursorSV568self.results_type = 'arrow'569elif self.cursorclass is PandasCursor:570self.cursorclass = PandasCursorSV571self.results_type = 'pandas'572elif self.cursorclass is SSPandasCursor:573self.cursorclass = SSPandasCursorSV574self.results_type = 'pandas'575elif self.cursorclass is PolarsCursor:576self.cursorclass = PolarsCursorSV577self.results_type = 'polars'578elif self.cursorclass is SSPolarsCursor:579self.cursorclass = SSPolarsCursorSV580self.results_type = 'polars'581582self._result = None583self._affected_rows = 0584self.host_info = 'Not connected'585586# specified autocommit mode. None means use server default.587self.autocommit_mode = autocommit588589if conv is None:590conv = converters.conversions591592conv = conv.copy()593594self.parse_json = parse_json595self.invalid_values = (invalid_values or {}).copy()596597# Disable JSON parsing for Arrow598if self.results_type in ['arrow']:599conv[245] = None600self.parse_json = False601602# Disable date/time parsing for polars; let polars do the parsing603elif self.results_type in ['polars']:604conv[7] = None605conv[10] = None606conv[12] = None607608# Need for MySQLdb compatibility.609self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}610self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}611self.sql_mode = sql_mode612self.init_command = init_command613self.max_allowed_packet = max_allowed_packet614self._auth_plugin_map = auth_plugin_map or {}615self._binary_prefix = binary_prefix616self.server_public_key = server_public_key617self.interpolate_query_with_empty_args = interpolate_query_with_empty_args618619if self.connection_params['nan_as_null'] or \620self.connection_params['inf_as_null']:621float_encoder = self.encoders.get(float)622if float_encoder is not None:623self.encoders[float] = functools.partial(624float_encoder,625nan_as_null=self.connection_params['nan_as_null'],626inf_as_null=self.connection_params['inf_as_null'],627)628629from .. import __version__ as VERSION_STRING630631if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:632VERSION_STRING += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']633634self._connect_attrs = {635'_os': str(sys.platform),636'_pid': str(os.getpid()),637'_client_name': 'SingleStoreDB Python Client',638'_client_version': VERSION_STRING,639}640641if program_name:642self._connect_attrs['program_name'] = program_name643if conn_attrs is not None:644# do not overwrite the attributes that we set ourselves645for k, v in conn_attrs.items():646if k not in self._connect_attrs:647self._connect_attrs[k] = v648649self._is_committable = True650self._in_sync = False651self._tls_sni_servername = tls_sni_servername652self._track_env = bool(track_env) or self.host == 'singlestore.com'653self._enable_extended_data_types = enable_extended_data_types654if vector_data_format.lower() in ['json', 'binary']:655self._vector_data_format = vector_data_format656else:657raise ValueError(658'unknown value for vector_data_format, '659f'expecting "json" or "binary": {vector_data_format}',660)661self._connection_info = {}662events.subscribe(self._handle_event)663664if defer_connect or self._track_env:665self._sock = None666else:667self.connect()668669def _handle_event(self, data: Dict[str, Any]) -> None:670if data.get('name', '') == 'singlestore.portal.connection_updated':671self._connection_info = dict(data)672673@property674def messages(self):675# TODO676[]677678def __enter__(self):679return self680681def __exit__(self, *exc_info):682del exc_info683self.close()684685def _raise_mysql_exception(self, data):686err.raise_mysql_exception(data)687688def _create_ssl_ctx(self, sslp):689if isinstance(sslp, ssl.SSLContext):690return sslp691ca = sslp.get('ca')692capath = sslp.get('capath')693hasnoca = ca is None and capath is None694ctx = ssl.create_default_context(cafile=ca, capath=capath)695ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)696verify_mode_value = sslp.get('verify_mode')697if verify_mode_value is None:698ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED699elif isinstance(verify_mode_value, bool):700ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE701else:702if isinstance(verify_mode_value, str):703verify_mode_value = verify_mode_value.lower()704if verify_mode_value in ('none', '0', 'false', 'no'):705ctx.verify_mode = ssl.CERT_NONE706elif verify_mode_value == 'optional':707ctx.verify_mode = ssl.CERT_OPTIONAL708elif verify_mode_value in ('required', '1', 'true', 'yes'):709ctx.verify_mode = ssl.CERT_REQUIRED710else:711ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED712if 'cert' in sslp:713ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))714if 'cipher' in sslp:715ctx.set_ciphers(sslp['cipher'])716ctx.options |= ssl.OP_NO_SSLv2717ctx.options |= ssl.OP_NO_SSLv3718return ctx719720def close(self):721"""722Send the quit message and close the socket.723724See `Connection.close()725<https://www.python.org/dev/peps/pep-0249/#Connection.close>`_726in the specification.727728Raises729------730Error : If the connection is already closed.731732"""733self._result = None734if self.host == 'singlestore.com':735return736if self._closed:737raise err.Error('Already closed')738events.unsubscribe(self._handle_event)739self._closed = True740if self._sock is None:741return742send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)743try:744self._write_bytes(send_data)745except Exception:746pass747finally:748self._force_close()749750@property751def open(self):752"""Return True if the connection is open."""753return self._sock is not None754755def is_connected(self):756"""Return True if the connection is open."""757return self.open758759def _force_close(self):760"""Close connection without QUIT message."""761if self._sock:762try:763self._sock.close()764except: # noqa765pass766self._sock = None767self._rfile = None768769__del__ = _force_close770771def autocommit(self, value):772"""Enable autocommit in the server."""773self.autocommit_mode = bool(value)774current = self.get_autocommit()775if value != current:776self._send_autocommit_mode()777778def get_autocommit(self):779"""Retrieve autocommit status."""780return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT)781782def _read_ok_packet(self):783pkt = self._read_packet()784if not pkt.is_ok_packet():785raise err.OperationalError(786CR.CR_COMMANDS_OUT_OF_SYNC,787'Command Out of Sync',788)789ok = OKPacketWrapper(pkt)790self.server_status = ok.server_status791return ok792793def _send_autocommit_mode(self):794"""Set whether or not to commit after every execute()."""795log_query('SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode))796self._execute_command(797COMMAND.COM_QUERY, 'SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode),798)799self._read_ok_packet()800801def begin(self):802"""Begin transaction."""803log_query('BEGIN')804if self.host == 'singlestore.com':805return806self._execute_command(COMMAND.COM_QUERY, 'BEGIN')807self._read_ok_packet()808809def commit(self):810"""811Commit changes to stable storage.812813See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_814in the specification.815816"""817log_query('COMMIT')818if not self._is_committable or self.host == 'singlestore.com':819self._is_committable = True820return821self._execute_command(COMMAND.COM_QUERY, 'COMMIT')822self._read_ok_packet()823824def rollback(self):825"""826Roll back the current transaction.827828See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_829in the specification.830831"""832log_query('ROLLBACK')833if not self._is_committable or self.host == 'singlestore.com':834self._is_committable = True835return836self._execute_command(COMMAND.COM_QUERY, 'ROLLBACK')837self._read_ok_packet()838839def show_warnings(self):840"""Send the "SHOW WARNINGS" SQL command."""841log_query('SHOW WARNINGS')842self._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS')843result = self.resultclass(self)844result.read()845return result.rows846847def select_db(self, db):848"""849Set current db.850851db : str852The name of the db.853854"""855self._execute_command(COMMAND.COM_INIT_DB, db)856self._read_ok_packet()857858def escape(self, obj, mapping=None):859"""860Escape whatever value is passed.861862Non-standard, for internal use; do not use this in your applications.863864"""865dtype = type(obj)866if dtype is str or isinstance(obj, str):867return "'{}'".format(self.escape_string(obj))868if dtype is bytes or dtype is bytearray or isinstance(obj, (bytes, bytearray)):869return self._quote_bytes(obj)870if mapping is None:871mapping = self.encoders872return converters.escape_item(obj, self.charset, mapping=mapping)873874def literal(self, obj):875"""876Alias for escape().877878Non-standard, for internal use; do not use this in your applications.879880"""881return self.escape(obj, self.encoders)882883def escape_string(self, s):884"""Escape a string value."""885if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:886return s.replace("'", "''")887return converters.escape_string(s)888889def _quote_bytes(self, s):890if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:891if self._binary_prefix:892return "_binary X'{}'".format(s.hex())893return "X'{}'".format(s.hex())894return converters.escape_bytes(s)895896def cursor(self):897"""Create a new cursor to execute queries with."""898return self.cursorclass(self)899900# The following methods are INTERNAL USE ONLY (called from Cursor)901def query(self, sql, unbuffered=False, infile_stream=None):902"""903Run a query on the server.904905Internal use only.906907"""908# if DEBUG:909# print("DEBUG: sending query:", sql)910handler = fusion.get_handler(sql)911if handler is not None:912self._is_committable = False913self._result = fusion.execute(self, sql, handler=handler)914self._affected_rows = self._result.affected_rows915else:916self._is_committable = True917if isinstance(sql, str):918sql = sql.encode(self.encoding, 'surrogateescape')919self._local_infile_stream = infile_stream920self._execute_command(COMMAND.COM_QUERY, sql)921self._affected_rows = self._read_query_result(unbuffered=unbuffered)922self._local_infile_stream = None923return self._affected_rows924925def next_result(self, unbuffered=False):926"""927Retrieve the next result set.928929Internal use only.930931"""932self._affected_rows = self._read_query_result(unbuffered=unbuffered)933return self._affected_rows934935def affected_rows(self):936"""937Return number of affected rows.938939Internal use only.940941"""942return self._affected_rows943944def kill(self, thread_id):945"""946Execute kill command.947948Internal use only.949950"""951arg = struct.pack('<I', thread_id)952self._execute_command(COMMAND.COM_PROCESS_KILL, arg)953return self._read_ok_packet()954955def ping(self, reconnect=True):956"""957Check if the server is alive.958959Parameters960----------961reconnect : bool, optional962If the connection is closed, reconnect.963964Raises965------966Error : If the connection is closed and reconnect=False.967968"""969if self._sock is None:970if reconnect:971self.connect()972reconnect = False973else:974raise err.Error('Already closed')975try:976self._execute_command(COMMAND.COM_PING, '')977self._read_ok_packet()978except Exception:979if reconnect:980self.connect()981self.ping(False)982else:983raise984985def set_charset(self, charset):986"""Deprecated. Use set_character_set() instead."""987# This function has been implemented in old PyMySQL.988# But this name is different from MySQLdb.989# So we keep this function for compatibility and add990# new set_character_set() function.991self.set_character_set(charset)992993def set_character_set(self, charset, collation=None):994"""995Set charaset (and collation) on the server.996997Send "SET NAMES charset [COLLATE collation]" query.998Update Connection.encoding based on charset.9991000Parameters1001----------1002charset : str1003The charset to enable.1004collation : str, optional1005The collation value10061007"""1008# Make sure charset is supported.1009encoding = charset_by_name(charset).encoding10101011if collation:1012query = f'SET NAMES {charset} COLLATE {collation}'1013else:1014query = f'SET NAMES {charset}'1015self._execute_command(COMMAND.COM_QUERY, query)1016self._read_packet()1017self.charset = charset1018self.encoding = encoding1019self.collation = collation10201021def _sync_connection(self):1022"""Synchronize connection with env variable."""1023if self._in_sync:1024return10251026if not self._track_env:1027return10281029url = self._connection_info.get('connection_url')1030if not url:1031url = os.environ.get('SINGLESTOREDB_URL')1032if not url:1033return10341035out = {}1036urlp = connection._parse_url(url)1037out.update(urlp)10381039out = connection._cast_params(out)10401041# Set default port based on driver.1042if 'port' not in out or not out['port']:1043out['port'] = int(get_option('port') or 3306)10441045# If there is no user and the password is empty, remove the password key.1046if 'user' not in out and not out.get('password', None):1047out.pop('password', None)10481049if out['host'] == 'singlestore.com':1050raise err.InterfaceError(0, 'Connection URL has not been established')10511052# If it's just a password change, we don't need to reconnect1053if self._sock is not None and \1054(self.host, self.port, self.user, self.db) == \1055(out['host'], out['port'], out['user'], out.get('database')):1056return10571058self.host = out['host']1059self.port = out['port']1060self.user = out['user']1061if isinstance(out['password'], str):1062self.password = out['password'].encode('latin-1')1063else:1064self.password = out['password'] or b''1065self.db = out.get('database')1066try:1067self._in_sync = True1068self.connect()1069finally:1070self._in_sync = False10711072def connect(self, sock=None):1073"""1074Connect to server using existing parameters.10751076Internal use only.10771078"""1079self._closed = False1080try:1081if sock is None:1082if self.unix_socket:1083sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)1084sock.settimeout(self.connect_timeout)1085sock.connect(self.unix_socket)1086self.host_info = 'Localhost via UNIX socket'1087self._secure = True1088if DEBUG:1089print('connected using unix_socket')1090else:1091kwargs = {}1092if self.bind_address is not None:1093kwargs['source_address'] = (self.bind_address, 0)1094while True:1095try:1096sock = socket.create_connection(1097(self.host, self.port), self.connect_timeout, **kwargs,1098)1099break1100except OSError as e:1101if e.errno == errno.EINTR:1102continue1103raise1104self.host_info = 'socket %s:%d' % (self.host, self.port)1105if DEBUG:1106print('connected using socket')1107sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)1108sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)1109sock.settimeout(None)11101111self._sock = sock1112self._rfile = sock.makefile('rb')1113self._next_seq_id = 011141115self._get_server_information()1116self._request_authentication()11171118# Send "SET NAMES" query on init for:1119# - Ensure charaset (and collation) is set to the server.1120# - collation_id in handshake packet may be ignored.1121# - If collation is not specified, we don't know what is server's1122# default collation for the charset. For example, default collation1123# of utf8mb4 is:1124# - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci1125# - MySQL 8.0: utf8mb4_0900_ai_ci1126#1127# Reference:1128# - https://github.com/PyMySQL/PyMySQL/issues/10921129# - https://github.com/wagtail/wagtail/issues/94771130# - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)1131self.set_character_set(self.charset, self.collation)11321133if self.sql_mode is not None:1134c = self.cursor()1135c.execute('SET sql_mode=%s', (self.sql_mode,))1136c.close()11371138if self._enable_extended_data_types:1139c = self.cursor()1140try:1141c.execute('SET @@SESSION.enable_extended_types_metadata=on')1142except self.OperationalError:1143pass1144c.close()11451146if self._vector_data_format:1147c = self.cursor()1148try:1149val = self._vector_data_format1150c.execute(f'SET @@SESSION.vector_type_project_format={val}')1151except self.OperationalError:1152pass1153c.close()11541155if self.init_command is not None:1156c = self.cursor()1157c.execute(self.init_command)1158c.close()11591160if self.autocommit_mode is not None:1161self.autocommit(self.autocommit_mode)11621163except BaseException as e:1164self._rfile = None1165if sock is not None:1166try:1167sock.close()1168except: # noqa1169pass11701171if isinstance(e, (OSError, IOError, socket.error)):1172exc = err.OperationalError(1173CR.CR_CONN_HOST_ERROR,1174f'Can\'t connect to MySQL server on {self.host!r} ({e})',1175)1176# Keep original exception and traceback to investigate error.1177exc.original_exception = e1178exc.traceback = traceback.format_exc()1179if DEBUG:1180print(exc.traceback)1181raise exc11821183# If e is neither DatabaseError or IOError, It's a bug.1184# But raising AssertionError hides original error.1185# So just reraise it.1186raise11871188def write_packet(self, payload):1189"""1190Writes an entire "mysql packet" in its entirety to the network.11911192Adds its length and sequence number.11931194"""1195# Internal note: when you build packet manually and calls _write_bytes()1196# directly, you should set self._next_seq_id properly.1197data = _pack_int24(len(payload)) + bytes([self._next_seq_id]) + payload1198if DEBUG:1199dump_packet(data)1200self._write_bytes(data)1201self._next_seq_id = (self._next_seq_id + 1) % 25612021203def _read_packet(self, packet_type=MysqlPacket):1204"""1205Read an entire "mysql packet" in its entirety from the network.12061207Raises1208------1209OperationalError : If the connection to the MySQL server is lost.1210InternalError : If the packet sequence number is wrong.12111212Returns1213-------1214MysqlPacket12151216"""1217buff = bytearray()1218while True:1219packet_header = self._read_bytes(4)1220# if DEBUG: dump_packet(packet_header)12211222btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)1223bytes_to_read = btrl + (btrh << 16)1224if packet_number != self._next_seq_id:1225self._force_close()1226if packet_number == 0:1227# MariaDB sends error packet with seqno==0 when shutdown1228raise err.OperationalError(1229CR.CR_SERVER_LOST,1230'Lost connection to MySQL server during query',1231)1232raise err.InternalError(1233'Packet sequence number wrong - got %d expected %d'1234% (packet_number, self._next_seq_id),1235)1236self._next_seq_id = (self._next_seq_id + 1) % 25612371238recv_data = self._read_bytes(bytes_to_read)1239if DEBUG:1240dump_packet(recv_data)1241buff += recv_data1242# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html1243if bytes_to_read == 0xFFFFFF:1244continue1245if bytes_to_read < MAX_PACKET_LEN:1246break12471248packet = packet_type(bytes(buff), self.encoding)1249if packet.is_error_packet():1250if self._result is not None and self._result.unbuffered_active is True:1251self._result.unbuffered_active = False1252packet.raise_for_error()1253return packet12541255def _read_bytes(self, num_bytes):1256if self._read_timeout is not None:1257self._sock.settimeout(self._read_timeout)1258while True:1259try:1260data = self._rfile.read(num_bytes)1261break1262except OSError as e:1263if e.errno == errno.EINTR:1264continue1265self._force_close()1266raise err.OperationalError(1267CR.CR_SERVER_LOST,1268'Lost connection to MySQL server during query (%s)' % (e,),1269)1270except BaseException:1271# Don't convert unknown exception to MySQLError.1272self._force_close()1273raise1274if len(data) < num_bytes:1275self._force_close()1276raise err.OperationalError(1277CR.CR_SERVER_LOST, 'Lost connection to MySQL server during query',1278)1279return data12801281def _write_bytes(self, data):1282if self._write_timeout is not None:1283self._sock.settimeout(self._write_timeout)1284try:1285self._sock.sendall(data)1286except OSError as e:1287self._force_close()1288raise err.OperationalError(1289CR.CR_SERVER_GONE_ERROR, f'MySQL server has gone away ({e!r})',1290)12911292def _read_query_result(self, unbuffered=False):1293self._result = None1294if unbuffered:1295result = self.resultclass(self, unbuffered=unbuffered)1296else:1297result = self.resultclass(self)1298result.read()1299self._result = result1300if result.server_status is not None:1301self.server_status = result.server_status1302return result.affected_rows13031304def insert_id(self):1305if self._result:1306return self._result.insert_id1307else:1308return 013091310def _execute_command(self, command, sql):1311"""1312Execute command.13131314Raises1315------1316InterfaceError : If the connection is closed.1317ValueError : If no username was specified.13181319"""1320self._sync_connection()13211322if self._sock is None:1323raise err.InterfaceError(0, 'The connection has been closed')13241325# If the last query was unbuffered, make sure it finishes before1326# sending new commands1327if self._result is not None:1328if self._result.unbuffered_active:1329warnings.warn('Previous unbuffered result was left incomplete')1330self._result._finish_unbuffered_query()1331while self._result.has_next:1332self.next_result()1333self._result = None13341335if isinstance(sql, str):1336sql = sql.encode(self.encoding)13371338packet_size = min(MAX_PACKET_LEN, len(sql) + 1) # +1 is for command13391340# tiny optimization: build first packet manually instead of1341# calling self..write_packet()1342prelude = struct.pack('<iB', packet_size, command)1343packet = prelude + sql[: packet_size - 1]1344self._write_bytes(packet)1345if DEBUG:1346dump_packet(packet)1347self._next_seq_id = 113481349if packet_size < MAX_PACKET_LEN:1350return13511352sql = sql[packet_size - 1:]1353while True:1354packet_size = min(MAX_PACKET_LEN, len(sql))1355self.write_packet(sql[:packet_size])1356sql = sql[packet_size:]1357if not sql and packet_size < MAX_PACKET_LEN:1358break13591360def _request_authentication(self): # noqa: C9011361# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse1362if int(self.server_version.split('.', 1)[0]) >= 5:1363self.client_flag |= CLIENT.MULTI_RESULTS13641365if self.user is None:1366raise ValueError('Did not specify a username')13671368charset_id = charset_by_name(self.charset).id1369if isinstance(self.user, str):1370self.user = self.user.encode(self.encoding)13711372data_init = struct.pack(1373'<iIB23s', self.client_flag, MAX_PACKET_LEN, charset_id, b'',1374)13751376if self.ssl and self.server_capabilities & CLIENT.SSL:1377self.write_packet(data_init)13781379hostname = self.host1380if self._tls_sni_servername:1381hostname = self._tls_sni_servername1382self._sock = self.ctx.wrap_socket(self._sock, server_hostname=hostname)1383self._rfile = self._sock.makefile('rb')1384self._secure = True13851386data = data_init + self.user + b'\0'13871388authresp = b''1389plugin_name = None13901391if self._auth_plugin_name == '':1392plugin_name = b''1393authresp = _auth.scramble_native_password(self.password, self.salt)1394elif self._auth_plugin_name == 'mysql_native_password':1395plugin_name = b'mysql_native_password'1396authresp = _auth.scramble_native_password(self.password, self.salt)1397elif self._auth_plugin_name == 'caching_sha2_password':1398plugin_name = b'caching_sha2_password'1399if self.password:1400if DEBUG:1401print('caching_sha2: trying fast path')1402authresp = _auth.scramble_caching_sha2(self.password, self.salt)1403else:1404if DEBUG:1405print('caching_sha2: empty password')1406elif self._auth_plugin_name == 'sha256_password':1407plugin_name = b'sha256_password'1408if self.ssl and self.server_capabilities & CLIENT.SSL:1409authresp = self.password + b'\0'1410elif self.password:1411authresp = b'\1' # request public key1412else:1413authresp = b'\0' # empty password14141415if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:1416data += _lenenc_int(len(authresp)) + authresp1417elif self.server_capabilities & CLIENT.SECURE_CONNECTION:1418data += struct.pack('B', len(authresp)) + authresp1419else: # pragma: no cover - no testing against servers w/o secure auth (>=5.0)1420data += authresp + b'\0'14211422if self.server_capabilities & CLIENT.CONNECT_WITH_DB:1423db = self.db1424if isinstance(db, str):1425db = db.encode(self.encoding)1426data += (db or b'') + b'\0'14271428if self.server_capabilities & CLIENT.PLUGIN_AUTH:1429data += (plugin_name or b'') + b'\0'14301431if self.server_capabilities & CLIENT.CONNECT_ATTRS:1432connect_attrs = b''1433for k, v in self._connect_attrs.items():1434k = k.encode('utf-8')1435connect_attrs += _lenenc_int(len(k)) + k1436v = v.encode('utf-8')1437connect_attrs += _lenenc_int(len(v)) + v1438data += _lenenc_int(len(connect_attrs)) + connect_attrs14391440self.write_packet(data)1441auth_packet = self._read_packet()14421443# if authentication method isn't accepted the first byte1444# will have the octet 2541445if auth_packet.is_auth_switch_request():1446if DEBUG:1447print('received auth switch')1448# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest1449auth_packet.read_uint8() # 0xfe packet identifier1450plugin_name = auth_packet.read_string()1451if (1452self.server_capabilities & CLIENT.PLUGIN_AUTH1453and plugin_name is not None1454):1455auth_packet = self._process_auth(plugin_name, auth_packet)1456else:1457raise err.OperationalError('received unknown auth switch request')1458elif auth_packet.is_extra_auth_data():1459if DEBUG:1460print('received extra data')1461# https://dev.mysql.com/doc/internals/en/successful-authentication.html1462if self._auth_plugin_name == 'caching_sha2_password':1463auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)1464elif self._auth_plugin_name == 'sha256_password':1465auth_packet = _auth.sha256_password_auth(self, auth_packet)1466else:1467raise err.OperationalError(1468'Received extra packet for auth method %r', self._auth_plugin_name,1469)14701471if DEBUG:1472print('Succeed to auth')14731474def _process_auth(self, plugin_name, auth_packet):1475handler = self._get_auth_plugin_handler(plugin_name)1476if handler:1477try:1478return handler.authenticate(auth_packet)1479except AttributeError:1480if plugin_name != b'dialog':1481raise err.OperationalError(1482CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1483"Authentication plugin '%s'"1484' not loaded: - %r missing authenticate method'1485% (plugin_name, type(handler)),1486)1487if plugin_name == b'caching_sha2_password':1488return _auth.caching_sha2_password_auth(self, auth_packet)1489elif plugin_name == b'sha256_password':1490return _auth.sha256_password_auth(self, auth_packet)1491elif plugin_name == b'mysql_native_password':1492data = _auth.scramble_native_password(self.password, auth_packet.read_all())1493elif plugin_name == b'client_ed25519':1494data = _auth.ed25519_password(self.password, auth_packet.read_all())1495elif plugin_name == b'mysql_old_password':1496data = (1497_auth.scramble_old_password(self.password, auth_packet.read_all())1498+ b'\0'1499)1500elif plugin_name == b'mysql_clear_password':1501# https://dev.mysql.com/doc/internals/en/clear-text-authentication.html1502data = self.password + b'\0'1503elif plugin_name == b'auth_gssapi_client':1504data = _auth.gssapi_auth(auth_packet.read_all())1505elif plugin_name == b'dialog':1506pkt = auth_packet1507while True:1508flag = pkt.read_uint8()1509echo = (flag & 0x06) == 0x021510last = (flag & 0x01) == 0x011511prompt = pkt.read_all()15121513if prompt == b'Password: ':1514self.write_packet(self.password + b'\0')1515elif handler:1516resp = 'no response - TypeError within plugin.prompt method'1517try:1518resp = handler.prompt(echo, prompt)1519self.write_packet(resp + b'\0')1520except AttributeError:1521raise err.OperationalError(1522CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1523"Authentication plugin '%s'"1524' not loaded: - %r missing prompt method'1525% (plugin_name, handler),1526)1527except TypeError:1528raise err.OperationalError(1529CR.CR_AUTH_PLUGIN_ERR,1530"Authentication plugin '%s'"1531" %r didn't respond with string. Returned '%r' to prompt %r"1532% (plugin_name, handler, resp, prompt),1533)1534else:1535raise err.OperationalError(1536CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1537"Authentication plugin '%s' not configured" % (plugin_name,),1538)1539pkt = self._read_packet()1540pkt.check_error()1541if pkt.is_ok_packet() or last:1542break1543return pkt1544else:1545raise err.OperationalError(1546CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1547"Authentication plugin '%s' not configured" % plugin_name,1548)15491550self.write_packet(data)1551pkt = self._read_packet()1552pkt.check_error()1553return pkt15541555def _get_auth_plugin_handler(self, plugin_name):1556plugin_class = self._auth_plugin_map.get(plugin_name)1557if not plugin_class and isinstance(plugin_name, bytes):1558plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))1559if plugin_class:1560try:1561handler = plugin_class(self)1562except TypeError:1563raise err.OperationalError(1564CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1565"Authentication plugin '%s'"1566' not loaded: - %r cannot be constructed with connection object'1567% (plugin_name, plugin_class),1568)1569else:1570handler = None1571return handler15721573# _mysql support1574def thread_id(self):1575return self.server_thread_id[0]15761577def character_set_name(self):1578return self.charset15791580def get_host_info(self):1581return self.host_info15821583def get_proto_info(self):1584return self.protocol_version15851586def _get_server_information(self):1587i = 01588packet = self._read_packet()1589data = packet.get_all_data()15901591self.protocol_version = data[i]1592i += 115931594server_end = data.find(b'\0', i)1595self.server_version = data[i:server_end].decode('latin1')1596i = server_end + 115971598self.server_thread_id = struct.unpack('<I', data[i: i + 4])1599i += 416001601self.salt = data[i: i + 8]1602i += 9 # 8 + 1(filler)16031604self.server_capabilities = struct.unpack('<H', data[i: i + 2])[0]1605i += 216061607if len(data) >= i + 6:1608lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i: i + 6])1609i += 61610# TODO: deprecate server_language and server_charset.1611# mysqlclient-python doesn't provide it.1612self.server_language = lang1613try:1614self.server_charset = charset_by_id(lang).name1615except KeyError:1616# unknown collation1617self.server_charset = None16181619self.server_status = stat1620if DEBUG:1621print('server_status: %x' % stat)16221623self.server_capabilities |= cap_h << 161624if DEBUG:1625print('salt_len:', salt_len)1626salt_len = max(12, salt_len - 9)16271628# reserved1629i += 1016301631if len(data) >= i + salt_len:1632# salt_len includes auth_plugin_data_part_1 and filler1633self.salt += data[i: i + salt_len]1634i += salt_len16351636i += 11637# AUTH PLUGIN NAME may appear here.1638if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:1639# Due to Bug#59453 the auth-plugin-name is missing the terminating1640# NUL-char in versions prior to 5.5.10 and 5.6.2.1641# ref: https://dev.mysql.com/doc/internals/en/1642# connection-phase-packets.html#packet-Protocol::Handshake1643# didn't use version checks as mariadb is corrected and reports1644# earlier than those two.1645server_end = data.find(b'\0', i)1646if server_end < 0: # pragma: no cover - very specific upstream bug1647# not found \0 and last field so take it all1648self._auth_plugin_name = data[i:].decode('utf-8')1649else:1650self._auth_plugin_name = data[i:server_end].decode('utf-8')16511652def get_server_info(self):1653return self.server_version16541655Warning = err.Warning1656Error = err.Error1657InterfaceError = err.InterfaceError1658DatabaseError = err.DatabaseError1659DataError = err.DataError1660OperationalError = err.OperationalError1661IntegrityError = err.IntegrityError1662InternalError = err.InternalError1663ProgrammingError = err.ProgrammingError1664NotSupportedError = err.NotSupportedError166516661667class MySQLResult:1668"""1669Results of a SQL query.16701671Parameters1672----------1673connection : Connection1674The connection the result came from.1675unbuffered : bool, optional1676Should the reads be unbuffered?16771678"""16791680def __init__(self, connection, unbuffered=False):1681self.connection = connection1682self.affected_rows = None1683self.insert_id = None1684self.server_status = None1685self.warning_count = 01686self.message = None1687self.field_count = 01688self.description = None1689self.rows = None1690self.has_next = None1691self.unbuffered_active = False1692self.converters = []1693self.fields = []1694self.encoding_errors = self.connection.encoding_errors1695if unbuffered:1696try:1697self.init_unbuffered_query()1698except Exception:1699self.connection = None1700self.unbuffered_active = False1701raise17021703def __del__(self):1704if self.unbuffered_active:1705self._finish_unbuffered_query()17061707def read(self):1708try:1709first_packet = self.connection._read_packet()17101711if first_packet.is_ok_packet():1712self._read_ok_packet(first_packet)1713elif first_packet.is_load_local_packet():1714self._read_load_local_packet(first_packet)1715else:1716self._read_result_packet(first_packet)1717finally:1718self.connection = None17191720def init_unbuffered_query(self):1721"""1722Initialize an unbuffered query.17231724Raises1725------1726OperationalError : If the connection to the MySQL server is lost.1727InternalError : Other errors.17281729"""1730self.unbuffered_active = True1731first_packet = self.connection._read_packet()17321733if first_packet.is_ok_packet():1734self._read_ok_packet(first_packet)1735self.unbuffered_active = False1736self.connection = None1737elif first_packet.is_load_local_packet():1738self._read_load_local_packet(first_packet)1739self.unbuffered_active = False1740self.connection = None1741else:1742self.field_count = first_packet.read_length_encoded_integer()1743self._get_descriptions()17441745# Apparently, MySQLdb picks this number because it's the maximum1746# value of a 64bit unsigned integer. Since we're emulating MySQLdb,1747# we set it to this instead of None, which would be preferred.1748self.affected_rows = 1844674407370955161517491750def _read_ok_packet(self, first_packet):1751ok_packet = OKPacketWrapper(first_packet)1752self.affected_rows = ok_packet.affected_rows1753self.insert_id = ok_packet.insert_id1754self.server_status = ok_packet.server_status1755self.warning_count = ok_packet.warning_count1756self.message = ok_packet.message1757self.has_next = ok_packet.has_next17581759def _read_load_local_packet(self, first_packet):1760if not self.connection._local_infile:1761raise RuntimeError(1762'**WARN**: Received LOAD_LOCAL packet but local_infile option is false.',1763)1764load_packet = LoadLocalPacketWrapper(first_packet)1765sender = LoadLocalFile(load_packet.filename, self.connection)1766try:1767sender.send_data()1768except Exception:1769self.connection._read_packet() # skip ok packet1770raise17711772ok_packet = self.connection._read_packet()1773if (1774not ok_packet.is_ok_packet()1775): # pragma: no cover - upstream induced protocol error1776raise err.OperationalError(1777CR.CR_COMMANDS_OUT_OF_SYNC,1778'Commands Out of Sync',1779)1780self._read_ok_packet(ok_packet)17811782def _check_packet_is_eof(self, packet):1783if not packet.is_eof_packet():1784return False1785# TODO: Support CLIENT.DEPRECATE_EOF1786# 1) Add DEPRECATE_EOF to CAPABILITIES1787# 2) Mask CAPABILITIES with server_capabilities1788# 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper1789# instead of EOFPacketWrapper1790wp = EOFPacketWrapper(packet)1791self.warning_count = wp.warning_count1792self.has_next = wp.has_next1793return True17941795def _read_result_packet(self, first_packet):1796self.field_count = first_packet.read_length_encoded_integer()1797self._get_descriptions()1798self._read_rowdata_packet()17991800def _read_rowdata_packet_unbuffered(self):1801# Check if in an active query1802if not self.unbuffered_active:1803return18041805# EOF1806packet = self.connection._read_packet()1807if self._check_packet_is_eof(packet):1808self.unbuffered_active = False1809self.connection = None1810self.rows = None1811return18121813row = self._read_row_from_packet(packet)1814self.affected_rows = 11815self.rows = (row,) # rows should tuple of row for MySQL-python compatibility.1816return row18171818def _finish_unbuffered_query(self):1819# After much reading on the MySQL protocol, it appears that there is,1820# in fact, no way to stop MySQL from sending all the data after1821# executing a query, so we just spin, and wait for an EOF packet.1822while self.unbuffered_active and self.connection._sock is not None:1823try:1824packet = self.connection._read_packet()1825except err.OperationalError as e:1826if e.args[0] in (1827ER.QUERY_TIMEOUT,1828ER.STATEMENT_TIMEOUT,1829):1830# if the query timed out we can simply ignore this error1831self.unbuffered_active = False1832self.connection = None1833return18341835raise18361837if self._check_packet_is_eof(packet):1838self.unbuffered_active = False1839self.connection = None # release reference to kill cyclic reference.18401841def _read_rowdata_packet(self):1842"""Read a rowdata packet for each data row in the result set."""1843rows = []1844while True:1845packet = self.connection._read_packet()1846if self._check_packet_is_eof(packet):1847self.connection = None # release reference to kill cyclic reference.1848break1849rows.append(self._read_row_from_packet(packet))18501851self.affected_rows = len(rows)1852self.rows = tuple(rows)18531854def _read_row_from_packet(self, packet):1855row = []1856for i, (encoding, converter) in enumerate(self.converters):1857try:1858data = packet.read_length_coded_string()1859except IndexError:1860# No more columns in this row1861# See https://github.com/PyMySQL/PyMySQL/pull/4341862break1863if data is not None:1864if encoding is not None:1865try:1866data = data.decode(encoding, errors=self.encoding_errors)1867except UnicodeDecodeError:1868raise UnicodeDecodeError(1869'failed to decode string value in column '1870f"'{self.fields[i].name}' using encoding '{encoding}'; " +1871"use the 'encoding_errors' option on the connection " +1872'to specify how to handle this error',1873)1874if DEBUG:1875print('DEBUG: DATA = ', data)1876if converter is not None:1877data = converter(data)1878row.append(data)1879return tuple(row)18801881def _get_descriptions(self):1882"""Read a column descriptor packet for each column in the result."""1883self.fields = []1884self.converters = []1885use_unicode = self.connection.use_unicode1886conn_encoding = self.connection.encoding1887description = []18881889for i in range(self.field_count):1890field = self.connection._read_packet(FieldDescriptorPacket)1891self.fields.append(field)1892description.append(field.description())1893field_type = field.type_code1894if use_unicode:1895if field_type == FIELD_TYPE.JSON:1896# When SELECT from JSON column: charset = binary1897# When SELECT CAST(... AS JSON): charset = connection encoding1898# This behavior is different from TEXT / BLOB.1899# We should decode result by connection encoding regardless charsetnr.1900# See https://github.com/PyMySQL/PyMySQL/issues/4881901encoding = conn_encoding # SELECT CAST(... AS JSON)1902elif field_type in TEXT_TYPES:1903if field.charsetnr == 63: # binary1904# TEXTs with charset=binary means BINARY types.1905encoding = None1906else:1907encoding = conn_encoding1908else:1909# Integers, Dates and Times, and other basic data is encoded in ascii1910encoding = 'ascii'1911else:1912encoding = None1913converter = self.connection.decoders.get(field_type)1914if converter is converters.through:1915converter = None1916if DEBUG:1917print(f'DEBUG: field={field}, converter={converter}')1918self.converters.append((encoding, converter))19191920eof_packet = self.connection._read_packet()1921assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'1922self.description = tuple(description)192319241925class MySQLResultSV(MySQLResult):19261927def __init__(self, connection, unbuffered=False):1928MySQLResult.__init__(self, connection, unbuffered=unbuffered)1929self.options = {1930k: v for k, v in dict(1931default_converters=converters.decoders,1932results_type=connection.results_type,1933parse_json=connection.parse_json,1934invalid_values=connection.invalid_values,1935unbuffered=unbuffered,1936encoding_errors=connection.encoding_errors,1937).items() if v is not UNSET1938}19391940def _read_rowdata_packet(self, *args, **kwargs):1941return _singlestoredb_accel.read_rowdata_packet(self, False, *args, **kwargs)19421943def _read_rowdata_packet_unbuffered(self, *args, **kwargs):1944return _singlestoredb_accel.read_rowdata_packet(self, True, *args, **kwargs)194519461947class LoadLocalFile:19481949def __init__(self, filename, connection):1950self.filename = filename1951self.connection = connection19521953def send_data(self):1954"""Send data packets from the local file to the server"""1955if not self.connection._sock:1956raise err.InterfaceError(0, 'Connection is closed')19571958conn = self.connection1959infile = conn._local_infile_stream19601961# 16KB is efficient enough1962packet_size = min(conn.max_allowed_packet, 16 * 1024)19631964try:19651966if self.filename in [':stream:', b':stream:']:19671968if infile is None:1969raise err.OperationalError(1970ER.FILE_NOT_FOUND,1971':stream: specified for LOCAL INFILE, but no stream was supplied',1972)19731974# Binary IO1975elif isinstance(infile, io.RawIOBase):1976while True:1977chunk = infile.read(packet_size)1978if not chunk:1979break1980conn.write_packet(chunk)19811982# Text IO1983elif isinstance(infile, io.TextIOBase):1984while True:1985chunk = infile.read(packet_size)1986if not chunk:1987break1988conn.write_packet(chunk.encode('utf8'))19891990# Iterable of bytes or str1991elif isinstance(infile, Iterable):1992for chunk in infile:1993if not chunk:1994continue1995if isinstance(chunk, str):1996conn.write_packet(chunk.encode('utf8'))1997else:1998conn.write_packet(chunk)19992000# Queue (empty value ends the iteration)2001elif isinstance(infile, queue.Queue):2002while True:2003chunk = infile.get()2004if not chunk:2005break2006if isinstance(chunk, str):2007conn.write_packet(chunk.encode('utf8'))2008else:2009conn.write_packet(chunk)20102011else:2012raise err.OperationalError(2013ER.FILE_NOT_FOUND,2014':stream: specified for LOCAL INFILE, ' +2015f'but stream type is unrecognized: {infile}',2016)20172018else:2019try:2020with open(self.filename, 'rb') as open_file:2021while True:2022chunk = open_file.read(packet_size)2023if not chunk:2024break2025conn.write_packet(chunk)2026except OSError:2027raise err.OperationalError(2028ER.FILE_NOT_FOUND,2029f"Can't find file '{self.filename!s}'",2030)20312032finally:2033if not conn._closed:2034# send the empty packet to signify we are done sending data2035conn.write_packet(b'')203620372038