# Path: blob/main/singlestoredb/mysql/cursors.py
# (scraped page metadata: 469 views)
# type: ignore
"""
Cursor implementations: buffered and unbuffered (``SS*``) cursors, plus
variants returning dicts, named tuples, Arrow tables, numpy arrays,
pandas DataFrames and polars DataFrames. ``*SV`` classes are the
variants intended for the C extension.
"""
import re
from collections import namedtuple

from . import err
from ..connection import Cursor as BaseCursor
from ..utils import results
from ..utils.debug import log_query
from ..utils.results import get_schema

try:
    from pydantic import BaseModel
    has_pydantic = True
except ImportError:
    has_pydantic = False


#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only batches simple bulk INSERT / REPLACE statements.
#: It can be used to load large datasets.
RE_INSERT_VALUES = re.compile(
    r'\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)'
    + r'(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))'
    + r'(\s*(?:ON DUPLICATE.*)?);?\s*\Z',
    re.IGNORECASE | re.DOTALL,
)


class Cursor(BaseCursor):
    """
    This is the object used to interact with the database.

    Do not create an instance of a Cursor yourself. Call
    connection.Connection.cursor().

    See `Cursor <https://www.python.org/dev/peps/pep-0249/#cursor-objects>`_ in
    the specification.

    Parameters
    ----------
    connection : Connection
        The connection the cursor is associated with.

    """

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet - packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection):
        self._connection = connection
        self.warning_count = 0
        self._description = None
        self._format_schema = None
        self._rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        self._executed = None
        self._result = None
        self._rows = None
        self.lastrowid = None

    @property
    def messages(self):
        # TODO
        return []

    @property
    def description(self):
        return self._description

    @property
    def _schema(self):
        return self._format_schema

    @property
    def connection(self):
        return self._connection

    @property
    def rownumber(self):
        return self._rownumber

    def close(self):
        """Closing a cursor just exhausts all remaining data."""
        conn = self._connection
        if conn is None:
            return
        try:
            while self.nextset():
                pass
        finally:
            self._connection = None

    @property
    def open(self) -> bool:
        """True until :meth:`close` has detached the connection."""
        return self._connection is not None

    def is_connected(self):
        return self.open

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        del exc_info
        self.close()

    def _get_db(self):
        """Return the connection, raising ProgrammingError if the cursor is closed."""
        if not self._connection:
            raise err.ProgrammingError('Cursor closed')
        return self._connection

    def _check_executed(self):
        if not self._executed:
            raise err.ProgrammingError('execute() first')

    def _conv_row(self, row):
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    setoutputsize = setoutputsizes

    def _nextset(self, unbuffered=False):
        """
        Advance to the next result set.

        Returns None when this cursor's result is not the connection's
        current result or when no further result set exists; otherwise
        reads the next result and returns True.

        """
        conn = self._get_db()
        current_result = self._result
        if current_result is None or current_result is not conn._result:
            return None
        if not current_result.has_next:
            return None
        self._result = None
        self._clear_result()
        conn.next_result(unbuffered=unbuffered)
        self._do_get_result()
        return True

    def nextset(self):
        return self._nextset(False)

    def _escape_args(self, args, conn):
        """
        Escape query parameters with ``conn.literal`` / ``conn.escape``.

        Sequences become a tuple of literals, mappings (and pydantic models
        via ``model_dump()``) become a dict of literals, and anything else is
        passed to ``conn.escape`` as a single value.

        """
        dtype = type(args)
        literal = conn.literal
        if dtype is tuple or dtype is list or isinstance(args, (tuple, list)):
            return tuple(literal(arg) for arg in args)
        elif dtype is dict or isinstance(args, dict):
            return {key: literal(val) for (key, val) in args.items()}
        elif has_pydantic and isinstance(args, BaseModel):
            return {key: literal(val) for (key, val) in args.model_dump().items()}
        # If it's not a dictionary let's try escaping it anyways.
        # Worst case it will throw a Value error
        return conn.escape(args)

    def mogrify(self, query, args=None):
        """
        Returns the exact string sent to the database by calling the execute() method.

        This method follows the extension to the DB API 2.0 followed by Psycopg.

        Parameters
        ----------
        query : str
            Query to mogrify.
        args : Sequence[Any] or Dict[str, Any] or Any, optional
            Parameters used with query. When ``None`` or an empty
            tuple/list/dict, the query is returned unchanged (no ``%``
            processing). Any other value, including falsy scalars such as
            ``0`` or ``''``, is escaped and substituted.

        Returns
        -------
        str : The query with argument binding applied.

        """
        conn = self._get_db()

        # Only skip binding for "no parameters". Testing truthiness here
        # would silently drop scalar parameters like 0, '' or False.
        if args is None:
            return query
        if isinstance(args, (tuple, list, dict)) and len(args) == 0:
            return query

        return query % self._escape_args(args, conn)

    def execute(self, query, args=None, infile_stream=None):
        """
        Execute a query.

        Parameters are bound by :meth:`mogrify` using Python ``%``-style
        formatting: use ``%s`` placeholders when ``args`` is a sequence and
        ``%(name)s`` placeholders when ``args`` is a mapping (or a pydantic
        model). When parameters are given, a literal ``%`` in the query
        must be written as ``%%``. Any result sets still pending on this
        cursor are drained first.

        Parameters
        ----------
        query : str
            Query to execute.
        args : Sequence[Any] or Dict[str, Any] or Any, optional
            Parameters used with query. (optional)
        infile_stream : io.BytesIO or Iterator[bytes], optional
            Data stream for ``LOCAL INFILE`` statements

        Returns
        -------
        int : Number of affected rows.

        """
        while self.nextset():
            pass

        log_query(query, args)

        query = self.mogrify(query, args)

        result = self._query(query, infile_stream=infile_stream)
        self._executed = query
        return result

    def executemany(self, query, args=None):
        """
        Run one query against several parameter sets.

        When the query is a simple ``INSERT``/``REPLACE ... VALUES (...)``
        (optionally followed by ``ON DUPLICATE ...``), rows are batched into
        multi-row statements no longer than :attr:`max_stmt_length` bytes.
        Otherwise it is equivalent to looping over args with execute().
        Objects providing ``itertuples`` (e.g. pandas DataFrames) are
        iterated row by row in both cases.

        Parameters
        ----------
        query : str
            Query to execute.
        args : Sequence[Any], optional
            Sequence of sequences or mappings. It is used as parameter.

        Returns
        -------
        int : Number of rows affected, if any. ``None`` when args is empty.

        """
        if args is None or len(args) == 0:
            return

        m = RE_INSERT_VALUES.match(query)
        if m:
            # ``% ()`` collapses any ``%%`` escapes in the non-parameter prefix.
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ''
            assert q_values[0] == '(' and q_values[-1] == ')'
            return self._do_execute_many(
                q_prefix,
                q_values,
                q_postfix,
                args,
                self.max_stmt_length,
                self._get_db().encoding,
            )

        self.rowcount = sum(
            self.execute(query, arg) for arg in self._iter_param_rows(args)
        )
        return self.rowcount

    @staticmethod
    def _iter_param_rows(args):
        """Iterate parameter rows; DataFrame-like objects yield row tuples."""
        if hasattr(args, 'itertuples'):
            return args.itertuples(index=False)
        return iter(args)

    def _do_execute_many(
        self, prefix, values, postfix, args, max_stmt_length, encoding,
    ):
        """
        Send ``prefix + values[,values...] + postfix`` in size-limited batches.

        Returns the total number of affected rows (also stored in rowcount).

        """
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)
        sql = bytearray(prefix)
        args = self._iter_param_rows(args)
        v = values % escape(next(args), conn)
        if isinstance(v, str):
            v = v.encode(encoding, 'surrogateescape')
        sql += v
        rows = 0
        for arg in args:
            v = values % escape(arg, conn)
            if isinstance(v, str):
                v = v.encode(encoding, 'surrogateescape')
            # The extra 1 accounts for the ',' separator.
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                rows += self.execute(sql + postfix)
                sql = bytearray(prefix)
            else:
                sql += b','
            sql += v
        rows += self.execute(sql + postfix)
        self.rowcount = rows
        return rows

    def callproc(self, procname, args=()):
        """
        Execute stored procedure procname with args.

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.

        Parameters
        ----------
        procname : str
            Name of procedure to execute on server.
        args : Sequence[Any], optional
            Sequence of parameters to use with procedure.

        Returns
        -------
        Sequence[Any] : The original args.

        """
        conn = self._get_db()
        if args:
            fmt = f'@_{procname}_%d=%s'
            self._query(
                'SET %s'
                % ','.join(
                    fmt % (index, conn.escape(arg)) for index, arg in enumerate(args)
                ),
            )
            self.nextset()

        q = 'CALL {}({})'.format(
            procname,
            ','.join(['@_%s_%d' % (procname, i) for i in range(len(args))]),
        )
        self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row."""
        self._check_executed()
        return self._unchecked_fetchone()

    def _unchecked_fetchone(self):
        """Fetch the next row."""
        if self._rows is None or self._rownumber >= len(self._rows):
            return None
        result = self._rows[self._rownumber]
        self._rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Fetch several rows."""
        self._check_executed()
        if self._rows is None:
            self.warning_count = self._result.warning_count
            return ()
        end = self._rownumber + (size or self.arraysize)
        result = self._rows[self._rownumber: end]
        self._rownumber = min(end, len(self._rows))
        return result

    def fetchall(self):
        """Fetch all the rows."""
        self._check_executed()
        if self._rows is None:
            return ()
        if self._rownumber:
            result = self._rows[self._rownumber:]
        else:
            result = self._rows
        self._rownumber = len(self._rows)
        return result

    def _iter_unchecked(self):
        """Return a generator of rows until ``_unchecked_fetchone`` yields None."""
        fetch = self._unchecked_fetchone

        def row_gen():
            while True:
                row = fetch()
                if row is None:
                    break
                yield row

        return row_gen()

    def fetchall_unbuffered(self):
        """
        Fetch all remaining rows, as a generator.

        Rows of this buffered cursor are already in memory; the method
        exists so the Arrow/numpy/pandas/polars mixins' own
        ``fetchall_unbuffered()`` (which calls ``super()``) also works on
        buffered cursors.

        """
        self._check_executed()
        return self._iter_unchecked()

    def scroll(self, value, mode='relative'):
        """
        Move the row position to ``value``.

        ``mode`` is ``'relative'`` (offset from the current row) or
        ``'absolute'``. Raises IndexError when the target row does not
        exist, including when there is no result set.

        """
        self._check_executed()
        if mode == 'relative':
            r = self._rownumber + value
        elif mode == 'absolute':
            r = value
        else:
            raise err.ProgrammingError('unknown scroll mode %s' % mode)

        if self._rows is None or not (0 <= r < len(self._rows)):
            raise IndexError('out of range')
        self._rownumber = r

    def _query(self, q, infile_stream=None):
        conn = self._get_db()
        self._clear_result()
        conn.query(q, infile_stream=infile_stream)
        self._do_get_result()
        return self.rowcount

    def _clear_result(self):
        self._rownumber = 0
        self._result = None

        self.rowcount = 0
        self.warning_count = 0
        self._description = None
        self._format_schema = None
        self.lastrowid = None
        self._rows = None

    def _do_get_result(self):
        """Copy the connection's current result into the cursor state."""
        conn = self._get_db()

        self._result = result = conn._result

        self.rowcount = result.affected_rows
        self.warning_count = result.warning_count
        # Affected rows is set to max int64 for compatibility with MySQLdb, but
        # the DB-API requires this value to be -1. This happens in unbuffered mode.
        if self.rowcount == 18446744073709551615:
            self.rowcount = -1
        self._description = result.description
        if self._description:
            self._format_schema = get_schema(
                self.connection._results_type,
                result.description,
            )
        self.lastrowid = result.insert_id
        self._rows = result.rows

    def __iter__(self):
        self._check_executed()
        return self._iter_unchecked()

    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError


class CursorSV(Cursor):
    """Cursor class for C extension."""


class ArrowCursorMixin:
    """Fetch methods for Arrow Tables."""

    def fetchone(self):
        return results.results_to_arrow(
            self.description, super().fetchone(), single=True, schema=self._schema,
        )

    def fetchall(self):
        return results.results_to_arrow(
            self.description, super().fetchall(), schema=self._schema,
        )

    def fetchall_unbuffered(self):
        return results.results_to_arrow(
            self.description, super().fetchall_unbuffered(), schema=self._schema,
        )

    def fetchmany(self, size=None):
        return results.results_to_arrow(
            self.description, super().fetchmany(size), schema=self._schema,
        )


class ArrowCursor(ArrowCursorMixin, Cursor):
    """A cursor which returns results as an Arrow Table."""


class ArrowCursorSV(ArrowCursorMixin, CursorSV):
    """A cursor which returns results as an Arrow Table for C extension."""


class NumpyCursorMixin:
    """Fetch methods for numpy arrays."""

    def fetchone(self):
        return results.results_to_numpy(
            self.description, super().fetchone(), single=True, schema=self._schema,
        )

    def fetchall(self):
        return results.results_to_numpy(
            self.description, super().fetchall(), schema=self._schema,
        )

    def fetchall_unbuffered(self):
        return results.results_to_numpy(
            self.description, super().fetchall_unbuffered(), schema=self._schema,
        )

    def fetchmany(self, size=None):
        return results.results_to_numpy(
            self.description, super().fetchmany(size), schema=self._schema,
        )


class NumpyCursor(NumpyCursorMixin, Cursor):
    """A cursor which returns results as a numpy array."""


class NumpyCursorSV(NumpyCursorMixin, CursorSV):
    """A cursor which returns results as a numpy array for C extension."""


class PandasCursorMixin:
    """Fetch methods for pandas DataFrames."""

    def fetchone(self):
        return results.results_to_pandas(
            self.description, super().fetchone(), single=True, schema=self._schema,
        )

    def fetchall(self):
        return results.results_to_pandas(
            self.description, super().fetchall(), schema=self._schema,
        )

    def fetchall_unbuffered(self):
        return results.results_to_pandas(
            self.description, super().fetchall_unbuffered(), schema=self._schema,
        )

    def fetchmany(self, size=None):
        return results.results_to_pandas(
            self.description, super().fetchmany(size), schema=self._schema,
        )


class PandasCursor(PandasCursorMixin, Cursor):
    """A cursor which returns results as a pandas DataFrame."""


class PandasCursorSV(PandasCursorMixin, CursorSV):
    """A cursor which returns results as a pandas DataFrame for C extension."""


class PolarsCursorMixin:
    """Fetch methods for polars DataFrames."""

    def fetchone(self):
        return results.results_to_polars(
            self.description, super().fetchone(), single=True, schema=self._schema,
        )

    def fetchall(self):
        return results.results_to_polars(
            self.description, super().fetchall(), schema=self._schema,
        )

    def fetchall_unbuffered(self):
        return results.results_to_polars(
            self.description, super().fetchall_unbuffered(), schema=self._schema,
        )

    def fetchmany(self, size=None):
        return results.results_to_polars(
            self.description, super().fetchmany(size), schema=self._schema,
        )


class PolarsCursor(PolarsCursorMixin, Cursor):
    """A cursor which returns results as a polars DataFrame."""


class PolarsCursorSV(PolarsCursorMixin, CursorSV):
    """A cursor which returns results as a polars DataFrame for C extension."""


class DictCursorMixin:
    """Convert rows to ``dict_type`` mappings keyed by column name."""

    # You can override this to use OrderedDict or other dict-like types.
    dict_type = dict

    def _do_get_result(self):
        super()._do_get_result()
        fields = []
        if self._description:
            for f in self._result.fields:
                name = f.name
                # Disambiguate duplicate column names with the table name.
                if name in fields:
                    name = f.table_name + '.' + name
                fields.append(name)
            self._fields = fields

        if fields and self._rows:
            self._rows = [self._conv_row(r) for r in self._rows]

    def _conv_row(self, row):
        if row is None:
            return None
        return self.dict_type(zip(self._fields, row))


class DictCursor(DictCursorMixin, Cursor):
    """A cursor which returns results as a dictionary."""


class DictCursorSV(Cursor):
    """A cursor which returns results as a dictionary for C extension."""


class NamedtupleCursorMixin:
    """Convert rows to ``Row`` named tuples built from the column names."""

    def _do_get_result(self):
        super()._do_get_result()
        fields = []
        if self._description:
            for f in self._result.fields:
                name = f.name
                # Disambiguate duplicate column names with the table name.
                if name in fields:
                    name = f.table_name + '.' + name
                fields.append(name)
            self._fields = fields
            # rename=True replaces names that are not valid identifiers.
            self._namedtuple = namedtuple('Row', self._fields, rename=True)

        if fields and self._rows:
            self._rows = [self._conv_row(r) for r in self._rows]

    def _conv_row(self, row):
        if row is None:
            return None
        return self._namedtuple(*row)


class NamedtupleCursor(NamedtupleCursorMixin, Cursor):
    """A cursor which returns results in a named tuple."""


class NamedtupleCursorSV(Cursor):
    """A cursor which returns results as a named tuple for C extension."""


class SSCursor(Cursor):
    """
    Unbuffered Cursor, mainly useful for queries that return a lot of data,
    or for connections to remote servers over a slow network.

    Instead of copying every row of data into a buffer, this will fetch
    rows as needed. The upside of this is the client uses much less memory,
    and rows are returned much faster when traveling over a slow network
    or if the result set is very big.

    There are limitations, though. The MySQL protocol doesn't support
    returning the total number of rows, so the only way to tell how many rows
    there are is to iterate over every row returned. Also, it currently isn't
    possible to scroll backwards, as only the current row is held in memory.

    """

    def _conv_row(self, row):
        return row

    def close(self):
        """Finish any in-progress unbuffered read, drain results, detach."""
        conn = self._connection
        if conn is None:
            return

        if self._result is not None and self._result is conn._result:
            self._result._finish_unbuffered_query()

        try:
            while self.nextset():
                pass
        finally:
            self._connection = None

    __del__ = close

    def _query(self, q, infile_stream=None):
        conn = self._get_db()
        self._clear_result()
        conn.query(q, unbuffered=True, infile_stream=infile_stream)
        self._do_get_result()
        return self.rowcount

    def nextset(self):
        return self._nextset(unbuffered=True)

    def read_next(self):
        """Read next row."""
        return self._conv_row(self._result._read_rowdata_packet_unbuffered())

    def fetchone(self):
        """Fetch next row."""
        self._check_executed()
        return self._unchecked_fetchone()

    def _unchecked_fetchone(self):
        """Fetch next row."""
        row = self.read_next()
        if row is None:
            self.warning_count = self._result.warning_count
            return None
        self._rownumber += 1
        return row

    def fetchall(self):
        """
        Fetch all, as per MySQLdb.

        Pretty useless for large queries, as it is buffered.
        See fetchall_unbuffered(), if you want an unbuffered
        generator version of this method.

        """
        # Iterate raw rows directly rather than via self.fetchall_unbuffered():
        # the format mixins override that method to return an already
        # converted Table/DataFrame, which their fetchall() would convert again.
        self._check_executed()
        return list(self._iter_unchecked())

    def fetchall_unbuffered(self):
        """
        Fetch all, implemented as a generator.

        This is not a standard DB-API operation, however, it doesn't make
        sense to return everything in a list, as that would use ridiculous
        memory for large result sets.

        """
        self._check_executed()
        return self._iter_unchecked()

    def __iter__(self):
        return self.fetchall_unbuffered()

    def fetchmany(self, size=None):
        """Fetch up to ``size`` rows (default ``arraysize``)."""
        self._check_executed()
        if size is None:
            size = self.arraysize

        rows = []
        for _ in range(size):
            row = self.read_next()
            if row is None:
                self.warning_count = self._result.warning_count
                break
            rows.append(row)
            self._rownumber += 1
        return rows

    def scroll(self, value, mode='relative'):
        """Skip forward by reading rows; backwards scrolling is not supported."""
        self._check_executed()

        if mode == 'relative':
            if value < 0:
                raise err.NotSupportedError(
                    'Backwards scrolling not supported by this cursor',
                )

            for _ in range(value):
                self.read_next()
            self._rownumber += value
        elif mode == 'absolute':
            if value < self._rownumber:
                raise err.NotSupportedError(
                    'Backwards scrolling not supported by this cursor',
                )

            end = value - self._rownumber
            for _ in range(end):
                self.read_next()
            self._rownumber = value
        else:
            raise err.ProgrammingError('unknown scroll mode %s' % mode)


class SSCursorSV(SSCursor):
    """An unbuffered cursor for use with PyMySQLsv."""

    def _unchecked_fetchone(self):
        """Fetch next row."""
        row = self._result._read_rowdata_packet_unbuffered(1)
        if row is None:
            return None
        self._rownumber += 1
        return row

    def fetchone(self):
        """Fetch next row."""
        self._check_executed()
        return self._unchecked_fetchone()

    def fetchmany(self, size=None):
        """Fetch many."""
        self._check_executed()
        if size is None:
            size = self.arraysize
        out = self._result._read_rowdata_packet_unbuffered(size)
        if out is None:
            return []
        # A request for one row returns the row itself rather than a list.
        if size == 1:
            self._rownumber += 1
            return [out]
        self._rownumber += len(out)
        return out

    def scroll(self, value, mode='relative'):
        self._check_executed()

        if mode == 'relative':
            if value < 0:
                raise err.NotSupportedError(
                    'Backwards scrolling not supported by this cursor',
                )

            self._result._read_rowdata_packet_unbuffered(value)
            self._rownumber += value
        elif mode == 'absolute':
            if value < self._rownumber:
                raise err.NotSupportedError(
                    'Backwards scrolling not supported by this cursor',
                )

            end = value - self._rownumber
            self._result._read_rowdata_packet_unbuffered(end)
            self._rownumber = value
        else:
            raise err.ProgrammingError('unknown scroll mode %s' % mode)


class SSDictCursor(DictCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a dictionary."""


class SSDictCursorSV(SSCursorSV):
    """An unbuffered cursor for the C extension, which returns a dictionary."""


class SSNamedtupleCursor(NamedtupleCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a named tuple."""


class SSNamedtupleCursorSV(SSCursorSV):
    """An unbuffered cursor for the C extension, which returns results as named tuple."""


class SSArrowCursor(ArrowCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as an Arrow Table."""


class SSArrowCursorSV(ArrowCursorMixin, SSCursorSV):
    """An unbuffered cursor, which returns results as an Arrow Table (accelerated)."""


class SSNumpyCursor(NumpyCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a numpy array."""


class SSNumpyCursorSV(NumpyCursorMixin, SSCursorSV):
    """An unbuffered cursor, which returns results as a numpy array (accelerated)."""


class SSPandasCursor(PandasCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a pandas DataFrame."""


class SSPandasCursorSV(PandasCursorMixin, SSCursorSV):
    """An unbuffered cursor, which returns results as a pandas DataFrame (accelerated)."""


class SSPolarsCursor(PolarsCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a polars DataFrame."""


class SSPolarsCursorSV(PolarsCursorMixin, SSCursorSV):
    """An unbuffered cursor, which returns results as a polars DataFrame (accelerated)."""