Path: blob/main/singlestoredb/http/connection.py
801 views
#!/usr/bin/env python
"""SingleStoreDB HTTP API interface."""
import datetime
import decimal
import functools
import io
import json
import math
import os
import re
import time
from base64 import b64decode
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from urllib.parse import urljoin
from urllib.parse import urlparse

import requests

try:
    import numpy as np
    has_numpy = True
except ImportError:
    has_numpy = False

try:
    import pygeos
    has_pygeos = True
except ImportError:
    has_pygeos = False

try:
    import shapely.geometry
    import shapely.wkt
    has_shapely = True
except ImportError:
    has_shapely = False

try:
    import pydantic
    has_pydantic = True
except ImportError:
    has_pydantic = False

from .. import connection
from .. import fusion
from .. import types
from ..config import get_option
from ..converters import converters
from ..exceptions import DatabaseError  # noqa: F401
from ..exceptions import DataError
from ..exceptions import Error  # noqa: F401
from ..exceptions import IntegrityError
from ..exceptions import InterfaceError
from ..exceptions import InternalError
from ..exceptions import NotSupportedError
from ..exceptions import OperationalError
from ..exceptions import ProgrammingError
from ..exceptions import Warning  # noqa: F401
from ..utils.convert_rows import convert_rows
from ..utils.debug import log_query
from ..utils.mogrify import mogrify
from ..utils.results import Description
from ..utils.results import format_results
from ..utils.results import get_schema
from ..utils.results import Result


# DB-API settings
apilevel = '2.0'
paramstyle = 'named'
threadsafety = 1


# Server / client error codes mapped to DB-API exception classes by
# `get_exc_type`.  The sets are checked in the order interface, data,
# programming, integrity, so a code listed in more than one set resolves
# to the first match.
_interface_errors = set([
    0,
    2013,  # CR_SERVER_LOST
    2006,  # CR_SERVER_GONE_ERROR
    2012,  # CR_HANDSHAKE_ERR
    2004,  # CR_IPSOCK_ERROR
    2014,  # CR_COMMANDS_OUT_OF_SYNC
])
_data_errors = set([
    1406,  # ER_DATA_TOO_LONG
    1441,  # ER_DATETIME_FUNCTION_OVERFLOW
    1365,  # ER_DIVISION_BY_ZERO
    1230,  # ER_NO_DEFAULT
    1171,  # ER_PRIMARY_CANT_HAVE_NULL
    1264,  # ER_WARN_DATA_OUT_OF_RANGE
    1265,  # ER_WARN_DATA_TRUNCATED
])
_programming_errors = set([
    1065,  # ER_EMPTY_QUERY
    1179,  # ER_CANT_DO_THIS_DURING_AN_TRANSACTION
    1007,  # ER_DB_CREATE_EXISTS
    1110,  # ER_FIELD_SPECIFIED_TWICE
    1111,  # ER_INVALID_GROUP_FUNC_USE
    1082,  # ER_NO_SUCH_INDEX
    1741,  # ER_NO_SUCH_KEY_VALUE
    1146,  # ER_NO_SUCH_TABLE
    1449,  # ER_NO_SUCH_USER
    1064,  # ER_PARSE_ERROR
    1149,  # ER_SYNTAX_ERROR
    1113,  # ER_TABLE_MUST_HAVE_COLUMNS
    1112,  # ER_UNSUPPORTED_EXTENSION
    1102,  # ER_WRONG_DB_NAME
    1103,  # ER_WRONG_TABLE_NAME
    1049,  # ER_BAD_DB_ERROR
    1582,  # ER_WRONG_PARAMCOUNT_TO_NATIVE_FCT (wrong number of args)
])
_integrity_errors = set([
    1215,  # ER_CANNOT_ADD_FOREIGN
    1062,  # ER_DUP_ENTRY
    1169,  # ER_DUP_UNIQUE
    1364,  # ER_NO_DEFAULT_FOR_FIELD
    1216,  # ER_NO_REFERENCED_ROW
    1452,  # ER_NO_REFERENCED_ROW_2
    1217,  # ER_ROW_IS_REFERENCED
    1451,  # ER_ROW_IS_REFERENCED_2
    1460,  # ER_XAER_OUTSIDE
    1401,  # ER_XAER_RMERR
    1048,  # ER_BAD_NULL_ERROR
    1264,  # ER_DATA_OUT_OF_RANGE; also in _data_errors, which wins
    4025,  # ER_CONSTRAINT_FAILED
    1826,  # ER_DUP_CONSTRAINT_NAME
])


def get_precision_scale(type_code: str) -> Tuple[Optional[int], Optional[int]]:
    """
    Parse the precision and scale from a data type.

    Parameters
    ----------
    type_code : str
        Data type string, e.g. ``'DECIMAL(10, 2)'`` or ``'VARCHAR(255)'``

    Returns
    -------
    (Optional[int], Optional[int])
        ``(precision, scale)``.  Both are ``None`` when the type has no
        parenthesized arguments; scale is ``None`` when only one integer
        argument is given.

    Raises
    ------
    ValueError
        If the type contains parentheses that hold neither one integer
        nor two comma-separated integers

    """
    if '(' not in type_code:
        return (None, None)
    m = re.search(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)', type_code)
    if m:
        return int(m.group(1)), int(m.group(2))
    m = re.search(r'\(\s*(\d+)\s*\)', type_code)
    if m:
        return (int(m.group(1)), None)
    raise ValueError(f'Unrecognized type code: {type_code}')


def get_exc_type(code: int) -> type:
    """
    Map error code to DB-API error type.

    Parameters
    ----------
    code : int
        Server/client error number

    Returns
    -------
    type
        The matching exception class from the module-level error sets.
        Otherwise ``OperationalError`` for codes >= 1000, and
        ``InternalError`` for anything smaller.

    """
    if code in _interface_errors:
        return InterfaceError
    if code in _data_errors:
        return DataError
    if code in _programming_errors:
        return ProgrammingError
    if code in _integrity_errors:
        return IntegrityError
    if code >= 1000:
        return OperationalError
    return InternalError


def identity(x: Any) -> Any:
    """Return input value."""
    return x


def b64decode_converter(
    converter: Callable[..., Any],
    x: Optional[str],
    encoding: str = 'utf-8',
) -> Optional[bytes]:
    """
    Decode value before applying converter.

    Parameters
    ----------
    converter : Callable or None
        Converter applied to the decoded bytes; when ``None`` the decoded
        bytes are returned as-is
    x : str or None
        Base64-encoded value; ``None`` passes through unchanged
    encoding : str, optional
        Accepted for signature compatibility; not used

    Returns
    -------
    bytes or converted value or None

    """
    if x is None:
        return None
    if converter is None:
        return b64decode(x)
    return converter(b64decode(x))


def encode_timedelta(obj: datetime.timedelta) -> str:
    """
    Encode timedelta as str.

    The result has the form ``[-]HH:MM:SS[.ffffff]``.  The hour field is
    not wrapped at 24, so long durations produce more than two hour digits.
    Negative durations get a single leading minus sign.

    Parameters
    ----------
    obj : datetime.timedelta
        Duration to encode

    Returns
    -------
    str

    """
    sign = ''
    if obj < datetime.timedelta(0):
        # timedelta normalizes negatives as (negative days, non-negative
        # seconds/microseconds).  Formatting those fields directly yields
        # strings like '-1:59:59' for -1s, so format the magnitude and
        # prefix the sign once.
        sign = '-'
        obj = -obj
    total_seconds = int(obj.days) * 86400 + int(obj.seconds)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    if obj.microseconds:
        fmt = '{0}{1:02d}:{2:02d}:{3:02d}.{4:06d}'
    else:
        fmt = '{0}{1:02d}:{2:02d}:{3:02d}'
    return fmt.format(sign, hours, minutes, seconds, obj.microseconds)


def encode_time(obj: datetime.time) -> str:
    """Encode time as ``HH:MM:SS[.ffffff]``."""
    if obj.microsecond:
        fmt = '{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'
    else:
        fmt = '{0.hour:02}:{0.minute:02}:{0.second:02}'
    return fmt.format(obj)


def encode_datetime(obj: datetime.datetime) -> str:
    """Encode datetime as ``YYYY-MM-DD HH:MM:SS[.ffffff]``."""
    if obj.microsecond:
        fmt = '{0.year:04}-{0.month:02}-{0.day:02} ' \
            '{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'
    else:
        fmt = '{0.year:04}-{0.month:02}-{0.day:02} ' \
            '{0.hour:02}:{0.minute:02}:{0.second:02}'
    return fmt.format(obj)


def encode_date(obj: datetime.date) -> str:
    """Encode date as ``YYYY-MM-DD``."""
    fmt = '{0.year:04}-{0.month:02}-{0.day:02}'
    return fmt.format(obj)


def encode_struct_time(obj: time.struct_time) -> str:
    """Encode time struct to str (first six fields, as a datetime)."""
    return encode_datetime(datetime.datetime(*obj[:6]))


def encode_decimal(o: decimal.Decimal) -> str:
    """Encode decimal to str in fixed-point (non-exponent) notation."""
    return format(o, 'f')


# Most argument encoding is done by the JSON encoder, but these
# are exceptions to the rule.
encoders = {
    datetime.datetime: encode_datetime,
    datetime.date: encode_date,
    datetime.time: encode_time,
    datetime.timedelta: encode_timedelta,
    time.struct_time: encode_struct_time,
    decimal.Decimal: encode_decimal,
}


if has_shapely:
    encoders[shapely.geometry.Point] = shapely.wkt.dumps
    encoders[shapely.geometry.Polygon] = shapely.wkt.dumps
    encoders[shapely.geometry.LineString] = shapely.wkt.dumps

if has_numpy:

    def encode_ndarray(obj: np.ndarray) -> bytes:  # type: ignore
        """Encode an ndarray as bytes."""
        # NOTE(review): parameters are sent as a JSON request body
        # (``json=data`` in Cursor._execute), and JSON cannot serialize
        # bytes.  Confirm how ndarray parameters are meant to be sent.
        return obj.tobytes()

    encoders[np.ndarray] = encode_ndarray

if has_pygeos:
    encoders[pygeos.Geometry] = pygeos.io.to_wkt


def convert_special_type(
    arg: Any,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> Any:
    """
    Convert special data type objects.

    Parameters
    ----------
    arg : Any
        Parameter value to convert
    nan_as_null : bool, optional
        If True, NaN floats (Python or NumPy) are returned as ``None``
    inf_as_null : bool, optional
        If True, infinite floats (Python or NumPy) are returned as ``None``

    Returns
    -------
    Any
        The value produced by the ``encoders`` entry for ``type(arg)``
        (exact type match), or ``arg`` unchanged when there is no entry

    """
    dtype = type(arg)
    if dtype is float or \
            (
                has_numpy and dtype in (
                    np.float16, np.float32, np.float64,
                    getattr(np, 'float128', np.float64),
                )
            ):
        if nan_as_null and math.isnan(arg):
            return None
        if inf_as_null and math.isinf(arg):
            return None
    func = encoders.get(dtype, None)
    if func is not None:
        return func(arg)  # type: ignore
    return arg


def convert_special_params(
    params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> Optional[Union[Sequence[Any], Dict[str, Any]]]:
    """
    Convert parameters of special data types.

    Applies `convert_special_type` to every value.  Dicts stay dicts
    (same keys); any other sequence is returned as a tuple.  ``None`` is
    returned unchanged.

    """
    if params is None:
        return params
    converter = functools.partial(
        convert_special_type,
        nan_as_null=nan_as_null,
        inf_as_null=inf_as_null,
    )
    if isinstance(params, Dict):
        return {k: converter(v) for k, v in params.items()}
    return tuple(map(converter, params))


class PyMyField(object):
    """Field for PyMySQL compatibility."""

    def __init__(self, name: str, flags: int, charset: int) -> None:
        self.name = name
        self.flags = flags
        self.charsetnr = charset


class PyMyResult(object):
    """Result for PyMySQL compatibility."""

    def __init__(self) -> None:
        self.fields: List[PyMyField] = []
        self.unbuffered_active = False

    def append(self, item: PyMyField) -> None:
        """Add a field description to the result."""
        self.fields.append(item)


class Cursor(connection.Cursor):
    """
    SingleStoreDB HTTP database cursor.

    Cursor objects should not be created directly. They should come from
    the `cursor` method on the `Connection` object.

    Parameters
    ----------
    connection : Connection
        The HTTP Connection object the cursor belongs to

    """

    def __init__(self, conn: 'Connection'):
        connection.Cursor.__init__(self, conn)
        self._connection: Optional[Connection] = conn
        # One list of rows per result set.
        self._results: List[List[Tuple[Any, ...]]] = [[]]
        self._results_type: str = self._connection._results_type \
            if self._connection is not None else 'tuples'
        # Position within the current result set / index of the current
        # result set; -1 means "no current row / result".
        self._row_idx: int = -1
        self._result_idx: int = -1
        self._descriptions: List[List[Description]] = []
        self._schemas: List[Dict[str, Any]] = []
        self.arraysize: int = get_option('results.arraysize')
        self.rowcount: int = 0
        self.lastrowid: Optional[int] = None
        self._pymy_results: List[PyMyResult] = []
        self._expect_results: bool = False

    @property
    def _result(self) -> Optional[PyMyResult]:
        """Return Result object for PyMySQL compatibility."""
        # Bounds-check like `description` / `_schema`: `_results` can hold
        # more sets than `_pymy_results` (callproc sentinel, executemany).
        if self._result_idx < 0 or self._result_idx >= len(self._pymy_results):
            return None
        return self._pymy_results[self._result_idx]

    @property
    def description(self) -> Optional[List[Description]]:
        """Return description for current result set."""
        if not self._descriptions:
            return None
        if self._result_idx >= 0 and self._result_idx < len(self._descriptions):
            return self._descriptions[self._result_idx]
        return None

    @property
    def _schema(self) -> Optional[Any]:
        """Return the schema for the current result set, if any."""
        if not self._schemas:
            return None
        if self._result_idx >= 0 and self._result_idx < len(self._schemas):
            return self._schemas[self._result_idx]
        return None

    def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """
        Invoke a POST request on the HTTP connection.

        A ``timeout`` equal to the connection's ``connect_timeout`` is
        supplied when the caller does not pass one.

        Parameters
        ----------
        path : str
            The path of the resource
        *args : positional parameters, optional
            Extra parameters to the POST request
        **kwargs : keyword parameters, optional
            Extra keyword parameters to the POST request

        Returns
        -------
        requests.Response

        Raises
        ------
        ProgrammingError
            If the cursor has been closed

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')
        if 'timeout' not in kwargs:
            kwargs['timeout'] = self._connection.connection_params['connect_timeout']
        return self._connection._post(path, *args, **kwargs)

    def callproc(
        self, name: str,
        params: Optional[Sequence[Any]] = None,
    ) -> None:
        """
        Call a stored procedure.

        Parameters
        ----------
        name : str
            Name of the stored procedure
        params : sequence, optional
            Parameters to the stored procedure

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        name = connection._name_check(name)

        if not params:
            self._execute(f'CALL {name}();', is_callproc=True)
        else:
            keys = ', '.join(['%s'] * len(params))
            self._execute(f'CALL {name}({keys});', params, is_callproc=True)

    def close(self) -> None:
        """Close the cursor."""
        self._connection = None

    def execute(
        self, query: str,
        args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        infile_stream: Optional[  # type: ignore
            Union[
                io.RawIOBase,
                io.TextIOBase,
                Iterable[Union[bytes, str]],
                connection.InfileQueue,
            ]
        ] = None,
    ) -> int:
        """
        Execute a SQL statement.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : iterable or dict, optional
            Parameters to substitute into the SQL code
        infile_stream : stream or iterable, optional
            Accepted for interface compatibility; not used by the HTTP
            implementation

        Returns
        -------
        int
            Number of rows returned or affected

        """
        return self._execute(query, args, infile_stream=infile_stream)

    def _validate_param_subs(
        self, query: str,
        args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    ) -> None:
        """
        Make sure the parameter substitutions are valid.

        The %-interpolation result is discarded; the call exists only so a
        mismatch between placeholders and ``args`` raises (TypeError /
        KeyError) before anything is sent to the server.

        """
        if args:
            if isinstance(args, Sequence):
                query = query % tuple(args)
            else:
                query = query % args

    def _execute_fusion_query(
        self,
        oper: Union[str, bytes],
        params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        handler: Any = None,
    ) -> int:
        """
        Run a client-side (fusion) command and store its single result set.

        Returns
        -------
        int
            Number of rows in the result

        """
        oper = mogrify(oper, params)

        if isinstance(oper, bytes):
            oper = oper.decode('utf-8')

        log_query(oper, None)

        # Fusion handlers are run with plain tuple results; the caller's
        # results type is restored afterwards and used for the schema.
        results_type = self._results_type
        self._results_type = 'tuples'
        try:
            mgmt_res = fusion.execute(
                self._connection,  # type: ignore
                oper,
                handler=handler,
            )
        finally:
            self._results_type = results_type

        self._descriptions.append(list(mgmt_res.description))
        self._schemas.append(get_schema(self._results_type, list(mgmt_res.description)))
        self._results.append(list(mgmt_res.rows))
        self.rowcount = len(self._results[-1])

        pymy_res = PyMyResult()
        for field in mgmt_res.fields:
            pymy_res.append(
                PyMyField(
                    field.name,
                    field.flags,
                    field.charsetnr,
                ),
            )

        self._pymy_results.append(pymy_res)

        # Exactly one result set was stored above.  Make it current even
        # when it has no rows, so `description` is available, matching
        # `_execute`.
        self._row_idx = 0
        self._result_idx = 0

        return self.rowcount

    def _execute(
        self, oper: str,
        params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        is_callproc: bool = False,
        infile_stream: Optional[  # type: ignore
            Union[
                io.RawIOBase,
                io.TextIOBase,
                Iterable[Union[bytes, str]],
                connection.InfileQueue,
            ]
        ] = None,
    ) -> int:
        """
        Send a statement to the Data API and store the results.

        Statements starting with SELECT/SHOW/CALL/ECHO/DESCRIBE/WITH are
        sent to the ``query/tuples`` endpoint; everything else to ``exec``.
        Fusion commands are intercepted and run client-side.

        Parameters
        ----------
        oper : str
            SQL statement
        params : sequence or dict or pydantic model, optional
            Parameters to substitute into the statement
        is_callproc : bool, optional
            Append an empty trailing result set for PyMySQL/MySQLdb
            compatibility
        infile_stream : stream or iterable, optional
            Not used by the HTTP implementation

        Returns
        -------
        int
            Rows in the first result set (queries) or rows affected (exec)

        Raises
        ------
        ProgrammingError
            If the cursor is closed
        InterfaceError, DataError, ProgrammingError, IntegrityError,
        OperationalError, InternalError
            On HTTP errors (mapped via `get_exc_type`) or API-reported errors

        """
        self._descriptions = []
        self._schemas = []
        self._results = []
        self._pymy_results = []
        self._row_idx = -1
        self._result_idx = -1
        self.rowcount = 0
        self._expect_results = False

        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        sql_type = 'exec'
        if re.match(r'^\s*(select|show|call|echo|describe|with)\s+', oper, flags=re.I):
            self._expect_results = True
            sql_type = 'query'

        if has_pydantic and isinstance(params, pydantic.BaseModel):
            params = params.model_dump()

        self._validate_param_subs(oper, params)

        handler = fusion.get_handler(oper)
        if handler is not None:
            return self._execute_fusion_query(oper, params, handler=handler)

        interpolate_query_with_empty_args = self._connection.connection_params.get(
            'interpolate_query_with_empty_args', False,
        )
        oper, params = self._connection._convert_params(
            oper, params, interpolate_query_with_empty_args,
        )

        log_query(oper, params)

        data: Dict[str, Any] = dict(sql=oper)
        if params is not None:
            data['args'] = convert_special_params(
                params,
                nan_as_null=self._connection.connection_params['nan_as_null'],
                inf_as_null=self._connection.connection_params['inf_as_null'],
            )
        if self._connection._database:
            data['database'] = self._connection._database

        if sql_type == 'query':
            res = self._post('query/tuples', json=data)
        else:
            res = self._post('exec', json=data)

        if res.status_code >= 400:
            if res.text:
                # Server errors look like "Error <code> ...: <message>".
                m = re.match(r'^Error\s+(\d+).*?:', res.text)
                if m:
                    code = m.group(1)
                    msg = res.text.split(':', 1)[-1]
                    icode = int(code.split()[-1])
                else:
                    icode = res.status_code
                    msg = res.text
                raise get_exc_type(icode)(icode, msg.strip())
            raise InterfaceError(errno=res.status_code, msg='HTTP Error')

        out = json.loads(res.text)

        if 'error' in out:
            raise OperationalError(
                errno=out['error'].get('code', 0),
                msg=out['error'].get('message', 'HTTP Error'),
            )

        if sql_type == 'query':
            # description: (name, type_code, display_size, internal_size,
            #               precision, scale, null_ok, column_flags, charset)

            # Remove converters for things the JSON parser already converted:
            # FLOAT, DOUBLE, NULL, VARCHAR, JSON, ENUM, TINY/MEDIUM/LONG/BLOB,
            # VAR_STRING, STRING.
            http_converters = dict(self._connection.decoders)
            for code in (4, 5, 6, 15, 245, 247, 249, 250, 251, 252, 253, 254):
                http_converters.pop(code, None)

            # Merge passed in converters
            if self._connection._conv:
                for k, v in self._connection._conv.items():
                    if isinstance(k, int):
                        http_converters[k] = v

            # Make JSON a string for Arrow
            if 'arrow' in self._results_type:
                def json_to_str(x: Any) -> Optional[str]:
                    if x is None:
                        return None
                    return json.dumps(x)
                http_converters[245] = json_to_str

            # Don't convert date/times in polars (TIMESTAMP, DATE, DATETIME)
            elif 'polars' in self._results_type:
                http_converters.pop(7, None)
                http_converters.pop(10, None)
                http_converters.pop(12, None)

            results = out.get('results', [])

            # Convert data to Python types
            if results and results[0]:
                self._row_idx = 0
                self._result_idx = 0

            for result in results:

                pymy_res = PyMyResult()
                convs = []

                description: List[Description] = []
                for i, col in enumerate(result.get('columns', [])):
                    charset = 0
                    flags = 0
                    data_type = col['dataType'].split('(')[0]
                    type_code = types.ColumnType.get_code(data_type)
                    prec, scale = get_precision_scale(col['dataType'])
                    converter = http_converters.get(type_code, None)

                    # Check the full type string: for e.g. "INT(11) UNSIGNED"
                    # the keyword follows the parenthesized part that
                    # `data_type` strips off.
                    if 'UNSIGNED' in col['dataType']:
                        flags = 32  # MySQL UNSIGNED_FLAG

                    if data_type.endswith('BLOB') or data_type.endswith('BINARY'):
                        # Binary values arrive base64-encoded in the JSON.
                        converter = functools.partial(
                            b64decode_converter, converter,  # type: ignore
                        )
                        charset = 63  # BINARY

                    if type_code == 0:  # DECIMAL
                        type_code = types.ColumnType.get_code('NEWDECIMAL')
                    elif type_code == 15:  # VARCHAR / VARBINARY
                        type_code = types.ColumnType.get_code('VARSTRING')

                    if converter is not None:
                        convs.append((i, None, converter))

                    description.append(
                        Description(
                            str(col['name']), type_code,
                            None, None, prec, scale,
                            col.get('nullable', False),
                            flags, charset,
                        ),
                    )
                    pymy_res.append(PyMyField(col['name'], flags, charset))

                self._descriptions.append(description)
                self._schemas.append(get_schema(self._results_type, description))

                rows = convert_rows(result.get('rows', []), convs)

                self._results.append(rows)
                self._pymy_results.append(pymy_res)

            # For compatibility with PyMySQL/MySQLdb
            if is_callproc:
                self._results.append([])

            # The server may return no result sets at all.
            self.rowcount = len(self._results[0]) if self._results else 0

        else:
            # For compatibility with PyMySQL/MySQLdb
            if is_callproc:
                self._results.append([])

            self.rowcount = out['rowsAffected']

        return self.rowcount

    def executemany(
        self, query: str,
        args: Optional[Sequence[Union[Sequence[Any], Dict[str, Any]]]] = None,
    ) -> int:
        """
        Execute SQL code against multiple sets of parameters.

        Each execution's current rows become one result set of this cursor
        (reachable via `nextset`).  All sets share the description/schema of
        the last execution that produced one.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : iterable of iterables or dicts, optional
            Sets of parameters to substitute into the SQL code; objects with
            an ``itertuples`` method (e.g. DataFrames) are iterated by row

        Returns
        -------
        int
            Sum of the row counts of all executions

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        results = []
        rowcount = 0
        if args is not None and len(args) > 0:
            description = []
            schema = {}
            # Detect dataframes
            if hasattr(args, 'itertuples'):
                argiter = args.itertuples(index=False)  # type: ignore
            else:
                argiter = iter(args)
            for params in argiter:
                self.execute(query, params)
                if self._descriptions:
                    description = self._descriptions[-1]
                if self._schemas:
                    schema = self._schemas[-1]
                if self._rows is not None:
                    results.append(self._rows)
                rowcount += self.rowcount
            self._results = results
            self._descriptions = [description for _ in range(len(results))]
            self._schemas = [schema for _ in range(len(results))]
        else:
            self.execute(query)
            rowcount += self.rowcount

        self.rowcount = rowcount

        return self.rowcount

    @property
    def _has_row(self) -> bool:
        """Determine if a row is available."""
        if self._result_idx < 0 or self._result_idx >= len(self._results):
            return False
        if self._row_idx < 0 or self._row_idx >= len(self._results[self._result_idx]):
            return False
        return True

    @property
    def _rows(self) -> List[Tuple[Any, ...]]:
        """Return current set of rows (empty when no row is available)."""
        if not self._has_row:
            return []
        return self._results[self._result_idx]

    def fetchone(self) -> Optional[Result]:
        """
        Fetch a single row from the result set.

        Returns
        -------
        tuple
            Values of the returned row if there are rows remaining
        None
            If there are no rows left to return

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            return None
        out = self._rows[self._row_idx]
        self._row_idx += 1
        return format_results(
            self._results_type,
            self.description or [],
            out, single=True,
            schema=self._schema,
        )

    def fetchmany(
        self,
        size: Optional[int] = None,
    ) -> Result:
        """
        Fetch `size` rows from the result.

        If `size` is not specified, the `arraysize` attribute is used.
        Sizes below 1 are treated as 1.

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining.
            When no rows remain, ``{}`` is returned for dict-based results
            types and ``()`` otherwise.

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            if 'dict' in self._results_type:
                return {}
            return tuple()
        if not size:
            size = max(int(self.arraysize), 1)
        else:
            size = max(int(size), 1)
        out = self._rows[self._row_idx:self._row_idx+size]
        self._row_idx += len(out)
        return format_results(
            self._results_type, self.description or [],
            out, schema=self._schema,
        )

    def fetchall(self) -> Result:
        """
        Fetch all remaining rows in the result set.

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining.
            When no rows remain, ``{}`` is returned for dict-based results
            types and ``()`` otherwise.

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            if 'dict' in self._results_type:
                return {}
            return tuple()
        out = list(self._rows[self._row_idx:])
        # Advance past the rows just returned.  Assigning len(out) here would
        # rewind the cursor whenever rows had already been fetched.
        self._row_idx += len(out)
        return format_results(
            self._results_type, self.description or [],
            out, schema=self._schema,
        )

    def nextset(self) -> Optional[bool]:
        """
        Skip to the next available result set.

        Returns
        -------
        True
            If another result set is available (it becomes current)
        None
            If there are no more result sets

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')

        if self._result_idx < 0:
            self._row_idx = -1
            return None

        self._result_idx += 1
        self._row_idx = 0

        if self._result_idx >= len(self._results):
            self._result_idx = -1
            self._row_idx = -1
            return None

        self.rowcount = len(self._results[self._result_idx])

        return True

    def setinputsizes(self, sizes: Sequence[int]) -> None:
        """Predefine memory areas for parameters (no-op)."""
        pass

    def setoutputsize(self, size: int, column: Optional[str] = None) -> None:
        """Set a column buffer size for fetches of large columns (no-op)."""
        pass

    @property
    def rownumber(self) -> Optional[int]:
        """
        Return the zero-based index of the cursor in the result set.

        Returns
        -------
        int or None
            ``None`` when there is no current position

        """
        if self._row_idx < 0:
            return None
        return self._row_idx

    def scroll(self, value: int, mode: str = 'relative') -> None:
        """
        Scroll the cursor to the position in the result set.

        Parameters
        ----------
        value : int
            Value of the positional move
        mode : str
            Type of move that should be made: 'relative' or 'absolute'

        Raises
        ------
        ValueError
            If `mode` is not 'relative' or 'absolute'

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if mode == 'relative':
            self._row_idx += value
        elif mode == 'absolute':
            self._row_idx = value
        else:
            raise ValueError(
                f'{mode} is not a valid mode, '
                'expecting "relative" or "absolute"',
            )

    def next(self) -> Optional[Result]:
        """
        Return the next row from the result set for use in iterators.

        Returns
        -------
        tuple
            Values from the next result row

        Raises
        ------
        StopIteration
            If no more rows exist

        """
        if self._connection is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        out = self.fetchone()
        if out is None:
            raise StopIteration
        return out

    __next__ = next

    def __iter__(self) -> Iterable[Tuple[Any, ...]]:
        """
        Return an iterator over the remaining raw rows of the current set.

        Rows are not passed through `format_results`, and the cursor
        position is not advanced.

        """
        return iter(self._rows[self._row_idx:])

    def __enter__(self) -> 'Cursor':
        """Enter a context."""
        return self

    def __exit__(
        self, exc_type: Optional[object],
        exc_value: Optional[Exception], exc_traceback: Optional[str],
    ) -> None:
        """Exit a context, closing the cursor."""
        self.close()

    @property
    def open(self) -> bool:
        """Check if the cursor is still connected."""
        if self._connection is None:
            return False
        return self._connection.is_connected()

    def is_connected(self) -> bool:
        """
        Check if the cursor is still connected.

        Returns
        -------
        bool

        """
        return self.open


class Connection(connection.Connection):
    """
    SingleStoreDB HTTP database connection.

    Instances of this object are typically created through the
    `connection` function rather than creating them directly.

    See Also
    --------
    `connect`

    """
    driver = 'https'
    paramstyle = 'qmark'

    def __init__(self, **kwargs: Any):
        from .. import __version__ as client_version

        if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
            client_version += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']

        connection.Connection.__init__(self, **kwargs)

        host = kwargs.get('host', get_option('host'))
        port = kwargs.get('port', get_option('http_port'))

        self._sess: Optional[requests.Session] = requests.Session()

        user = kwargs.get('user', get_option('user'))
        password = kwargs.get('password', get_option('password'))
        if user is not None and password is not None:
            self._sess.auth = (user, password)
        elif user is not None:
            self._sess.auth = (user, '')
        self._sess.headers.update({
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Accept-Encoding': 'compress,identity',
            'User-Agent': f'SingleStoreDB-Python/{client_version}',
        })

        if kwargs.get('ssl_disabled', get_option('ssl_disabled')):
            self._sess.verify = False
        else:
            ssl_key = kwargs.get('ssl_key', get_option('ssl_key'))
            ssl_cert = kwargs.get('ssl_cert', get_option('ssl_cert'))
            if ssl_key and ssl_cert:
                self._sess.cert = (ssl_key, ssl_cert)
            elif ssl_cert:
                self._sess.cert = ssl_cert

            ssl_ca = kwargs.get('ssl_ca', get_option('ssl_ca'))
            if ssl_ca:
                self._sess.verify = ssl_ca

            # `connect()` forwards ssl_verify_cert=None when the caller did
            # not set it.  Treat None as "unspecified" (keep verification on)
            # rather than as a falsy request to disable it.
            ssl_verify_cert = kwargs.get('ssl_verify_cert', True)
            if ssl_verify_cert is not None and not ssl_verify_cert:
                self._sess.verify = False

        if kwargs.get('multi_statements', False):
            raise self.InterfaceError(
                0, 'The Data API does not allow multiple '
                'statements within a query',
            )

        self._version = kwargs.get('version', 'v2')
        self.driver = kwargs.get('driver', 'https')

        self.encoders = {k: v for (k, v) in converters.items() if type(k) is not int}
        self.decoders = {k: v for (k, v) in converters.items() if type(k) is int}

        self._database = kwargs.get('database', get_option('database'))
        self._url = f'{self.driver}://{host}:{port}/api/{self._version}/'
        self._host = host
        self._messages: List[Tuple[int, str]] = []
        self._autocommit: bool = True
        self._conv = kwargs.get('conv', None)
        self._in_sync: bool = False
        self._track_env: bool = kwargs.get('track_env', False) \
            or host == 'singlestore.com'

    @property
    def messages(self) -> List[Tuple[int, str]]:
        """Return the list of (code, message) tuples for this connection."""
        return self._messages

    def connect(self) -> 'Connection':
        """Connect to the server (no-op for HTTP; returns self)."""
        return self

    def _sync_connection(self, kwargs: Dict[str, Any]) -> None:
        """
        Synchronize connection with env variable.

        When environment tracking is enabled, ``SINGLESTOREDB_URL`` is
        re-read.  If its host, port, user or database differ from the
        current ones, a new ``requests.Session`` is created for the new
        endpoint.  If only the password differs, the existing session's
        credentials are updated in place.

        Parameters
        ----------
        kwargs : Dict[str, Any]
            Keyword arguments of the pending POST request.  When a JSON body
            dict is present and a database is set, its ``database`` key is
            updated so the in-flight request targets the new database.

        Raises
        ------
        InterfaceError
            If the connection is closed, or the URL's host is
            ``singlestore.com``

        """
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed.')

        if self._in_sync:
            return

        if not self._track_env:
            return

        url = os.environ.get('SINGLESTOREDB_URL')
        if not url:
            if self._host == 'singlestore.com':
                raise InterfaceError(0, 'Connection URL has not been established')
            return

        out = {}
        urlp = connection._parse_url(url)
        out.update(urlp)
        out = connection._cast_params(out)

        # Set default port based on driver.
        # NOTE(review): __init__ defaults the port from the 'http_port'
        # option, but this reads 'port' — confirm which option is intended.
        if 'port' not in out or not out['port']:
            if out.get('driver', 'https') == 'http':
                out['port'] = int(get_option('port') or 80)
            else:
                out['port'] = int(get_option('port') or 443)

        # If there is no user and the password is empty, remove the password key.
        if 'user' not in out and not out.get('password', None):
            out.pop('password', None)

        if out['host'] == 'singlestore.com':
            raise InterfaceError(0, 'Connection URL has not been established')

        user = out.get('user')
        password = out.get('password')

        # Get current connection attributes
        curr_url = urlparse(self._url, scheme='singlestoredb', allow_fragments=True)
        if self._sess.auth is not None:
            auth = tuple(self._sess.auth)  # type: ignore
        else:
            auth = (None, None)  # type: ignore

        # Same endpoint, user and database: no new session is needed.  HTTP
        # credentials are sent with every request, though, so a rotated
        # password must still be applied to the existing session.
        if (curr_url.hostname, curr_url.port, auth[0], self._database) == \
                (out['host'], out['port'], user, out.get('database')):
            if user is not None and password is not None and auth[1] != password:
                self._sess.auth = (user, password)
            return

        try:
            self._in_sync = True
            sess = requests.Session()
            if user is not None:
                # Same rule as __init__: a user without a password gets ''.
                sess.auth = (user, password if password is not None else '')
            sess.headers.update(self._sess.headers)
            sess.verify = self._sess.verify
            sess.cert = self._sess.cert
            self._database = out.get('database')
            self._host = out['host']
            self._url = f'{out.get("driver", "https")}://{out["host"]}:{out["port"]}' \
                f'/api/{self._version}/'
            self._sess = sess
            if self._database and isinstance(kwargs.get('json'), dict):
                kwargs['json']['database'] = self._database
        finally:
            self._in_sync = False

    def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """
        Invoke a POST request on the HTTP connection.

        Parameters
        ----------
        path : str
            The path of the resource, relative to the API base URL
        *args : positional parameters, optional
            Extra parameters to the POST request
        **kwargs : keyword parameters, optional
            Extra keyword parameters to the POST request

        Returns
        -------
        requests.Response

        Raises
        ------
        InterfaceError
            If the connection is closed

        """
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed.')

        self._sync_connection(kwargs)

        return self._sess.post(urljoin(self._url, path), *args, **kwargs)

    def close(self) -> None:
        """Close the connection (ignored for the ``singlestore.com`` host)."""
        if self._host == 'singlestore.com':
            return
        if self._sess is None:
            raise Error(errno=2048, msg='Connection is closed')
        self._sess = None

    def autocommit(self, value: bool = True) -> None:
        """Set autocommit mode."""
        if self._host == 'singlestore.com':
            return
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        self._autocommit = value

    def commit(self) -> None:
        """
        Commit the pending transaction.

        A no-op in autocommit mode; otherwise raises NotSupportedError.

        """
        if self._host == 'singlestore.com':
            return
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        if self._autocommit:
            return
        raise NotSupportedError(msg='operation not supported')

    def rollback(self) -> None:
        """
        Rollback the pending transaction.

        A no-op in autocommit mode; otherwise raises NotSupportedError.

        """
        if self._host == 'singlestore.com':
            return
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        if self._autocommit:
            return
        raise NotSupportedError(msg='operation not supported')

    def cursor(self) -> Cursor:
        """
        Create a new cursor object.

        Returns
        -------
        Cursor

        """
        return Cursor(self)

    def __enter__(self) -> 'Connection':
        """Enter a context."""
        return self

    def __exit__(
        self, exc_type: Optional[object],
        exc_value: Optional[Exception], exc_traceback: Optional[str],
    ) -> None:
        """Exit a context, closing the connection."""
        self.close()

    @property
    def open(self) -> bool:
        """
        Check if the database is still connected.

        Sends a GET to the server's ``/ping`` endpoint.  The connection is
        considered open only for a non-error (< 400) response whose body is
        ``pong``.  Network failures are reported as ``False`` instead of
        being raised.

        """
        if self._sess is None:
            return False
        # Keep only "<scheme>://<host>:<port>" from the API base URL.
        url = '/'.join(self._url.split('/')[:3]) + '/ping'
        try:
            res = self._sess.get(
                url, timeout=self.connection_params.get('connect_timeout'),
            )
        except requests.exceptions.RequestException:
            return False
        return res.status_code < 400 and res.text == 'pong'

    def is_connected(self) -> bool:
        """
        Check if the database is still connected.

        Returns
        -------
        bool

        """
        return self.open


def connect(
    host: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    port: Optional[int] = None,
    database: Optional[str] = None,
    driver: Optional[str] = None,
    pure_python: Optional[bool] = None,
    local_infile: Optional[bool] = None,
    charset: Optional[str] = None,
    ssl_key: Optional[str] = None,
    ssl_cert: Optional[str] = None,
    ssl_ca: Optional[str] = None,
    ssl_disabled: Optional[bool] = None,
    ssl_cipher: Optional[str] = None,
    ssl_verify_cert: Optional[bool] = None,
    ssl_verify_identity: Optional[bool] = None,
    conv: Optional[Dict[int, Callable[..., Any]]] = None,
    credential_type: Optional[str] = None,
    autocommit: Optional[bool] = None,
    results_type: Optional[str] = None,
    buffered: Optional[bool] = None,
    results_format: Optional[str] = None,
    program_name: Optional[str] = None,
    conn_attrs: Optional[Dict[str, str]] = None,
    multi_statements: Optional[bool] = None,
    connect_timeout: Optional[int] = None,
    nan_as_null: Optional[bool] = None,
    inf_as_null: Optional[bool] = None,
    encoding_errors: Optional[str] = None,
    track_env: Optional[bool] = None,
    enable_extended_data_types: Optional[bool] = None,
    vector_data_format: Optional[str] = None,
) -> Connection:
    """
    Create a new SingleStoreDB HTTP connection.

    Every argument, including those left as ``None``, is forwarded to
    `Connection` as a keyword argument.

    NOTE(review): because ``None`` values are forwarded, the
    ``kwargs.get(name, default)`` fallbacks in ``Connection.__init__``
    (e.g. for ``host``, ``port``, ``driver``) do not apply to arguments
    omitted here.  Confirm whether callers are expected to fill them in.

    Returns
    -------
    Connection

    """
    return Connection(**dict(locals()))