Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/fusion/result.py
469 views
1
#!/usr/bin/env python3
2
from __future__ import annotations
3
4
import re
5
from typing import Any
6
from typing import Iterable
7
from typing import List
8
from typing import Optional
9
from typing import Tuple
10
from typing import Union
11
12
from .. import connection
13
from ..mysql.constants.FIELD_TYPE import BLOB # noqa: F401
14
from ..mysql.constants.FIELD_TYPE import BOOL # noqa: F401
15
from ..mysql.constants.FIELD_TYPE import DATE # noqa: F401
16
from ..mysql.constants.FIELD_TYPE import DATETIME # noqa: F401
17
from ..mysql.constants.FIELD_TYPE import DOUBLE # noqa: F401
18
from ..mysql.constants.FIELD_TYPE import JSON # noqa: F401
19
from ..mysql.constants.FIELD_TYPE import LONGLONG as INTEGER # noqa: F401
20
from ..mysql.constants.FIELD_TYPE import STRING # noqa: F401
21
from ..utils.results import Description
22
from ..utils.results import format_results
23
24
25
class FusionField(object):
    """
    Field descriptor kept for PyMySQL compatibility.

    Parameters
    ----------
    name : str
        Name of the field / column
    flags : int
        Field flags (``FusionSQLResult.add_field`` passes 0)
    charset : int
        Character-set number, exposed as the ``charsetnr`` attribute

    """

    def __init__(self, name: str, flags: int, charset: int) -> None:
        # The charset is stored under PyMySQL's attribute name ``charsetnr``.
        self.charsetnr = charset
        self.flags = flags
        self.name = name
32
33
34
class FusionSQLColumn(object):
    """
    Column accessor for a FusionSQLResult.

    Gives read-only access to one column of the parent result's rows.

    Parameters
    ----------
    result : FusionSQLResult
        The result whose rows are read
    index : int
        Position of the column within each row

    """

    def __init__(self, result: FusionSQLResult, index: int) -> None:
        self._result = result
        self._index = index

    def __getitem__(self, index: Any) -> Any:
        """
        Return the column value(s) at a row position or slice of rows.

        An integer index returns a single value. A slice returns a list of
        this column's values for the selected rows.

        """
        rows = self._result.rows
        if isinstance(index, slice):
            # ``rows[index]`` is a list of rows here; project the column out
            # of each row instead of indexing that list by column position.
            return [row[self._index] for row in rows[index]]
        return rows[index][self._index]

    def __iter__(self) -> Iterable[Any]:
        """Iterate over this column's values in row order."""
        for row in iter(self._result):
            yield row[self._index]
49
50
51
class FieldIndexDict(dict):  # type: ignore
    """
    Case-insensitive dictionary for column name lookups.

    Maps column names to column positions. Keys are stored lower-cased.
    Every lookup, membership test, insertion, update and removal lower-cases
    the key first, so ``d['Name']`` and ``d['NAME']`` refer to the same entry.
    Non-string keys are never present.

    """

    # Sentinel that distinguishes "no default given" from ``default=None``.
    _MISSING: Any = object()

    def __init__(self, *args: Any, **kwargs: int) -> None:
        super().__init__()
        # dict.__init__ would store keys verbatim; route initial data through
        # __setitem__ so the keys are normalized.
        self.update(*args, **kwargs)

    @staticmethod
    def _lookup_key(key: Any) -> str:
        """Return the stored form of `key`; raise KeyError for non-strings."""
        if not isinstance(key, str):
            raise KeyError(key)
        return key.lower()

    def __getitem__(self, key: str) -> int:
        return super().__getitem__(self._lookup_key(key))

    def __setitem__(self, key: str, value: int) -> None:
        super().__setitem__(key.lower(), value)

    def __delitem__(self, key: str) -> None:
        super().__delitem__(self._lookup_key(key))

    def __contains__(self, key: object) -> bool:
        return isinstance(key, str) and super().__contains__(key.lower())

    def get(self, key: Any, default: Any = None) -> Any:
        """Return the value for `key` (case-insensitive), else `default`."""
        return self[key] if key in self else default

    def pop(self, key: Any, default: Any = _MISSING) -> Any:
        """Remove `key` (case-insensitive) and return its value."""
        if key in self:
            return super().pop(key.lower())
        if default is FieldIndexDict._MISSING:
            raise KeyError(key)
        return default

    def setdefault(self, key: str, default: Any = None) -> Any:
        """Insert `key` with `default` if absent; return the stored value."""
        if key not in self:
            self[key] = default
        return self[key]

    def update(self, *args: Any, **kwargs: int) -> None:
        """Update from a mapping/iterable and keywords, normalizing keys."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def copy(self) -> FieldIndexDict:
        """Return a shallow copy of the same type."""
        out = type(self)()
        out.update(self)
        return out
70
71
72
class FusionSQLResult(object):
    """
    Result for Fusion SQL commands.

    Holds column metadata (``description``, ``fields``), row data
    (``rows``, a list of tuples) and result-status attributes such as
    ``affected_rows`` and ``insert_id``. Columns can be looked up by name,
    case-insensitively, as attributes (``result.name``) or items
    (``result['name']``). The filtering, sorting and limiting helpers each
    return a new result and leave this one unchanged.

    """

    def __init__(self) -> None:
        self.connection: Any = None
        self.affected_rows: Optional[int] = None
        self.insert_id: int = 0
        self.server_status: Optional[int] = None
        self.warning_count: int = 0
        self.message: Optional[str] = None
        self.description: List[Description] = []
        self.rows: Any = []
        self.has_next: bool = False
        self.unbuffered_active: bool = False
        # One entry per column; filled by add_field() / format_results().
        self.converters: List[Any] = []
        self.fields: List[FusionField] = []
        # Lower-cased column name -> column position.
        self._field_indexes: FieldIndexDict = FieldIndexDict()
        # Position of the last row returned by an unbuffered read
        # (-1 means nothing has been read yet).
        self._row_idx: int = -1

    def copy(self) -> FusionSQLResult:
        """
        Copy the result.

        List and dict attributes (rows, description, fields, field indexes,
        ...) are copied one level deep, so the copy's rows can be filtered or
        re-sorted without affecting this result. Other attributes are shared.

        Returns
        -------
        FusionSQLResult

        """
        cls = type(self)
        # Skip __init__: every attribute it sets is overwritten below, and
        # a subclass may not accept a zero-argument constructor.
        out = cls.__new__(cls)
        for name, value in vars(self).items():
            if isinstance(value, list):
                value = list(value)
            elif isinstance(value, dict):
                value = value.copy()
            setattr(out, name, value)
        return out

    def _read_rowdata_packet_unbuffered(
        self,
        size: int = 1,
    ) -> Optional[List[Any]]:
        """
        Return the next `size` rows of an unbuffered read.

        The last batch may contain fewer than `size` rows. Once every row has
        been returned, the next call releases the row buffer, resets the read
        position, and returns None.

        Parameters
        ----------
        size : int, optional
            Maximum number of rows to return

        Returns
        -------
        List[Any] or None
            The rows read, or None when no rows remain

        """
        if not self.rows:
            return None

        if size < 1:
            # Nothing requested; leave the read position untouched.
            return []

        start = self._row_idx + 1
        out = list(self.rows[start:start + size])
        if not out:
            # Every row has been consumed.
            self._row_idx = -1
            self.rows = []
            return None

        # Advance by what was actually read so a short final batch is
        # returned instead of being discarded.
        self._row_idx += len(out)
        return out

    def _finish_unbuffered_query(self) -> None:
        """Discard unread rows and reset the unbuffered read state."""
        self._row_idx = -1
        self.rows = []
        self.affected_rows = None

    def format_results(self, connection: connection.Connection) -> None:
        """
        Format the results using the connection converters and options.

        Each value is passed through the connection's decoder for its column
        type (when one exists). The rows are then reshaped in place by the
        module-level ``format_results`` according to
        ``connection._results_type``.

        Parameters
        ----------
        connection : Connection
            The connection containing the converters and options

        """
        self.converters = [
            (item.charset, connection.decoders.get(item.type_code))
            for item in self.description
        ]

        # Convert values column by column; columns without a decoder are
        # passed through unchanged.
        for i, row in enumerate(self.rows):
            self.rows[i] = tuple(
                value if converter is None else converter(value)
                for (_, converter), value in zip(self.converters, row)
            )

        self.rows[:] = format_results(
            connection._results_type, self.description, self.rows,
        )

    def __iter__(self) -> Iterable[Tuple[Any, ...]]:
        return iter(self.rows)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, key: Any) -> Any:
        """
        Return a column accessor for a str key, otherwise row(s) by index.

        Raises
        ------
        KeyError
            If `key` is a str that is not a column name

        """
        if isinstance(key, str):
            return FusionSQLColumn(self, self._field_indexes[key])
        return self.rows[key]

    def __getattr__(self, name: str) -> Any:
        """
        Return a column accessor for the column called `name`.

        Only called when normal attribute lookup fails. Unknown names raise
        AttributeError (not KeyError), so ``hasattr`` and
        ``getattr(obj, name, default)`` behave as usual.

        """
        # Read the index map from the instance dict directly. Going through
        # ``self._field_indexes`` would re-enter __getattr__ forever on an
        # instance whose __init__ never ran (e.g. one built by copy.copy).
        field_indexes = self.__dict__.get('_field_indexes')
        if field_indexes is None or name not in field_indexes:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}',
            )
        return FusionSQLColumn(self, field_indexes[name])

    def add_field(self, name: str, dtype: int) -> None:
        """
        Add a new field / column to the data set.

        Parameters
        ----------
        name : str
            The name of the field / column
        dtype : int
            The MySQL field type: BLOB, BOOL, DATE, DATETIME,
            DOUBLE, JSON, INTEGER, or STRING

        """
        charset = 0
        if dtype in (JSON, STRING):
            encoding = 'utf-8'
        elif dtype == BLOB:
            # BLOB columns get charset 63 and no text encoding.
            charset = 63
            encoding = None
        else:
            encoding = 'ascii'
        self.description.append(
            Description(name, dtype, None, None, 0, 0, True, 0, charset),
        )
        self.fields.append(FusionField(name, 0, charset))
        self._field_indexes[name] = len(self.fields) - 1
        self.converters.append((encoding, None))

    def set_rows(self, data: List[Tuple[Any, ...]]) -> None:
        """
        Set the rows of the result.

        Parameters
        ----------
        data : List[Tuple[Any, ...]]
            The data should be a list of tuples where each element of the
            tuple corresponds to a field added to the result with
            the :meth:`add_field` method.

        """
        self.rows = list(data)
        self.affected_rows = 0

    @staticmethod
    def _like_to_regex(pattern: str) -> re.Pattern[str]:
        """
        Compile a SQL LIKE pattern into an equivalent regular expression.

        ``%`` matches any run of characters (including none) and ``_``
        matches exactly one character. Every other character matches itself.
        Matching is case-insensitive. No escape character is supported.

        """
        parts = []
        for char in pattern:
            if char == '%':
                parts.append('.*')
            elif char == '_':
                parts.append('.')
            else:
                parts.append(re.escape(char))
        # DOTALL so wildcards also span newlines; callers use fullmatch, so
        # no ^/$ anchors (which re.M would loosen) are needed.
        return re.compile(''.join(parts), flags=re.I | re.S)

    @staticmethod
    def _like_match(regex: re.Pattern[str], value: Any) -> bool:
        """Return True if non-NULL `value` fully matches the LIKE `regex`."""
        if value is None:
            # NULL never matches a LIKE pattern.
            return False
        if isinstance(value, (bytes, bytearray)):
            value = value.decode('utf-8', errors='replace')
        elif not isinstance(value, str):
            value = str(value)
        return regex.fullmatch(value) is not None

    def _compile_likers(
        self,
        patterns: Iterable[Tuple[str, str]],
    ) -> List[Tuple[int, re.Pattern[str]]]:
        """
        Resolve (column name, LIKE pattern) pairs to (position, regex) pairs.

        Empty or None patterns are skipped: they impose no constraint.

        Raises
        ------
        KeyError
            If a column name does not exist in the result

        """
        likers = []
        for name, pattern in patterns:
            if name not in self._field_indexes:
                raise KeyError(f'field name does not exist in results: {name}')
            if not pattern:
                continue
            likers.append(
                (self._field_indexes[name], self._like_to_regex(pattern)),
            )
        return likers

    def like(self, **kwargs: str) -> FusionSQLResult:
        """
        Return a new result containing only rows that match all `kwargs` like patterns.

        Patterns use SQL LIKE syntax (``%`` = any run, ``_`` = any single
        character) and match case-insensitively against the whole value.
        NULL values never match. Empty patterns are ignored.

        Parameters
        ----------
        **kwargs : str
            Each parameter name corresponds to a column name in the result. The value
            of the parameters is a LIKE pattern to match.

        Returns
        -------
        FusionSQLResult

        """
        likers = self._compile_likers(kwargs.items())
        out = self.copy()
        out.rows[:] = [
            row for row in self.rows
            if all(self._like_match(regex, row[i]) for i, regex in likers)
        ]
        return out

    like_all = like

    def like_any(self, **kwargs: str) -> FusionSQLResult:
        """
        Return a new result containing only rows that match any `kwargs` like patterns.

        Patterns use SQL LIKE syntax (``%`` = any run, ``_`` = any single
        character) and match case-insensitively against the whole value.
        NULL values never match. Empty patterns are ignored; if every
        pattern is empty, all rows are kept.

        Parameters
        ----------
        **kwargs : str
            Each parameter name corresponds to a column name in the result. The value
            of the parameters is a LIKE pattern to match.

        Returns
        -------
        FusionSQLResult

        """
        likers = self._compile_likers(kwargs.items())
        out = self.copy()
        if likers:
            out.rows[:] = [
                row for row in self.rows
                if any(self._like_match(regex, row[i]) for i, regex in likers)
            ]
        return out

    def filter(self, **kwargs: str) -> FusionSQLResult:
        """
        Return a new result containing only rows that match all `kwargs` values.

        Parameters
        ----------
        **kwargs : str
            Each parameter name corresponds to a column name in the result. The value
            of the parameters is the value to match.

        Returns
        -------
        FusionSQLResult

        Raises
        ------
        KeyError
            If a column name does not exist in the result

        """
        criteria = []
        for name, value in kwargs.items():
            if name not in self._field_indexes:
                raise KeyError(f'field name does not exist in results: {name}')
            criteria.append((self._field_indexes[name], value))

        out = self.copy()
        # Compare each criterion against its own column position.
        out.rows[:] = [
            row for row in self.rows
            if all(row[i] == value for i, value in criteria)
        ]
        return out

    def limit(self, n_rows: int) -> FusionSQLResult:
        """
        Return a new result containing only `n_rows` rows.

        Parameters
        ----------
        n_rows : int
            The number of rows to limit the result to. A falsy value
            (None or 0) means "no limit", and all rows are kept.

        Returns
        -------
        FusionSQLResult

        """
        out = self.copy()
        if n_rows:
            out.rows[:] = out.rows[:n_rows]
        return out

    def sort_by(
        self,
        by: Union[str, List[str]],
        ascending: Union[bool, List[bool]] = True,
    ) -> FusionSQLResult:
        """
        Return a new result with rows sorted in specified order.

        NULL values sort before all other values when ascending and after
        them when descending.

        Parameters
        ----------
        by : str or List[str]
            Name or names of columns to sort by, most significant first
        ascending : bool or List[bool], optional
            Should the sort order be ascending? If not all sort columns
            use the same ordering, a list of booleans can be supplied to
            indicate the order for each column.

        Returns
        -------
        FusionSQLResult

        Raises
        ------
        KeyError
            If a sort column does not exist in the result
        ValueError
            If `ascending` is a list whose length differs from `by`

        """
        if not by:
            return self.copy()

        columns = [by] if isinstance(by, str) else list(by)

        if isinstance(ascending, bool):
            # One flag applies to every sort column.
            orders = [ascending] * len(columns)
        else:
            orders = list(ascending)
            if len(orders) != len(columns):
                raise ValueError(
                    'length of `ascending` must match the number of sort columns',
                )

        for name in columns:
            if name not in self._field_indexes:
                raise KeyError(f'field name does not exist in results: {name}')

        out = self.copy()
        # Python's sort is stable, so sorting from the least to the most
        # significant column yields the combined multi-column ordering.
        for name, asc in zip(reversed(columns), reversed(orders)):
            idx = self._field_indexes[name]
            out.rows.sort(
                key=lambda row, idx=idx: (row[idx] is not None, row[idx]),
                reverse=not asc,
            )
        return out

    order_by = sort_by
400
401