Path: blob/main/singlestoredb/mysql/cursors.py
801 views
# type: ignore
import re
from collections import namedtuple

from . import err
from ..connection import Cursor as BaseCursor
from ..utils import results
from ..utils.debug import log_query
from ..utils.mogrify import should_interpolate_query
from ..utils.results import get_schema

try:
    from pydantic import BaseModel
    has_pydantic = True
except ImportError:
    has_pydantic = False


#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
RE_INSERT_VALUES = re.compile(
    r'\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)'
    + r'(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))'
    + r'(\s*(?:ON DUPLICATE.*)?);?\s*\Z',
    re.IGNORECASE | re.DOTALL,
)


class Cursor(BaseCursor):
    """
    This is the object used to interact with the database.

    Do not create an instance of a Cursor yourself. Call
    connection.Connection.cursor().

    See `Cursor <https://www.python.org/dev/peps/pep-0249/#cursor-objects>`_ in
    the specification.

    Parameters
    ----------
    connection : Connection
        The connection the cursor is associated with.

    """

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet - packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection):
        self._connection = connection
        self.warning_count = 0
        self._description = None
        self._format_schema = None
        self._rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        self._executed = None
        self._result = None
        self._rows = None
        self.lastrowid = None

    @property
    def messages(self):
        # TODO
        return []

    @property
    def description(self):
        return self._description

    @property
    def _schema(self):
        return self._format_schema

    @property
    def connection(self):
        return self._connection

    @property
    def rownumber(self):
        return self._rownumber

    def close(self):
        """Closing a cursor just exhausts all remaining data."""
        conn = self._connection
        if conn is None:
            return
        try:
            # Drain every pending result set before releasing the connection.
            while self.nextset():
                pass
        finally:
            self._connection = None

    @property
    def open(self) -> bool:
        # The cursor is open as long as it still holds a connection.
        return self._connection is not None

    def is_connected(self):
        return self.open

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        del exc_info
        self.close()

    def _get_db(self):
        if not self._connection:
            raise err.ProgrammingError('Cursor closed')
        return self._connection

    def _check_executed(self):
        if not self._executed:
            raise err.ProgrammingError('execute() first')

    def _conv_row(self, row):
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    setoutputsize = setoutputsizes

    def _nextset(self, unbuffered=False):
        """Get the next query set."""
        conn = self._get_db()
        current_result = self._result
        # Only advance when our result is still the connection's active one.
        if current_result is None or current_result is not conn._result:
            return None
        if not current_result.has_next:
            return None
        self._result = None
        self._clear_result()
        conn.next_result(unbuffered=unbuffered)
        self._do_get_result()
        return True

    def nextset(self):
        return self._nextset(unbuffered=False)

    def _escape_args(self, args, conn):
        literal = conn.literal
        if isinstance(args, (tuple, list)):
            return tuple(literal(arg) for arg in args)
        elif isinstance(args, dict):
            return {key: literal(val) for (key, val) in args.items()}
        elif has_pydantic and isinstance(args, BaseModel):
            return {key: literal(val) for (key, val) in args.model_dump().items()}
        # If it's not a dictionary let's try escaping it anyways.
        # Worst case it will throw a Value error
        return conn.escape(args)

    def mogrify(self, query, args=None):
        """
        Returns the exact string sent to the database by calling the execute() method.

        This method follows the extension to the DB API 2.0 followed by Psycopg.

        Parameters
        ----------
        query : str
            Query to mogrify.
        args : Sequence[Any] or Dict[str, Any] or Any, optional
            Parameters used with query. (optional)

        Returns
        -------
        str : The query with argument binding applied.

        """
        conn = self._get_db()

        if should_interpolate_query(conn.interpolate_query_with_empty_args, args):
            query = query % self._escape_args(args, conn)

        return query

    def execute(self, query, args=None, infile_stream=None):
        """
        Execute a query.

        If args is a list or tuple, :1, :2, etc. can be used as a
        placeholder in the query. If args is a dict, :name can be used
        as a placeholder in the query.

        Parameters
        ----------
        query : str
            Query to execute.
        args : Sequence[Any] or Dict[str, Any] or Any, optional
            Parameters used with query. (optional)
        infile_stream : io.BytesIO or Iterator[bytes], optional
            Data stream for ``LOCAL INFILE`` statements

        Returns
        -------
        int : Number of affected rows.

        """
        # Exhaust any leftover result sets from a previous execution.
        while self.nextset():
            pass

        log_query(query, args)

        query = self.mogrify(query, args)

        result = self._query(query, infile_stream=infile_stream)
        self._executed = query
        return result

    def executemany(self, query, args=None):
        """
        Run several data against one query.

        This method improves performance on multiple-row INSERT and
        REPLACE. Otherwise it is equivalent to looping over args with
        execute().

        Parameters
        ----------
        query : str
            Query to execute.
        args : Sequence[Any], optional
            Sequence of sequences or mappings. It is used as parameter.

        Returns
        -------
        int : Number of rows affected, if any.

        """
        if args is None or len(args) == 0:
            return

        m = RE_INSERT_VALUES.match(query)
        if m:
            # Bulk-insert fast path: reuse the statement prefix/postfix and
            # pack as many VALUES tuples per statement as will fit.
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ''
            assert q_values[0] == '(' and q_values[-1] == ')'
            return self._do_execute_many(
                q_prefix,
                q_values,
                q_postfix,
                args,
                self.max_stmt_length,
                self._get_db().encoding,
            )

        # Fallback: run the query once per parameter set.
        self.rowcount = sum(self.execute(query, arg) for arg in args)
        return self.rowcount

    def _do_execute_many(
        self, prefix, values, postfix, args, max_stmt_length, encoding,
    ):
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)
        sql = bytearray(prefix)
        # Detect dataframes
        if hasattr(args, 'itertuples'):
            args = args.itertuples(index=False)
        else:
            args = iter(args)
        v = values % escape(next(args), conn)
        if isinstance(v, str):
            v = v.encode(encoding, 'surrogateescape')
        sql += v
        rows = 0
        for arg in args:
            v = values % escape(arg, conn)
            if isinstance(v, str):
                v = v.encode(encoding, 'surrogateescape')
            # Flush the statement when the next tuple would exceed the limit.
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                rows += self.execute(sql + postfix)
                sql = bytearray(prefix)
            else:
                sql += b','
            sql += v
        rows += self.execute(sql + postfix)
        self.rowcount = rows
        return rows

    def callproc(self, procname, args=()):
        """
        Execute stored procedure procname with args.

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.

        Parameters
        ----------
        procname : str
            Name of procedure to execute on server.
        args : Sequence[Any], optional
            Sequence of parameters to use with procedure.

        Returns
        -------
        Sequence[Any] : The original args.

        """
        conn = self._get_db()
        if args:
            # Stash each argument in a server variable @_procname_n first.
            fmt = f'@_{procname}_%d=%s'
            self._query(
                'SET %s'
                % ','.join(
                    fmt % (index, conn.escape(arg)) for index, arg in enumerate(args)
                ),
            )
            self.nextset()

        q = 'CALL {}({})'.format(
            procname,
            ','.join(f'@_{procname}_{i}' for i in range(len(args))),
        )
        self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row."""
        self._check_executed()
        return self._unchecked_fetchone()

    def _unchecked_fetchone(self):
        """Fetch the next row."""
        if self._rows is None or self._rownumber >= len(self._rows):
            return None
        result = self._rows[self._rownumber]
        self._rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Fetch several rows."""
        self._check_executed()
        if self._rows is None:
            self.warning_count = self._result.warning_count
            return ()
        end = self._rownumber + (size or self.arraysize)
        result = self._rows[self._rownumber: end]
        self._rownumber = min(end, len(self._rows))
        return result

    def fetchall(self):
        """Fetch all the rows."""
        self._check_executed()
        if self._rows is None:
            return ()
        if self._rownumber:
            result = self._rows[self._rownumber:]
        else:
            result = self._rows
        self._rownumber = len(self._rows)
        return result

    def scroll(self, value, mode='relative'):
        self._check_executed()
        if mode == 'relative':
            r = self._rownumber + value
        elif mode == 'absolute':
            r = value
        else:
            raise err.ProgrammingError('unknown scroll mode %s' % mode)

        if not (0 <= r < len(self._rows)):
            raise IndexError('out of range')
        self._rownumber = r

    def _query(self, q, infile_stream=None):
        conn = self._get_db()
        self._clear_result()
        conn.query(q, infile_stream=infile_stream)
        self._do_get_result()
        return self.rowcount

    def _clear_result(self):
        self._rownumber = 0
        self._result = None

        self.rowcount = 0
        self.warning_count = 0
        self._description = None
        self._format_schema = None
        self.lastrowid = None
        self._rows = None

    def _do_get_result(self):
        conn = self._get_db()

        self._result = result = conn._result

        self.rowcount = result.affected_rows
        self.warning_count = result.warning_count
        # Affected rows is set to max int64 for compatibility with MySQLdb, but
        # the DB-API requires this value to be -1. This happens in unbuffered mode.
        if self.rowcount == 18446744073709551615:
            self.rowcount = -1
        self._description = result.description
        if self._description:
            self._format_schema = get_schema(
                self.connection._results_type,
                result.description,
            )
        self.lastrowid = result.insert_id
        self._rows = result.rows

    def __iter__(self):
        self._check_executed()

        def fetchall_unbuffered_gen(fetch=self._unchecked_fetchone):
            # Yield rows until the buffered result is exhausted.
            row = fetch()
            while row is not None:
                yield row
                row = fetch()
        return fetchall_unbuffered_gen()

    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError


class CursorSV(Cursor):
    """Cursor class for C extension."""


class ArrowCursorMixin:
    """Fetch methods for Arrow Tables."""

    def fetchone(self):
        return results.results_to_arrow(
            self.description, super().fetchone(), single=True, schema=self._schema,
        )

    def fetchall(self):
        return results.results_to_arrow(
            self.description, super().fetchall(), schema=self._schema,
        )

    def fetchall_unbuffered(self):
        return results.results_to_arrow(
            self.description, super().fetchall_unbuffered(), schema=self._schema,
        )

    def fetchmany(self, size=None):
        return results.results_to_arrow(
            self.description, super().fetchmany(size), schema=self._schema,
        )


class ArrowCursor(ArrowCursorMixin, Cursor):
    """A cursor which returns results as an Arrow Table."""


class ArrowCursorSV(ArrowCursorMixin, CursorSV):
    """A cursor which returns results as an Arrow Table for C extension."""


class NumpyCursorMixin:
    """Fetch methods for numpy arrays."""

    def fetchone(self):
        return results.results_to_numpy(
            self.description, super().fetchone(), single=True, schema=self._schema,
        )

    def fetchall(self):
        return results.results_to_numpy(
            self.description, super().fetchall(), schema=self._schema,
        )

    def fetchall_unbuffered(self):
        return results.results_to_numpy(
            self.description, super().fetchall_unbuffered(), schema=self._schema,
        )

    def fetchmany(self, size=None):
        return results.results_to_numpy(
            self.description, super().fetchmany(size), schema=self._schema,
        )


class NumpyCursor(NumpyCursorMixin, Cursor):
    """A cursor which returns results as a numpy array."""


class NumpyCursorSV(NumpyCursorMixin, CursorSV):
    """A cursor which returns results as a numpy array for C extension."""


class PandasCursorMixin:
    """Fetch methods for pandas DataFrames."""

    def fetchone(self):
        return results.results_to_pandas(
            self.description, super().fetchone(), single=True, schema=self._schema,
        )

    def fetchall(self):
        return results.results_to_pandas(
            self.description, super().fetchall(), schema=self._schema,
        )

    def fetchall_unbuffered(self):
        return results.results_to_pandas(
            self.description, super().fetchall_unbuffered(), schema=self._schema,
        )

    def fetchmany(self, size=None):
        return results.results_to_pandas(
            self.description, super().fetchmany(size), schema=self._schema,
        )


class PandasCursor(PandasCursorMixin, Cursor):
    """A cursor which returns results as a pandas DataFrame."""


class PandasCursorSV(PandasCursorMixin, CursorSV):
    """A cursor which returns results as a pandas DataFrame for C extension."""


class PolarsCursorMixin:
    """Fetch methods for polars DataFrames."""

    def fetchone(self):
        return results.results_to_polars(
            self.description, super().fetchone(), single=True, schema=self._schema,
        )

    def fetchall(self):
        return results.results_to_polars(
            self.description, super().fetchall(), schema=self._schema,
        )

    def fetchall_unbuffered(self):
        return results.results_to_polars(
            self.description, super().fetchall_unbuffered(), schema=self._schema,
        )

    def fetchmany(self, size=None):
        return results.results_to_polars(
            self.description, super().fetchmany(size), schema=self._schema,
        )


class PolarsCursor(PolarsCursorMixin, Cursor):
    """A cursor which returns results as a polars DataFrame."""


class PolarsCursorSV(PolarsCursorMixin, CursorSV):
    """A cursor which returns results as a polars DataFrame for C extension."""


class DictCursorMixin:
    # You can override this to use OrderedDict or other dict-like types.
    dict_type = dict

    def _do_get_result(self):
        super(DictCursorMixin, self)._do_get_result()
        fields = []
        if self._description:
            for f in self._result.fields:
                name = f.name
                # Disambiguate duplicate column names with the table name.
                if name in fields:
                    name = f.table_name + '.' + name
                fields.append(name)
            self._fields = fields

        if fields and self._rows:
            self._rows = [self._conv_row(r) for r in self._rows]

    def _conv_row(self, row):
        if row is None:
            return None
        return self.dict_type(zip(self._fields, row))


class DictCursor(DictCursorMixin, Cursor):
    """A cursor which returns results as a dictionary."""


class DictCursorSV(Cursor):
    """A cursor which returns results as a dictionary for C extension."""


class NamedtupleCursorMixin:

    def _do_get_result(self):
        super(NamedtupleCursorMixin, self)._do_get_result()
        fields = []
        if self._description:
            for f in self._result.fields:
                name = f.name
                # Disambiguate duplicate column names with the table name.
                if name in fields:
                    name = f.table_name + '.' + name
                fields.append(name)
            self._fields = fields
            self._namedtuple = namedtuple('Row', self._fields, rename=True)

        if fields and self._rows:
            self._rows = [self._conv_row(r) for r in self._rows]

    def _conv_row(self, row):
        if row is None:
            return None
        return self._namedtuple(*row)


class NamedtupleCursor(NamedtupleCursorMixin, Cursor):
    """A cursor which returns results in a named tuple."""


class NamedtupleCursorSV(Cursor):
    """A cursor which returns results as a named tuple for C extension."""


class SSCursor(Cursor):
    """
    Unbuffered Cursor, mainly useful for queries that return a lot of data,
    or for connections to remote servers over a slow network.

    Instead of copying every row of data into a buffer, this will fetch
    rows as needed. The upside of this is the client uses much less memory,
    and rows are returned much faster when traveling over a slow network
    or if the result set is very big.

    There are limitations, though. The MySQL protocol doesn't support
    returning the total number of rows, so the only way to tell how many rows
    there are is to iterate over every row returned. Also, it currently isn't
    possible to scroll backwards, as only the current row is held in memory.

    """

    def _conv_row(self, row):
        return row

    def close(self):
        conn = self._connection
        if conn is None:
            return

        if self._result is not None and self._result is conn._result:
            self._result._finish_unbuffered_query()

        try:
            while self.nextset():
                pass
        finally:
            self._connection = None

    __del__ = close

    def _query(self, q, infile_stream=None):
        conn = self._get_db()
        self._clear_result()
        conn.query(q, unbuffered=True, infile_stream=infile_stream)
        self._do_get_result()
        return self.rowcount

    def nextset(self):
        return self._nextset(unbuffered=True)

    def read_next(self):
        """Read next row."""
        return self._conv_row(self._result._read_rowdata_packet_unbuffered())

    def fetchone(self):
        """Fetch next row."""
        self._check_executed()
        return self._unchecked_fetchone()

    def _unchecked_fetchone(self):
        """Fetch next row."""
        row = self.read_next()
        if row is None:
            self.warning_count = self._result.warning_count
            return None
        self._rownumber += 1
        return row

    def fetchall(self):
        """
        Fetch all, as per MySQLdb.

        Pretty useless for large queries, as it is buffered.
        See fetchall_unbuffered(), if you want an unbuffered
        generator version of this method.

        """
        return list(self.fetchall_unbuffered())

    def fetchall_unbuffered(self):
        """
        Fetch all, implemented as a generator.

        This is not a standard DB-API operation, however, it doesn't make
        sense to return everything in a list, as that would use ridiculous
        memory for large result sets.

        """
        self._check_executed()

        def fetchall_unbuffered_gen(fetch=self._unchecked_fetchone):
            # Yield rows straight off the wire until the server is done.
            row = fetch()
            while row is not None:
                yield row
                row = fetch()
        return fetchall_unbuffered_gen()

    def __iter__(self):
        return self.fetchall_unbuffered()

    def fetchmany(self, size=None):
        """Fetch many."""
        self._check_executed()
        if size is None:
            size = self.arraysize

        rows = []
        for _ in range(size):
            row = self.read_next()
            if row is None:
                self.warning_count = self._result.warning_count
                break
            rows.append(row)
            self._rownumber += 1
        return rows

    def scroll(self, value, mode='relative'):
        self._check_executed()

        if mode == 'relative':
            if value < 0:
                raise err.NotSupportedError(
                    'Backwards scrolling not supported by this cursor',
                )

            for _ in range(value):
                self.read_next()
            self._rownumber += value
        elif mode == 'absolute':
            if value < self._rownumber:
                raise err.NotSupportedError(
                    'Backwards scrolling not supported by this cursor',
                )

            end = value - self._rownumber
            for _ in range(end):
                self.read_next()
            self._rownumber = value
        else:
            raise err.ProgrammingError('unknown scroll mode %s' % mode)


class SSCursorSV(SSCursor):
    """An unbuffered cursor for use with PyMySQLsv."""

    def _unchecked_fetchone(self):
        """Fetch next row."""
        row = self._result._read_rowdata_packet_unbuffered(1)
        if row is None:
            return None
        self._rownumber += 1
        return row

    def fetchone(self):
        """Fetch next row."""
        self._check_executed()
        return self._unchecked_fetchone()

    def fetchmany(self, size=None):
        """Fetch many."""
        self._check_executed()
        if size is None:
            size = self.arraysize
        out = self._result._read_rowdata_packet_unbuffered(size)
        if out is None:
            return []
        if size == 1:
            self._rownumber += 1
            return [out]
        self._rownumber += len(out)
        return out

    def scroll(self, value, mode='relative'):
        self._check_executed()

        if mode == 'relative':
            if value < 0:
                raise err.NotSupportedError(
                    'Backwards scrolling not supported by this cursor',
                )

            self._result._read_rowdata_packet_unbuffered(value)
            self._rownumber += value
        elif mode == 'absolute':
            if value < self._rownumber:
                raise err.NotSupportedError(
                    'Backwards scrolling not supported by this cursor',
                )

            end = value - self._rownumber
            self._result._read_rowdata_packet_unbuffered(end)
            self._rownumber = value
        else:
            raise err.ProgrammingError('unknown scroll mode %s' % mode)


class SSDictCursor(DictCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a dictionary."""


class SSDictCursorSV(SSCursorSV):
    """An unbuffered cursor for the C extension, which returns a dictionary."""


class SSNamedtupleCursor(NamedtupleCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a named tuple."""


class SSNamedtupleCursorSV(SSCursorSV):
    """An unbuffered cursor for the C extension, which returns results as named tuple."""


class SSArrowCursor(ArrowCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as an Arrow Table."""


class SSArrowCursorSV(ArrowCursorMixin, SSCursorSV):
    """An unbuffered cursor, which returns results as an Arrow Table (accelerated)."""


class SSNumpyCursor(NumpyCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a numpy array."""


class SSNumpyCursorSV(NumpyCursorMixin, SSCursorSV):
    """An unbuffered cursor, which returns results as a numpy array (accelerated)."""


class SSPandasCursor(PandasCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a pandas DataFrame."""


class SSPandasCursorSV(PandasCursorMixin, SSCursorSV):
    """An unbuffered cursor, which returns results as a pandas DataFrame (accelerated)."""


class SSPolarsCursor(PolarsCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a polars DataFrame."""


class SSPolarsCursorSV(PolarsCursorMixin, SSCursorSV):
    """An unbuffered cursor, which returns results as a polars DataFrame (accelerated)."""