Path: blob/main/singlestoredb/http/connection.py
469 views
#!/usr/bin/env python1"""SingleStoreDB HTTP API interface."""2import datetime3import decimal4import functools5import io6import json7import math8import os9import re10import time11from base64 import b64decode12from typing import Any13from typing import Callable14from typing import Dict15from typing import Iterable16from typing import List17from typing import Optional18from typing import Sequence19from typing import Tuple20from typing import Union21from urllib.parse import urljoin22from urllib.parse import urlparse2324import requests2526try:27import numpy as np28has_numpy = True29except ImportError:30has_numpy = False3132try:33import pygeos34has_pygeos = True35except ImportError:36has_pygeos = False3738try:39import shapely.geometry40import shapely.wkt41has_shapely = True42except ImportError:43has_shapely = False4445try:46import pydantic47has_pydantic = True48except ImportError:49has_pydantic = False5051from .. import connection52from .. import fusion53from .. import types54from ..config import get_option55from ..converters import converters56from ..exceptions import DatabaseError # noqa: F40157from ..exceptions import DataError58from ..exceptions import Error # noqa: F40159from ..exceptions import IntegrityError60from ..exceptions import InterfaceError61from ..exceptions import InternalError62from ..exceptions import NotSupportedError63from ..exceptions import OperationalError64from ..exceptions import ProgrammingError65from ..exceptions import Warning # noqa: F40166from ..utils.convert_rows import convert_rows67from ..utils.debug import log_query68from ..utils.mogrify import mogrify69from ..utils.results import Description70from ..utils.results import format_results71from ..utils.results import get_schema72from ..utils.results import Result737475# DB-API settings76apilevel = '2.0'77paramstyle = 'named'78threadsafety = 1798081_interface_errors = set([820,832013, # CR_SERVER_LOST842006, # CR_SERVER_GONE_ERROR852012, # CR_HANDSHAKE_ERR862004, # CR_IPSOCK_ERROR872014, # CR_COMMANDS_OUT_OF_SYNC88])89_data_errors = set([901406, # ER_DATA_TOO_LONG911441, # ER_DATETIME_FUNCTION_OVERFLOW921365, # ER_DIVISION_BY_ZERO931230, # ER_NO_DEFAULT941171, # ER_PRIMARY_CANT_HAVE_NULL951264, # ER_WARN_DATA_OUT_OF_RANGE961265, # ER_WARN_DATA_TRUNCATED97])98_programming_errors = set([991065, # ER_EMPTY_QUERY1001179, # ER_CANT_DO_THIS_DURING_AN_TRANSACTION1011007, # ER_DB_CREATE_EXISTS1021110, # ER_FIELD_SPECIFIED_TWICE1031111, # ER_INVALID_GROUP_FUNC_USE1041082, # ER_NO_SUCH_INDEX1051741, # ER_NO_SUCH_KEY_VALUE1061146, # ER_NO_SUCH_TABLE1071449, # ER_NO_SUCH_USER1081064, # ER_PARSE_ERROR1091149, # ER_SYNTAX_ERROR1101113, # ER_TABLE_MUST_HAVE_COLUMNS1111112, # ER_UNSUPPORTED_EXTENSION1121102, # ER_WRONG_DB_NAME1131103, # ER_WRONG_TABLE_NAME1141049, # ER_BAD_DB_ERROR1151582, # ER_??? Wrong number of args116])117_integrity_errors = set([1181215, # ER_CANNOT_ADD_FOREIGN1191062, # ER_DUP_ENTRY1201169, # ER_DUP_UNIQUE1211364, # ER_NO_DEFAULT_FOR_FIELD1221216, # ER_NO_REFERENCED_ROW1231452, # ER_NO_REFERENCED_ROW_21241217, # ER_ROW_IS_REFERENCED1251451, # ER_ROW_IS_REFERENCED_21261460, # ER_XAER_OUTSIDE1271401, # ER_XAER_RMERR1281048, # ER_BAD_NULL_ERROR1291264, # ER_DATA_OUT_OF_RANGE1304025, # ER_CONSTRAINT_FAILED1311826, # ER_DUP_CONSTRAINT_NAME132])133134135def get_precision_scale(type_code: str) -> Tuple[Optional[int], Optional[int]]:136"""Parse the precision and scale from a data type."""137if '(' not in type_code:138return (None, None)139m = re.search(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)', type_code)140if m:141return int(m.group(1)), int(m.group(2))142m = re.search(r'\(\s*(\d+)\s*\)', type_code)143if m:144return (int(m.group(1)), None)145raise ValueError(f'Unrecognized type code: {type_code}')146147148def get_exc_type(code: int) -> type:149"""Map error code to DB-API error type."""150if code in _interface_errors:151return InterfaceError152if code in _data_errors:153return DataError154if code in _programming_errors:155return ProgrammingError156if code in _integrity_errors:157return IntegrityError158if code >= 1000:159return OperationalError160return InternalError161162163def identity(x: Any) -> Any:164"""Return input value."""165return x166167168def b64decode_converter(169converter: Callable[..., Any],170x: Optional[str],171encoding: str = 'utf-8',172) -> Optional[bytes]:173"""Decode value before applying converter."""174if x is None:175return None176if converter is None:177return b64decode(x)178return converter(b64decode(x))179180181def encode_timedelta(obj: datetime.timedelta) -> str:182"""Encode timedelta as str."""183seconds = int(obj.seconds) % 60184minutes = int(obj.seconds // 60) % 60185hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24186if obj.microseconds:187fmt = '{0:02d}:{1:02d}:{2:02d}.{3:06d}'188else:189fmt = '{0:02d}:{1:02d}:{2:02d}'190return fmt.format(hours, minutes, seconds, obj.microseconds)191192193def encode_time(obj: datetime.time) -> str:194"""Encode time as str."""195if obj.microsecond:196fmt = '{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'197else:198fmt = '{0.hour:02}:{0.minute:02}:{0.second:02}'199return fmt.format(obj)200201202def encode_datetime(obj: datetime.datetime) -> str:203"""Encode datetime as str."""204if obj.microsecond:205fmt = '{0.year:04}-{0.month:02}-{0.day:02} ' \206'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'207else:208fmt = '{0.year:04}-{0.month:02}-{0.day:02} ' \209'{0.hour:02}:{0.minute:02}:{0.second:02}'210return fmt.format(obj)211212213def encode_date(obj: datetime.date) -> str:214"""Encode date as str."""215fmt = '{0.year:04}-{0.month:02}-{0.day:02}'216return fmt.format(obj)217218219def encode_struct_time(obj: time.struct_time) -> str:220"""Encode time struct to str."""221return encode_datetime(datetime.datetime(*obj[:6]))222223224def encode_decimal(o: decimal.Decimal) -> str:225"""Encode decimal to str."""226return format(o, 'f')227228229# Most argument encoding is done by the JSON encoder, but these230# are exceptions to the rule.231encoders = {232datetime.datetime: encode_datetime,233datetime.date: encode_date,234datetime.time: encode_time,235datetime.timedelta: encode_timedelta,236time.struct_time: encode_struct_time,237decimal.Decimal: encode_decimal,238}239240241if has_shapely:242encoders[shapely.geometry.Point] = shapely.wkt.dumps243encoders[shapely.geometry.Polygon] = shapely.wkt.dumps244encoders[shapely.geometry.LineString] = shapely.wkt.dumps245246if has_numpy:247248def encode_ndarray(obj: np.ndarray) -> bytes: # type: ignore249"""Encode an ndarray as bytes."""250return obj.tobytes()251252encoders[np.ndarray] = encode_ndarray253254if has_pygeos:255encoders[pygeos.Geometry] = pygeos.io.to_wkt256257258def convert_special_type(259arg: Any,260nan_as_null: bool = False,261inf_as_null: bool = False,262) -> Any:263"""Convert special data type objects."""264dtype = type(arg)265if dtype is float or \266(267has_numpy and dtype in (268np.float16, np.float32, np.float64,269getattr(np, 'float128', np.float64),270)271):272if nan_as_null and math.isnan(arg):273return None274if inf_as_null and math.isinf(arg):275return None276func = encoders.get(dtype, None)277if func is not None:278return func(arg) # type: ignore279return arg280281282def convert_special_params(283params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,284nan_as_null: bool = False,285inf_as_null: bool = False,286) -> Optional[Union[Sequence[Any], Dict[str, Any]]]:287"""Convert parameters of special data types."""288if params is None:289return params290converter = functools.partial(291convert_special_type,292nan_as_null=nan_as_null,293inf_as_null=inf_as_null,294)295if isinstance(params, Dict):296return {k: converter(v) for k, v in params.items()}297return tuple(map(converter, params))298299300class PyMyField(object):301"""Field for PyMySQL compatibility."""302303def __init__(self, name: str, flags: int, charset: int) -> None:304self.name = name305self.flags = flags306self.charsetnr = charset307308309class PyMyResult(object):310"""Result for PyMySQL compatibility."""311312def __init__(self) -> None:313self.fields: List[PyMyField] = []314self.unbuffered_active = False315316def append(self, item: PyMyField) -> None:317self.fields.append(item)318319320class Cursor(connection.Cursor):321"""322SingleStoreDB HTTP database cursor.323324Cursor objects should not be created directly. They should come from325the `cursor` method on the `Connection` object.326327Parameters328----------329connection : Connection330The HTTP Connection object the cursor belongs to331332"""333334def __init__(self, conn: 'Connection'):335connection.Cursor.__init__(self, conn)336self._connection: Optional[Connection] = conn337self._results: List[List[Tuple[Any, ...]]] = [[]]338self._results_type: str = self._connection._results_type \339if self._connection is not None else 'tuples'340self._row_idx: int = -1341self._result_idx: int = -1342self._descriptions: List[List[Description]] = []343self._schemas: List[Dict[str, Any]] = []344self.arraysize: int = get_option('results.arraysize')345self.rowcount: int = 0346self.lastrowid: Optional[int] = None347self._pymy_results: List[PyMyResult] = []348self._expect_results: bool = False349350@property351def _result(self) -> Optional[PyMyResult]:352"""Return Result object for PyMySQL compatibility."""353if self._result_idx < 0:354return None355return self._pymy_results[self._result_idx]356357@property358def description(self) -> Optional[List[Description]]:359"""Return description for current result set."""360if not self._descriptions:361return None362if self._result_idx >= 0 and self._result_idx < len(self._descriptions):363return self._descriptions[self._result_idx]364return None365366@property367def _schema(self) -> Optional[Any]:368if not self._schemas:369return None370if self._result_idx >= 0 and self._result_idx < len(self._schemas):371return self._schemas[self._result_idx]372return None373374def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:375"""376Invoke a POST request on the HTTP connection.377378Parameters379----------380path : str381The path of the resource382*args : positional parameters, optional383Extra parameters to the POST request384**kwargs : keyword parameters, optional385Extra keyword parameters to the POST request386387Returns388-------389requests.Response390391"""392if self._connection is None:393raise ProgrammingError(errno=2048, msg='Connection is closed.')394if 'timeout' not in kwargs:395kwargs['timeout'] = self._connection.connection_params['connect_timeout']396return self._connection._post(path, *args, **kwargs)397398def callproc(399self, name: str,400params: Optional[Sequence[Any]] = None,401) -> None:402"""403Call a stored procedure.404405Parameters406----------407name : str408Name of the stored procedure409params : sequence, optional410Parameters to the stored procedure411412"""413if self._connection is None:414raise ProgrammingError(errno=2048, msg='Connection is closed.')415416name = connection._name_check(name)417418if not params:419self._execute(f'CALL {name}();', is_callproc=True)420else:421keys = ', '.join(['%s' for i in range(len(params))])422self._execute(f'CALL {name}({keys});', params, is_callproc=True)423424def close(self) -> None:425"""Close the cursor."""426self._connection = None427428def execute(429self, query: str,430args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,431infile_stream: Optional[ # type: ignore432Union[433io.RawIOBase,434io.TextIOBase,435Iterable[Union[bytes, str]],436connection.InfileQueue,437]438] = None,439) -> int:440"""441Execute a SQL statement.442443Parameters444----------445query : str446The SQL statement to execute447args : iterable or dict, optional448Parameters to substitute into the SQL code449450"""451return self._execute(query, args, infile_stream=infile_stream)452453def _validate_param_subs(454self, query: str,455args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,456) -> None:457"""Make sure the parameter substitions are valid."""458if args:459if isinstance(args, Sequence):460query = query % tuple(args)461else:462query = query % args463464def _execute_fusion_query(465self,466oper: Union[str, bytes],467params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,468handler: Any = None,469) -> int:470oper = mogrify(oper, params)471472if isinstance(oper, bytes):473oper = oper.decode('utf-8')474475log_query(oper, None)476477results_type = self._results_type478self._results_type = 'tuples'479try:480mgmt_res = fusion.execute(481self._connection, # type: ignore482oper,483handler=handler,484)485finally:486self._results_type = results_type487488self._descriptions.append(list(mgmt_res.description))489self._schemas.append(get_schema(self._results_type, list(mgmt_res.description)))490self._results.append(list(mgmt_res.rows))491self.rowcount = len(self._results[-1])492493pymy_res = PyMyResult()494for field in mgmt_res.fields:495pymy_res.append(496PyMyField(497field.name,498field.flags,499field.charsetnr,500),501)502503self._pymy_results.append(pymy_res)504505if self._results and self._results[0]:506self._row_idx = 0507self._result_idx = 0508509return self.rowcount510511def _execute(512self, oper: str,513params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,514is_callproc: bool = False,515infile_stream: Optional[ # type: ignore516Union[517io.RawIOBase,518io.TextIOBase,519Iterable[Union[bytes, str]],520connection.InfileQueue,521]522] = None,523) -> int:524self._descriptions = []525self._schemas = []526self._results = []527self._pymy_results = []528self._row_idx = -1529self._result_idx = -1530self.rowcount = 0531self._expect_results = False532533if self._connection is None:534raise ProgrammingError(errno=2048, msg='Connection is closed.')535536sql_type = 'exec'537if re.match(r'^\s*(select|show|call|echo|describe|with)\s+', oper, flags=re.I):538self._expect_results = True539sql_type = 'query'540541if has_pydantic and isinstance(params, pydantic.BaseModel):542params = params.model_dump()543544self._validate_param_subs(oper, params)545546handler = fusion.get_handler(oper)547if handler is not None:548return self._execute_fusion_query(oper, params, handler=handler)549550oper, params = self._connection._convert_params(oper, params)551552log_query(oper, params)553554data: Dict[str, Any] = dict(sql=oper)555if params is not None:556data['args'] = convert_special_params(557params,558nan_as_null=self._connection.connection_params['nan_as_null'],559inf_as_null=self._connection.connection_params['inf_as_null'],560)561if self._connection._database:562data['database'] = self._connection._database563564if sql_type == 'query':565res = self._post('query/tuples', json=data)566else:567res = self._post('exec', json=data)568569if res.status_code >= 400:570if res.text:571m = re.match(r'^Error\s+(\d+).*?:', res.text)572if m:573code = m.group(1)574msg = res.text.split(':', 1)[-1]575icode = int(code.split()[-1])576else:577icode = res.status_code578msg = res.text579raise get_exc_type(icode)(icode, msg.strip())580raise InterfaceError(errno=res.status_code, msg='HTTP Error')581582out = json.loads(res.text)583584if 'error' in out:585raise OperationalError(586errno=out['error'].get('code', 0),587msg=out['error'].get('message', 'HTTP Error'),588)589590if sql_type == 'query':591# description: (name, type_code, display_size, internal_size,592# precision, scale, null_ok, column_flags, charset)593594# Remove converters for things the JSON parser already converted595http_converters = dict(self._connection.decoders)596http_converters.pop(4, None)597http_converters.pop(5, None)598http_converters.pop(6, None)599http_converters.pop(15, None)600http_converters.pop(245, None)601http_converters.pop(247, None)602http_converters.pop(249, None)603http_converters.pop(250, None)604http_converters.pop(251, None)605http_converters.pop(252, None)606http_converters.pop(253, None)607http_converters.pop(254, None)608609# Merge passed in converters610if self._connection._conv:611for k, v in self._connection._conv.items():612if isinstance(k, int):613http_converters[k] = v614615# Make JSON a string for Arrow616if 'arrow' in self._results_type:617def json_to_str(x: Any) -> Optional[str]:618if x is None:619return None620return json.dumps(x)621http_converters[245] = json_to_str622623# Don't convert date/times in polars624elif 'polars' in self._results_type:625http_converters.pop(7, None)626http_converters.pop(10, None)627http_converters.pop(12, None)628629results = out['results']630631# Convert data to Python types632if results and results[0]:633self._row_idx = 0634self._result_idx = 0635636for result in results:637638pymy_res = PyMyResult()639convs = []640641description: List[Description] = []642for i, col in enumerate(result.get('columns', [])):643charset = 0644flags = 0645data_type = col['dataType'].split('(')[0]646type_code = types.ColumnType.get_code(data_type)647prec, scale = get_precision_scale(col['dataType'])648converter = http_converters.get(type_code, None)649650if 'UNSIGNED' in data_type:651flags = 32652653if data_type.endswith('BLOB') or data_type.endswith('BINARY'):654converter = functools.partial(655b64decode_converter, converter, # type: ignore656)657charset = 63 # BINARY658659if type_code == 0: # DECIMAL660type_code = types.ColumnType.get_code('NEWDECIMAL')661elif type_code == 15: # VARCHAR / VARBINARY662type_code = types.ColumnType.get_code('VARSTRING')663664if converter is not None:665convs.append((i, None, converter))666667description.append(668Description(669str(col['name']), type_code,670None, None, prec, scale,671col.get('nullable', False),672flags, charset,673),674)675pymy_res.append(PyMyField(col['name'], flags, charset))676677self._descriptions.append(description)678self._schemas.append(get_schema(self._results_type, description))679680rows = convert_rows(result.get('rows', []), convs)681682self._results.append(rows)683self._pymy_results.append(pymy_res)684685# For compatibility with PyMySQL/MySQLdb686if is_callproc:687self._results.append([])688689self.rowcount = len(self._results[0])690691else:692# For compatibility with PyMySQL/MySQLdb693if is_callproc:694self._results.append([])695696self.rowcount = out['rowsAffected']697698return self.rowcount699700def executemany(701self, query: str,702args: Optional[Sequence[Union[Sequence[Any], Dict[str, Any]]]] = None,703) -> int:704"""705Execute SQL code against multiple sets of parameters.706707Parameters708----------709query : str710The SQL statement to execute711args : iterable of iterables or dicts, optional712Sets of parameters to substitute into the SQL code713714"""715if self._connection is None:716raise ProgrammingError(errno=2048, msg='Connection is closed.')717718results = []719rowcount = 0720if args is not None and len(args) > 0:721description = []722schema = {}723# Detect dataframes724if hasattr(args, 'itertuples'):725argiter = args.itertuples(index=False) # type: ignore726else:727argiter = iter(args)728for params in argiter:729self.execute(query, params)730if self._descriptions:731description = self._descriptions[-1]732if self._schemas:733schema = self._schemas[-1]734if self._rows is not None:735results.append(self._rows)736rowcount += self.rowcount737self._results = results738self._descriptions = [description for _ in range(len(results))]739self._schemas = [schema for _ in range(len(results))]740else:741self.execute(query)742rowcount += self.rowcount743744self.rowcount = rowcount745746return self.rowcount747748@property749def _has_row(self) -> bool:750"""Determine if a row is available."""751if self._result_idx < 0 or self._result_idx >= len(self._results):752return False753if self._row_idx < 0 or self._row_idx >= len(self._results[self._result_idx]):754return False755return True756757@property758def _rows(self) -> List[Tuple[Any, ...]]:759"""Return current set of rows."""760if not self._has_row:761return []762return self._results[self._result_idx]763764def fetchone(self) -> Optional[Result]:765"""766Fetch a single row from the result set.767768Returns769-------770tuple771Values of the returned row if there are rows remaining772None773If there are no rows left to return774775"""776if self._connection is None:777raise ProgrammingError(errno=2048, msg='Connection is closed')778if not self._expect_results:779raise self._connection.ProgrammingError(msg='No query has been submitted')780if not self._has_row:781return None782out = self._rows[self._row_idx]783self._row_idx += 1784return format_results(785self._results_type,786self.description or [],787out, single=True,788schema=self._schema,789)790791def fetchmany(792self,793size: Optional[int] = None,794) -> Result:795"""796Fetch `size` rows from the result.797798If `size` is not specified, the `arraysize` attribute is used.799800Returns801-------802list of tuples803Values of the returned rows if there are rows remaining804805"""806if self._connection is None:807raise ProgrammingError(errno=2048, msg='Connection is closed')808if not self._expect_results:809raise self._connection.ProgrammingError(msg='No query has been submitted')810if not self._has_row:811if 'dict' in self._results_type:812return {}813return tuple()814if not size:815size = max(int(self.arraysize), 1)816else:817size = max(int(size), 1)818out = self._rows[self._row_idx:self._row_idx+size]819self._row_idx += len(out)820return format_results(821self._results_type, self.description or [],822out, schema=self._schema,823)824825def fetchall(self) -> Result:826"""827Fetch all rows in the result set.828829Returns830-------831list of tuples832Values of the returned rows if there are rows remaining833834"""835if self._connection is None:836raise ProgrammingError(errno=2048, msg='Connection is closed')837if not self._expect_results:838raise self._connection.ProgrammingError(msg='No query has been submitted')839if not self._has_row:840if 'dict' in self._results_type:841return {}842return tuple()843out = list(self._rows[self._row_idx:])844self._row_idx = len(out)845return format_results(846self._results_type, self.description or [],847out, schema=self._schema,848)849850def nextset(self) -> Optional[bool]:851"""Skip to the next available result set."""852if self._connection is None:853raise ProgrammingError(errno=2048, msg='Connection is closed')854855if self._result_idx < 0:856self._row_idx = -1857return None858859self._result_idx += 1860self._row_idx = 0861862if self._result_idx >= len(self._results):863self._result_idx = -1864self._row_idx = -1865return None866867self.rowcount = len(self._results[self._result_idx])868869return True870871def setinputsizes(self, sizes: Sequence[int]) -> None:872"""Predefine memory areas for parameters."""873pass874875def setoutputsize(self, size: int, column: Optional[str] = None) -> None:876"""Set a column buffer size for fetches of large columns."""877pass878879@property880def rownumber(self) -> Optional[int]:881"""882Return the zero-based index of the cursor in the result set.883884Returns885-------886int887888"""889if self._row_idx < 0:890return None891return self._row_idx892893def scroll(self, value: int, mode: str = 'relative') -> None:894"""895Scroll the cursor to the position in the result set.896897Parameters898----------899value : int900Value of the positional move901mode : str902Type of move that should be made: 'relative' or 'absolute'903904"""905if self._connection is None:906raise ProgrammingError(errno=2048, msg='Connection is closed')907if mode == 'relative':908self._row_idx += value909elif mode == 'absolute':910self._row_idx = value911else:912raise ValueError(913f'{mode} is not a valid mode, '914'expecting "relative" or "absolute"',915)916917def next(self) -> Optional[Result]:918"""919Return the next row from the result set for use in iterators.920921Returns922-------923tuple924Values from the next result row925None926If no more rows exist927928"""929if self._connection is None:930raise InterfaceError(errno=2048, msg='Connection is closed')931out = self.fetchone()932if out is None:933raise StopIteration934return out935936__next__ = next937938def __iter__(self) -> Iterable[Tuple[Any, ...]]:939"""Return result iterator."""940return iter(self._rows[self._row_idx:])941942def __enter__(self) -> 'Cursor':943"""Enter a context."""944return self945946def __exit__(947self, exc_type: Optional[object],948exc_value: Optional[Exception], exc_traceback: Optional[str],949) -> None:950"""Exit a context."""951self.close()952953@property954def open(self) -> bool:955"""Check if the cursor is still connected."""956if self._connection is None:957return False958return self._connection.is_connected()959960def is_connected(self) -> bool:961"""962Check if the cursor is still connected.963964Returns965-------966bool967968"""969return self.open970971972class Connection(connection.Connection):973"""974SingleStoreDB HTTP database connection.975976Instances of this object are typically created through the977`connection` function rather than creating them directly.978979See Also980--------981`connect`982983"""984driver = 'https'985paramstyle = 'qmark'986987def __init__(self, **kwargs: Any):988from .. import __version__ as client_version989990if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:991client_version += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']992993connection.Connection.__init__(self, **kwargs)994995host = kwargs.get('host', get_option('host'))996port = kwargs.get('port', get_option('http_port'))997998self._sess: Optional[requests.Session] = requests.Session()9991000user = kwargs.get('user', get_option('user'))1001password = kwargs.get('password', get_option('password'))1002if user is not None and password is not None:1003self._sess.auth = (user, password)1004elif user is not None:1005self._sess.auth = (user, '')1006self._sess.headers.update({1007'Content-Type': 'application/json',1008'Accept': 'application/json',1009'Accept-Encoding': 'compress,identity',1010'User-Agent': f'SingleStoreDB-Python/{client_version}',1011})10121013if kwargs.get('ssl_disabled', get_option('ssl_disabled')):1014self._sess.verify = False1015else:1016ssl_key = kwargs.get('ssl_key', get_option('ssl_key'))1017ssl_cert = kwargs.get('ssl_cert', get_option('ssl_cert'))1018if ssl_key and ssl_cert:1019self._sess.cert = (ssl_key, ssl_cert)1020elif ssl_cert:1021self._sess.cert = ssl_cert10221023ssl_ca = kwargs.get('ssl_ca', get_option('ssl_ca'))1024if ssl_ca:1025self._sess.verify = ssl_ca10261027ssl_verify_cert = kwargs.get('ssl_verify_cert', True)1028if not ssl_verify_cert:1029self._sess.verify = False10301031if kwargs.get('multi_statements', False):1032raise self.InterfaceError(10330, 'The Data API does not allow multiple '1034'statements within a query',1035)10361037self._version = kwargs.get('version', 'v2')1038self.driver = kwargs.get('driver', 'https')10391040self.encoders = {k: v for (k, v) in converters.items() if type(k) is not int}1041self.decoders = {k: v for (k, v) in converters.items() if type(k) is int}10421043self._database = kwargs.get('database', get_option('database'))1044self._url = f'{self.driver}://{host}:{port}/api/{self._version}/'1045self._host = host1046self._messages: List[Tuple[int, str]] = []1047self._autocommit: bool = True1048self._conv = kwargs.get('conv', None)1049self._in_sync: bool = False1050self._track_env: bool = kwargs.get('track_env', False) \1051or host == 'singlestore.com'10521053@property1054def messages(self) -> List[Tuple[int, str]]:1055return self._messages10561057def connect(self) -> 'Connection':1058"""Connect to the server."""1059return self10601061def _sync_connection(self, kwargs: Dict[str, Any]) -> None:1062"""Synchronize connection with env variable."""1063if self._sess is None:1064raise InterfaceError(errno=2048, msg='Connection is closed.')10651066if self._in_sync:1067return10681069if not self._track_env:1070return10711072url = os.environ.get('SINGLESTOREDB_URL')1073if not url:1074if self._host == 'singlestore.com':1075raise InterfaceError(0, 'Connection URL has not been established')1076return10771078out = {}1079urlp = connection._parse_url(url)1080out.update(urlp)1081out = connection._cast_params(out)10821083# Set default port based on driver.1084if 'port' not in out or not out['port']:1085if out.get('driver', 'https') == 'http':1086out['port'] = int(get_option('port') or 80)1087else:1088out['port'] = int(get_option('port') or 443)10891090# If there is no user and the password is empty, remove the password key.1091if 'user' not in out and not out.get('password', None):1092out.pop('password', None)10931094if out['host'] == 'singlestore.com':1095raise InterfaceError(0, 'Connection URL has not been established')10961097# Get current connection attributes1098curr_url = urlparse(self._url, scheme='singlestoredb', allow_fragments=True)1099if self._sess.auth is not None:1100auth = tuple(self._sess.auth) # type: ignore1101else:1102auth = (None, None) # type: ignore11031104# If it's just a password change, we don't need to reconnect1105if (curr_url.hostname, curr_url.port, auth[0], self._database) == \1106(out['host'], out['port'], out['user'], out.get('database')):1107return11081109try:1110self._in_sync = True1111sess = requests.Session()1112sess.auth = (out['user'], out['password'])1113sess.headers.update(self._sess.headers)1114sess.verify = self._sess.verify1115sess.cert = self._sess.cert1116self._database = out.get('database')1117self._host = out['host']1118self._url = f'{out.get("driver", "https")}://{out["host"]}:{out["port"]}' \1119f'/api/{self._version}/'1120self._sess = sess1121if self._database:1122kwargs['json']['database'] = self._database1123finally:1124self._in_sync = False11251126def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:1127"""1128Invoke a POST request on the HTTP connection.11291130Parameters1131----------1132path : str1133The path of the resource1134*args : positional parameters, optional1135Extra parameters to the POST request1136**kwargs : keyword parameters, optional1137Extra keyword parameters to the POST request11381139Returns1140-------1141requests.Response11421143"""1144if self._sess is None:1145raise InterfaceError(errno=2048, msg='Connection is closed.')11461147self._sync_connection(kwargs)11481149return self._sess.post(urljoin(self._url, path), *args, **kwargs)11501151def close(self) -> None:1152"""Close the connection."""1153if self._host == 'singlestore.com':1154return1155if self._sess is None:1156raise Error(errno=2048, msg='Connection is closed')1157self._sess = None11581159def autocommit(self, value: bool = True) -> None:1160"""Set autocommit mode."""1161if self._host == 'singlestore.com':1162return1163if self._sess is None:1164raise InterfaceError(errno=2048, msg='Connection is closed')1165self._autocommit = value11661167def commit(self) -> None:1168"""Commit the pending transaction."""1169if self._host == 'singlestore.com':1170return1171if self._sess is None:1172raise InterfaceError(errno=2048, msg='Connection is closed')1173if self._autocommit:1174return1175raise NotSupportedError(msg='operation not supported')11761177def rollback(self) -> None:1178"""Rollback the pending transaction."""1179if self._host == 'singlestore.com':1180return1181if self._sess is None:1182raise InterfaceError(errno=2048, msg='Connection is closed')1183if self._autocommit:1184return1185raise NotSupportedError(msg='operation not supported')11861187def cursor(self) -> Cursor:1188"""1189Create a new cursor object.11901191Returns1192-------1193Cursor11941195"""1196return Cursor(self)11971198def __enter__(self) -> 'Connection':1199"""Enter a context."""1200return self12011202def __exit__(1203self, exc_type: Optional[object],1204exc_value: Optional[Exception], exc_traceback: Optional[str],1205) -> None:1206"""Exit a context."""1207self.close()12081209@property1210def open(self) -> bool:1211"""Check if the database is still connected."""1212if self._sess is None:1213return False1214url = '/'.join(self._url.split('/')[:3]) + '/ping'1215res = self._sess.get(url)1216if res.status_code <= 400 and res.text == 'pong':1217return True1218return False12191220def is_connected(self) -> bool:1221"""1222Check if the database is still connected.12231224Returns1225-------1226bool12271228"""1229return self.open123012311232def connect(1233host: Optional[str] = None,1234user: Optional[str] = None,1235password: Optional[str] = None,1236port: Optional[int] = None,1237database: Optional[str] = None,1238driver: Optional[str] = None,1239pure_python: Optional[bool] = None,1240local_infile: Optional[bool] = None,1241charset: Optional[str] = None,1242ssl_key: Optional[str] = None,1243ssl_cert: Optional[str] = None,1244ssl_ca: Optional[str] = None,1245ssl_disabled: Optional[bool] = None,1246ssl_cipher: Optional[str] = None,1247ssl_verify_cert: Optional[bool] = None,1248ssl_verify_identity: Optional[bool] = None,1249conv: Optional[Dict[int, Callable[..., Any]]] = None,1250credential_type: Optional[str] = None,1251autocommit: Optional[bool] = None,1252results_type: Optional[str] = None,1253buffered: Optional[bool] = None,1254results_format: Optional[str] = None,1255program_name: Optional[str] = None,1256conn_attrs: Optional[Dict[str, str]] = None,1257multi_statements: Optional[bool] = None,1258connect_timeout: Optional[int] = None,1259nan_as_null: Optional[bool] = None,1260inf_as_null: Optional[bool] = None,1261encoding_errors: Optional[str] = None,1262track_env: Optional[bool] = None,1263enable_extended_data_types: Optional[bool] = None,1264vector_data_format: Optional[str] = None,1265) -> Connection:1266return Connection(**dict(locals()))126712681269