# Path: blob/main/singlestoredb/mysql/connection.py
# (scraped page metadata: 469 views)
# type: ignore1# Python implementation of the MySQL client-server protocol2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html3# Error codes:4# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html5import errno6import functools7import io8import os9import queue10import socket11import struct12import sys13import traceback14import warnings15from typing import Any16from typing import Dict17from typing import Iterable1819try:20import _singlestoredb_accel21except (ImportError, ModuleNotFoundError):22_singlestoredb_accel = None2324from . import _auth25from ..utils import events2627from .charset import charset_by_name, charset_by_id28from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS29from . import converters30from .cursors import (31Cursor,32CursorSV,33DictCursor,34DictCursorSV,35NamedtupleCursor,36NamedtupleCursorSV,37ArrowCursor,38ArrowCursorSV,39NumpyCursor,40NumpyCursorSV,41PandasCursor,42PandasCursorSV,43PolarsCursor,44PolarsCursorSV,45SSCursor,46SSCursorSV,47SSDictCursor,48SSDictCursorSV,49SSNamedtupleCursor,50SSNamedtupleCursorSV,51SSArrowCursor,52SSArrowCursorSV,53SSNumpyCursor,54SSNumpyCursorSV,55SSPandasCursor,56SSPandasCursorSV,57SSPolarsCursor,58SSPolarsCursorSV,59)60from .optionfile import Parser61from .protocol import (62dump_packet,63MysqlPacket,64FieldDescriptorPacket,65OKPacketWrapper,66EOFPacketWrapper,67LoadLocalPacketWrapper,68)69from . import err70from ..config import get_option71from .. import fusion72from .. 
import connection73from ..connection import Connection as BaseConnection74from ..utils.debug import log_query7576try:77import ssl7879SSL_ENABLED = True80except ImportError:81ssl = None82SSL_ENABLED = False8384try:85import getpass8687DEFAULT_USER = getpass.getuser()88del getpass89except (ImportError, KeyError):90# KeyError occurs when there's no entry in OS database for a current user.91DEFAULT_USER = None9293DEBUG = get_option('debug.connection')9495TEXT_TYPES = {96FIELD_TYPE.BIT,97FIELD_TYPE.BLOB,98FIELD_TYPE.LONG_BLOB,99FIELD_TYPE.MEDIUM_BLOB,100FIELD_TYPE.STRING,101FIELD_TYPE.TINY_BLOB,102FIELD_TYPE.VAR_STRING,103FIELD_TYPE.VARCHAR,104FIELD_TYPE.GEOMETRY,105FIELD_TYPE.BSON,106FIELD_TYPE.FLOAT32_VECTOR_JSON,107FIELD_TYPE.FLOAT64_VECTOR_JSON,108FIELD_TYPE.INT8_VECTOR_JSON,109FIELD_TYPE.INT16_VECTOR_JSON,110FIELD_TYPE.INT32_VECTOR_JSON,111FIELD_TYPE.INT64_VECTOR_JSON,112FIELD_TYPE.FLOAT32_VECTOR,113FIELD_TYPE.FLOAT64_VECTOR,114FIELD_TYPE.INT8_VECTOR,115FIELD_TYPE.INT16_VECTOR,116FIELD_TYPE.INT32_VECTOR,117FIELD_TYPE.INT64_VECTOR,118}119120UNSET = 'unset'121122DEFAULT_CHARSET = 'utf8mb4'123124MAX_PACKET_LEN = 2**24 - 1125126127def _pack_int24(n):128return struct.pack('<I', n)[:3]129130131# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger132def _lenenc_int(i):133if i < 0:134raise ValueError(135'Encoding %d is less than 0 - no representation in LengthEncodedInteger' % i,136)137elif i < 0xFB:138return bytes([i])139elif i < (1 << 16):140return b'\xfc' + struct.pack('<H', i)141elif i < (1 << 24):142return b'\xfd' + struct.pack('<I', i)[:3]143elif i < (1 << 64):144return b'\xfe' + struct.pack('<Q', i)145else:146raise ValueError(147'Encoding %x is larger than %x - no representation in LengthEncodedInteger'148% (i, (1 << 64)),149)150151152class Connection(BaseConnection):153"""154Representation of a socket with a mysql server.155156The proper way to get an instance of this class is to call157``connect()``.158159Establish a connection 
to the SingleStoreDB database.160161Parameters162----------163host : str, optional164Host where the database server is located.165user : str, optional166Username to log in as.167password : str, optional168Password to use.169database : str, optional170Database to use, None to not use a particular one.171port : int, optional172Server port to use, default is usually OK. (default: 3306)173bind_address : str, optional174When the client has multiple network interfaces, specify175the interface from which to connect to the host. Argument can be176a hostname or an IP address.177unix_socket : str, optional178Use a unix socket rather than TCP/IP.179read_timeout : int, optional180The timeout for reading from the connection in seconds181(default: None - no timeout)182write_timeout : int, optional183The timeout for writing to the connection in seconds184(default: None - no timeout)185charset : str, optional186Charset to use.187collation : str, optional188The charset collation189sql_mode : str, optional190Default SQL_MODE to use.191read_default_file : str, optional192Specifies my.cnf file to read these parameters from under the193[client] section.194conv : Dict[str, Callable[Any]], optional195Conversion dictionary to use instead of the default one.196This is used to provide custom marshalling and unmarshalling of types.197See converters.198use_unicode : bool, optional199Whether or not to default to unicode strings.200This option defaults to true.201client_flag : int, optional202Custom flags to send to MySQL. 
Find potential values in constants.CLIENT.203cursorclass : type, optional204Custom cursor class to use.205init_command : str, optional206Initial SQL statement to run when connection is established.207connect_timeout : int, optional208The timeout for connecting to the database in seconds.209(default: 10, min: 1, max: 31536000)210ssl : Dict[str, str], optional211A dict of arguments similar to mysql_ssl_set()'s parameters or212an ssl.SSLContext.213ssl_ca : str, optional214Path to the file that contains a PEM-formatted CA certificate.215ssl_cert : str, optional216Path to the file that contains a PEM-formatted client certificate.217ssl_cipher : str, optional218SSL ciphers to allow.219ssl_disabled : bool, optional220A boolean value that disables usage of TLS.221ssl_key : str, optional222Path to the file that contains a PEM-formatted private key for the223client certificate.224ssl_verify_cert : str, optional225Set to true to check the server certificate's validity.226ssl_verify_identity : bool, optional227Set to true to check the server's identity.228tls_sni_servername: str, optional229Set server host name for TLS connection230read_default_group : str, optional231Group to read from in the configuration file.232autocommit : bool, optional233Autocommit mode. None means use server default. (default: False)234local_infile : bool, optional235Boolean to enable the use of LOAD DATA LOCAL command. (default: False)236max_allowed_packet : int, optional237Max size of packet sent to server in bytes. (default: 16MB)238Only used to limit size of "LOAD LOCAL INFILE" data packet smaller239than default (16KB).240defer_connect : bool, optional241Don't explicitly connect on construction - wait for connect call.242(default: False)243auth_plugin_map : Dict[str, type], optional244A dict of plugin names to a class that processes that plugin.245The class will take the Connection object as the argument to the246constructor. 
The class needs an authenticate method taking an247authentication packet as an argument. For the dialog plugin, a248prompt(echo, prompt) method can be used (if no authenticate method)249for returning a string from the user. (experimental)250server_public_key : str, optional251SHA256 authentication plugin public key value. (default: None)252binary_prefix : bool, optional253Add _binary prefix on bytes and bytearray. (default: False)254compress :255Not supported.256named_pipe :257Not supported.258db : str, optional259**DEPRECATED** Alias for database.260passwd : str, optional261**DEPRECATED** Alias for password.262parse_json : bool, optional263Parse JSON values into Python objects?264invalid_values : Dict[int, Any], optional265Dictionary of values to use in place of invalid values266found during conversion of data. The default is to return the byte content267containing the invalid value. The keys are the integers associtated with268the column type.269pure_python : bool, optional270Should we ignore the C extension even if it's available?271This can be given explicitly using True or False, or if the value is None,272the C extension will be loaded if it is available. 
If set to False and273the C extension can't be loaded, a NotSupportedError is raised.274nan_as_null : bool, optional275Should NaN values be treated as NULLs in parameter substitution including276uploading data?277inf_as_null : bool, optional278Should Inf values be treated as NULLs in parameter substitution including279uploading data?280track_env : bool, optional281Should the connection track the SINGLESTOREDB_URL environment variable?282enable_extended_data_types : bool, optional283Should extended data types (BSON, vector) be enabled?284vector_data_format : str, optional285Specify the data type of vector values: json or binary286287See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_288in the specification.289290"""291292driver = 'mysql'293paramstyle = 'pyformat'294295_sock = None296_auth_plugin_name = ''297_closed = False298_secure = False299_tls_sni_servername = None300301def __init__( # noqa: C901302self,303*,304user=None, # The first four arguments is based on DB-API 2.0 recommendation.305password='',306host=None,307database=None,308unix_socket=None,309port=0,310charset='',311collation=None,312sql_mode=None,313read_default_file=None,314conv=None,315use_unicode=True,316client_flag=0,317cursorclass=None,318init_command=None,319connect_timeout=10,320read_default_group=None,321autocommit=False,322local_infile=False,323max_allowed_packet=16 * 1024 * 1024,324defer_connect=False,325auth_plugin_map=None,326read_timeout=None,327write_timeout=None,328bind_address=None,329binary_prefix=False,330program_name=None,331server_public_key=None,332ssl=None,333ssl_ca=None,334ssl_cert=None,335ssl_cipher=None,336ssl_disabled=None,337ssl_key=None,338ssl_verify_cert=None,339ssl_verify_identity=None,340tls_sni_servername=None,341parse_json=True,342invalid_values=None,343pure_python=None,344buffered=True,345results_type='tuples',346compress=None, # not supported347named_pipe=None, # not supported348passwd=None, # deprecated349db=None, # 
deprecated350driver=None, # internal use351conn_attrs=None,352multi_statements=None,353client_found_rows=None,354nan_as_null=None,355inf_as_null=None,356encoding_errors='strict',357track_env=False,358enable_extended_data_types=True,359vector_data_format='binary',360):361BaseConnection.__init__(**dict(locals()))362363if db is not None and database is None:364# We will raise warning in 2022 or later.365# See https://github.com/PyMySQL/PyMySQL/issues/939366# warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)367database = db368if passwd is not None and not password:369# We will raise warning in 2022 or later.370# See https://github.com/PyMySQL/PyMySQL/issues/939371# warnings.warn(372# "'passwd' is deprecated, use 'password'", DeprecationWarning, 3373# )374password = passwd375376if compress or named_pipe:377raise NotImplementedError(378'compress and named_pipe arguments are not supported',379)380381self._local_infile = bool(local_infile)382self._local_infile_stream = None383if self._local_infile:384client_flag |= CLIENT.LOCAL_FILES385if multi_statements:386client_flag |= CLIENT.MULTI_STATEMENTS387if client_found_rows:388client_flag |= CLIENT.FOUND_ROWS389390if read_default_group and not read_default_file:391if sys.platform.startswith('win'):392read_default_file = 'c:\\my.ini'393else:394read_default_file = '/etc/my.cnf'395396if read_default_file:397if not read_default_group:398read_default_group = 'client'399400cfg = Parser()401cfg.read(os.path.expanduser(read_default_file))402403def _config(key, arg):404if arg:405return arg406try:407return cfg.get(read_default_group, key)408except Exception:409return arg410411user = _config('user', user)412password = _config('password', password)413host = _config('host', host)414database = _config('database', database)415unix_socket = _config('socket', unix_socket)416port = int(_config('port', port))417bind_address = _config('bind-address', bind_address)418charset = _config('default-character-set', charset)419if 
not ssl:420ssl = {}421if isinstance(ssl, dict):422for key in ['ca', 'capath', 'cert', 'key', 'cipher']:423value = _config('ssl-' + key, ssl.get(key))424if value:425ssl[key] = value426427self.ssl = False428if not ssl_disabled:429if ssl_ca or ssl_cert or ssl_key or ssl_cipher or \430ssl_verify_cert or ssl_verify_identity:431ssl = {432'ca': ssl_ca,433'check_hostname': bool(ssl_verify_identity),434'verify_mode': ssl_verify_cert435if ssl_verify_cert is not None436else False,437}438if ssl_cert is not None:439ssl['cert'] = ssl_cert440if ssl_key is not None:441ssl['key'] = ssl_key442if ssl_cipher is not None:443ssl['cipher'] = ssl_cipher444if ssl:445if not SSL_ENABLED:446raise NotImplementedError('ssl module not found')447self.ssl = True448client_flag |= CLIENT.SSL449self.ctx = self._create_ssl_ctx(ssl)450451self.host = host or 'localhost'452self.port = port or 3306453if type(self.port) is not int:454raise ValueError('port should be of type int')455self.user = user or DEFAULT_USER456self.password = password or b''457if isinstance(self.password, str):458self.password = self.password.encode('latin1')459self.db = database460self.unix_socket = unix_socket461self.bind_address = bind_address462if not (0 < connect_timeout <= 31536000):463raise ValueError('connect_timeout should be >0 and <=31536000')464self.connect_timeout = connect_timeout or None465if read_timeout is not None and read_timeout <= 0:466raise ValueError('read_timeout should be > 0')467self._read_timeout = read_timeout468if write_timeout is not None and write_timeout <= 0:469raise ValueError('write_timeout should be > 0')470self._write_timeout = write_timeout471472self.charset = charset or DEFAULT_CHARSET473self.collation = collation474self.use_unicode = use_unicode475self.encoding_errors = encoding_errors476477self.encoding = charset_by_name(self.charset).encoding478479client_flag |= CLIENT.CAPABILITIES480client_flag |= CLIENT.CONNECT_WITH_DB481482self.client_flag = client_flag483484self.pure_python = 
pure_python485self.results_type = results_type486self.resultclass = MySQLResult487if cursorclass is not None:488self.cursorclass = cursorclass489elif buffered:490if 'dict' in self.results_type:491self.cursorclass = DictCursor492elif 'namedtuple' in self.results_type:493self.cursorclass = NamedtupleCursor494elif 'numpy' in self.results_type:495self.cursorclass = NumpyCursor496elif 'arrow' in self.results_type:497self.cursorclass = ArrowCursor498elif 'pandas' in self.results_type:499self.cursorclass = PandasCursor500elif 'polars' in self.results_type:501self.cursorclass = PolarsCursor502else:503self.cursorclass = Cursor504else:505if 'dict' in self.results_type:506self.cursorclass = SSDictCursor507elif 'namedtuple' in self.results_type:508self.cursorclass = SSNamedtupleCursor509elif 'numpy' in self.results_type:510self.cursorclass = SSNumpyCursor511elif 'arrow' in self.results_type:512self.cursorclass = SSArrowCursor513elif 'pandas' in self.results_type:514self.cursorclass = SSPandasCursor515elif 'polars' in self.results_type:516self.cursorclass = SSPolarsCursor517else:518self.cursorclass = SSCursor519520if self.pure_python is False and _singlestoredb_accel is None:521try:522import _singlestortedb_accel # noqa: F401523except Exception:524import traceback525traceback.print_exc(file=sys.stderr)526finally:527raise err.NotSupportedError(528'pure_python=False, but the '529'C extension can not be loaded',530)531532if self.pure_python is True:533pass534535# The C extension handles these types internally.536elif _singlestoredb_accel is not None:537self.resultclass = MySQLResultSV538if self.cursorclass is Cursor:539self.cursorclass = CursorSV540elif self.cursorclass is SSCursor:541self.cursorclass = SSCursorSV542elif self.cursorclass is DictCursor:543self.cursorclass = DictCursorSV544self.results_type = 'dicts'545elif self.cursorclass is SSDictCursor:546self.cursorclass = SSDictCursorSV547self.results_type = 'dicts'548elif self.cursorclass is 
NamedtupleCursor:549self.cursorclass = NamedtupleCursorSV550self.results_type = 'namedtuples'551elif self.cursorclass is SSNamedtupleCursor:552self.cursorclass = SSNamedtupleCursorSV553self.results_type = 'namedtuples'554elif self.cursorclass is NumpyCursor:555self.cursorclass = NumpyCursorSV556self.results_type = 'numpy'557elif self.cursorclass is SSNumpyCursor:558self.cursorclass = SSNumpyCursorSV559self.results_type = 'numpy'560elif self.cursorclass is ArrowCursor:561self.cursorclass = ArrowCursorSV562self.results_type = 'arrow'563elif self.cursorclass is SSArrowCursor:564self.cursorclass = SSArrowCursorSV565self.results_type = 'arrow'566elif self.cursorclass is PandasCursor:567self.cursorclass = PandasCursorSV568self.results_type = 'pandas'569elif self.cursorclass is SSPandasCursor:570self.cursorclass = SSPandasCursorSV571self.results_type = 'pandas'572elif self.cursorclass is PolarsCursor:573self.cursorclass = PolarsCursorSV574self.results_type = 'polars'575elif self.cursorclass is SSPolarsCursor:576self.cursorclass = SSPolarsCursorSV577self.results_type = 'polars'578579self._result = None580self._affected_rows = 0581self.host_info = 'Not connected'582583# specified autocommit mode. 
None means use server default.584self.autocommit_mode = autocommit585586if conv is None:587conv = converters.conversions588589conv = conv.copy()590591self.parse_json = parse_json592self.invalid_values = (invalid_values or {}).copy()593594# Disable JSON parsing for Arrow595if self.results_type in ['arrow']:596conv[245] = None597self.parse_json = False598599# Disable date/time parsing for polars; let polars do the parsing600elif self.results_type in ['polars']:601conv[7] = None602conv[10] = None603conv[12] = None604605# Need for MySQLdb compatibility.606self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}607self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}608self.sql_mode = sql_mode609self.init_command = init_command610self.max_allowed_packet = max_allowed_packet611self._auth_plugin_map = auth_plugin_map or {}612self._binary_prefix = binary_prefix613self.server_public_key = server_public_key614615if self.connection_params['nan_as_null'] or \616self.connection_params['inf_as_null']:617float_encoder = self.encoders.get(float)618if float_encoder is not None:619self.encoders[float] = functools.partial(620float_encoder,621nan_as_null=self.connection_params['nan_as_null'],622inf_as_null=self.connection_params['inf_as_null'],623)624625from .. 
import __version__ as VERSION_STRING626627if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:628VERSION_STRING += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']629630self._connect_attrs = {631'_os': str(sys.platform),632'_pid': str(os.getpid()),633'_client_name': 'SingleStoreDB Python Client',634'_client_version': VERSION_STRING,635}636637if program_name:638self._connect_attrs['program_name'] = program_name639if conn_attrs is not None:640# do not overwrite the attributes that we set ourselves641for k, v in conn_attrs.items():642if k not in self._connect_attrs:643self._connect_attrs[k] = v644645self._is_committable = True646self._in_sync = False647self._tls_sni_servername = tls_sni_servername648self._track_env = bool(track_env) or self.host == 'singlestore.com'649self._enable_extended_data_types = enable_extended_data_types650if vector_data_format.lower() in ['json', 'binary']:651self._vector_data_format = vector_data_format652else:653raise ValueError(654'unknown value for vector_data_format, '655f'expecting "json" or "binary": {vector_data_format}',656)657self._connection_info = {}658events.subscribe(self._handle_event)659660if defer_connect or self._track_env:661self._sock = None662else:663self.connect()664665def _handle_event(self, data: Dict[str, Any]) -> None:666if data.get('name', '') == 'singlestore.portal.connection_updated':667self._connection_info = dict(data)668669@property670def messages(self):671# TODO672[]673674def __enter__(self):675return self676677def __exit__(self, *exc_info):678del exc_info679self.close()680681def _raise_mysql_exception(self, data):682err.raise_mysql_exception(data)683684def _create_ssl_ctx(self, sslp):685if isinstance(sslp, ssl.SSLContext):686return sslp687ca = sslp.get('ca')688capath = sslp.get('capath')689hasnoca = ca is None and capath is None690ctx = ssl.create_default_context(cafile=ca, capath=capath)691ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)692verify_mode_value = sslp.get('verify_mode')693if 
verify_mode_value is None:694ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED695elif isinstance(verify_mode_value, bool):696ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE697else:698if isinstance(verify_mode_value, str):699verify_mode_value = verify_mode_value.lower()700if verify_mode_value in ('none', '0', 'false', 'no'):701ctx.verify_mode = ssl.CERT_NONE702elif verify_mode_value == 'optional':703ctx.verify_mode = ssl.CERT_OPTIONAL704elif verify_mode_value in ('required', '1', 'true', 'yes'):705ctx.verify_mode = ssl.CERT_REQUIRED706else:707ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED708if 'cert' in sslp:709ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))710if 'cipher' in sslp:711ctx.set_ciphers(sslp['cipher'])712ctx.options |= ssl.OP_NO_SSLv2713ctx.options |= ssl.OP_NO_SSLv3714return ctx715716def close(self):717"""718Send the quit message and close the socket.719720See `Connection.close()721<https://www.python.org/dev/peps/pep-0249/#Connection.close>`_722in the specification.723724Raises725------726Error : If the connection is already closed.727728"""729self._result = None730if self.host == 'singlestore.com':731return732if self._closed:733raise err.Error('Already closed')734events.unsubscribe(self._handle_event)735self._closed = True736if self._sock is None:737return738send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)739try:740self._write_bytes(send_data)741except Exception:742pass743finally:744self._force_close()745746@property747def open(self):748"""Return True if the connection is open."""749return self._sock is not None750751def is_connected(self):752"""Return True if the connection is open."""753return self.open754755def _force_close(self):756"""Close connection without QUIT message."""757if self._sock:758try:759self._sock.close()760except: # noqa761pass762self._sock = None763self._rfile = None764765__del__ = _force_close766767def autocommit(self, value):768"""Enable autocommit in 
the server."""769self.autocommit_mode = bool(value)770current = self.get_autocommit()771if value != current:772self._send_autocommit_mode()773774def get_autocommit(self):775"""Retrieve autocommit status."""776return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT)777778def _read_ok_packet(self):779pkt = self._read_packet()780if not pkt.is_ok_packet():781raise err.OperationalError(782CR.CR_COMMANDS_OUT_OF_SYNC,783'Command Out of Sync',784)785ok = OKPacketWrapper(pkt)786self.server_status = ok.server_status787return ok788789def _send_autocommit_mode(self):790"""Set whether or not to commit after every execute()."""791log_query('SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode))792self._execute_command(793COMMAND.COM_QUERY, 'SET AUTOCOMMIT = %s' % self.escape(self.autocommit_mode),794)795self._read_ok_packet()796797def begin(self):798"""Begin transaction."""799log_query('BEGIN')800if self.host == 'singlestore.com':801return802self._execute_command(COMMAND.COM_QUERY, 'BEGIN')803self._read_ok_packet()804805def commit(self):806"""807Commit changes to stable storage.808809See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_810in the specification.811812"""813log_query('COMMIT')814if not self._is_committable or self.host == 'singlestore.com':815self._is_committable = True816return817self._execute_command(COMMAND.COM_QUERY, 'COMMIT')818self._read_ok_packet()819820def rollback(self):821"""822Roll back the current transaction.823824See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_825in the specification.826827"""828log_query('ROLLBACK')829if not self._is_committable or self.host == 'singlestore.com':830self._is_committable = True831return832self._execute_command(COMMAND.COM_QUERY, 'ROLLBACK')833self._read_ok_packet()834835def show_warnings(self):836"""Send the "SHOW WARNINGS" SQL command."""837log_query('SHOW WARNINGS')838self._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS')839result = 
self.resultclass(self)840result.read()841return result.rows842843def select_db(self, db):844"""845Set current db.846847db : str848The name of the db.849850"""851self._execute_command(COMMAND.COM_INIT_DB, db)852self._read_ok_packet()853854def escape(self, obj, mapping=None):855"""856Escape whatever value is passed.857858Non-standard, for internal use; do not use this in your applications.859860"""861dtype = type(obj)862if dtype is str or isinstance(obj, str):863return "'{}'".format(self.escape_string(obj))864if dtype is bytes or dtype is bytearray or isinstance(obj, (bytes, bytearray)):865return self._quote_bytes(obj)866if mapping is None:867mapping = self.encoders868return converters.escape_item(obj, self.charset, mapping=mapping)869870def literal(self, obj):871"""872Alias for escape().873874Non-standard, for internal use; do not use this in your applications.875876"""877return self.escape(obj, self.encoders)878879def escape_string(self, s):880"""Escape a string value."""881if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:882return s.replace("'", "''")883return converters.escape_string(s)884885def _quote_bytes(self, s):886if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:887if self._binary_prefix:888return "_binary X'{}'".format(s.hex())889return "X'{}'".format(s.hex())890return converters.escape_bytes(s)891892def cursor(self):893"""Create a new cursor to execute queries with."""894return self.cursorclass(self)895896# The following methods are INTERNAL USE ONLY (called from Cursor)897def query(self, sql, unbuffered=False, infile_stream=None):898"""899Run a query on the server.900901Internal use only.902903"""904# if DEBUG:905# print("DEBUG: sending query:", sql)906handler = fusion.get_handler(sql)907if handler is not None:908self._is_committable = False909self._result = fusion.execute(self, sql, handler=handler)910self._affected_rows = self._result.affected_rows911else:912self._is_committable = True913if 
isinstance(sql, str):914sql = sql.encode(self.encoding, 'surrogateescape')915self._local_infile_stream = infile_stream916self._execute_command(COMMAND.COM_QUERY, sql)917self._affected_rows = self._read_query_result(unbuffered=unbuffered)918self._local_infile_stream = None919return self._affected_rows920921def next_result(self, unbuffered=False):922"""923Retrieve the next result set.924925Internal use only.926927"""928self._affected_rows = self._read_query_result(unbuffered=unbuffered)929return self._affected_rows930931def affected_rows(self):932"""933Return number of affected rows.934935Internal use only.936937"""938return self._affected_rows939940def kill(self, thread_id):941"""942Execute kill command.943944Internal use only.945946"""947arg = struct.pack('<I', thread_id)948self._execute_command(COMMAND.COM_PROCESS_KILL, arg)949return self._read_ok_packet()950951def ping(self, reconnect=True):952"""953Check if the server is alive.954955Parameters956----------957reconnect : bool, optional958If the connection is closed, reconnect.959960Raises961------962Error : If the connection is closed and reconnect=False.963964"""965if self._sock is None:966if reconnect:967self.connect()968reconnect = False969else:970raise err.Error('Already closed')971try:972self._execute_command(COMMAND.COM_PING, '')973self._read_ok_packet()974except Exception:975if reconnect:976self.connect()977self.ping(False)978else:979raise980981def set_charset(self, charset):982"""Deprecated. 
Use set_character_set() instead."""983# This function has been implemented in old PyMySQL.984# But this name is different from MySQLdb.985# So we keep this function for compatibility and add986# new set_character_set() function.987self.set_character_set(charset)988989def set_character_set(self, charset, collation=None):990"""991Set charaset (and collation) on the server.992993Send "SET NAMES charset [COLLATE collation]" query.994Update Connection.encoding based on charset.995996Parameters997----------998charset : str999The charset to enable.1000collation : str, optional1001The collation value10021003"""1004# Make sure charset is supported.1005encoding = charset_by_name(charset).encoding10061007if collation:1008query = f'SET NAMES {charset} COLLATE {collation}'1009else:1010query = f'SET NAMES {charset}'1011self._execute_command(COMMAND.COM_QUERY, query)1012self._read_packet()1013self.charset = charset1014self.encoding = encoding1015self.collation = collation10161017def _sync_connection(self):1018"""Synchronize connection with env variable."""1019if self._in_sync:1020return10211022if not self._track_env:1023return10241025url = self._connection_info.get('connection_url')1026if not url:1027url = os.environ.get('SINGLESTOREDB_URL')1028if not url:1029return10301031out = {}1032urlp = connection._parse_url(url)1033out.update(urlp)10341035out = connection._cast_params(out)10361037# Set default port based on driver.1038if 'port' not in out or not out['port']:1039out['port'] = int(get_option('port') or 3306)10401041# If there is no user and the password is empty, remove the password key.1042if 'user' not in out and not out.get('password', None):1043out.pop('password', None)10441045if out['host'] == 'singlestore.com':1046raise err.InterfaceError(0, 'Connection URL has not been established')10471048# If it's just a password change, we don't need to reconnect1049if self._sock is not None and \1050(self.host, self.port, self.user, self.db) == \1051(out['host'], out['port'], 
out['user'], out.get('database')):1052return10531054self.host = out['host']1055self.port = out['port']1056self.user = out['user']1057if isinstance(out['password'], str):1058self.password = out['password'].encode('latin-1')1059else:1060self.password = out['password'] or b''1061self.db = out.get('database')1062try:1063self._in_sync = True1064self.connect()1065finally:1066self._in_sync = False10671068def connect(self, sock=None):1069"""1070Connect to server using existing parameters.10711072Internal use only.10731074"""1075self._closed = False1076try:1077if sock is None:1078if self.unix_socket:1079sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)1080sock.settimeout(self.connect_timeout)1081sock.connect(self.unix_socket)1082self.host_info = 'Localhost via UNIX socket'1083self._secure = True1084if DEBUG:1085print('connected using unix_socket')1086else:1087kwargs = {}1088if self.bind_address is not None:1089kwargs['source_address'] = (self.bind_address, 0)1090while True:1091try:1092sock = socket.create_connection(1093(self.host, self.port), self.connect_timeout, **kwargs,1094)1095break1096except OSError as e:1097if e.errno == errno.EINTR:1098continue1099raise1100self.host_info = 'socket %s:%d' % (self.host, self.port)1101if DEBUG:1102print('connected using socket')1103sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)1104sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)1105sock.settimeout(None)11061107self._sock = sock1108self._rfile = sock.makefile('rb')1109self._next_seq_id = 011101111self._get_server_information()1112self._request_authentication()11131114# Send "SET NAMES" query on init for:1115# - Ensure charaset (and collation) is set to the server.1116# - collation_id in handshake packet may be ignored.1117# - If collation is not specified, we don't know what is server's1118# default collation for the charset. 
For example, default collation1119# of utf8mb4 is:1120# - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci1121# - MySQL 8.0: utf8mb4_0900_ai_ci1122#1123# Reference:1124# - https://github.com/PyMySQL/PyMySQL/issues/10921125# - https://github.com/wagtail/wagtail/issues/94771126# - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)1127self.set_character_set(self.charset, self.collation)11281129if self.sql_mode is not None:1130c = self.cursor()1131c.execute('SET sql_mode=%s', (self.sql_mode,))1132c.close()11331134if self._enable_extended_data_types:1135c = self.cursor()1136try:1137c.execute('SET @@SESSION.enable_extended_types_metadata=on')1138except self.OperationalError:1139pass1140c.close()11411142if self._vector_data_format:1143c = self.cursor()1144try:1145val = self._vector_data_format1146c.execute(f'SET @@SESSION.vector_type_project_format={val}')1147except self.OperationalError:1148pass1149c.close()11501151if self.init_command is not None:1152c = self.cursor()1153c.execute(self.init_command)1154c.close()11551156if self.autocommit_mode is not None:1157self.autocommit(self.autocommit_mode)11581159except BaseException as e:1160self._rfile = None1161if sock is not None:1162try:1163sock.close()1164except: # noqa1165pass11661167if isinstance(e, (OSError, IOError, socket.error)):1168exc = err.OperationalError(1169CR.CR_CONN_HOST_ERROR,1170f'Can\'t connect to MySQL server on {self.host!r} ({e})',1171)1172# Keep original exception and traceback to investigate error.1173exc.original_exception = e1174exc.traceback = traceback.format_exc()1175if DEBUG:1176print(exc.traceback)1177raise exc11781179# If e is neither DatabaseError or IOError, It's a bug.1180# But raising AssertionError hides original error.1181# So just reraise it.1182raise11831184def write_packet(self, payload):1185"""1186Writes an entire "mysql packet" in its entirety to the network.11871188Adds its length and sequence number.11891190"""1191# Internal note: when you build packet manually and calls 
_write_bytes()1192# directly, you should set self._next_seq_id properly.1193data = _pack_int24(len(payload)) + bytes([self._next_seq_id]) + payload1194if DEBUG:1195dump_packet(data)1196self._write_bytes(data)1197self._next_seq_id = (self._next_seq_id + 1) % 25611981199def _read_packet(self, packet_type=MysqlPacket):1200"""1201Read an entire "mysql packet" in its entirety from the network.12021203Raises1204------1205OperationalError : If the connection to the MySQL server is lost.1206InternalError : If the packet sequence number is wrong.12071208Returns1209-------1210MysqlPacket12111212"""1213buff = bytearray()1214while True:1215packet_header = self._read_bytes(4)1216# if DEBUG: dump_packet(packet_header)12171218btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)1219bytes_to_read = btrl + (btrh << 16)1220if packet_number != self._next_seq_id:1221self._force_close()1222if packet_number == 0:1223# MariaDB sends error packet with seqno==0 when shutdown1224raise err.OperationalError(1225CR.CR_SERVER_LOST,1226'Lost connection to MySQL server during query',1227)1228raise err.InternalError(1229'Packet sequence number wrong - got %d expected %d'1230% (packet_number, self._next_seq_id),1231)1232self._next_seq_id = (self._next_seq_id + 1) % 25612331234recv_data = self._read_bytes(bytes_to_read)1235if DEBUG:1236dump_packet(recv_data)1237buff += recv_data1238# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html1239if bytes_to_read == 0xFFFFFF:1240continue1241if bytes_to_read < MAX_PACKET_LEN:1242break12431244packet = packet_type(bytes(buff), self.encoding)1245if packet.is_error_packet():1246if self._result is not None and self._result.unbuffered_active is True:1247self._result.unbuffered_active = False1248packet.raise_for_error()1249return packet12501251def _read_bytes(self, num_bytes):1252if self._read_timeout is not None:1253self._sock.settimeout(self._read_timeout)1254while True:1255try:1256data = self._rfile.read(num_bytes)1257break1258except 
OSError as e:1259if e.errno == errno.EINTR:1260continue1261self._force_close()1262raise err.OperationalError(1263CR.CR_SERVER_LOST,1264'Lost connection to MySQL server during query (%s)' % (e,),1265)1266except BaseException:1267# Don't convert unknown exception to MySQLError.1268self._force_close()1269raise1270if len(data) < num_bytes:1271self._force_close()1272raise err.OperationalError(1273CR.CR_SERVER_LOST, 'Lost connection to MySQL server during query',1274)1275return data12761277def _write_bytes(self, data):1278if self._write_timeout is not None:1279self._sock.settimeout(self._write_timeout)1280try:1281self._sock.sendall(data)1282except OSError as e:1283self._force_close()1284raise err.OperationalError(1285CR.CR_SERVER_GONE_ERROR, f'MySQL server has gone away ({e!r})',1286)12871288def _read_query_result(self, unbuffered=False):1289self._result = None1290if unbuffered:1291result = self.resultclass(self, unbuffered=unbuffered)1292else:1293result = self.resultclass(self)1294result.read()1295self._result = result1296if result.server_status is not None:1297self.server_status = result.server_status1298return result.affected_rows12991300def insert_id(self):1301if self._result:1302return self._result.insert_id1303else:1304return 013051306def _execute_command(self, command, sql):1307"""1308Execute command.13091310Raises1311------1312InterfaceError : If the connection is closed.1313ValueError : If no username was specified.13141315"""1316self._sync_connection()13171318if self._sock is None:1319raise err.InterfaceError(0, 'The connection has been closed')13201321# If the last query was unbuffered, make sure it finishes before1322# sending new commands1323if self._result is not None:1324if self._result.unbuffered_active:1325warnings.warn('Previous unbuffered result was left incomplete')1326self._result._finish_unbuffered_query()1327while self._result.has_next:1328self.next_result()1329self._result = None13301331if isinstance(sql, str):1332sql = 
sql.encode(self.encoding)13331334packet_size = min(MAX_PACKET_LEN, len(sql) + 1) # +1 is for command13351336# tiny optimization: build first packet manually instead of1337# calling self..write_packet()1338prelude = struct.pack('<iB', packet_size, command)1339packet = prelude + sql[: packet_size - 1]1340self._write_bytes(packet)1341if DEBUG:1342dump_packet(packet)1343self._next_seq_id = 113441345if packet_size < MAX_PACKET_LEN:1346return13471348sql = sql[packet_size - 1:]1349while True:1350packet_size = min(MAX_PACKET_LEN, len(sql))1351self.write_packet(sql[:packet_size])1352sql = sql[packet_size:]1353if not sql and packet_size < MAX_PACKET_LEN:1354break13551356def _request_authentication(self): # noqa: C9011357# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse1358if int(self.server_version.split('.', 1)[0]) >= 5:1359self.client_flag |= CLIENT.MULTI_RESULTS13601361if self.user is None:1362raise ValueError('Did not specify a username')13631364charset_id = charset_by_name(self.charset).id1365if isinstance(self.user, str):1366self.user = self.user.encode(self.encoding)13671368data_init = struct.pack(1369'<iIB23s', self.client_flag, MAX_PACKET_LEN, charset_id, b'',1370)13711372if self.ssl and self.server_capabilities & CLIENT.SSL:1373self.write_packet(data_init)13741375hostname = self.host1376if self._tls_sni_servername:1377hostname = self._tls_sni_servername1378self._sock = self.ctx.wrap_socket(self._sock, server_hostname=hostname)1379self._rfile = self._sock.makefile('rb')1380self._secure = True13811382data = data_init + self.user + b'\0'13831384authresp = b''1385plugin_name = None13861387if self._auth_plugin_name == '':1388plugin_name = b''1389authresp = _auth.scramble_native_password(self.password, self.salt)1390elif self._auth_plugin_name == 'mysql_native_password':1391plugin_name = b'mysql_native_password'1392authresp = _auth.scramble_native_password(self.password, self.salt)1393elif self._auth_plugin_name == 
'caching_sha2_password':1394plugin_name = b'caching_sha2_password'1395if self.password:1396if DEBUG:1397print('caching_sha2: trying fast path')1398authresp = _auth.scramble_caching_sha2(self.password, self.salt)1399else:1400if DEBUG:1401print('caching_sha2: empty password')1402elif self._auth_plugin_name == 'sha256_password':1403plugin_name = b'sha256_password'1404if self.ssl and self.server_capabilities & CLIENT.SSL:1405authresp = self.password + b'\0'1406elif self.password:1407authresp = b'\1' # request public key1408else:1409authresp = b'\0' # empty password14101411if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:1412data += _lenenc_int(len(authresp)) + authresp1413elif self.server_capabilities & CLIENT.SECURE_CONNECTION:1414data += struct.pack('B', len(authresp)) + authresp1415else: # pragma: no cover - no testing against servers w/o secure auth (>=5.0)1416data += authresp + b'\0'14171418if self.server_capabilities & CLIENT.CONNECT_WITH_DB:1419db = self.db1420if isinstance(db, str):1421db = db.encode(self.encoding)1422data += (db or b'') + b'\0'14231424if self.server_capabilities & CLIENT.PLUGIN_AUTH:1425data += (plugin_name or b'') + b'\0'14261427if self.server_capabilities & CLIENT.CONNECT_ATTRS:1428connect_attrs = b''1429for k, v in self._connect_attrs.items():1430k = k.encode('utf-8')1431connect_attrs += _lenenc_int(len(k)) + k1432v = v.encode('utf-8')1433connect_attrs += _lenenc_int(len(v)) + v1434data += _lenenc_int(len(connect_attrs)) + connect_attrs14351436self.write_packet(data)1437auth_packet = self._read_packet()14381439# if authentication method isn't accepted the first byte1440# will have the octet 2541441if auth_packet.is_auth_switch_request():1442if DEBUG:1443print('received auth switch')1444# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest1445auth_packet.read_uint8() # 0xfe packet identifier1446plugin_name = auth_packet.read_string()1447if (1448self.server_capabilities & 
CLIENT.PLUGIN_AUTH1449and plugin_name is not None1450):1451auth_packet = self._process_auth(plugin_name, auth_packet)1452else:1453raise err.OperationalError('received unknown auth switch request')1454elif auth_packet.is_extra_auth_data():1455if DEBUG:1456print('received extra data')1457# https://dev.mysql.com/doc/internals/en/successful-authentication.html1458if self._auth_plugin_name == 'caching_sha2_password':1459auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)1460elif self._auth_plugin_name == 'sha256_password':1461auth_packet = _auth.sha256_password_auth(self, auth_packet)1462else:1463raise err.OperationalError(1464'Received extra packet for auth method %r', self._auth_plugin_name,1465)14661467if DEBUG:1468print('Succeed to auth')14691470def _process_auth(self, plugin_name, auth_packet):1471handler = self._get_auth_plugin_handler(plugin_name)1472if handler:1473try:1474return handler.authenticate(auth_packet)1475except AttributeError:1476if plugin_name != b'dialog':1477raise err.OperationalError(1478CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1479"Authentication plugin '%s'"1480' not loaded: - %r missing authenticate method'1481% (plugin_name, type(handler)),1482)1483if plugin_name == b'caching_sha2_password':1484return _auth.caching_sha2_password_auth(self, auth_packet)1485elif plugin_name == b'sha256_password':1486return _auth.sha256_password_auth(self, auth_packet)1487elif plugin_name == b'mysql_native_password':1488data = _auth.scramble_native_password(self.password, auth_packet.read_all())1489elif plugin_name == b'client_ed25519':1490data = _auth.ed25519_password(self.password, auth_packet.read_all())1491elif plugin_name == b'mysql_old_password':1492data = (1493_auth.scramble_old_password(self.password, auth_packet.read_all())1494+ b'\0'1495)1496elif plugin_name == b'mysql_clear_password':1497# https://dev.mysql.com/doc/internals/en/clear-text-authentication.html1498data = self.password + b'\0'1499elif plugin_name == b'auth_gssapi_client':1500data = 
_auth.gssapi_auth(auth_packet.read_all())1501elif plugin_name == b'dialog':1502pkt = auth_packet1503while True:1504flag = pkt.read_uint8()1505echo = (flag & 0x06) == 0x021506last = (flag & 0x01) == 0x011507prompt = pkt.read_all()15081509if prompt == b'Password: ':1510self.write_packet(self.password + b'\0')1511elif handler:1512resp = 'no response - TypeError within plugin.prompt method'1513try:1514resp = handler.prompt(echo, prompt)1515self.write_packet(resp + b'\0')1516except AttributeError:1517raise err.OperationalError(1518CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1519"Authentication plugin '%s'"1520' not loaded: - %r missing prompt method'1521% (plugin_name, handler),1522)1523except TypeError:1524raise err.OperationalError(1525CR.CR_AUTH_PLUGIN_ERR,1526"Authentication plugin '%s'"1527" %r didn't respond with string. Returned '%r' to prompt %r"1528% (plugin_name, handler, resp, prompt),1529)1530else:1531raise err.OperationalError(1532CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1533"Authentication plugin '%s' not configured" % (plugin_name,),1534)1535pkt = self._read_packet()1536pkt.check_error()1537if pkt.is_ok_packet() or last:1538break1539return pkt1540else:1541raise err.OperationalError(1542CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1543"Authentication plugin '%s' not configured" % plugin_name,1544)15451546self.write_packet(data)1547pkt = self._read_packet()1548pkt.check_error()1549return pkt15501551def _get_auth_plugin_handler(self, plugin_name):1552plugin_class = self._auth_plugin_map.get(plugin_name)1553if not plugin_class and isinstance(plugin_name, bytes):1554plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))1555if plugin_class:1556try:1557handler = plugin_class(self)1558except TypeError:1559raise err.OperationalError(1560CR.CR_AUTH_PLUGIN_CANNOT_LOAD,1561"Authentication plugin '%s'"1562' not loaded: - %r cannot be constructed with connection object'1563% (plugin_name, plugin_class),1564)1565else:1566handler = None1567return handler15681569# _mysql support1570def 
thread_id(self):1571return self.server_thread_id[0]15721573def character_set_name(self):1574return self.charset15751576def get_host_info(self):1577return self.host_info15781579def get_proto_info(self):1580return self.protocol_version15811582def _get_server_information(self):1583i = 01584packet = self._read_packet()1585data = packet.get_all_data()15861587self.protocol_version = data[i]1588i += 115891590server_end = data.find(b'\0', i)1591self.server_version = data[i:server_end].decode('latin1')1592i = server_end + 115931594self.server_thread_id = struct.unpack('<I', data[i: i + 4])1595i += 415961597self.salt = data[i: i + 8]1598i += 9 # 8 + 1(filler)15991600self.server_capabilities = struct.unpack('<H', data[i: i + 2])[0]1601i += 216021603if len(data) >= i + 6:1604lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i: i + 6])1605i += 61606# TODO: deprecate server_language and server_charset.1607# mysqlclient-python doesn't provide it.1608self.server_language = lang1609try:1610self.server_charset = charset_by_id(lang).name1611except KeyError:1612# unknown collation1613self.server_charset = None16141615self.server_status = stat1616if DEBUG:1617print('server_status: %x' % stat)16181619self.server_capabilities |= cap_h << 161620if DEBUG:1621print('salt_len:', salt_len)1622salt_len = max(12, salt_len - 9)16231624# reserved1625i += 1016261627if len(data) >= i + salt_len:1628# salt_len includes auth_plugin_data_part_1 and filler1629self.salt += data[i: i + salt_len]1630i += salt_len16311632i += 11633# AUTH PLUGIN NAME may appear here.1634if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:1635# Due to Bug#59453 the auth-plugin-name is missing the terminating1636# NUL-char in versions prior to 5.5.10 and 5.6.2.1637# ref: https://dev.mysql.com/doc/internals/en/1638# connection-phase-packets.html#packet-Protocol::Handshake1639# didn't use version checks as mariadb is corrected and reports1640# earlier than those two.1641server_end = data.find(b'\0', 
i)1642if server_end < 0: # pragma: no cover - very specific upstream bug1643# not found \0 and last field so take it all1644self._auth_plugin_name = data[i:].decode('utf-8')1645else:1646self._auth_plugin_name = data[i:server_end].decode('utf-8')16471648def get_server_info(self):1649return self.server_version16501651Warning = err.Warning1652Error = err.Error1653InterfaceError = err.InterfaceError1654DatabaseError = err.DatabaseError1655DataError = err.DataError1656OperationalError = err.OperationalError1657IntegrityError = err.IntegrityError1658InternalError = err.InternalError1659ProgrammingError = err.ProgrammingError1660NotSupportedError = err.NotSupportedError166116621663class MySQLResult:1664"""1665Results of a SQL query.16661667Parameters1668----------1669connection : Connection1670The connection the result came from.1671unbuffered : bool, optional1672Should the reads be unbuffered?16731674"""16751676def __init__(self, connection, unbuffered=False):1677self.connection = connection1678self.affected_rows = None1679self.insert_id = None1680self.server_status = None1681self.warning_count = 01682self.message = None1683self.field_count = 01684self.description = None1685self.rows = None1686self.has_next = None1687self.unbuffered_active = False1688self.converters = []1689self.fields = []1690self.encoding_errors = self.connection.encoding_errors1691if unbuffered:1692try:1693self.init_unbuffered_query()1694except Exception:1695self.connection = None1696self.unbuffered_active = False1697raise16981699def __del__(self):1700if self.unbuffered_active:1701self._finish_unbuffered_query()17021703def read(self):1704try:1705first_packet = self.connection._read_packet()17061707if first_packet.is_ok_packet():1708self._read_ok_packet(first_packet)1709elif first_packet.is_load_local_packet():1710self._read_load_local_packet(first_packet)1711else:1712self._read_result_packet(first_packet)1713finally:1714self.connection = None17151716def 
init_unbuffered_query(self):1717"""1718Initialize an unbuffered query.17191720Raises1721------1722OperationalError : If the connection to the MySQL server is lost.1723InternalError : Other errors.17241725"""1726self.unbuffered_active = True1727first_packet = self.connection._read_packet()17281729if first_packet.is_ok_packet():1730self._read_ok_packet(first_packet)1731self.unbuffered_active = False1732self.connection = None1733elif first_packet.is_load_local_packet():1734self._read_load_local_packet(first_packet)1735self.unbuffered_active = False1736self.connection = None1737else:1738self.field_count = first_packet.read_length_encoded_integer()1739self._get_descriptions()17401741# Apparently, MySQLdb picks this number because it's the maximum1742# value of a 64bit unsigned integer. Since we're emulating MySQLdb,1743# we set it to this instead of None, which would be preferred.1744self.affected_rows = 1844674407370955161517451746def _read_ok_packet(self, first_packet):1747ok_packet = OKPacketWrapper(first_packet)1748self.affected_rows = ok_packet.affected_rows1749self.insert_id = ok_packet.insert_id1750self.server_status = ok_packet.server_status1751self.warning_count = ok_packet.warning_count1752self.message = ok_packet.message1753self.has_next = ok_packet.has_next17541755def _read_load_local_packet(self, first_packet):1756if not self.connection._local_infile:1757raise RuntimeError(1758'**WARN**: Received LOAD_LOCAL packet but local_infile option is false.',1759)1760load_packet = LoadLocalPacketWrapper(first_packet)1761sender = LoadLocalFile(load_packet.filename, self.connection)1762try:1763sender.send_data()1764except Exception:1765self.connection._read_packet() # skip ok packet1766raise17671768ok_packet = self.connection._read_packet()1769if (1770not ok_packet.is_ok_packet()1771): # pragma: no cover - upstream induced protocol error1772raise err.OperationalError(1773CR.CR_COMMANDS_OUT_OF_SYNC,1774'Commands Out of 
Sync',1775)1776self._read_ok_packet(ok_packet)17771778def _check_packet_is_eof(self, packet):1779if not packet.is_eof_packet():1780return False1781# TODO: Support CLIENT.DEPRECATE_EOF1782# 1) Add DEPRECATE_EOF to CAPABILITIES1783# 2) Mask CAPABILITIES with server_capabilities1784# 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper1785# instead of EOFPacketWrapper1786wp = EOFPacketWrapper(packet)1787self.warning_count = wp.warning_count1788self.has_next = wp.has_next1789return True17901791def _read_result_packet(self, first_packet):1792self.field_count = first_packet.read_length_encoded_integer()1793self._get_descriptions()1794self._read_rowdata_packet()17951796def _read_rowdata_packet_unbuffered(self):1797# Check if in an active query1798if not self.unbuffered_active:1799return18001801# EOF1802packet = self.connection._read_packet()1803if self._check_packet_is_eof(packet):1804self.unbuffered_active = False1805self.connection = None1806self.rows = None1807return18081809row = self._read_row_from_packet(packet)1810self.affected_rows = 11811self.rows = (row,) # rows should tuple of row for MySQL-python compatibility.1812return row18131814def _finish_unbuffered_query(self):1815# After much reading on the MySQL protocol, it appears that there is,1816# in fact, no way to stop MySQL from sending all the data after1817# executing a query, so we just spin, and wait for an EOF packet.1818while self.unbuffered_active and self.connection._sock is not None:1819try:1820packet = self.connection._read_packet()1821except err.OperationalError as e:1822if e.args[0] in (1823ER.QUERY_TIMEOUT,1824ER.STATEMENT_TIMEOUT,1825):1826# if the query timed out we can simply ignore this error1827self.unbuffered_active = False1828self.connection = None1829return18301831raise18321833if self._check_packet_is_eof(packet):1834self.unbuffered_active = False1835self.connection = None # release reference to kill cyclic reference.18361837def _read_rowdata_packet(self):1838"""Read a 
rowdata packet for each data row in the result set."""1839rows = []1840while True:1841packet = self.connection._read_packet()1842if self._check_packet_is_eof(packet):1843self.connection = None # release reference to kill cyclic reference.1844break1845rows.append(self._read_row_from_packet(packet))18461847self.affected_rows = len(rows)1848self.rows = tuple(rows)18491850def _read_row_from_packet(self, packet):1851row = []1852for i, (encoding, converter) in enumerate(self.converters):1853try:1854data = packet.read_length_coded_string()1855except IndexError:1856# No more columns in this row1857# See https://github.com/PyMySQL/PyMySQL/pull/4341858break1859if data is not None:1860if encoding is not None:1861try:1862data = data.decode(encoding, errors=self.encoding_errors)1863except UnicodeDecodeError:1864raise UnicodeDecodeError(1865'failed to decode string value in column '1866f"'{self.fields[i].name}' using encoding '{encoding}'; " +1867"use the 'encoding_errors' option on the connection " +1868'to specify how to handle this error',1869)1870if DEBUG:1871print('DEBUG: DATA = ', data)1872if converter is not None:1873data = converter(data)1874row.append(data)1875return tuple(row)18761877def _get_descriptions(self):1878"""Read a column descriptor packet for each column in the result."""1879self.fields = []1880self.converters = []1881use_unicode = self.connection.use_unicode1882conn_encoding = self.connection.encoding1883description = []18841885for i in range(self.field_count):1886field = self.connection._read_packet(FieldDescriptorPacket)1887self.fields.append(field)1888description.append(field.description())1889field_type = field.type_code1890if use_unicode:1891if field_type == FIELD_TYPE.JSON:1892# When SELECT from JSON column: charset = binary1893# When SELECT CAST(... 
AS JSON): charset = connection encoding1894# This behavior is different from TEXT / BLOB.1895# We should decode result by connection encoding regardless charsetnr.1896# See https://github.com/PyMySQL/PyMySQL/issues/4881897encoding = conn_encoding # SELECT CAST(... AS JSON)1898elif field_type in TEXT_TYPES:1899if field.charsetnr == 63: # binary1900# TEXTs with charset=binary means BINARY types.1901encoding = None1902else:1903encoding = conn_encoding1904else:1905# Integers, Dates and Times, and other basic data is encoded in ascii1906encoding = 'ascii'1907else:1908encoding = None1909converter = self.connection.decoders.get(field_type)1910if converter is converters.through:1911converter = None1912if DEBUG:1913print(f'DEBUG: field={field}, converter={converter}')1914self.converters.append((encoding, converter))19151916eof_packet = self.connection._read_packet()1917assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'1918self.description = tuple(description)191919201921class MySQLResultSV(MySQLResult):19221923def __init__(self, connection, unbuffered=False):1924MySQLResult.__init__(self, connection, unbuffered=unbuffered)1925self.options = {1926k: v for k, v in dict(1927default_converters=converters.decoders,1928results_type=connection.results_type,1929parse_json=connection.parse_json,1930invalid_values=connection.invalid_values,1931unbuffered=unbuffered,1932encoding_errors=connection.encoding_errors,1933).items() if v is not UNSET1934}19351936def _read_rowdata_packet(self, *args, **kwargs):1937return _singlestoredb_accel.read_rowdata_packet(self, False, *args, **kwargs)19381939def _read_rowdata_packet_unbuffered(self, *args, **kwargs):1940return _singlestoredb_accel.read_rowdata_packet(self, True, *args, **kwargs)194119421943class LoadLocalFile:19441945def __init__(self, filename, connection):1946self.filename = filename1947self.connection = connection19481949def send_data(self):1950"""Send data packets from the local file to the server"""1951if not 
self.connection._sock:1952raise err.InterfaceError(0, 'Connection is closed')19531954conn = self.connection1955infile = conn._local_infile_stream19561957# 16KB is efficient enough1958packet_size = min(conn.max_allowed_packet, 16 * 1024)19591960try:19611962if self.filename in [':stream:', b':stream:']:19631964if infile is None:1965raise err.OperationalError(1966ER.FILE_NOT_FOUND,1967':stream: specified for LOCAL INFILE, but no stream was supplied',1968)19691970# Binary IO1971elif isinstance(infile, io.RawIOBase):1972while True:1973chunk = infile.read(packet_size)1974if not chunk:1975break1976conn.write_packet(chunk)19771978# Text IO1979elif isinstance(infile, io.TextIOBase):1980while True:1981chunk = infile.read(packet_size)1982if not chunk:1983break1984conn.write_packet(chunk.encode('utf8'))19851986# Iterable of bytes or str1987elif isinstance(infile, Iterable):1988for chunk in infile:1989if not chunk:1990continue1991if isinstance(chunk, str):1992conn.write_packet(chunk.encode('utf8'))1993else:1994conn.write_packet(chunk)19951996# Queue (empty value ends the iteration)1997elif isinstance(infile, queue.Queue):1998while True:1999chunk = infile.get()2000if not chunk:2001break2002if isinstance(chunk, str):2003conn.write_packet(chunk.encode('utf8'))2004else:2005conn.write_packet(chunk)20062007else:2008raise err.OperationalError(2009ER.FILE_NOT_FOUND,2010':stream: specified for LOCAL INFILE, ' +2011f'but stream type is unrecognized: {infile}',2012)20132014else:2015try:2016with open(self.filename, 'rb') as open_file:2017while True:2018chunk = open_file.read(packet_size)2019if not chunk:2020break2021conn.write_packet(chunk)2022except OSError:2023raise err.OperationalError(2024ER.FILE_NOT_FOUND,2025f"Can't find file '{self.filename!s}'",2026)20272028finally:2029if not conn._closed:2030# send the empty packet to signify we are done sending data2031conn.write_packet(b'')203220332034