Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/http/connection.py
469 views
1
#!/usr/bin/env python
2
"""SingleStoreDB HTTP API interface."""
3
import datetime
4
import decimal
5
import functools
6
import io
7
import json
8
import math
9
import os
10
import re
11
import time
12
from base64 import b64decode
13
from typing import Any
14
from typing import Callable
15
from typing import Dict
16
from typing import Iterable
17
from typing import List
18
from typing import Optional
19
from typing import Sequence
20
from typing import Tuple
21
from typing import Union
22
from urllib.parse import urljoin
23
from urllib.parse import urlparse
24
25
import requests
26
27
try:
28
import numpy as np
29
has_numpy = True
30
except ImportError:
31
has_numpy = False
32
33
try:
34
import pygeos
35
has_pygeos = True
36
except ImportError:
37
has_pygeos = False
38
39
try:
40
import shapely.geometry
41
import shapely.wkt
42
has_shapely = True
43
except ImportError:
44
has_shapely = False
45
46
try:
47
import pydantic
48
has_pydantic = True
49
except ImportError:
50
has_pydantic = False
51
52
from .. import connection
53
from .. import fusion
54
from .. import types
55
from ..config import get_option
56
from ..converters import converters
57
from ..exceptions import DatabaseError # noqa: F401
58
from ..exceptions import DataError
59
from ..exceptions import Error # noqa: F401
60
from ..exceptions import IntegrityError
61
from ..exceptions import InterfaceError
62
from ..exceptions import InternalError
63
from ..exceptions import NotSupportedError
64
from ..exceptions import OperationalError
65
from ..exceptions import ProgrammingError
66
from ..exceptions import Warning # noqa: F401
67
from ..utils.convert_rows import convert_rows
68
from ..utils.debug import log_query
69
from ..utils.mogrify import mogrify
70
from ..utils.results import Description
71
from ..utils.results import format_results
72
from ..utils.results import get_schema
73
from ..utils.results import Result
74
75
76
# DB-API 2.0 (PEP 249) module-level globals.
# Supported DB-API specification level.
apilevel = '2.0'
# PEP 249 level 1: threads may share the module, but not connections.
threadsafety = 1
# Parameter marker style advertised at module level.
# NOTE(review): the Connection class below declares paramstyle = 'qmark'
# and callproc builds '%s' markers; confirm which value is authoritative.
paramstyle = 'named'
80
81
82
# Server/client error codes grouped by the DB-API exception class that
# get_exc_type() maps them to. get_exc_type checks the sets in the order
# interface -> data -> programming -> integrity, so a code listed in more
# than one set resolves to the first match.
_interface_errors = {
    0,
    2013,  # CR_SERVER_LOST
    2006,  # CR_SERVER_GONE_ERROR
    2012,  # CR_HANDSHAKE_ERR
    2004,  # CR_IPSOCK_ERROR
    2014,  # CR_COMMANDS_OUT_OF_SYNC
}

_data_errors = {
    1406,  # ER_DATA_TOO_LONG
    1441,  # ER_DATETIME_FUNCTION_OVERFLOW
    1365,  # ER_DIVISION_BY_ZERO
    1230,  # ER_NO_DEFAULT
    1171,  # ER_PRIMARY_CANT_HAVE_NULL
    1264,  # ER_WARN_DATA_OUT_OF_RANGE
    1265,  # ER_WARN_DATA_TRUNCATED
}

_programming_errors = {
    1065,  # ER_EMPTY_QUERY
    1179,  # ER_CANT_DO_THIS_DURING_AN_TRANSACTION
    1007,  # ER_DB_CREATE_EXISTS
    1110,  # ER_FIELD_SPECIFIED_TWICE
    1111,  # ER_INVALID_GROUP_FUNC_USE
    1082,  # ER_NO_SUCH_INDEX
    1741,  # ER_NO_SUCH_KEY_VALUE
    1146,  # ER_NO_SUCH_TABLE
    1449,  # ER_NO_SUCH_USER
    1064,  # ER_PARSE_ERROR
    1149,  # ER_SYNTAX_ERROR
    1113,  # ER_TABLE_MUST_HAVE_COLUMNS
    1112,  # ER_UNSUPPORTED_EXTENSION
    1102,  # ER_WRONG_DB_NAME
    1103,  # ER_WRONG_TABLE_NAME
    1049,  # ER_BAD_DB_ERROR
    1582,  # wrong number of args to a native function (presumably ER_WRONG_PARAMCOUNT_TO_NATIVE_FCT)
}

_integrity_errors = {
    1215,  # ER_CANNOT_ADD_FOREIGN
    1062,  # ER_DUP_ENTRY
    1169,  # ER_DUP_UNIQUE
    1364,  # ER_NO_DEFAULT_FOR_FIELD
    1216,  # ER_NO_REFERENCED_ROW
    1452,  # ER_NO_REFERENCED_ROW_2
    1217,  # ER_ROW_IS_REFERENCED
    1451,  # ER_ROW_IS_REFERENCED_2
    1460,  # ER_XAER_OUTSIDE
    1401,  # ER_XAER_RMERR
    1048,  # ER_BAD_NULL_ERROR
    1264,  # ER_DATA_OUT_OF_RANGE (also in _data_errors, which is checked first)
    4025,  # ER_CONSTRAINT_FAILED
    1826,  # ER_DUP_CONSTRAINT_NAME
}
134
135
136
def get_precision_scale(type_code: str) -> Tuple[Optional[int], Optional[int]]:
    """
    Extract ``(precision, scale)`` from a type string such as ``DECIMAL(10,2)``.

    Parameters
    ----------
    type_code : str
        Data type text, e.g. ``INT``, ``VARCHAR(255)``, ``DECIMAL(10, 2)``

    Returns
    -------
    Tuple[Optional[int], Optional[int]]
        ``(None, None)`` when there is no parenthesized part,
        ``(precision, None)`` for a single number, and
        ``(precision, scale)`` for a pair of numbers

    Raises
    ------
    ValueError
        If parentheses are present but contain neither one nor two integers

    """
    if '(' not in type_code:
        return (None, None)

    pair = re.search(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)', type_code)
    if pair is not None:
        return (int(pair.group(1)), int(pair.group(2)))

    single = re.search(r'\(\s*(\d+)\s*\)', type_code)
    if single is None:
        raise ValueError(f'Unrecognized type code: {type_code}')
    return (int(single.group(1)), None)
147
148
149
def get_exc_type(code: int) -> type:
    """
    Map a server/client error code to a DB-API exception class.

    The explicit code sets are consulted in a fixed order and the first
    match wins. Unlisted codes of 1000 or more become OperationalError;
    anything lower becomes InternalError.

    Parameters
    ----------
    code : int
        Numeric error code

    Returns
    -------
    type
        Exception class to raise

    """
    lookup_order = (
        (_interface_errors, InterfaceError),
        (_data_errors, DataError),
        (_programming_errors, ProgrammingError),
        (_integrity_errors, IntegrityError),
    )
    for codes, exc_class in lookup_order:
        if code in codes:
            return exc_class
    return OperationalError if code >= 1000 else InternalError
162
163
164
def identity(x: Any) -> Any:
    """Return *x* unchanged."""
    unchanged = x
    return unchanged
167
168
169
def b64decode_converter(
    converter: Callable[..., Any],
    x: Optional[str],
    encoding: str = 'utf-8',
) -> Optional[bytes]:
    """
    Base64-decode a value, then optionally pass it through a converter.

    Parameters
    ----------
    converter : Callable or None
        Applied to the decoded bytes; when None the raw bytes are returned
    x : str or None
        Base64 text; None is passed through unchanged
    encoding : str, optional
        Accepted for call compatibility; not used by this function

    Returns
    -------
    Optional[bytes]
        None, the decoded bytes, or the converter's result

    """
    if x is None:
        return None
    decoded = b64decode(x)
    return decoded if converter is None else converter(decoded)
180
181
182
def encode_timedelta(obj: datetime.timedelta) -> str:
    """
    Encode a timedelta as a ``[-]HH:MM:SS[.ffffff]`` string.

    Hours are not wrapped at 24, so durations of a day or more produce hour
    values above 23 (e.g. 1 day 2 hours -> ``26:00:00``). Negative durations
    are encoded as a leading ``-`` followed by the magnitude
    (e.g. ``-1 second`` -> ``-00:00:01``).

    Parameters
    ----------
    obj : datetime.timedelta
        Duration to encode

    Returns
    -------
    str

    """
    # timedelta normalizes negative values to negative days plus positive
    # seconds (-1s == -1 day + 86399s). Encoding those fields directly
    # mixes signs (e.g. "-1:59:59"), so encode the magnitude and prepend
    # the sign instead.
    sign = '-' if obj < datetime.timedelta(0) else ''
    obj = abs(obj)
    hours = int(obj.days) * 24 + int(obj.seconds) // 3600
    minutes = (int(obj.seconds) // 60) % 60
    seconds = int(obj.seconds) % 60
    if obj.microseconds:
        return f'{sign}{hours:02d}:{minutes:02d}:{seconds:02d}.{obj.microseconds:06d}'
    return f'{sign}{hours:02d}:{minutes:02d}:{seconds:02d}'
192
193
194
def encode_time(obj: datetime.time) -> str:
    """
    Encode a time as ``HH:MM:SS``.

    A ``.ffffff`` fraction is appended only when the microsecond field is
    non-zero.
    """
    text = f'{obj.hour:02}:{obj.minute:02}:{obj.second:02}'
    if obj.microsecond:
        text += f'.{obj.microsecond:06}'
    return text
201
202
203
def encode_datetime(obj: datetime.datetime) -> str:
    """
    Encode a datetime as ``YYYY-MM-DD HH:MM:SS``.

    A ``.ffffff`` fraction is appended only when the microsecond field is
    non-zero. Timezone information is not included.
    """
    text = (
        f'{obj.year:04}-{obj.month:02}-{obj.day:02} '
        f'{obj.hour:02}:{obj.minute:02}:{obj.second:02}'
    )
    if obj.microsecond:
        text += f'.{obj.microsecond:06}'
    return text
212
213
214
def encode_date(obj: datetime.date) -> str:
    """Encode a date as a zero-padded ``YYYY-MM-DD`` string."""
    return f'{obj.year:04}-{obj.month:02}-{obj.day:02}'
218
219
220
def encode_struct_time(obj: time.struct_time) -> str:
    """
    Encode a ``time.struct_time`` as a datetime string.

    Only the year through second fields are used; weekday, yearday and
    DST fields are ignored.
    """
    year, month, day, hour, minute, second = obj[:6]
    as_datetime = datetime.datetime(year, month, day, hour, minute, second)
    return encode_datetime(as_datetime)
223
224
225
def encode_decimal(o: decimal.Decimal) -> str:
    """Encode a Decimal in fixed-point notation (never scientific)."""
    return '{:f}'.format(o)
228
229
230
# Most argument encoding is done by the JSON encoder, but these
# are exceptions to the rule. Keys are exact Python types (looked up via
# type(value)); values convert an instance into a JSON-friendly value.
encoders = {
    datetime.datetime: encode_datetime,
    datetime.date: encode_date,
    datetime.time: encode_time,
    datetime.timedelta: encode_timedelta,
    time.struct_time: encode_struct_time,
    decimal.Decimal: encode_decimal,
}

# Optional geometry / array types are registered only when the
# corresponding package imported successfully.
if has_shapely:
    # Shapely geometries are sent as WKT text.
    encoders.update(
        dict.fromkeys(
            (
                shapely.geometry.Point,
                shapely.geometry.Polygon,
                shapely.geometry.LineString,
            ),
            shapely.wkt.dumps,
        ),
    )

if has_numpy:

    def encode_ndarray(obj: np.ndarray) -> bytes:  # type: ignore
        """Return the raw data buffer of an ndarray as bytes."""
        return obj.tobytes()

    encoders[np.ndarray] = encode_ndarray

if has_pygeos:
    # PyGEOS geometries are sent as WKT text.
    encoders[pygeos.Geometry] = pygeos.io.to_wkt
257
258
259
def convert_special_type(
    arg: Any,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> Any:
    """
    Convert a single parameter value of a special type.

    Parameters
    ----------
    arg : Any
        Value to convert
    nan_as_null : bool, optional
        Map float NaN values (Python or NumPy floats) to None
    inf_as_null : bool, optional
        Map float infinities (Python or NumPy floats) to None

    Returns
    -------
    Any
        None for filtered NaN/inf values, the registered encoder's output
        for types found in ``encoders`` (exact type match), otherwise
        ``arg`` itself

    """
    arg_type = type(arg)

    is_float = arg_type is float
    if not is_float and has_numpy:
        numpy_floats = (
            np.float16, np.float32, np.float64,
            # float128 is platform dependent; fall back to a duplicate
            getattr(np, 'float128', np.float64),
        )
        is_float = arg_type in numpy_floats

    if is_float:
        if nan_as_null and math.isnan(arg):
            return None
        if inf_as_null and math.isinf(arg):
            return None

    encoder = encoders.get(arg_type, None)
    if encoder is None:
        return arg
    return encoder(arg)  # type: ignore
281
282
283
def convert_special_params(
    params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> Optional[Union[Sequence[Any], Dict[str, Any]]]:
    """
    Apply ``convert_special_type`` to every query parameter.

    Parameters
    ----------
    params : sequence or dict, optional
        Positional or named parameters; None is returned unchanged
    nan_as_null : bool, optional
        Map float NaN values to None
    inf_as_null : bool, optional
        Map float infinities to None

    Returns
    -------
    dict, tuple, or None
        A new dict for dict input, otherwise a tuple of converted values

    """
    if params is None:
        return None

    def _convert(value: Any) -> Any:
        return convert_special_type(
            value,
            nan_as_null=nan_as_null,
            inf_as_null=inf_as_null,
        )

    if isinstance(params, dict):
        return {key: _convert(value) for key, value in params.items()}
    return tuple(_convert(value) for value in params)
299
300
301
class PyMyField(object):
    """
    Column metadata record for PyMySQL compatibility.

    Parameters
    ----------
    name : str
        Column name
    flags : int
        Column flag bits
    charset : int
        Character set number, stored as ``charsetnr``

    """

    def __init__(self, name: str, flags: int, charset: int) -> None:
        # The attribute is named ``charsetnr`` rather than ``charset``
        # to match the PyMySQL field object this class stands in for.
        self.name, self.flags, self.charsetnr = name, flags, charset
308
309
310
class PyMyResult(object):
    """Result-set metadata container for PyMySQL compatibility."""

    def __init__(self) -> None:
        # Field metadata in the order it was appended.
        self.fields: List[PyMyField] = []
        # Always False here; kept as an attribute for PyMySQL compatibility.
        self.unbuffered_active = False

    def append(self, item: PyMyField) -> None:
        """Add a field descriptor to the end of ``fields``."""
        self.fields += [item]
319
320
321
class Cursor(connection.Cursor):
    """
    SingleStoreDB HTTP database cursor.

    Cursor objects should not be created directly. They should come from
    the `cursor` method on the `Connection` object.

    Parameters
    ----------
    connection : Connection
        The HTTP Connection object the cursor belongs to

    """

    def __init__(self, conn: 'Connection'):
        connection.Cursor.__init__(self, conn)
        self._connection: Optional[Connection] = conn
        # One list of row tuples per result set.
        self._results: List[List[Tuple[Any, ...]]] = [[]]
        self._results_type: str = self._connection._results_type \
            if self._connection is not None else 'tuples'
        # Read position inside the current result set, and index of the
        # current result set; -1 means there is nothing to read.
        self._row_idx: int = -1
        self._result_idx: int = -1
        self._descriptions: List[List[Description]] = []
        self._schemas: List[Dict[str, Any]] = []
        self.arraysize: int = get_option('results.arraysize')
        self.rowcount: int = 0
        self.lastrowid: Optional[int] = None
        self._pymy_results: List[PyMyResult] = []
        # Set by _execute when the statement was classified as returning rows;
        # the fetch* methods refuse to run otherwise.
        self._expect_results: bool = False

    @property
    def _result(self) -> Optional[PyMyResult]:
        """Return Result object for PyMySQL compatibility."""
        if self._result_idx < 0:
            return None
        return self._pymy_results[self._result_idx]

    @property
    def description(self) -> Optional[List[Description]]:
        """Return description for current result set."""
        if not self._descriptions:
            return None
        if self._result_idx >= 0 and self._result_idx < len(self._descriptions):
            return self._descriptions[self._result_idx]
        return None

    @property
    def _schema(self) -> Optional[Any]:
        """Return the schema computed for the current result set, if any."""
        if not self._schemas:
            return None
        if self._result_idx >= 0 and self._result_idx < len(self._schemas):
            return self._schemas[self._result_idx]
        return None

    def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """
        Invoke a POST request on the HTTP connection.

        If no ``timeout`` is given, the connection's ``connect_timeout``
        parameter is used.

        Parameters
        ----------
        path : str
            The path of the resource
        *args : positional parameters, optional
            Extra parameters to the POST request
        **kwargs : keyword parameters, optional
            Extra keyword parameters to the POST request

        Returns
        -------
        requests.Response

        Raises
        ------
        ProgrammingError
            If the cursor has been closed

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')
        if 'timeout' not in kwargs:
            kwargs['timeout'] = self._connection.connection_params['connect_timeout']
        return self._connection._post(path, *args, **kwargs)

    def callproc(
        self, name: str,
        params: Optional[Sequence[Any]] = None,
    ) -> None:
        """
        Call a stored procedure.

        Parameters
        ----------
        name : str
            Name of the stored procedure
        params : sequence, optional
            Parameters to the stored procedure

        Raises
        ------
        ProgrammingError
            If the cursor has been closed

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        name = connection._name_check(name)

        if not params:
            self._execute(f'CALL {name}();', is_callproc=True)
        else:
            keys = ', '.join(['%s' for i in range(len(params))])
            self._execute(f'CALL {name}({keys});', params, is_callproc=True)

    def close(self) -> None:
        """Close the cursor by detaching it from its connection."""
        self._connection = None

    def execute(
        self, query: str,
        args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        infile_stream: Optional[  # type: ignore
            Union[
                io.RawIOBase,
                io.TextIOBase,
                Iterable[Union[bytes, str]],
                connection.InfileQueue,
            ]
        ] = None,
    ) -> int:
        """
        Execute a SQL statement.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : iterable or dict, optional
            Parameters to substitute into the SQL code
        infile_stream : stream or iterable, optional
            Accepted for interface compatibility and passed to ``_execute``

        Returns
        -------
        int
            Row count of the statement

        """
        return self._execute(query, args, infile_stream=infile_stream)

    def _validate_param_subs(
        self, query: str,
        args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    ) -> None:
        """
        Make sure the parameter substitutions are valid.

        The query is %-formatted with ``args`` only so that a mismatch
        between placeholders and arguments raises; the formatted string
        is discarded.
        """
        if args:
            if isinstance(args, Sequence):
                query = query % tuple(args)
            else:
                query = query % args

    def _execute_fusion_query(
        self,
        oper: Union[str, bytes],
        params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        handler: Any = None,
    ) -> int:
        """
        Execute a statement through a Fusion SQL handler.

        The statement is mogrified with ``params`` and handed to
        ``fusion.execute``; its description, rows and field metadata are
        stored as a single result set on this cursor.

        Returns
        -------
        int
            Number of rows in the produced result set

        """
        oper = mogrify(oper, params)

        if isinstance(oper, bytes):
            oper = oper.decode('utf-8')

        log_query(oper, None)

        # Run the handler with plain tuples; the configured results type
        # is restored afterwards and used when building the schema below.
        results_type = self._results_type
        self._results_type = 'tuples'
        try:
            mgmt_res = fusion.execute(
                self._connection,  # type: ignore
                oper,
                handler=handler,
            )
        finally:
            self._results_type = results_type

        self._descriptions.append(list(mgmt_res.description))
        self._schemas.append(get_schema(self._results_type, list(mgmt_res.description)))
        self._results.append(list(mgmt_res.rows))
        self.rowcount = len(self._results[-1])

        pymy_res = PyMyResult()
        for field in mgmt_res.fields:
            pymy_res.append(
                PyMyField(
                    field.name,
                    field.flags,
                    field.charsetnr,
                ),
            )

        self._pymy_results.append(pymy_res)

        if self._results and self._results[0]:
            self._row_idx = 0
            self._result_idx = 0

        return self.rowcount

    def _execute(
        self, oper: str,
        params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        is_callproc: bool = False,
        infile_stream: Optional[  # type: ignore
            Union[
                io.RawIOBase,
                io.TextIOBase,
                Iterable[Union[bytes, str]],
                connection.InfileQueue,
            ]
        ] = None,
    ) -> int:
        """
        Execute a statement against the Data API and load its results.

        Statements starting with SELECT/SHOW/CALL/ECHO/DESCRIBE/WITH are
        posted to ``query/tuples`` and their rows are converted and stored;
        everything else is posted to ``exec`` and only the affected row
        count is kept. Statements recognized by ``fusion.get_handler`` are
        routed to ``_execute_fusion_query`` instead.

        Parameters
        ----------
        oper : str
            SQL statement
        params : sequence, dict or pydantic model, optional
            Statement parameters
        is_callproc : bool, optional
            Append an extra empty result set (PyMySQL/MySQLdb behavior)
        infile_stream : stream or iterable, optional
            Accepted for interface compatibility; not used here

        Returns
        -------
        int
            Rows in the first result set (queries) or rows affected (exec)

        Raises
        ------
        ProgrammingError
            If the cursor has been closed
        DatabaseError subclass
            Chosen by ``get_exc_type`` for HTTP error responses
        InterfaceError
            For HTTP errors without a response body
        OperationalError
            If the JSON response carries an ``error`` entry

        """
        self._descriptions = []
        self._schemas = []
        self._results = []
        self._pymy_results = []
        self._row_idx = -1
        self._result_idx = -1
        self.rowcount = 0
        self._expect_results = False

        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        sql_type = 'exec'
        if re.match(r'^\s*(select|show|call|echo|describe|with)\s+', oper, flags=re.I):
            self._expect_results = True
            sql_type = 'query'

        if has_pydantic and isinstance(params, pydantic.BaseModel):
            params = params.model_dump()

        self._validate_param_subs(oper, params)

        handler = fusion.get_handler(oper)
        if handler is not None:
            return self._execute_fusion_query(oper, params, handler=handler)

        oper, params = self._connection._convert_params(oper, params)

        log_query(oper, params)

        data: Dict[str, Any] = dict(sql=oper)
        if params is not None:
            data['args'] = convert_special_params(
                params,
                nan_as_null=self._connection.connection_params['nan_as_null'],
                inf_as_null=self._connection.connection_params['inf_as_null'],
            )
        if self._connection._database:
            data['database'] = self._connection._database

        if sql_type == 'query':
            res = self._post('query/tuples', json=data)
        else:
            res = self._post('exec', json=data)

        if res.status_code >= 400:
            if res.text:
                # Server errors look like "Error <code> ...: <message>"
                m = re.match(r'^Error\s+(\d+).*?:', res.text)
                if m:
                    code = m.group(1)
                    msg = res.text.split(':', 1)[-1]
                    icode = int(code.split()[-1])
                else:
                    icode = res.status_code
                    msg = res.text
                raise get_exc_type(icode)(icode, msg.strip())
            raise InterfaceError(errno=res.status_code, msg='HTTP Error')

        out = json.loads(res.text)

        if 'error' in out:
            raise OperationalError(
                errno=out['error'].get('code', 0),
                msg=out['error'].get('message', 'HTTP Error'),
            )

        if sql_type == 'query':
            # description: (name, type_code, display_size, internal_size,
            #               precision, scale, null_ok, column_flags, charset)

            # Remove converters for things the JSON parser already converted
            http_converters = dict(self._connection.decoders)
            for json_native_code in (
                4, 5, 6, 15, 245, 247, 249, 250, 251, 252, 253, 254,
            ):
                http_converters.pop(json_native_code, None)

            # Merge passed in converters
            if self._connection._conv:
                for k, v in self._connection._conv.items():
                    if isinstance(k, int):
                        http_converters[k] = v

            # Make JSON a string for Arrow
            if 'arrow' in self._results_type:
                def json_to_str(x: Any) -> Optional[str]:
                    if x is None:
                        return None
                    return json.dumps(x)
                http_converters[245] = json_to_str

            # Don't convert date/times in polars
            elif 'polars' in self._results_type:
                http_converters.pop(7, None)
                http_converters.pop(10, None)
                http_converters.pop(12, None)

            results = out['results']

            # Convert data to Python types
            if results and results[0]:
                self._row_idx = 0
                self._result_idx = 0

            for result in results:

                pymy_res = PyMyResult()
                convs = []

                description: List[Description] = []
                for i, col in enumerate(result.get('columns', [])):
                    charset = 0
                    flags = 0
                    data_type = col['dataType'].split('(')[0]
                    type_code = types.ColumnType.get_code(data_type)
                    prec, scale = get_precision_scale(col['dataType'])
                    converter = http_converters.get(type_code, None)

                    if 'UNSIGNED' in data_type:
                        flags = 32

                    # Binary values arrive base64-encoded in the JSON payload
                    if data_type.endswith('BLOB') or data_type.endswith('BINARY'):
                        converter = functools.partial(
                            b64decode_converter, converter,  # type: ignore
                        )
                        charset = 63  # BINARY

                    if type_code == 0:  # DECIMAL
                        type_code = types.ColumnType.get_code('NEWDECIMAL')
                    elif type_code == 15:  # VARCHAR / VARBINARY
                        type_code = types.ColumnType.get_code('VARSTRING')

                    if converter is not None:
                        convs.append((i, None, converter))

                    description.append(
                        Description(
                            str(col['name']), type_code,
                            None, None, prec, scale,
                            col.get('nullable', False),
                            flags, charset,
                        ),
                    )
                    pymy_res.append(PyMyField(col['name'], flags, charset))

                self._descriptions.append(description)
                self._schemas.append(get_schema(self._results_type, description))

                rows = convert_rows(result.get('rows', []), convs)

                self._results.append(rows)
                self._pymy_results.append(pymy_res)

            # For compatibility with PyMySQL/MySQLdb
            if is_callproc:
                self._results.append([])

            self.rowcount = len(self._results[0])

        else:
            # For compatibility with PyMySQL/MySQLdb
            if is_callproc:
                self._results.append([])

            self.rowcount = out['rowsAffected']

        return self.rowcount

    def executemany(
        self, query: str,
        args: Optional[Sequence[Union[Sequence[Any], Dict[str, Any]]]] = None,
    ) -> int:
        """
        Execute SQL code against multiple sets of parameters.

        Each parameter set is executed separately; the row lists from each
        execution are kept as consecutive result sets and the row counts
        are summed. Objects with ``itertuples`` (e.g. DataFrames) are
        iterated row by row.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : iterable of iterables or dicts, optional
            Sets of parameters to substitute into the SQL code

        Returns
        -------
        int
            Sum of the row counts of all executions

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        results = []
        rowcount = 0
        if args is not None and len(args) > 0:
            description = []
            schema = {}
            # Detect dataframes
            if hasattr(args, 'itertuples'):
                argiter = args.itertuples(index=False)  # type: ignore
            else:
                argiter = iter(args)
            for params in argiter:
                self.execute(query, params)
                if self._descriptions:
                    description = self._descriptions[-1]
                if self._schemas:
                    schema = self._schemas[-1]
                if self._rows is not None:
                    results.append(self._rows)
                rowcount += self.rowcount
            self._results = results
            self._descriptions = [description for _ in range(len(results))]
            self._schemas = [schema for _ in range(len(results))]
        else:
            self.execute(query)
            rowcount += self.rowcount

        self.rowcount = rowcount

        return self.rowcount

    @property
    def _has_row(self) -> bool:
        """Determine if a row is available."""
        if self._result_idx < 0 or self._result_idx >= len(self._results):
            return False
        if self._row_idx < 0 or self._row_idx >= len(self._results[self._result_idx]):
            return False
        return True

    @property
    def _rows(self) -> List[Tuple[Any, ...]]:
        """Return current set of rows (empty when no row is available)."""
        if not self._has_row:
            return []
        return self._results[self._result_idx]

    def fetchone(self) -> Optional[Result]:
        """
        Fetch a single row from the result set.

        Returns
        -------
        tuple
            Values of the returned row if there are rows remaining
        None
            If there are no rows left to return

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            return None
        out = self._rows[self._row_idx]
        self._row_idx += 1
        return format_results(
            self._results_type,
            self.description or [],
            out, single=True,
            schema=self._schema,
        )

    def fetchmany(
        self,
        size: Optional[int] = None,
    ) -> Result:
        """
        Fetch `size` rows from the result.

        If `size` is not specified, the `arraysize` attribute is used.

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            if 'dict' in self._results_type:
                return {}
            return tuple()
        if not size:
            size = max(int(self.arraysize), 1)
        else:
            size = max(int(size), 1)
        out = self._rows[self._row_idx:self._row_idx+size]
        self._row_idx += len(out)
        return format_results(
            self._results_type, self.description or [],
            out, schema=self._schema,
        )

    def fetchall(self) -> Result:
        """
        Fetch all remaining rows in the result set.

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            if 'dict' in self._results_type:
                return {}
            return tuple()
        out = list(self._rows[self._row_idx:])
        # Advance past the rows just returned. Assigning len(out) here would
        # rewind the cursor after a partial fetchone()/fetchmany() and make
        # already-returned rows available again.
        self._row_idx += len(out)
        return format_results(
            self._results_type, self.description or [],
            out, schema=self._schema,
        )

    def nextset(self) -> Optional[bool]:
        """
        Skip to the next available result set.

        Returns
        -------
        True
            If another result set is available (``rowcount`` is updated)
        None
            If there are no more result sets

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')

        if self._result_idx < 0:
            self._row_idx = -1
            return None

        self._result_idx += 1
        self._row_idx = 0

        if self._result_idx >= len(self._results):
            self._result_idx = -1
            self._row_idx = -1
            return None

        self.rowcount = len(self._results[self._result_idx])

        return True

    def setinputsizes(self, sizes: Sequence[int]) -> None:
        """Predefine memory areas for parameters (no-op)."""
        pass

    def setoutputsize(self, size: int, column: Optional[str] = None) -> None:
        """Set a column buffer size for fetches of large columns (no-op)."""
        pass

    @property
    def rownumber(self) -> Optional[int]:
        """
        Return the zero-based index of the cursor in the result set.

        Returns
        -------
        int
            None when there is no current position

        """
        if self._row_idx < 0:
            return None
        return self._row_idx

    def scroll(self, value: int, mode: str = 'relative') -> None:
        """
        Scroll the cursor to the position in the result set.

        The new position is not bounds-checked.

        Parameters
        ----------
        value : int
            Value of the positional move
        mode : str
            Type of move that should be made: 'relative' or 'absolute'

        Raises
        ------
        ValueError
            If `mode` is not 'relative' or 'absolute'

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if mode == 'relative':
            self._row_idx += value
        elif mode == 'absolute':
            self._row_idx = value
        else:
            raise ValueError(
                f'{mode} is not a valid mode, '
                'expecting "relative" or "absolute"',
            )

    def next(self) -> Optional[Result]:
        """
        Return the next row from the result set for use in iterators.

        Returns
        -------
        tuple
            Values from the next result row

        Raises
        ------
        StopIteration
            If no more rows exist

        """
        if self._connection is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        out = self.fetchone()
        if out is None:
            raise StopIteration
        return out

    __next__ = next

    def __iter__(self) -> Iterable[Tuple[Any, ...]]:
        """Return an iterator over the unread raw rows of the current set."""
        return iter(self._rows[self._row_idx:])

    def __enter__(self) -> 'Cursor':
        """Enter a context."""
        return self

    def __exit__(
        self, exc_type: Optional[object],
        exc_value: Optional[Exception], exc_traceback: Optional[str],
    ) -> None:
        """Exit a context, closing the cursor."""
        self.close()

    @property
    def open(self) -> bool:
        """Check if the cursor is still connected."""
        if self._connection is None:
            return False
        return self._connection.is_connected()

    def is_connected(self) -> bool:
        """
        Check if the cursor is still connected.

        Returns
        -------
        bool

        """
        return self.open
971
972
973
class Connection(connection.Connection):
974
"""
975
SingleStoreDB HTTP database connection.
976
977
Instances of this object are typically created through the
978
`connection` function rather than creating them directly.
979
980
See Also
981
--------
982
`connect`
983
984
"""
985
driver = 'https'
986
paramstyle = 'qmark'
987
988
def __init__(self, **kwargs: Any):
989
from .. import __version__ as client_version
990
991
if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
992
client_version += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']
993
994
connection.Connection.__init__(self, **kwargs)
995
996
host = kwargs.get('host', get_option('host'))
997
port = kwargs.get('port', get_option('http_port'))
998
999
self._sess: Optional[requests.Session] = requests.Session()
1000
1001
user = kwargs.get('user', get_option('user'))
1002
password = kwargs.get('password', get_option('password'))
1003
if user is not None and password is not None:
1004
self._sess.auth = (user, password)
1005
elif user is not None:
1006
self._sess.auth = (user, '')
1007
self._sess.headers.update({
1008
'Content-Type': 'application/json',
1009
'Accept': 'application/json',
1010
'Accept-Encoding': 'compress,identity',
1011
'User-Agent': f'SingleStoreDB-Python/{client_version}',
1012
})
1013
1014
if kwargs.get('ssl_disabled', get_option('ssl_disabled')):
1015
self._sess.verify = False
1016
else:
1017
ssl_key = kwargs.get('ssl_key', get_option('ssl_key'))
1018
ssl_cert = kwargs.get('ssl_cert', get_option('ssl_cert'))
1019
if ssl_key and ssl_cert:
1020
self._sess.cert = (ssl_key, ssl_cert)
1021
elif ssl_cert:
1022
self._sess.cert = ssl_cert
1023
1024
ssl_ca = kwargs.get('ssl_ca', get_option('ssl_ca'))
1025
if ssl_ca:
1026
self._sess.verify = ssl_ca
1027
1028
ssl_verify_cert = kwargs.get('ssl_verify_cert', True)
1029
if not ssl_verify_cert:
1030
self._sess.verify = False
1031
1032
if kwargs.get('multi_statements', False):
1033
raise self.InterfaceError(
1034
0, 'The Data API does not allow multiple '
1035
'statements within a query',
1036
)
1037
1038
self._version = kwargs.get('version', 'v2')
1039
self.driver = kwargs.get('driver', 'https')
1040
1041
self.encoders = {k: v for (k, v) in converters.items() if type(k) is not int}
1042
self.decoders = {k: v for (k, v) in converters.items() if type(k) is int}
1043
1044
self._database = kwargs.get('database', get_option('database'))
1045
self._url = f'{self.driver}://{host}:{port}/api/{self._version}/'
1046
self._host = host
1047
self._messages: List[Tuple[int, str]] = []
1048
self._autocommit: bool = True
1049
self._conv = kwargs.get('conv', None)
1050
self._in_sync: bool = False
1051
self._track_env: bool = kwargs.get('track_env', False) \
1052
or host == 'singlestore.com'
1053
1054
@property
1055
def messages(self) -> List[Tuple[int, str]]:
1056
return self._messages
1057
1058
def connect(self) -> 'Connection':
1059
"""Connect to the server."""
1060
return self
1061
1062
def _sync_connection(self, kwargs: Dict[str, Any]) -> None:
1063
"""Synchronize connection with env variable."""
1064
if self._sess is None:
1065
raise InterfaceError(errno=2048, msg='Connection is closed.')
1066
1067
if self._in_sync:
1068
return
1069
1070
if not self._track_env:
1071
return
1072
1073
url = os.environ.get('SINGLESTOREDB_URL')
1074
if not url:
1075
if self._host == 'singlestore.com':
1076
raise InterfaceError(0, 'Connection URL has not been established')
1077
return
1078
1079
out = {}
1080
urlp = connection._parse_url(url)
1081
out.update(urlp)
1082
out = connection._cast_params(out)
1083
1084
# Set default port based on driver.
1085
if 'port' not in out or not out['port']:
1086
if out.get('driver', 'https') == 'http':
1087
out['port'] = int(get_option('port') or 80)
1088
else:
1089
out['port'] = int(get_option('port') or 443)
1090
1091
# If there is no user and the password is empty, remove the password key.
1092
if 'user' not in out and not out.get('password', None):
1093
out.pop('password', None)
1094
1095
if out['host'] == 'singlestore.com':
1096
raise InterfaceError(0, 'Connection URL has not been established')
1097
1098
# Get current connection attributes
1099
curr_url = urlparse(self._url, scheme='singlestoredb', allow_fragments=True)
1100
if self._sess.auth is not None:
1101
auth = tuple(self._sess.auth) # type: ignore
1102
else:
1103
auth = (None, None) # type: ignore
1104
1105
# If it's just a password change, we don't need to reconnect
1106
if (curr_url.hostname, curr_url.port, auth[0], self._database) == \
1107
(out['host'], out['port'], out['user'], out.get('database')):
1108
return
1109
1110
try:
1111
self._in_sync = True
1112
sess = requests.Session()
1113
sess.auth = (out['user'], out['password'])
1114
sess.headers.update(self._sess.headers)
1115
sess.verify = self._sess.verify
1116
sess.cert = self._sess.cert
1117
self._database = out.get('database')
1118
self._host = out['host']
1119
self._url = f'{out.get("driver", "https")}://{out["host"]}:{out["port"]}' \
1120
f'/api/{self._version}/'
1121
self._sess = sess
1122
if self._database:
1123
kwargs['json']['database'] = self._database
1124
finally:
1125
self._in_sync = False
1126
1127
def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
    """
    Send a POST request to a resource on the HTTP API.

    Parameters
    ----------
    path : str
        Resource path, resolved against the connection's base API URL
        with ``urljoin``
    *args : positional parameters, optional
        Forwarded unchanged to the session's ``post`` method
    **kwargs : keyword parameters, optional
        Forwarded to the session's ``post`` method; ``_sync_connection``
        receives this dict first and may modify it

    Raises
    ------
    InterfaceError
        If the connection has been closed

    Returns
    -------
    requests.Response

    """
    if self._sess is None:
        raise InterfaceError(errno=2048, msg='Connection is closed.')

    # This may swap in a new session and base URL, so both are read
    # only after the call below.
    self._sync_connection(kwargs)

    session = self._sess
    target_url = urljoin(self._url, path)
    return session.post(target_url, *args, **kwargs)
1151
1152
def close(self) -> None:
    """
    Close the connection.

    Connections whose host is ``'singlestore.com'`` are left untouched.
    ``_sync_connection`` treats that host as "Connection URL has not been
    established".

    Raises
    ------
    InterfaceError
        If the connection has already been closed

    """
    if self._host == 'singlestore.com':
        return
    if self._sess is None:
        # Use the same exception type as the other methods here for a
        # closed connection. Assumes ..exceptions follows the PEP 249
        # hierarchy (InterfaceError subclasses Error), so callers that
        # catch Error are unaffected.
        raise InterfaceError(errno=2048, msg='Connection is closed')
    self._sess = None
1159
1160
def autocommit(self, value: bool = True) -> None:
    """
    Turn autocommit mode on or off.

    Parameters
    ----------
    value : bool, optional
        The new autocommit setting (defaults to True)

    Raises
    ------
    InterfaceError
        If the connection has been closed

    """
    # Connections to the 'singlestore.com' host are skipped entirely:
    # no state is changed and no error is raised.
    if self._host != 'singlestore.com':
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        self._autocommit = value
1167
1168
def commit(self) -> None:
    """
    Commit the pending transaction.

    Only autocommit mode is supported. With autocommit enabled this is a
    no-op, and it is always a no-op for the ``'singlestore.com'`` host.

    Raises
    ------
    InterfaceError
        If the connection has been closed
    NotSupportedError
        If autocommit is disabled

    """
    if self._host == 'singlestore.com':
        return
    if self._sess is None:
        raise InterfaceError(errno=2048, msg='Connection is closed')
    if not self._autocommit:
        raise NotSupportedError(msg='operation not supported')
1177
1178
def rollback(self) -> None:
    """
    Roll back the pending transaction.

    Only autocommit mode is supported. With autocommit enabled this is a
    no-op, and it is always a no-op for the ``'singlestore.com'`` host.

    Raises
    ------
    InterfaceError
        If the connection has been closed
    NotSupportedError
        If autocommit is disabled

    """
    if self._host != 'singlestore.com':
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        if not self._autocommit:
            raise NotSupportedError(msg='operation not supported')
1187
1188
def cursor(self) -> Cursor:
    """
    Build a fresh cursor bound to this connection.

    Returns
    -------
    Cursor
        A new cursor constructed with this connection as its argument

    """
    new_cursor = Cursor(self)
    return new_cursor
1198
1199
def __enter__(self) -> 'Connection':
    """Support ``with`` statements; the connection itself is the target."""
    return self
1202
1203
def __exit__(
    self,
    exc_type: Optional[object],
    exc_value: Optional[Exception],
    exc_traceback: Optional[str],
) -> None:
    """
    Leave a ``with`` block by closing the connection.

    The exception details are ignored. Because this returns None, any
    exception raised inside the block continues to propagate.

    """
    self.close()
1209
1210
@property
def open(self) -> bool:
    """
    Check if the database is still connected.

    Sends a GET request to the server's ``/ping`` endpoint using the
    connection's session.

    Returns
    -------
    bool
        True only if the connection has a session, the ping request
        succeeded with a non-error status (< 400), and the response body
        is ``'pong'``. False otherwise, including when the request itself
        fails (for example, the server is unreachable).

    """
    if self._sess is None:
        return False

    # Keep only 'scheme://host:port' from the API base URL,
    # e.g. 'https://host:443/api/v2/' -> 'https://host:443'.
    base_url = '/'.join(self._url.split('/')[:3])

    try:
        res = self._sess.get(base_url + '/ping')
    except requests.exceptions.RequestException:
        # A failed request means the connection is not usable; report it
        # as such instead of raising from a liveness check.
        return False

    # Status codes of 400 and above are client/server errors.
    return res.status_code < 400 and res.text == 'pong'
1220
1221
def is_connected(self) -> bool:
    """
    Report whether the database connection is still alive.

    This is a method-call form of the ``open`` property.

    Returns
    -------
    bool

    """
    connected = self.open
    return connected
1231
1232
1233
def connect(
    host: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    port: Optional[int] = None,
    database: Optional[str] = None,
    driver: Optional[str] = None,
    pure_python: Optional[bool] = None,
    local_infile: Optional[bool] = None,
    charset: Optional[str] = None,
    ssl_key: Optional[str] = None,
    ssl_cert: Optional[str] = None,
    ssl_ca: Optional[str] = None,
    ssl_disabled: Optional[bool] = None,
    ssl_cipher: Optional[str] = None,
    ssl_verify_cert: Optional[bool] = None,
    ssl_verify_identity: Optional[bool] = None,
    conv: Optional[Dict[int, Callable[..., Any]]] = None,
    credential_type: Optional[str] = None,
    autocommit: Optional[bool] = None,
    results_type: Optional[str] = None,
    buffered: Optional[bool] = None,
    results_format: Optional[str] = None,
    program_name: Optional[str] = None,
    conn_attrs: Optional[Dict[str, str]] = None,
    multi_statements: Optional[bool] = None,
    connect_timeout: Optional[int] = None,
    nan_as_null: Optional[bool] = None,
    inf_as_null: Optional[bool] = None,
    encoding_errors: Optional[str] = None,
    track_env: Optional[bool] = None,
    enable_extended_data_types: Optional[bool] = None,
    vector_data_format: Optional[str] = None,
) -> Connection:
    """
    Create a new HTTP API connection.

    Every argument is passed unchanged, by keyword, to the ``Connection``
    constructor. See that class for the meaning of each parameter.

    Returns
    -------
    Connection

    """
    # Snapshot the parameters before binding any other local name, so the
    # mapping holds exactly the arguments above.
    options = dict(locals())
    return Connection(**options)
1268
1269