Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/http/connection.py
801 views
1
#!/usr/bin/env python
2
"""SingleStoreDB HTTP API interface."""
3
import datetime
4
import decimal
5
import functools
6
import io
7
import json
8
import math
9
import os
10
import re
11
import time
12
from base64 import b64decode
13
from collections.abc import Iterable
14
from collections.abc import Sequence
15
from typing import Any
16
from typing import Callable
17
from typing import Dict
18
from typing import List
19
from typing import Optional
20
from typing import Tuple
21
from typing import Union
22
from urllib.parse import urljoin
23
from urllib.parse import urlparse
24
25
import requests
26
27
try:
28
import numpy as np
29
has_numpy = True
30
except ImportError:
31
has_numpy = False
32
33
try:
34
import pygeos
35
has_pygeos = True
36
except ImportError:
37
has_pygeos = False
38
39
try:
40
import shapely.geometry
41
import shapely.wkt
42
has_shapely = True
43
except ImportError:
44
has_shapely = False
45
46
try:
47
import pydantic
48
has_pydantic = True
49
except ImportError:
50
has_pydantic = False
51
52
from .. import connection
53
from .. import fusion
54
from .. import types
55
from ..config import get_option
56
from ..converters import converters
57
from ..exceptions import DatabaseError # noqa: F401
58
from ..exceptions import DataError
59
from ..exceptions import Error # noqa: F401
60
from ..exceptions import IntegrityError
61
from ..exceptions import InterfaceError
62
from ..exceptions import InternalError
63
from ..exceptions import NotSupportedError
64
from ..exceptions import OperationalError
65
from ..exceptions import ProgrammingError
66
from ..exceptions import Warning # noqa: F401
67
from ..utils.convert_rows import convert_rows
68
from ..utils.debug import log_query
69
from ..utils.mogrify import mogrify
70
from ..utils.results import Description
71
from ..utils.results import format_results
72
from ..utils.results import get_schema
73
from ..utils.results import Result
74
75
76
# DB-API settings
apilevel = '2.0'
paramstyle = 'named'
threadsafety = 1


# Server error codes grouped by the DB-API exception class they map to.
# ``get_exc_type`` checks these sets in the order interface -> data ->
# programming -> integrity, so a code listed in more than one set resolves
# to the first matching category.

# Client/connection-level failures.
_interface_errors = {
    0,
    2013,  # CR_SERVER_LOST
    2006,  # CR_SERVER_GONE_ERROR
    2012,  # CR_HANDSHAKE_ERR
    2004,  # CR_IPSOCK_ERROR
    2014,  # CR_COMMANDS_OUT_OF_SYNC
}

# Problems with the data values themselves.
_data_errors = {
    1406,  # ER_DATA_TOO_LONG
    1441,  # ER_DATETIME_FUNCTION_OVERFLOW
    1365,  # ER_DIVISION_BY_ZERO
    1230,  # ER_NO_DEFAULT
    1171,  # ER_PRIMARY_CANT_HAVE_NULL
    1264,  # ER_WARN_DATA_OUT_OF_RANGE
    1265,  # ER_WARN_DATA_TRUNCATED
}

# Errors in the SQL text or in the objects it references.
_programming_errors = {
    1065,  # ER_EMPTY_QUERY
    1179,  # ER_CANT_DO_THIS_DURING_AN_TRANSACTION
    1007,  # ER_DB_CREATE_EXISTS
    1110,  # ER_FIELD_SPECIFIED_TWICE
    1111,  # ER_INVALID_GROUP_FUNC_USE
    1082,  # ER_NO_SUCH_INDEX
    1741,  # ER_NO_SUCH_KEY_VALUE
    1146,  # ER_NO_SUCH_TABLE
    1449,  # ER_NO_SUCH_USER
    1064,  # ER_PARSE_ERROR
    1149,  # ER_SYNTAX_ERROR
    1113,  # ER_TABLE_MUST_HAVE_COLUMNS
    1112,  # ER_UNSUPPORTED_EXTENSION
    1102,  # ER_WRONG_DB_NAME
    1103,  # ER_WRONG_TABLE_NAME
    1049,  # ER_BAD_DB_ERROR
    1582,  # ER_WRONG_PARAMCOUNT_TO_NATIVE_FCT (wrong number of args)
}

# Constraint violations.
_integrity_errors = {
    1215,  # ER_CANNOT_ADD_FOREIGN
    1062,  # ER_DUP_ENTRY
    1169,  # ER_DUP_UNIQUE
    1364,  # ER_NO_DEFAULT_FOR_FIELD
    1216,  # ER_NO_REFERENCED_ROW
    1452,  # ER_NO_REFERENCED_ROW_2
    1217,  # ER_ROW_IS_REFERENCED
    1451,  # ER_ROW_IS_REFERENCED_2
    1460,  # ER_XAER_OUTSIDE
    1401,  # ER_XAER_RMERR
    1048,  # ER_BAD_NULL_ERROR
    # NOTE: 1264 is also in _data_errors, which is checked first, so this
    # entry never takes effect in get_exc_type.
    1264,  # ER_DATA_OUT_OF_RANGE
    4025,  # ER_CONSTRAINT_FAILED
    1826,  # ER_DUP_CONSTRAINT_NAME
}
134
135
136
def get_precision_scale(type_code: str) -> Tuple[Optional[int], Optional[int]]:
    """
    Extract precision and scale from a column data-type string.

    Parameters
    ----------
    type_code : str
        Data type text such as ``'INT'``, ``'VARCHAR(255)'`` or
        ``'DECIMAL(10,2)'``

    Returns
    -------
    (int or None, int or None)
        ``(precision, scale)``. Both are None when the type has no
        parenthesized part, and scale is None when only one number is given.

    Raises
    ------
    ValueError
        If the parenthesized part is not one or two integers

    """
    if '(' not in type_code:
        return (None, None)

    pair = re.search(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)', type_code)
    if pair is not None:
        precision, scale = pair.groups()
        return int(precision), int(scale)

    single = re.search(r'\(\s*(\d+)\s*\)', type_code)
    if single is not None:
        return (int(single.group(1)), None)

    raise ValueError(f'Unrecognized type code: {type_code}')
147
148
149
def get_exc_type(code: int) -> type:
    """
    Map a server error code to a DB-API exception class.

    The module-level error-code sets are consulted in a fixed priority
    order: interface, data, programming, then integrity. Codes not found in
    any set become ``OperationalError`` when >= 1000, else ``InternalError``.

    Parameters
    ----------
    code : int
        Numeric error code

    Returns
    -------
    type
        Exception class to raise

    """
    lookup_order = (
        (_interface_errors, InterfaceError),
        (_data_errors, DataError),
        (_programming_errors, ProgrammingError),
        (_integrity_errors, IntegrityError),
    )
    for codes, exc_type in lookup_order:
        if code in codes:
            return exc_type
    return OperationalError if code >= 1000 else InternalError
162
163
164
def identity(x: Any) -> Any:
    """Hand back ``x`` unchanged (a no-op converter)."""
    value = x
    return value
167
168
169
def b64decode_converter(
    converter: Optional[Callable[..., Any]],
    x: Optional[str],
    encoding: str = 'utf-8',
) -> Any:
    """
    Base64-decode a value, then optionally pass it through a converter.

    Used for BLOB / BINARY result columns, whose values are base64 text.

    Parameters
    ----------
    converter : callable, optional
        Applied to the decoded bytes. When None, the raw bytes are returned.
    x : str, optional
        Base64-encoded value. None (SQL NULL) is passed through unchanged.
    encoding : str, optional
        Unused; kept for backward compatibility of the signature

    Returns
    -------
    Any
        None, the decoded bytes, or whatever ``converter`` returns

    """
    if x is None:
        return None
    raw = b64decode(x)
    if converter is None:
        return raw
    return converter(raw)
180
181
182
def encode_timedelta(obj: datetime.timedelta) -> str:
    """
    Encode a timedelta as a ``[-]HH:MM:SS[.ffffff]`` string.

    Hours are not wrapped at 24, so a duration of one day and one hour is
    rendered as ``'25:00:00'``. The fractional part is only emitted when the
    value has microseconds.

    ``timedelta`` normalizes negative durations to a negative ``days`` plus
    positive ``seconds`` / ``microseconds``. For that reason the total is
    computed first and its sign is emitted separately. For example,
    ``timedelta(seconds=-1)`` becomes ``'-00:00:01'``, not ``'-1:59:59'``.

    Parameters
    ----------
    obj : datetime.timedelta
        Duration to encode

    Returns
    -------
    str

    """
    total_us = (obj.days * 86400 + obj.seconds) * 1000000 + obj.microseconds
    sign = '-' if total_us < 0 else ''
    total_seconds, microseconds = divmod(abs(total_us), 1000000)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    out = f'{sign}{hours:02d}:{minutes:02d}:{seconds:02d}'
    if microseconds:
        out += f'.{microseconds:06d}'
    return out
192
193
194
def encode_time(obj: datetime.time) -> str:
    """Render a time of day as ``HH:MM:SS``, adding ``.ffffff`` only if non-zero."""
    text = '{0.hour:02}:{0.minute:02}:{0.second:02}'.format(obj)
    if obj.microsecond:
        text += '.{0.microsecond:06}'.format(obj)
    return text
201
202
203
def encode_datetime(obj: datetime.datetime) -> str:
    """
    Render a datetime as ``YYYY-MM-DD HH:MM:SS[.ffffff]``.

    The fractional seconds are only included when the value has a non-zero
    microsecond component. Any tzinfo is ignored.

    """
    date_part = '{0.year:04}-{0.month:02}-{0.day:02}'.format(obj)
    time_part = '{0.hour:02}:{0.minute:02}:{0.second:02}'.format(obj)
    if obj.microsecond:
        time_part = '{}.{:06}'.format(time_part, obj.microsecond)
    return date_part + ' ' + time_part
212
213
214
def encode_date(obj: datetime.date) -> str:
    """Render a date as ``YYYY-MM-DD`` with a four-digit, zero-padded year."""
    return '{:04}-{:02}-{:02}'.format(obj.year, obj.month, obj.day)
218
219
220
def encode_struct_time(obj: time.struct_time) -> str:
    """Render a ``time.struct_time`` as ``YYYY-MM-DD HH:MM:SS``."""
    year, month, day, hour, minute, second = obj[:6]
    return encode_datetime(
        datetime.datetime(year, month, day, hour, minute, second),
    )
223
224
225
def encode_decimal(o: decimal.Decimal) -> str:
    """Render a Decimal in fixed-point notation (never scientific)."""
    return '{:f}'.format(o)
228
229
230
# Most argument encoding is done by the JSON encoder. The types registered
# here are the exceptions that need explicit conversion first.
# ``convert_special_type`` looks them up by exact type (``type(arg)``), so
# subclasses of these types are not matched.
encoders = {
    datetime.datetime: encode_datetime,
    datetime.date: encode_date,
    datetime.time: encode_time,
    datetime.timedelta: encode_timedelta,
    time.struct_time: encode_struct_time,
    decimal.Decimal: encode_decimal,
}

if has_shapely:
    # Shapely geometries are sent as WKT text.
    encoders.update(
        dict.fromkeys(
            (
                shapely.geometry.Point,
                shapely.geometry.Polygon,
                shapely.geometry.LineString,
            ),
            shapely.wkt.dumps,
        ),
    )

if has_numpy:

    def encode_ndarray(obj: np.ndarray) -> bytes:  # type: ignore
        """Return the raw contents of an ndarray's data buffer."""
        # NOTE(review): bytes are not JSON-serializable by the stdlib
        # encoder; confirm how these values are sent over HTTP.
        return obj.tobytes()

    encoders[np.ndarray] = encode_ndarray

if has_pygeos:
    # PyGEOS geometries are also sent as WKT text.
    encoders[pygeos.Geometry] = pygeos.io.to_wkt
257
258
259
def convert_special_type(
    arg: Any,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> Any:
    """
    Convert a single parameter value into a JSON-friendly form.

    Parameters
    ----------
    arg : Any
        Parameter value
    nan_as_null : bool, optional
        Replace NaN floats (Python or NumPy float types) with None
    inf_as_null : bool, optional
        Replace infinite floats (Python or NumPy float types) with None

    Returns
    -------
    Any
        None for masked NaN/Inf values, the registered encoder's output when
        ``type(arg)`` is a key of ``encoders``, otherwise ``arg`` unchanged

    """
    arg_type = type(arg)

    float_types = (float,)
    if has_numpy:
        float_types += (
            np.float16, np.float32, np.float64,
            getattr(np, 'float128', np.float64),
        )

    if arg_type in float_types:
        if (nan_as_null and math.isnan(arg)) or (inf_as_null and math.isinf(arg)):
            return None

    encoder = encoders.get(arg_type, None)
    if encoder is None:
        return arg
    return encoder(arg)  # type: ignore
281
282
283
def convert_special_params(
    params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> Optional[Union[Sequence[Any], Dict[str, Any]]]:
    """
    Apply ``convert_special_type`` to every query parameter.

    Parameters
    ----------
    params : sequence or dict, optional
        Positional or named query parameters
    nan_as_null : bool, optional
        Replace NaN floats with None
    inf_as_null : bool, optional
        Replace infinite floats with None

    Returns
    -------
    tuple, dict or None
        None when ``params`` is None. A dict with converted values when
        ``params`` is a dict. Otherwise a tuple of converted values.

    """
    if params is None:
        return None

    def convert(value: Any) -> Any:
        return convert_special_type(
            value, nan_as_null=nan_as_null, inf_as_null=inf_as_null,
        )

    if isinstance(params, dict):
        return {key: convert(value) for key, value in params.items()}
    return tuple(convert(value) for value in params)
299
300
301
class PyMyField(object):
    """
    Minimal column-metadata record mimicking PyMySQL's field objects.

    Parameters
    ----------
    name : str
        Column name
    flags : int
        Column flag bits
    charset : int
        Character set number, exposed as the ``charsetnr`` attribute

    """

    def __init__(self, name: str, flags: int, charset: int) -> None:
        self.name, self.flags, self.charsetnr = name, flags, charset
308
309
310
class PyMyResult(object):
    """
    Minimal result-metadata container mimicking PyMySQL's result object.

    Holds the column descriptors in ``fields`` (a list of PyMyField) and an
    ``unbuffered_active`` flag, which is initialized to False.

    """

    def __init__(self) -> None:
        self.unbuffered_active = False
        self.fields: List[PyMyField] = []

    def append(self, item: PyMyField) -> None:
        """Add the descriptor for one more column."""
        self.fields += [item]
319
320
321
class Cursor(connection.Cursor):
    """
    SingleStoreDB HTTP database cursor.

    Cursor objects should not be created directly. They should come from
    the `cursor` method on the `Connection` object.

    Each ``execute`` call posts the statement to the HTTP API and stores all
    returned result sets in memory. The fetch methods then page through
    those stored rows.

    Parameters
    ----------
    connection : Connection
        The HTTP Connection object the cursor belongs to

    """

    def __init__(self, conn: 'Connection'):
        connection.Cursor.__init__(self, conn)
        self._connection: Optional[Connection] = conn
        # Row tuples for each result set of the last statement
        self._results: List[List[Tuple[Any, ...]]] = [[]]
        self._results_type: str = self._connection._results_type \
            if self._connection is not None else 'tuples'
        # Index of the next row in the current result set; -1 means none
        self._row_idx: int = -1
        # Index of the current result set; -1 means none
        self._result_idx: int = -1
        self._descriptions: List[List[Description]] = []
        self._schemas: List[Dict[str, Any]] = []
        self.arraysize: int = get_option('results.arraysize')
        self.rowcount: int = 0
        self.lastrowid: Optional[int] = None
        self._pymy_results: List[PyMyResult] = []
        # Set when the last statement is one that returns rows
        self._expect_results: bool = False

    @property
    def _result(self) -> Optional[PyMyResult]:
        """Return Result object for PyMySQL compatibility."""
        # Bounds-check: e.g. the extra empty set added for CALL has no
        # matching PyMyResult.
        if self._result_idx < 0 or self._result_idx >= len(self._pymy_results):
            return None
        return self._pymy_results[self._result_idx]

    @property
    def description(self) -> Optional[List[Description]]:
        """Return description for current result set."""
        if not self._descriptions:
            return None
        if self._result_idx >= 0 and self._result_idx < len(self._descriptions):
            return self._descriptions[self._result_idx]
        return None

    @property
    def _schema(self) -> Optional[Any]:
        """Return the schema object for the current result set, if any."""
        if not self._schemas:
            return None
        if self._result_idx >= 0 and self._result_idx < len(self._schemas):
            return self._schemas[self._result_idx]
        return None

    def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
        """
        Invoke a POST request on the HTTP connection.

        Parameters
        ----------
        path : str
            The path of the resource
        *args : positional parameters, optional
            Extra parameters to the POST request
        **kwargs : keyword parameters, optional
            Extra keyword parameters to the POST request. ``timeout``
            defaults to the connection's ``connect_timeout``.

        Returns
        -------
        requests.Response

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')
        if 'timeout' not in kwargs:
            kwargs['timeout'] = self._connection.connection_params['connect_timeout']
        return self._connection._post(path, *args, **kwargs)

    def callproc(
        self, name: str,
        params: Optional[Sequence[Any]] = None,
    ) -> None:
        """
        Call a stored procedure.

        Parameters
        ----------
        name : str
            Name of the stored procedure
        params : sequence, optional
            Parameters to the stored procedure

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        name = connection._name_check(name)

        if not params:
            self._execute(f'CALL {name}();', is_callproc=True)
        else:
            keys = ', '.join(['%s'] * len(params))
            self._execute(f'CALL {name}({keys});', params, is_callproc=True)

    def close(self) -> None:
        """Close the cursor."""
        self._connection = None

    def execute(
        self, query: str,
        args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        infile_stream: Optional[  # type: ignore
            Union[
                io.RawIOBase,
                io.TextIOBase,
                Iterable[Union[bytes, str]],
                connection.InfileQueue,
            ]
        ] = None,
    ) -> int:
        """
        Execute a SQL statement.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : iterable or dict, optional
            Parameters to substitute into the SQL code

        """
        return self._execute(query, args, infile_stream=infile_stream)

    def _validate_param_subs(
        self, query: str,
        args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    ) -> None:
        """
        Make sure the parameter substitutions are valid.

        Python %-formatting is applied only so that it raises (TypeError,
        KeyError, ValueError) when the placeholders don't match ``args``.
        The formatted string is discarded.

        """
        if args:
            if isinstance(args, Sequence):
                query = query % tuple(args)
            else:
                query = query % args

    def _execute_fusion_query(
        self,
        oper: Union[str, bytes],
        params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        handler: Any = None,
    ) -> int:
        """
        Run a statement through a fusion handler and store its result set.

        Returns
        -------
        int
            Number of rows returned by the handler

        """
        oper = mogrify(oper, params)

        if isinstance(oper, bytes):
            oper = oper.decode('utf-8')

        log_query(oper, None)

        # The handler is executed with tuple results; restore afterwards
        results_type = self._results_type
        self._results_type = 'tuples'
        try:
            mgmt_res = fusion.execute(
                self._connection,  # type: ignore
                oper,
                handler=handler,
            )
        finally:
            self._results_type = results_type

        self._descriptions.append(list(mgmt_res.description))
        self._schemas.append(get_schema(self._results_type, list(mgmt_res.description)))
        self._results.append(list(mgmt_res.rows))
        self.rowcount = len(self._results[-1])

        pymy_res = PyMyResult()
        for field in mgmt_res.fields:
            pymy_res.append(
                PyMyField(
                    field.name,
                    field.flags,
                    field.charsetnr,
                ),
            )

        self._pymy_results.append(pymy_res)

        if self._results and self._results[0]:
            self._row_idx = 0
            self._result_idx = 0

        return self.rowcount

    def _execute(
        self, oper: str,
        params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
        is_callproc: bool = False,
        infile_stream: Optional[  # type: ignore
            Union[
                io.RawIOBase,
                io.TextIOBase,
                Iterable[Union[bytes, str]],
                connection.InfileQueue,
            ]
        ] = None,
    ) -> int:
        """
        Run a statement through the HTTP API and store its results.

        Parameters
        ----------
        oper : str
            SQL statement to run
        params : sequence or dict, optional
            Parameters to substitute into the statement
        is_callproc : bool, optional
            Append an extra empty result set, as PyMySQL/MySQLdb do for CALL
        infile_stream : optional
            Accepted for interface compatibility; not used by this method

        Returns
        -------
        int
            Number of rows in the first result set for queries, otherwise
            the ``rowsAffected`` value from the server response

        """
        # Reset all per-statement state before anything can fail
        self._descriptions = []
        self._schemas = []
        self._results = []
        self._pymy_results = []
        self._row_idx = -1
        self._result_idx = -1
        self.rowcount = 0
        self._expect_results = False

        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        # Row-returning statements go to the query endpoint, others to exec
        sql_type = 'exec'
        if re.match(r'^\s*(select|show|call|echo|describe|with)\s+', oper, flags=re.I):
            self._expect_results = True
            sql_type = 'query'

        if has_pydantic and isinstance(params, pydantic.BaseModel):
            params = params.model_dump()

        self._validate_param_subs(oper, params)

        # Statements recognized by fusion are handled locally, not over HTTP
        handler = fusion.get_handler(oper)
        if handler is not None:
            return self._execute_fusion_query(oper, params, handler=handler)

        interpolate_query_with_empty_args = self._connection.connection_params.get(
            'interpolate_query_with_empty_args', False,
        )
        oper, params = self._connection._convert_params(
            oper, params, interpolate_query_with_empty_args,
        )

        log_query(oper, params)

        data: Dict[str, Any] = dict(sql=oper)
        if params is not None:
            data['args'] = convert_special_params(
                params,
                nan_as_null=self._connection.connection_params['nan_as_null'],
                inf_as_null=self._connection.connection_params['inf_as_null'],
            )
        if self._connection._database:
            data['database'] = self._connection._database

        if sql_type == 'query':
            res = self._post('query/tuples', json=data)
        else:
            res = self._post('exec', json=data)

        if res.status_code >= 400:
            # Bodies matching 'Error <code>...: <message>' carry a server
            # error code. Otherwise the HTTP status code is used.
            if res.text:
                m = re.match(r'^Error\s+(\d+).*?:', res.text)
                if m:
                    code = m.group(1)
                    msg = res.text.split(':', 1)[-1]
                    icode = int(code.split()[-1])
                else:
                    icode = res.status_code
                    msg = res.text
                raise get_exc_type(icode)(icode, msg.strip())
            raise InterfaceError(errno=res.status_code, msg='HTTP Error')

        out = json.loads(res.text)

        if 'error' in out:
            raise OperationalError(
                errno=out['error'].get('code', 0),
                msg=out['error'].get('message', 'HTTP Error'),
            )

        if sql_type == 'query':
            # description: (name, type_code, display_size, internal_size,
            #               precision, scale, null_ok, column_flags, charset)

            # Remove converters for things the JSON parser already converted
            http_converters = dict(self._connection.decoders)
            for type_id in (4, 5, 6, 15, 245, 247, 249, 250, 251, 252, 253, 254):
                http_converters.pop(type_id, None)

            # Merge passed in converters
            if self._connection._conv:
                for k, v in self._connection._conv.items():
                    if isinstance(k, int):
                        http_converters[k] = v

            # Make JSON a string for Arrow
            if 'arrow' in self._results_type:
                def json_to_str(x: Any) -> Optional[str]:
                    if x is None:
                        return None
                    return json.dumps(x)
                http_converters[245] = json_to_str

            # Don't convert date/times in polars
            elif 'polars' in self._results_type:
                for type_id in (7, 10, 12):
                    http_converters.pop(type_id, None)

            results = out['results']

            # Convert data to Python types
            if results and results[0]:
                self._row_idx = 0
                self._result_idx = 0

            for result in results:

                pymy_res = PyMyResult()
                convs = []

                description: List[Description] = []
                for i, col in enumerate(result.get('columns', [])):
                    charset = 0
                    flags = 0
                    data_type = col['dataType'].split('(')[0]
                    type_code = types.ColumnType.get_code(data_type)
                    prec, scale = get_precision_scale(col['dataType'])
                    converter = http_converters.get(type_code, None)

                    # Search the full type string: a modifier following the
                    # parenthesized arguments (e.g. 'DECIMAL(10,2) UNSIGNED')
                    # is cut off in `data_type`.
                    if 'UNSIGNED' in col['dataType']:
                        flags = 32

                    # Binary values are base64 text; decode them before
                    # applying any type converter.
                    if data_type.endswith('BLOB') or data_type.endswith('BINARY'):
                        converter = functools.partial(
                            b64decode_converter, converter,  # type: ignore
                        )
                        charset = 63  # BINARY

                    if type_code == 0:  # DECIMAL
                        type_code = types.ColumnType.get_code('NEWDECIMAL')
                    elif type_code == 15:  # VARCHAR / VARBINARY
                        type_code = types.ColumnType.get_code('VARSTRING')

                    if converter is not None:
                        convs.append((i, None, converter))

                    description.append(
                        Description(
                            str(col['name']), type_code,
                            None, None, prec, scale,
                            col.get('nullable', False),
                            flags, charset,
                        ),
                    )
                    pymy_res.append(PyMyField(col['name'], flags, charset))

                self._descriptions.append(description)
                self._schemas.append(get_schema(self._results_type, description))

                rows = convert_rows(result.get('rows', []), convs)

                self._results.append(rows)
                self._pymy_results.append(pymy_res)

            # For compatibility with PyMySQL/MySQLdb
            if is_callproc:
                self._results.append([])

            # A response without any result sets must not raise IndexError
            self.rowcount = len(self._results[0]) if self._results else 0

        else:
            # For compatibility with PyMySQL/MySQLdb
            if is_callproc:
                self._results.append([])

            self.rowcount = out['rowsAffected']
            # NOTE(review): relies on the Data API exec response carrying
            # 'lastInsertId'; stays None when the key is absent.
            self.lastrowid = out.get('lastInsertId')

        return self.rowcount

    def executemany(
        self, query: str,
        args: Optional[Sequence[Union[Sequence[Any], Dict[str, Any]]]] = None,
    ) -> int:
        """
        Execute SQL code against multiple sets of parameters.

        Each parameter set is executed separately. The rows of the first
        result set of each execution are kept as one result set, so the
        rows for successive parameter sets are reached with ``nextset``.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : iterable of iterables or dicts, optional
            Sets of parameters to substitute into the SQL code. An object
            with an ``itertuples`` method (e.g. a DataFrame) is iterated
            row by row.

        Returns
        -------
        int
            Sum of ``rowcount`` over all executions

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed.')

        results = []
        pymy_results = []
        rowcount = 0
        if args is not None and len(args) > 0:
            description = []
            schema = {}
            # Detect dataframes
            if hasattr(args, 'itertuples'):
                argiter = args.itertuples(index=False)  # type: ignore
            else:
                argiter = iter(args)
            for params in argiter:
                self.execute(query, params)
                if self._descriptions:
                    description = self._descriptions[-1]
                if self._schemas:
                    schema = self._schemas[-1]
                results.append(self._rows)
                # Keep PyMySQL result metadata aligned with `results` so the
                # `_result` property stays valid after `nextset`.
                pymy_results.append(
                    self._pymy_results[-1] if self._pymy_results else PyMyResult(),
                )
                rowcount += self.rowcount
            self._results = results
            self._pymy_results = pymy_results
            self._descriptions = [description for _ in range(len(results))]
            self._schemas = [schema for _ in range(len(results))]
            # The indices left by the last execute() only reflect that
            # execution. If any execution produced rows, start at the first
            # set so they are all reachable.
            if any(results):
                self._result_idx = 0
                self._row_idx = 0
        else:
            self.execute(query)
            rowcount += self.rowcount

        self.rowcount = rowcount

        return self.rowcount

    @property
    def _has_row(self) -> bool:
        """Determine if a row is available."""
        if self._result_idx < 0 or self._result_idx >= len(self._results):
            return False
        if self._row_idx < 0 or self._row_idx >= len(self._results[self._result_idx]):
            return False
        return True

    @property
    def _rows(self) -> List[Tuple[Any, ...]]:
        """Return current set of rows, or an empty list if no row is available."""
        if not self._has_row:
            return []
        return self._results[self._result_idx]

    def fetchone(self) -> Optional[Result]:
        """
        Fetch a single row from the result set.

        Returns
        -------
        tuple
            Values of the returned row if there are rows remaining
        None
            If there are no rows left to return

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            return None
        out = self._rows[self._row_idx]
        self._row_idx += 1
        return format_results(
            self._results_type,
            self.description or [],
            out, single=True,
            schema=self._schema,
        )

    def fetchmany(
        self,
        size: Optional[int] = None,
    ) -> Result:
        """
        Fetch `size` rows from the result.

        If `size` is not specified, the `arraysize` attribute is used.

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            if 'dict' in self._results_type:
                return {}
            return tuple()
        if not size:
            size = max(int(self.arraysize), 1)
        else:
            size = max(int(size), 1)
        out = self._rows[self._row_idx:self._row_idx+size]
        self._row_idx += len(out)
        return format_results(
            self._results_type, self.description or [],
            out, schema=self._schema,
        )

    def fetchall(self) -> Result:
        """
        Fetch all remaining rows in the current result set.

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if not self._expect_results:
            raise self._connection.ProgrammingError(msg='No query has been submitted')
        if not self._has_row:
            if 'dict' in self._results_type:
                return {}
            return tuple()
        out = list(self._rows[self._row_idx:])
        # Advance past the rows just returned. Assigning len(out) would move
        # the cursor backwards after a partial fetch.
        self._row_idx += len(out)
        return format_results(
            self._results_type, self.description or [],
            out, schema=self._schema,
        )

    def nextset(self) -> Optional[bool]:
        """
        Skip to the next available result set.

        Returns
        -------
        True
            If another result set is available
        None
            If there are no more result sets

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')

        if self._result_idx < 0:
            self._row_idx = -1
            return None

        self._result_idx += 1
        self._row_idx = 0

        if self._result_idx >= len(self._results):
            self._result_idx = -1
            self._row_idx = -1
            return None

        self.rowcount = len(self._results[self._result_idx])

        return True

    def setinputsizes(self, sizes: Sequence[int]) -> None:
        """Predefine memory areas for parameters (no-op)."""
        pass

    def setoutputsize(self, size: int, column: Optional[str] = None) -> None:
        """Set a column buffer size for fetches of large columns (no-op)."""
        pass

    @property
    def rownumber(self) -> Optional[int]:
        """
        Return the zero-based index of the cursor in the result set.

        Returns
        -------
        int
            None when there is no current position

        """
        if self._row_idx < 0:
            return None
        return self._row_idx

    def scroll(self, value: int, mode: str = 'relative') -> None:
        """
        Scroll the cursor to the position in the result set.

        The new position is not bounds-checked.

        Parameters
        ----------
        value : int
            Value of the positional move
        mode : str
            Type of move that should be made: 'relative' or 'absolute'

        """
        if self._connection is None:
            raise ProgrammingError(errno=2048, msg='Connection is closed')
        if mode == 'relative':
            self._row_idx += value
        elif mode == 'absolute':
            self._row_idx = value
        else:
            raise ValueError(
                f'{mode} is not a valid mode, '
                'expecting "relative" or "absolute"',
            )

    def next(self) -> Optional[Result]:
        """
        Return the next row from the result set for use in iterators.

        Returns
        -------
        tuple
            Values from the next result row

        Raises
        ------
        StopIteration
            If no more rows exist

        """
        if self._connection is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        out = self.fetchone()
        if out is None:
            raise StopIteration
        return out

    __next__ = next

    def __iter__(self) -> Iterable[Tuple[Any, ...]]:
        """
        Return an iterator over the remaining rows of the current set.

        Rows are yielded as stored, without results-type formatting, and
        the cursor position is not advanced.

        """
        return iter(self._rows[self._row_idx:])

    def __enter__(self) -> 'Cursor':
        """Enter a context."""
        return self

    def __exit__(
        self, exc_type: Optional[object],
        exc_value: Optional[Exception], exc_traceback: Optional[str],
    ) -> None:
        """Exit a context, closing the cursor."""
        self.close()

    @property
    def open(self) -> bool:
        """Check if the cursor is still connected."""
        if self._connection is None:
            return False
        return self._connection.is_connected()

    def is_connected(self) -> bool:
        """
        Check if the cursor is still connected.

        Returns
        -------
        bool

        """
        return self.open
976
977
978
class Connection(connection.Connection):
979
"""
980
SingleStoreDB HTTP database connection.
981
982
Instances of this object are typically created through the
983
`connection` function rather than creating them directly.
984
985
See Also
986
--------
987
`connect`
988
989
"""
990
driver = 'https'
991
paramstyle = 'qmark'
992
993
def __init__(self, **kwargs: Any):
994
from .. import __version__ as client_version
995
996
if 'SINGLESTOREDB_WORKLOAD_TYPE' in os.environ:
997
client_version += '+' + os.environ['SINGLESTOREDB_WORKLOAD_TYPE']
998
999
connection.Connection.__init__(self, **kwargs)
1000
1001
host = kwargs.get('host', get_option('host'))
1002
port = kwargs.get('port', get_option('http_port'))
1003
1004
self._sess: Optional[requests.Session] = requests.Session()
1005
1006
user = kwargs.get('user', get_option('user'))
1007
password = kwargs.get('password', get_option('password'))
1008
if user is not None and password is not None:
1009
self._sess.auth = (user, password)
1010
elif user is not None:
1011
self._sess.auth = (user, '')
1012
self._sess.headers.update({
1013
'Content-Type': 'application/json',
1014
'Accept': 'application/json',
1015
'Accept-Encoding': 'compress,identity',
1016
'User-Agent': f'SingleStoreDB-Python/{client_version}',
1017
})
1018
1019
if kwargs.get('ssl_disabled', get_option('ssl_disabled')):
1020
self._sess.verify = False
1021
else:
1022
ssl_key = kwargs.get('ssl_key', get_option('ssl_key'))
1023
ssl_cert = kwargs.get('ssl_cert', get_option('ssl_cert'))
1024
if ssl_key and ssl_cert:
1025
self._sess.cert = (ssl_key, ssl_cert)
1026
elif ssl_cert:
1027
self._sess.cert = ssl_cert
1028
1029
ssl_ca = kwargs.get('ssl_ca', get_option('ssl_ca'))
1030
if ssl_ca:
1031
self._sess.verify = ssl_ca
1032
1033
ssl_verify_cert = kwargs.get('ssl_verify_cert', True)
1034
if not ssl_verify_cert:
1035
self._sess.verify = False
1036
1037
if kwargs.get('multi_statements', False):
1038
raise self.InterfaceError(
1039
0, 'The Data API does not allow multiple '
1040
'statements within a query',
1041
)
1042
1043
self._version = kwargs.get('version', 'v2')
1044
self.driver = kwargs.get('driver', 'https')
1045
1046
self.encoders = {k: v for (k, v) in converters.items() if type(k) is not int}
1047
self.decoders = {k: v for (k, v) in converters.items() if type(k) is int}
1048
1049
self._database = kwargs.get('database', get_option('database'))
1050
self._url = f'{self.driver}://{host}:{port}/api/{self._version}/'
1051
self._host = host
1052
self._messages: List[Tuple[int, str]] = []
1053
self._autocommit: bool = True
1054
self._conv = kwargs.get('conv', None)
1055
self._in_sync: bool = False
1056
self._track_env: bool = kwargs.get('track_env', False) \
1057
or host == 'singlestore.com'
1058
1059
@property
1060
def messages(self) -> List[Tuple[int, str]]:
1061
return self._messages
1062
1063
def connect(self) -> 'Connection':
1064
"""Connect to the server."""
1065
return self
1066
1067
def _sync_connection(self, kwargs: Dict[str, Any]) -> None:
1068
"""Synchronize connection with env variable."""
1069
if self._sess is None:
1070
raise InterfaceError(errno=2048, msg='Connection is closed.')
1071
1072
if self._in_sync:
1073
return
1074
1075
if not self._track_env:
1076
return
1077
1078
url = os.environ.get('SINGLESTOREDB_URL')
1079
if not url:
1080
if self._host == 'singlestore.com':
1081
raise InterfaceError(0, 'Connection URL has not been established')
1082
return
1083
1084
out = {}
1085
urlp = connection._parse_url(url)
1086
out.update(urlp)
1087
out = connection._cast_params(out)
1088
1089
# Set default port based on driver.
1090
if 'port' not in out or not out['port']:
1091
if out.get('driver', 'https') == 'http':
1092
out['port'] = int(get_option('port') or 80)
1093
else:
1094
out['port'] = int(get_option('port') or 443)
1095
1096
# If there is no user and the password is empty, remove the password key.
1097
if 'user' not in out and not out.get('password', None):
1098
out.pop('password', None)
1099
1100
if out['host'] == 'singlestore.com':
1101
raise InterfaceError(0, 'Connection URL has not been established')
1102
1103
# Get current connection attributes
1104
curr_url = urlparse(self._url, scheme='singlestoredb', allow_fragments=True)
1105
if self._sess.auth is not None:
1106
auth = tuple(self._sess.auth) # type: ignore
1107
else:
1108
auth = (None, None) # type: ignore
1109
1110
# If it's just a password change, we don't need to reconnect
1111
if (curr_url.hostname, curr_url.port, auth[0], self._database) == \
1112
(out['host'], out['port'], out['user'], out.get('database')):
1113
return
1114
1115
try:
1116
self._in_sync = True
1117
sess = requests.Session()
1118
sess.auth = (out['user'], out['password'])
1119
sess.headers.update(self._sess.headers)
1120
sess.verify = self._sess.verify
1121
sess.cert = self._sess.cert
1122
self._database = out.get('database')
1123
self._host = out['host']
1124
self._url = f'{out.get("driver", "https")}://{out["host"]}:{out["port"]}' \
1125
f'/api/{self._version}/'
1126
self._sess = sess
1127
if self._database:
1128
kwargs['json']['database'] = self._database
1129
finally:
1130
self._in_sync = False
1131
1132
def _post(self, path: str, *args: Any, **kwargs: Any) -> requests.Response:
    """
    Send a POST request to ``path`` relative to the connection's base URL.

    Before the request is sent, ``self._sync_connection(kwargs)`` runs.
    It may replace the session and base URL and may modify the request
    keyword arguments in place.

    Parameters
    ----------
    path : str
        Resource path, joined onto the connection's base URL
    *args : positional parameters, optional
        Passed through unchanged to the session's ``post`` method
    **kwargs : keyword parameters, optional
        Passed through to the session's ``post`` method (possibly after
        being modified by the sync step)

    Raises
    ------
    InterfaceError
        If the connection has already been closed

    Returns
    -------
    requests.Response

    """
    if self._sess is None:
        raise InterfaceError(errno=2048, msg='Connection is closed.')

    # The sync step may swap ``self._sess`` / ``self._url``, so both are
    # read only after it has run.
    self._sync_connection(kwargs)

    endpoint = urljoin(self._url, path)
    return self._sess.post(endpoint, *args, **kwargs)
1156
1157
def close(self) -> None:
    """
    Close the connection.

    The session is dropped, so later operations on this connection see
    it as closed. When the host is ``'singlestore.com'`` (which the sync
    logic in this class treats as "connection URL has not been
    established") the call does nothing.

    Raises
    ------
    InterfaceError
        If the connection has already been closed

    """
    if self._host == 'singlestore.com':
        return
    if self._sess is None:
        # Raise InterfaceError like every other closed-connection check in
        # this class. Under the PEP 249 hierarchy InterfaceError derives
        # from Error, so callers that catch Error are unaffected.
        raise InterfaceError(errno=2048, msg='Connection is closed')
    self._sess = None
1164
1165
def autocommit(self, value: bool = True) -> None:
    """
    Turn autocommit mode on or off.

    Has no effect when the host is ``'singlestore.com'``.

    Parameters
    ----------
    value : bool, optional
        The new autocommit setting (defaults to True)

    Raises
    ------
    InterfaceError
        If the connection has already been closed

    """
    if self._host != 'singlestore.com':
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        self._autocommit = value
1172
1173
def commit(self) -> None:
    """
    Commit the pending transaction.

    Explicit transactions are not supported by this connection. The
    call is a no-op when autocommit is enabled (or when the host is
    ``'singlestore.com'``).

    Raises
    ------
    InterfaceError
        If the connection has already been closed
    NotSupportedError
        If autocommit is disabled

    """
    if self._host == 'singlestore.com':
        return
    if self._sess is None:
        raise InterfaceError(errno=2048, msg='Connection is closed')
    if not self._autocommit:
        raise NotSupportedError(msg='operation not supported')
1182
1183
def rollback(self) -> None:
    """
    Roll back the pending transaction.

    Explicit transactions are not supported by this connection. Nothing
    happens when autocommit is on or when the host is
    ``'singlestore.com'``.

    Raises
    ------
    InterfaceError
        If the connection has already been closed
    NotSupportedError
        If autocommit is disabled

    """
    if self._host != 'singlestore.com':
        if self._sess is None:
            raise InterfaceError(errno=2048, msg='Connection is closed')
        if not self._autocommit:
            raise NotSupportedError(msg='operation not supported')
1192
1193
def cursor(self) -> Cursor:
    """
    Build and return a fresh cursor bound to this connection.

    Returns
    -------
    Cursor

    """
    new_cursor = Cursor(self)
    return new_cursor
1203
1204
def __enter__(self) -> 'Connection':
    """
    Start a ``with`` block; the connection itself is the bound value.

    Returns
    -------
    Connection

    """
    return self
1207
1208
def __exit__(
    self,
    exc_type: Optional[type],
    exc_value: Optional[BaseException],
    exc_traceback: Optional[Any],
) -> None:
    """
    Leave a ``with`` block by closing the connection.

    The previous annotations declared ``exc_traceback`` as ``str`` and
    ``exc_value`` as ``Exception``. Python actually passes a traceback
    object (or None) and any ``BaseException`` instance (or None).

    Parameters
    ----------
    exc_type : type or None
        Class of the exception raised in the block, if any
    exc_value : BaseException or None
        The exception instance raised in the block, if any
    exc_traceback : traceback or None
        Traceback of the exception raised in the block, if any

    Returns
    -------
    None
        Falsy, so any exception raised in the block propagates

    """
    self.close()
1214
1215
@property
def open(self) -> bool:
    """
    Check if the database is still connected.

    Sends a GET request to the ``/ping`` endpoint on the scheme and
    host/port portion of the connection URL.

    Returns
    -------
    bool
        True if the session exists and the ping answered with a
        non-error status (< 400) and body ``'pong'``. False if the
        connection was closed, the server answered otherwise, or the
        ping request itself failed.

    """
    if self._sess is None:
        return False

    # Keep only '<scheme>://<host>:<port>' from the API base URL.
    base = '/'.join(self._url.split('/')[:3])
    try:
        res = self._sess.get(base + '/ping')
    except OSError:
        # requests.RequestException (ConnectionError, Timeout, ...) derives
        # from IOError/OSError. An unreachable server means "not connected"
        # and should not be an exception from a boolean status check.
        return False

    # 400 (Bad Request) is an error status, so compare strictly; this is
    # the same threshold requests uses for Response.ok.
    return res.status_code < 400 and res.text == 'pong'
1225
1226
def is_connected(self) -> bool:
    """
    Report whether the database connection is still alive.

    This is a method-call spelling of the ``open`` property.

    Returns
    -------
    bool

    """
    connected = self.open
    return connected
1236
1237
1238
def connect(
    host: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    port: Optional[int] = None,
    database: Optional[str] = None,
    driver: Optional[str] = None,
    pure_python: Optional[bool] = None,
    local_infile: Optional[bool] = None,
    charset: Optional[str] = None,
    ssl_key: Optional[str] = None,
    ssl_cert: Optional[str] = None,
    ssl_ca: Optional[str] = None,
    ssl_disabled: Optional[bool] = None,
    ssl_cipher: Optional[str] = None,
    ssl_verify_cert: Optional[bool] = None,
    ssl_verify_identity: Optional[bool] = None,
    conv: Optional[Dict[int, Callable[..., Any]]] = None,
    credential_type: Optional[str] = None,
    autocommit: Optional[bool] = None,
    results_type: Optional[str] = None,
    buffered: Optional[bool] = None,
    results_format: Optional[str] = None,
    program_name: Optional[str] = None,
    conn_attrs: Optional[Dict[str, str]] = None,
    multi_statements: Optional[bool] = None,
    connect_timeout: Optional[int] = None,
    nan_as_null: Optional[bool] = None,
    inf_as_null: Optional[bool] = None,
    encoding_errors: Optional[str] = None,
    track_env: Optional[bool] = None,
    enable_extended_data_types: Optional[bool] = None,
    vector_data_format: Optional[str] = None,
) -> Connection:
    """
    Create a new HTTP API connection.

    Every argument is forwarded unchanged, by keyword, to the
    ``Connection`` constructor; see that class for the meaning of each
    option.

    Returns
    -------
    Connection

    """
    # Snapshot the arguments before any other local name is bound, so that
    # exactly the parameters (and nothing else) are forwarded.
    options = dict(locals())
    return Connection(**options)
1273
1274