Path: blob/main/singlestoredb/fusion/result.py
469 views
#!/usr/bin/env python3
"""In-memory result sets returned by Fusion SQL command handlers."""
from __future__ import annotations

import re
from typing import Any
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from .. import connection
from ..mysql.constants.FIELD_TYPE import BLOB  # noqa: F401
from ..mysql.constants.FIELD_TYPE import BOOL  # noqa: F401
from ..mysql.constants.FIELD_TYPE import DATE  # noqa: F401
from ..mysql.constants.FIELD_TYPE import DATETIME  # noqa: F401
from ..mysql.constants.FIELD_TYPE import DOUBLE  # noqa: F401
from ..mysql.constants.FIELD_TYPE import JSON  # noqa: F401
from ..mysql.constants.FIELD_TYPE import LONGLONG as INTEGER  # noqa: F401
from ..mysql.constants.FIELD_TYPE import STRING  # noqa: F401
from ..utils.results import Description
from ..utils.results import format_results


class FusionField(object):
    """Field for PyMySQL compatibility."""

    def __init__(self, name: str, flags: int, charset: int) -> None:
        self.name = name
        self.flags = flags
        self.charsetnr = charset


class FusionSQLColumn(object):
    """Column accessor for a FusionSQLResult."""

    def __init__(self, result: FusionSQLResult, index: int) -> None:
        self._result = result
        self._index = index

    def __getitem__(self, index: Any) -> Any:
        """
        Return the column value(s) at row position `index`.

        An integer returns a single value. A slice returns a list of the
        column's values for the selected rows, rather than a whole row.

        """
        rows = self._result.rows
        if isinstance(index, slice):
            return [row[self._index] for row in rows[index]]
        return rows[index][self._index]

    def __iter__(self) -> Iterable[Any]:
        """Lazily yield this column's value from every row of the result."""
        for row in self._result:
            yield row[self._index]


class FieldIndexDict(dict):  # type: ignore
    """Case-insensitive dictionary for column name lookups."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        # Route initial items through __setitem__ so their keys are lower-cased.
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def __getitem__(self, key: str) -> int:
        return super().__getitem__(key.lower())

    def __setitem__(self, key: str, value: int) -> None:
        super().__setitem__(key.lower(), value)

    def __delitem__(self, key: str) -> None:
        super().__delitem__(key.lower())

    def __contains__(self, key: object) -> bool:
        if not isinstance(key, str):
            return False
        return super().__contains__(key.lower())

    def get(self, key: Any, default: Any = None) -> Any:
        """Case-insensitive ``dict.get``; non-string keys return `default`."""
        if not isinstance(key, str):
            return default
        return super().get(key.lower(), default)

    def copy(self) -> FieldIndexDict:
        out = type(self)()
        for k, v in self.items():
            out[k] = v
        return out


def _like_to_regex(pattern: str) -> re.Pattern[str]:
    """
    Compile a SQL LIKE pattern into a case-insensitive regular expression.

    Every non-word character in `pattern` is escaped so that it matches
    literally. The escaped ``%`` is then turned into ``.*``, which matches any
    run of characters. ``_`` is a word character, so it is matched literally
    rather than as a single-character wildcard. Callers must use
    ``fullmatch`` so that the whole value has to match the pattern.

    """
    escaped = re.sub(r'([^\w])', r'\\\1', pattern)
    # re.S lets ``%`` (``.*``) span newlines, as SQL LIKE's ``%`` does.
    return re.compile(re.sub(r'\\%', r'.*', escaped), flags=re.I | re.S)


def _like_matches(regex: re.Pattern[str], value: Any) -> bool:
    """Return True if `value` is not None and fully matches `regex`."""
    return value is not None and regex.fullmatch(value) is not None


class FusionSQLResult(object):
    """Result for Fusion SQL commands."""

    def __init__(self) -> None:
        self.connection: Any = None
        self.affected_rows: Optional[int] = None
        self.insert_id: int = 0
        self.server_status: Optional[int] = None
        self.warning_count: int = 0
        self.message: Optional[str] = None
        self.description: List[Description] = []
        self.rows: Any = []
        self.has_next: bool = False
        self.unbuffered_active: bool = False
        self.converters: List[Any] = []
        self.fields: List[FusionField] = []
        self._field_indexes: FieldIndexDict = FieldIndexDict()
        # Index of the last row handed out by the unbuffered reader; -1 = none.
        self._row_idx: int = -1

    def copy(self) -> FusionSQLResult:
        """Copy the result (lists and dicts are copied one level deep)."""
        out = type(self)()
        for k, v in vars(self).items():
            if isinstance(v, list):
                setattr(out, k, list(v))
            elif isinstance(v, dict):
                setattr(out, k, v.copy())
            else:
                setattr(out, k, v)
        return out

    def _read_rowdata_packet_unbuffered(self, size: int = 1) -> Optional[List[Any]]:
        """
        Return the next batch of up to `size` rows for unbuffered reading.

        Parameters
        ----------
        size : int, optional
            Maximum number of rows to return

        Returns
        -------
        List[Any] or None
            The next rows. The final batch may hold fewer than `size` rows.
            None is returned once all rows are consumed; at that point the
            reader state is reset. An empty list is returned when `size` < 1.

        """
        if not self.rows:
            return None

        if size < 1:
            return []

        start = self._row_idx + 1
        out = list(self.rows[start:start + size])
        if not out:
            # Exhausted: reset the reader state, as _finish_unbuffered_query does.
            self._row_idx = -1
            self.rows = []
            return None

        self._row_idx += len(out)
        return out

    def _finish_unbuffered_query(self) -> None:
        self._row_idx = -1
        self.rows = []
        self.affected_rows = None

    def format_results(self, connection: connection.Connection) -> None:
        """
        Format the results using the connection converters and options.

        Parameters
        ----------
        connection : Connection
            The connection containing the converters and options

        """
        self.converters = [
            (item.charset, connection.decoders.get(item.type_code))
            for item in self.description
        ]

        # Convert values; NULLs are passed through untouched, as the MySQL
        # protocol readers do, so decoders never receive None.
        for i, row in enumerate(self.rows):
            self.rows[i] = tuple(
                value if converter is None or value is None else converter(value)
                for (_, converter), value in zip(self.converters, row)
            )

        # NOTE(review): the slice assignment iterates whatever format_results
        # returns; this assumes it yields one item per row for the connection's
        # results type — confirm for DataFrame-style result types.
        self.rows[:] = format_results(
            connection._results_type, self.description, self.rows,
        )

    def __iter__(self) -> Iterable[Tuple[Any, ...]]:
        return iter(self.rows)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, key: Any) -> Any:
        """
        Return a column accessor for a str `key`; otherwise index the rows.

        Raises
        ------
        KeyError
            If `key` is a str that is not a column name

        """
        if isinstance(key, str):
            return FusionSQLColumn(self, self._field_indexes[key])
        return self.rows[key]

    def __getattr__(self, name: str) -> FusionSQLColumn:
        """
        Return a column accessor when `name` is a column of the result.

        Raises
        ------
        AttributeError
            If `name` is not a column name, as ``hasattr``/``getattr`` expect

        """
        # Read the index map from __dict__ directly. On instances created
        # without __init__ (e.g. by copy/pickle probing for __setstate__),
        # a plain self._field_indexes lookup would re-enter __getattr__
        # and recurse forever.
        field_indexes = self.__dict__.get('_field_indexes')
        if field_indexes is None or name not in field_indexes:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute '
                f'or field {name!r}',
            )
        return FusionSQLColumn(self, field_indexes[name])

    def add_field(self, name: str, dtype: int) -> None:
        """
        Add a new field / column to the data set.

        Parameters
        ----------
        name : str
            The name of the field / column
        dtype : int
            The MySQL field type: BLOB, BOOL, DATE, DATETIME,
            DOUBLE, JSON, INTEGER, or STRING

        """
        charset = 0
        if dtype in (JSON, STRING):
            encoding = 'utf-8'
        elif dtype == BLOB:
            # 63 is stored as the column charset for binary columns.
            charset = 63
            encoding = None
        else:
            encoding = 'ascii'
        self.description.append(
            Description(name, dtype, None, None, 0, 0, True, 0, charset),
        )
        self.fields.append(FusionField(name, 0, charset))
        self._field_indexes[name] = len(self.fields) - 1
        self.converters.append((encoding, None))

    def set_rows(self, data: List[Tuple[Any, ...]]) -> None:
        """
        Set the rows of the result.

        Parameters
        ----------
        data : List[Tuple[Any, ...]]
            The data should be a list of tuples where each element of the
            tuple corresponds to a field added to the result with
            the :meth:`add_field` method.

        """
        self.rows = list(data)
        self.affected_rows = 0

    def _compile_likers(self, kwargs: Any) -> List[Tuple[int, Any]]:
        """
        Build ``(column_index, regex)`` pairs from column=LIKE-pattern kwargs.

        Empty patterns are skipped.

        Raises
        ------
        KeyError
            If a keyword does not name a column of the result

        """
        likers = []
        for k, v in kwargs.items():
            if k not in self._field_indexes:
                raise KeyError(f'field name does not exist in results: {k}')
            if not v:
                continue
            likers.append((self._field_indexes[k], _like_to_regex(v)))
        return likers

    def like(self, **kwargs: str) -> FusionSQLResult:
        """
        Return a new result containing only rows that match all `kwargs` like patterns.

        Parameters
        ----------
        **kwargs : str
            Each parameter name corresponds to a column name in the result. The value
            of the parameters is a LIKE pattern to match. NULL values never
            match; empty patterns are ignored.

        Returns
        -------
        FusionSQLResult

        """
        likers = self._compile_likers(kwargs)
        out = self.copy()
        out.rows[:] = [
            row for row in self.rows
            if all(_like_matches(regex, row[i]) for i, regex in likers)
        ]
        return out

    like_all = like

    def like_any(self, **kwargs: str) -> FusionSQLResult:
        """
        Return a new result containing only rows that match any `kwargs` like patterns.

        Parameters
        ----------
        **kwargs : str
            Each parameter name corresponds to a column name in the result. The value
            of the parameters is a LIKE pattern to match. NULL values never
            match; empty patterns are ignored. If no non-empty pattern
            remains, all rows are kept (consistent with :meth:`like`).

        Returns
        -------
        FusionSQLResult

        """
        likers = self._compile_likers(kwargs)
        out = self.copy()
        if not likers:
            return out
        out.rows[:] = [
            row for row in self.rows
            if any(_like_matches(regex, row[i]) for i, regex in likers)
        ]
        return out

    def filter(self, **kwargs: str) -> FusionSQLResult:
        """
        Return a new result containing only rows that match all `kwargs` values.

        Parameters
        ----------
        **kwargs : str
            Each parameter name corresponds to a column name in the result. The value
            of the parameters is the value to match.

        Returns
        -------
        FusionSQLResult

        """
        if not kwargs:
            return self.copy()

        values = []
        for k, v in kwargs.items():
            if k not in self._field_indexes:
                raise KeyError(f'field name does not exist in results: {k}')
            values.append((self._field_indexes[k], v))

        out = self.copy()
        # Each value is compared against its own column (row[i]).
        out.rows[:] = [
            row for row in self.rows
            if all(row[i] == val for i, val in values)
        ]
        return out

    def limit(self, n_rows: int) -> FusionSQLResult:
        """
        Return a new result containing only `n_rows` rows.

        Parameters
        ----------
        n_rows : int
            The number of rows to limit the result to. A falsy value
            (0 or None) leaves all rows in place.

        Returns
        -------
        FusionSQLResult

        """
        out = self.copy()
        if n_rows:
            out.rows[:] = out.rows[:n_rows]
        return out

    def sort_by(
        self,
        by: Union[str, List[str]],
        ascending: Union[bool, List[bool]] = True,
    ) -> FusionSQLResult:
        """
        Return a new result with rows sorted in specified order.

        Parameters
        ----------
        by : str or List[str]
            Name or names of columns to sort by
        ascending : bool or List[bool], optional
            Should the sort order be ascending? A single value applies to
            every sort column. If not all sort columns use the same ordering,
            a list of booleans (one per column in `by`) can be supplied.

        Returns
        -------
        FusionSQLResult

        Raises
        ------
        KeyError
            If a name in `by` is not a column of the result
        ValueError
            If `ascending` is a list whose length differs from `by`

        """
        if not by:
            return self.copy()

        if isinstance(by, str):
            by = [by]
        by = list(by)

        if isinstance(ascending, (list, tuple)):
            ascending = list(ascending)
            if len(ascending) != len(by):
                raise ValueError(
                    'number of ascending flags does not match number of '
                    f'sort columns: {len(ascending)} != {len(by)}',
                )
        else:
            # A single flag applies to every sort column.
            ascending = [bool(ascending)] * len(by)

        indexes = [self._field_indexes[name] for name in by]

        out = self.copy()
        # list.sort is stable, so sorting by the least significant column
        # first and the most significant column last yields a multi-key sort.
        for idx, asc in zip(reversed(indexes), reversed(ascending)):
            out.rows.sort(
                # The leading 0/1 groups NULLs together so None is never
                # compared against a real value.
                key=lambda row, idx=idx: (
                    0 if row[idx] is None else 1,
                    row[idx],
                ),
                reverse=not asc,
            )
        return out

    order_by = sort_by