Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/connection.py
469 views
1
#!/usr/bin/env python
2
"""SingleStoreDB connections and cursors."""
3
import abc
4
import functools
5
import inspect
6
import io
7
import queue
8
import re
9
import sys
10
import warnings
11
import weakref
12
from collections.abc import Mapping
13
from collections.abc import MutableMapping
14
from typing import Any
15
from typing import Callable
16
from typing import Dict
17
from typing import Iterator
18
from typing import List
19
from typing import Optional
20
from typing import Sequence
21
from typing import Tuple
22
from typing import Union
23
from urllib.parse import parse_qs
24
from urllib.parse import unquote_plus
25
from urllib.parse import urlparse
26
27
import sqlparams
28
try:
29
from pandas import DataFrame
30
except ImportError:
31
class DataFrame(object): # type: ignore
32
def itertuples(self, *args: Any, **kwargs: Any) -> None:
33
pass
34
35
from . import auth
36
from . import exceptions
37
from .config import get_option
38
from .utils.results import Description
39
from .utils.results import Result
40
41
# Queue type accepted as a ``LOCAL INFILE`` data stream. Newer interpreters
# get the parameterized form; older ones use the bare class.
if sys.version_info >= (3, 10):
    InfileQueue = queue.Queue[Union[bytes, str]]
else:
    InfileQueue = queue.Queue


# DB-API settings
apilevel = '2.0'
threadsafety = 1
paramstyle = 'pyformat'
map_paramstyle = paramstyle
positional_paramstyle = 'format'


# Type codes for character-based columns
CHAR_COLUMNS = {245, *range(247, 256)}
56
57
58
def under2camel(s: str) -> str:
    """
    Format underscore-delimited strings to camel-case.

    The acronyms ``xml``, ``sql`` and ``json`` (any case) are upper-cased
    when delimited by word boundaries or underscores, so ``'sql_mode'``
    becomes ``'SQLMode'`` and ``'key_name'`` becomes ``'KeyName'``.

    """
    # Upper-case known acronyms while keeping their delimiters in place.
    s = re.sub(
        r'(\b|_)(xml|sql|json)(\b|_)',
        lambda m: m.group(1) + m.group(2).upper() + m.group(3),
        s,
        flags=re.I,
    )
    # Drop each run of underscores (and the string start) and capitalize the
    # character that follows it.
    s = re.sub(r'(?:^|_+)(\w)', lambda m: m.group(1).upper(), s)
    # Any underscores left dangling at the very end are removed.
    return re.sub(r'_+$', r'', s)
74
75
76
def nested_converter(
    conv: Callable[[Any], Any],
    inner: Callable[[Any], Any],
) -> Callable[[Any], Any]:
    """
    Create a pipeline of two functions.

    The returned callable applies ``inner`` to its argument first and then
    passes that result to ``conv``, i.e. ``conv(inner(value))``.

    """
    def converter(value: Any) -> Any:
        intermediate = inner(value)
        return conv(intermediate)

    return converter
84
85
86
def cast_bool_param(val: Any) -> bool:
    """
    Cast value to a bool.

    Accepted values are ``None``/``False`` (False), ``True`` (True), numbers
    equal to 0 or 1 (also given as strings such as ``'0'``/``'1'``), and the
    words on/t/true/y/yes/enabled/enable or off/f/false/n/no/disabled/disable
    (case-insensitive, surrounding whitespace ignored).

    Parameters
    ----------
    val : Any
        Value to interpret

    Returns
    -------
    bool

    Raises
    ------
    ValueError
        If the value can not be interpreted as a bool

    """
    if val is None or val is False:
        return False

    if val is True:
        return True

    # Test ints. ``int()`` truncates floats (0.5 -> 0, 1.9 -> 1), so for
    # non-string values only numbers that are exactly 0 or 1 are accepted.
    try:
        ival = int(val)
    except (TypeError, ValueError, OverflowError):
        pass
    else:
        if isinstance(val, (str, bytes, bytearray)) or ival == val:
            if ival == 1:
                return True
            if ival == 0:
                return False

    # Lowercase strings
    if hasattr(val, 'lower'):
        text = val.lower()
        if isinstance(text, str):
            text = text.strip()
        if text in ['on', 't', 'true', 'y', 'yes', 'enabled', 'enable']:
            return True
        if text in ['off', 'f', 'false', 'n', 'no', 'disabled', 'disable']:
            return False

    raise ValueError('Unrecognized value for bool: {}'.format(val))
112
113
114
def build_params(**kwargs: Any) -> Dict[str, Any]:
    """
    Construct connection parameters from given URL and arbitrary parameters.

    Every argument of ``connect`` that is not given explicitly is filled in
    from the configuration option of the same name. If ``host`` looks like
    a URL, it is parsed and its components override the individual values.

    Parameters
    ----------
    **kwargs : keyword-parameters, optional
        Arbitrary keyword parameters corresponding to connection parameters

    Returns
    -------
    dict

    """
    out: Dict[str, Any] = {}

    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    arg_names = inspect.getfullargspec(connect).args

    # Set known parameters. The results type is resolved separately below so
    # that the outcome does not depend on the order of ``connect``'s arguments.
    for name in arg_names:
        if name in ('results_type', 'results_format'):
            continue
        if name == 'conv':
            out[name] = kwargs.get(name, None)
        else:
            out[name] = kwargs.get(name, get_option(name))

    # ``results_format=`` is the deprecated spelling of ``results_type=``.
    # An explicit ``results_type=`` always takes precedence, and the alias can
    # never reset an explicit ``results_type=`` back to the default.
    if 'results_type' in arg_names or 'results_format' in arg_names:
        legacy_type = None
        if 'results_format' in arg_names:
            legacy_type = kwargs.get('results_format', None)
        if legacy_type is not None:
            warnings.warn(
                'The `results_format=` parameter has been '
                'renamed to `results_type=`.',
                DeprecationWarning,
            )
        if legacy_type is None:
            legacy_type = get_option('results.type')
        out['results_type'] = kwargs.get('results_type', legacy_type)

    # See if host actually contains a URL; definitely not a perfect test.
    host = out['host']
    if host and (':' in host or '/' in host or '@' in host or '?' in host):
        urlp = _parse_url(host)
        if 'driver' not in urlp:
            urlp['driver'] = get_option('driver')
        out.update(urlp)

    # Drops None values, so optional keys may be absent from here on.
    out = _cast_params(out)

    # Set default port based on driver.
    if not out.get('port'):
        driver = out.get('driver')
        if driver == 'http':
            out['port'] = int(get_option('http_port') or 80)
        elif driver == 'https':
            out['port'] = int(get_option('http_port') or 443)
        else:
            out['port'] = int(get_option('port') or 3306)

    # If there is no user and the password is empty, remove the password key.
    if 'user' not in out and not out.get('password', None):
        out.pop('password', None)

    # A CA certificate turns on certificate verification.
    # NOTE(review): this also overrides an explicit ``ssl_verify_cert=False``
    # whenever ``ssl_ca`` is set; confirm that is intended.
    if out.get('ssl_ca', '') and not out.get('ssl_verify_cert', None):
        out['ssl_verify_cert'] = True

    return out
176
177
178
def _get_param_types(func: Any) -> Dict[str, Any]:
179
"""
180
Retrieve the types for the parameters to the given function.
181
182
Note that if a parameter has multiple possible types, only the
183
first one is returned.
184
185
Parameters
186
----------
187
func : callable
188
Callable object to inspect the parameters of
189
190
Returns
191
-------
192
dict
193
194
"""
195
out = {}
196
args = inspect.getfullargspec(func)
197
for name in args.args:
198
ann = args.annotations[name]
199
if isinstance(ann, str):
200
ann = eval(ann)
201
if hasattr(ann, '__args__'):
202
out[name] = ann.__args__[0]
203
else:
204
out[name] = ann
205
return out
206
207
208
def _cast_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Cast known keys to appropriate values.

    Keys are lower-cased and entries whose value is ``None`` are dropped.
    Every remaining key must name a parameter of ``connect``; its value is
    converted according to that parameter's annotated type.

    Parameters
    ----------
    params : dict
        Dictionary of connection parameters

    Returns
    -------
    dict

    Raises
    ------
    ValueError
        If a key is not a known connection parameter, or a bool parameter
        has an unrecognized value

    """
    param_types = _get_param_types(connect)
    casted: Dict[str, Any] = {}
    for raw_key, value in params.items():
        key = raw_key.lower()
        if value is None:
            continue
        if key not in param_types:
            raise ValueError('Unrecognized connection parameter: {}'.format(key))

        dtype = param_types[key]
        if dtype is bool:
            casted[key] = cast_bool_param(value)
            continue

        generic_name = getattr(dtype, '_name', '')
        if generic_name in ['Dict', 'Mapping'] or str(dtype).startswith('typing.Dict'):
            casted[key] = dict(value)
        elif generic_name == 'List':
            casted[key] = list(value)
        elif generic_name == 'Tuple':
            casted[key] = tuple(value)
        else:
            # Plain types (str, int, ...) act as their own converters.
            casted[key] = dtype(value)
    return casted
244
245
246
def _parse_url(url: str) -> Dict[str, Any]:
247
"""
248
Parse a connection URL and return only the defined parts.
249
250
Parameters
251
----------
252
url : str
253
The URL passed in can be a full URL or a partial URL. At a minimum,
254
a host name must be specified. All other parts are optional.
255
256
Returns
257
-------
258
dict
259
260
"""
261
out: Dict[str, Any] = {}
262
263
if '//' not in url:
264
url = '//' + url
265
266
if url.startswith('singlestoredb+'):
267
url = re.sub(r'^singlestoredb\+', r'', url)
268
269
parts = urlparse(url, scheme='singlestoredb', allow_fragments=True)
270
271
url_db = parts.path
272
if url_db.startswith('/'):
273
url_db = url_db.split('/')[1].strip()
274
url_db = url_db.split('/')[0].strip() or ''
275
276
# Retrieve basic connection parameters
277
out['host'] = parts.hostname or None
278
out['port'] = parts.port or None
279
out['database'] = url_db or None
280
out['user'] = parts.username or None
281
282
# Allow an empty string for password
283
if out['user'] and parts.password is not None:
284
out['password'] = parts.password
285
286
if parts.scheme != 'singlestoredb':
287
out['driver'] = parts.scheme.lower()
288
289
if out.get('user'):
290
out['user'] = unquote_plus(out['user'])
291
292
if out.get('password'):
293
out['password'] = unquote_plus(out['password'])
294
295
if out.get('database'):
296
out['database'] = unquote_plus(out['database'])
297
298
# Convert query string to parameters
299
out.update({k.lower(): v[-1] for k, v in parse_qs(parts.query).items()})
300
301
return {k: v for k, v in out.items() if v is not None}
302
303
304
def _name_check(name: str) -> str:
305
"""
306
Make sure the given name is a legal variable name.
307
308
Parameters
309
----------
310
name : str
311
Name to check
312
313
Returns
314
-------
315
str
316
317
"""
318
name = name.strip()
319
if not re.match(r'^[A-Za-z_][\w+_]*$', name):
320
raise ValueError('Name contains invalid characters')
321
return name
322
323
324
def quote_identifier(name: str) -> str:
    """
    Escape identifier value.

    The name is wrapped in backticks. Any backtick inside it is doubled, so
    the name can not terminate the quoted identifier early and inject SQL.

    Parameters
    ----------
    name : str
        Identifier to quote

    Returns
    -------
    str

    """
    escaped = name.replace('`', '``')
    return f'`{escaped}`'
327
328
329
class Driver(object):
    """
    Compatibility class for driver name.

    Wraps a driver name string so it can be read back through ``.name``.

    Parameters
    ----------
    name : str
        Name of the driver

    """

    def __init__(self, name: str) -> None:
        self.name = name
334
335
336
class VariableAccessor(MutableMapping):  # type: ignore
    """
    Variable accessor class.

    Gives dictionary-style (``acc['name']``) and attribute-style
    (``acc.name``) access to the server variables of one scope. Reads run
    ``SHOW <vtype> VARIABLES LIKE ...``, and writes run ``SET ...``. String
    values 'on'/'true' and 'off'/'false' are returned as bools, and bools are
    sent as 'ON'/'OFF'. Variables can not be deleted.

    Names starting with an underscore are never treated as variables in
    attribute syntax (this keeps ``hasattr``, ``copy`` and IPython probes
    from querying the server). Use item syntax for such names.

    Parameters
    ----------
    conn : Connection
        Connection used to run the queries (held through a weak proxy)
    vtype : str
        Variable scope: 'global', 'local', 'cluster', 'cluster global',
        'cluster local', or '' (case-insensitive)

    Raises
    ------
    ValueError
        If ``vtype`` is not one of the supported scopes

    """

    def __init__(self, conn: 'Connection', vtype: str):
        # ``object.__setattr__`` bypasses this class's ``__setattr__``, which
        # would otherwise try to SET a server variable.
        object.__setattr__(self, 'connection', weakref.proxy(conn))
        object.__setattr__(self, 'vtype', vtype.lower())
        if self.vtype not in [
            'global', 'local', '',
            'cluster', 'cluster global', 'cluster local',
        ]:
            raise ValueError(
                'Variable type must be global, local, cluster, '
                'cluster global, cluster local, or empty',
            )

    def _cast_value(self, value: Any) -> Any:
        """Convert 'on'/'true' and 'off'/'false' strings to bools."""
        if isinstance(value, str):
            if value.lower() in ['on', 'true']:
                return True
            if value.lower() in ['off', 'false']:
                return False
        return value

    def __getitem__(self, name: str) -> Any:
        """
        Return the value of the named variable.

        Raises
        ------
        ValueError
            If ``name`` is not a legal variable name
        KeyError
            If no variable, or more than one variable, matches ``name``

        """
        name = _name_check(name)
        out = self.connection._iquery(
            'show {} variables like %s;'.format(self.vtype),
            [name],
        )
        # NOTE(review): '_' is a LIKE wildcard, so a name could in principle
        # match several variables; that case is reported below.
        if not out:
            raise KeyError(f"No variable found with the name '{name}'.")
        if len(out) > 1:
            raise KeyError(f"Multiple variables found with the name '{name}'.")
        return self._cast_value(out[0]['Value'])

    def __setitem__(self, name: str, value: Any) -> None:
        """Set the named variable; bools are sent as 'ON' / 'OFF'."""
        name = _name_check(name)
        if value is True:
            value = 'ON'
        elif value is False:
            value = 'OFF'
        if 'local' in self.vtype:
            # The local scope is spelled SESSION in SET statements.
            self.connection._iquery(
                'set {} {}=%s;'.format(
                    self.vtype.replace('local', 'session'), name,
                ), [value],
            )
        else:
            self.connection._iquery('set {} {}=%s;'.format(self.vtype, name), [value])

    def __delitem__(self, name: str) -> None:
        raise TypeError('Variables can not be deleted.')

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails. Private / special names
        # (``__deepcopy__``, ``_ipython_*`` probes) and this object's own
        # storage attributes (missing on an instance whose ``__init__`` has
        # not run, e.g. during copy) must not turn into server queries or
        # infinite recursion.
        if name.startswith('_') or name in ('connection', 'vtype'):
            raise AttributeError(name)
        return self[name]

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]

    def __len__(self) -> int:
        out = self.connection._iquery('show {} variables;'.format(self.vtype))
        return len(list(out))

    def __iter__(self) -> Iterator[str]:
        out = self.connection._iquery('show {} variables;'.format(self.vtype))
        # The first column of each row holds the variable name.
        return iter(list(x.values())[0] for x in out)
405
406
407
class Cursor(metaclass=abc.ABCMeta):
    """
    Database cursor for submitting commands and queries.

    This object should not be instantiated directly.
    The ``Connection.cursor`` method should be used.

    """

    def __init__(self, connection: 'Connection'):
        """Call ``Connection.cursor`` instead."""
        self.errorhandler = connection.errorhandler
        # A weak proxy, so the cursor does not keep its connection alive.
        self._connection: Optional[Connection] = weakref.proxy(connection)

        self._rownumber: Optional[int] = None

        self._description: Optional[List[Description]] = None

        #: Default batch size of ``fetchmany`` calls.
        self.arraysize = get_option('results.arraysize')

        self._converters: List[
            Tuple[
                int, Optional[str],
                Optional[Callable[..., Any]],
            ]
        ] = []

        #: Number of rows affected by the last query.
        self.rowcount: int = -1

        self._messages: List[Tuple[int, str]] = []

        #: Row ID of the last modified row.
        self.lastrowid: Optional[int] = None

    @property
    def messages(self) -> List[Tuple[int, str]]:
        """Messages created by the server."""
        return self._messages

    @property
    @abc.abstractmethod
    def description(self) -> Optional[List[Description]]:
        """The field descriptions of the last query."""
        return self._description

    @property
    @abc.abstractmethod
    def rownumber(self) -> Optional[int]:
        """The last modified row number."""
        return self._rownumber

    @property
    def connection(self) -> Optional['Connection']:
        """The connection that the cursor belongs to."""
        return self._connection

    @abc.abstractmethod
    def callproc(
        self, name: str,
        params: Optional[Sequence[Any]] = None,
    ) -> None:
        """
        Call a stored procedure.

        The result sets generated by a stored procedure can be retrieved
        like the results of any other query using :meth:`fetchone`,
        :meth:`fetchmany`, or :meth:`fetchall`. If the procedure generates
        multiple result sets, subsequent result sets can be accessed
        using :meth:`nextset`.

        Examples
        --------
        >>> cur.callproc('myprocedure', ['arg1', 'arg2'])
        >>> print(cur.fetchall())

        Parameters
        ----------
        name : str
            Name of the stored procedure
        params : iterable, optional
            Parameters to the stored procedure

        Raises
        ------
        InterfaceError
            If the cursor is closed
        ValueError
            If ``name`` is not a legal procedure name

        """
        # NOTE: The `callproc` interface varies quite a bit between drivers
        # so it is implemented using `execute` here.

        if not self.is_connected():
            raise exceptions.InterfaceError(2048, 'Cursor is closed.')

        name = _name_check(name)

        if not params:
            self.execute(f'CALL {name}();')
        else:
            # One ``format``-style placeholder per argument, matching the
            # positional parameter style documented for ``execute``.
            placeholders = ', '.join(['%s'] * len(params))
            self.execute(f'CALL {name}({placeholders});', params)

    @abc.abstractmethod
    def is_connected(self) -> bool:
        """Is the cursor still connected?"""
        raise NotImplementedError

    @abc.abstractmethod
    def close(self) -> None:
        """Close the cursor."""
        raise NotImplementedError

    @abc.abstractmethod
    def execute(
        self, query: str,
        args: Optional[Union[Sequence[Any], Dict[str, Any], Any]] = None,
        infile_stream: Optional[  # type: ignore
            Union[
                io.RawIOBase,
                io.TextIOBase,
                Iterator[Union[bytes, str]],
                InfileQueue,
            ]
        ] = None,
    ) -> int:
        """
        Execute a SQL statement.

        Queries can use the ``format``-style parameters (``%s``) when using a
        list of parameters or ``pyformat``-style parameters (``%(key)s``)
        when using a dictionary of parameters.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : Sequence or dict, optional
            Parameters to substitute into the SQL code
        infile_stream : io.RawIOBase or io.TextIOBase or Iterator[bytes|str], optional
            Data stream for ``LOCAL INFILE`` statement

        Examples
        --------
        Query with no parameters

        >>> cur.execute('select * from mytable')

        Query with positional parameters

        >>> cur.execute('select * from mytable where id < %s', [100])

        Query with named parameters

        >>> cur.execute('select * from mytable where id < %(max)s', dict(max=100))

        Returns
        -------
        Number of rows affected

        """
        raise NotImplementedError

    def executemany(
        self, query: str,
        args: Optional[Sequence[Union[Sequence[Any], Dict[str, Any], Any]]] = None,
    ) -> int:
        """
        Execute SQL code against multiple sets of parameters.

        Queries can use the ``format``-style parameters (``%s``) when using
        lists of parameters or ``pyformat``-style parameters (``%(key)s``)
        when using dictionaries of parameters.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : iterable of iterables or dicts, optional
            Sets of parameters to substitute into the SQL code

        Examples
        --------
        >>> cur.executemany('select * from mytable where id < %s',
        ...                 [[100], [200], [300]])

        >>> cur.executemany('select * from mytable where id < %(max)s',
        ...                 [dict(max=100), dict(max=200), dict(max=300)])

        Returns
        -------
        Number of rows affected
            The ``rowcount`` after the last execution

        """
        # NOTE: Just implement using `execute` to cover driver inconsistencies
        if not args:
            self.execute(query)
        else:
            for params in args:
                self.execute(query, params)
        return self.rowcount

    @abc.abstractmethod
    def fetchone(self) -> Optional[Result]:
        """
        Fetch a single row from the result set.

        Examples
        --------
        >>> while True:
        ...    row = cur.fetchone()
        ...    if row is None:
        ...       break
        ...    print(row)

        Returns
        -------
        tuple
            Values of the returned row if there are rows remaining

        """
        raise NotImplementedError

    @abc.abstractmethod
    def fetchmany(self, size: Optional[int] = None) -> Result:
        """
        Fetch `size` rows from the result.

        If `size` is not specified, the `arraysize` attribute is used.

        Examples
        --------
        >>> while True:
        ...    out = cur.fetchmany(100)
        ...    if not len(out):
        ...        break
        ...    for row in out:
        ...        print(row)

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining

        """
        raise NotImplementedError

    @abc.abstractmethod
    def fetchall(self) -> Result:
        """
        Fetch all rows in the result set.

        Examples
        --------
        >>> for row in cur.fetchall():
        ...    print(row)

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining
        None
            If there are no rows to return

        """
        raise NotImplementedError

    @abc.abstractmethod
    def nextset(self) -> Optional[bool]:
        """
        Skip to the next available result set.

        This is used when calling a procedure that returns multiple
        result sets.

        Note
        ----
        The ``nextset`` method must be called until it returns an empty
        set (i.e., once more than the number of expected result sets).
        This is to retain compatibility with PyMySQL and MySQLdb.

        Returns
        -------
        ``True``
            If another result set is available
        ``False``
            If no other result set is available

        """
        raise NotImplementedError

    @abc.abstractmethod
    def setinputsizes(self, sizes: Sequence[int]) -> None:
        """Predefine memory areas for parameters."""
        raise NotImplementedError

    @abc.abstractmethod
    def setoutputsize(self, size: int, column: Optional[str] = None) -> None:
        """Set a column buffer size for fetches of large columns."""
        raise NotImplementedError

    @abc.abstractmethod
    def scroll(self, value: int, mode: str = 'relative') -> None:
        """
        Scroll the cursor to the position in the result set.

        Parameters
        ----------
        value : int
            Value of the positional move
        mode : str
            Where to move the cursor from: 'relative' or 'absolute'

        """
        raise NotImplementedError

    def next(self) -> Optional[Result]:
        """
        Return the next row from the result set for use in iterators.

        Raises
        ------
        InterfaceError
            If the cursor is closed
        StopIteration
            If no more results exist

        Returns
        -------
        tuple of values

        """
        if not self.is_connected():
            raise exceptions.InterfaceError(2048, 'Cursor is closed.')
        out = self.fetchone()
        if out is None:
            raise StopIteration
        return out

    __next__ = next

    def __iter__(self) -> Any:
        """Return result iterator."""
        return self

    def __enter__(self) -> 'Cursor':
        """Enter a context."""
        return self

    def __exit__(
        self, exc_type: Optional[object],
        exc_value: Optional[Exception], exc_traceback: Optional[str],
    ) -> None:
        """Exit a context; the cursor is closed."""
        self.close()
754
755
756
class ShowResult(Sequence[Any]):
    """
    Simple result object.

    This object is primarily used for displaying results to a
    terminal or web browser, but it can also be treated like a
    simple data frame where columns are accessible using either
    dictionary key-like syntax or attribute syntax.

    Examples
    --------
    >>> conn.show.status().Value[10]

    >>> conn.show.status()[10]['Value']

    Parameters
    ----------
    *args : Any
        Parameters to send to underlying list constructor
    **kwargs : Any
        Keyword parameters to send to underlying list constructor

    See Also
    --------
    :attr:`Connection.show`

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self._data: List[Dict[str, Any]] = list(*args, **kwargs)

    def __getitem__(self, item: Union[int, slice]) -> Any:
        return self._data[item]

    def __getattr__(self, name: str) -> List[Any]:
        # Special names, IPython display probes and our own storage attribute
        # are never columns. ``_data`` only gets here when ``__init__`` has
        # not run (e.g. during copy / unpickling); without this check the
        # lookup below would recurse forever.
        if name.startswith('_ipython') or name == '_data' or \
                (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        # A column missing from any row raises KeyError.
        return [item[name] for item in self._data]

    def __len__(self) -> int:
        return len(self._data)

    def __repr__(self) -> str:
        if not self._data:
            return ''
        return '\n{}\n'.format(self._format_table(self._data))

    @property
    def columns(self) -> List[str]:
        """The columns in the result."""
        if not self._data:
            return []
        return list(self._data[0].keys())

    def _format_table(self, rows: Sequence[Dict[str, Any]]) -> str:
        """Render ``rows`` (dicts sharing the same keys) as a text table."""
        if not rows:
            return ''

        keys = list(rows[0].keys())
        widths = [len(x) for x in keys]
        for row in rows:
            for i, k in enumerate(keys):
                widths[i] = max(widths[i], len(str(row[k])))

        # Alignment follows the value types in the last row: text is
        # left-aligned, everything else is right-aligned.
        last = rows[-1]
        align = [
            '<' if isinstance(last[k], (bytes, bytearray, str)) else '>'
            for k in keys
        ]

        fmt = '| %s |' % '|'.join([' {:%s%d} ' % (x, y) for x, y in zip(align, widths)])

        out = [fmt.format(*keys)]
        out.append('-' * len(out[0]))
        for row in rows:
            out.append(fmt.format(*[str(x) for x in row.values()]))
        return '\n'.join(out)

    def __str__(self) -> str:
        return self.__repr__()

    def _repr_html_(self) -> str:
        """Render the rows as an HTML table with escaped cell contents."""
        from html import escape

        if not self._data:
            return ''
        cell_style = 'style="text-align: left; vertical-align: top"'
        out = ['<table border="1" class="dataframe">', '<thead>', '<tr>']
        for name in self._data[0].keys():
            out.append(f'<th {cell_style}>{escape(str(name))}</th>')
        out.extend(['</tr>', '</thead>', '<tbody>'])
        for row in self._data:
            out.append('<tr>')
            for item in row.values():
                out.append(f'<td {cell_style}>{escape(str(item))}</td>')
            out.append('</tr>')
        out.extend(['</tbody>', '</table>'])
        return ''.join(out)
862
863
864
class ShowAccessor(object):
    """
    Accessor for ``SHOW`` commands.

    Each method runs one ``SHOW ...`` statement. It returns a
    :class:`ShowResult` whose first column is renamed to ``Name`` and whose
    remaining column names are converted to camel-case.

    See Also
    --------
    :attr:`Connection.show`

    """

    def __init__(self, conn: 'Connection'):
        self._conn = conn

    def columns(self, table: str, full: bool = False) -> ShowResult:
        """Show the column information for the given table."""
        table = quote_identifier(table)
        if full:
            return self._iquery(f'full columns in {table}')
        return self._iquery(f'columns in {table}')

    def tables(self, extended: bool = False) -> ShowResult:
        """Show tables in the current database."""
        if extended:
            return self._iquery('tables extended')
        return self._iquery('tables')

    def warnings(self) -> ShowResult:
        """Show warnings."""
        return self._iquery('warnings')

    def errors(self) -> ShowResult:
        """Show errors."""
        return self._iquery('errors')

    def databases(self, extended: bool = False) -> ShowResult:
        """Show all databases in the server."""
        if extended:
            return self._iquery('databases extended')
        return self._iquery('databases')

    def database_status(self) -> ShowResult:
        """Show status of the current database."""
        return self._iquery('database status')

    def global_status(self) -> ShowResult:
        """Show global status of the current server."""
        return self._iquery('global status')

    def indexes(self, table: str) -> ShowResult:
        """Show all indexes in the given table."""
        table = quote_identifier(table)
        return self._iquery(f'indexes in {table}')

    def functions(self) -> ShowResult:
        """Show all functions in the current database."""
        return self._iquery('functions')

    def partitions(self, extended: bool = False) -> ShowResult:
        """Show partitions in the current database."""
        if extended:
            return self._iquery('partitions extended')
        return self._iquery('partitions')

    def pipelines(self) -> ShowResult:
        """Show all pipelines in the current database."""
        return self._iquery('pipelines')

    def plan(self, plan_id: int, json: bool = False) -> ShowResult:
        """Show the plan for the given plan ID."""
        # Casting to int keeps arbitrary text out of the statement.
        plan_id = int(plan_id)
        if json:
            return self._iquery(f'plan json {plan_id}')
        return self._iquery(f'plan {plan_id}')

    def plancache(self) -> ShowResult:
        """Show all query statements compiled and executed."""
        return self._iquery('plancache')

    def processlist(self) -> ShowResult:
        """Show details about currently running threads."""
        return self._iquery('processlist')

    def reproduction(self, outfile: Optional[str] = None) -> ShowResult:
        """Show troubleshooting data for query optimizer and code generation."""
        if outfile:
            # Escape for a double-quoted SQL string literal: backslashes
            # first, so the escapes added for quotes are not doubled.
            outfile = outfile.replace('\\', '\\\\').replace('"', '\\"')
            return self._iquery(f'reproduction into outfile "{outfile}"')
        return self._iquery('reproduction')

    def schemas(self) -> ShowResult:
        """Show schemas in the server."""
        return self._iquery('schemas')

    def session_status(self) -> ShowResult:
        """Show server status information for a session."""
        return self._iquery('session status')

    def status(self, extended: bool = False) -> ShowResult:
        """Show server status information."""
        if extended:
            return self._iquery('status extended')
        return self._iquery('status')

    def table_status(self) -> ShowResult:
        """Show table status information for the current database."""
        return self._iquery('table status')

    def procedures(self) -> ShowResult:
        """Show all procedures in the current database."""
        return self._iquery('procedures')

    def aggregates(self) -> ShowResult:
        """Show all aggregate functions in the current database."""
        return self._iquery('aggregates')

    def create_aggregate(self, name: str) -> ShowResult:
        """Show the function creation code for the given aggregate function."""
        name = quote_identifier(name)
        return self._iquery(f'create aggregate {name}')

    def create_function(self, name: str) -> ShowResult:
        """Show the function creation code for the given function."""
        name = quote_identifier(name)
        return self._iquery(f'create function {name}')

    def create_pipeline(self, name: str, extended: bool = False) -> ShowResult:
        """Show the pipeline creation code for the given pipeline."""
        name = quote_identifier(name)
        if extended:
            return self._iquery(f'create pipeline {name} extended')
        return self._iquery(f'create pipeline {name}')

    def create_table(self, name: str) -> ShowResult:
        """Show the table creation code for the given table."""
        name = quote_identifier(name)
        return self._iquery(f'create table {name}')

    def create_view(self, name: str) -> ShowResult:
        """Show the view creation code for the given view."""
        name = quote_identifier(name)
        return self._iquery(f'create view {name}')

    # NOTE(review): disabled draft. If revived, note that the character
    # classes ``[\w+-_]`` below contain the range ``+-_`` (many punctuation
    # characters), not the three literals; write ``[\w+_-]`` instead.
    # def grants(
    #     self,
    #     user: Optional[str] = None,
    #     hostname: Optional[str] = None,
    #     role: Optional[str] = None
    # ) -> ShowResult:
    #     """Show the privileges for the given user or role."""
    #     if user:
    #         if not re.match(r'^[\w+-_]+$', user):
    #             raise ValueError(f'User name is not valid: {user}')
    #         if hostname and not re.match(r'^[\w+-_\.]+$', hostname):
    #             raise ValueError(f'Hostname is not valid: {hostname}')
    #         if hostname:
    #             return self._iquery(f"grants for '{user}@{hostname}'")
    #         return self._iquery(f"grants for '{user}'")
    #     if role:
    #         if not re.match(r'^[\w+-_]+$', role):
    #             raise ValueError(f'Role is not valid: {role}')
    #         return self._iquery(f"grants for role '{role}'")
    #     return self._iquery('grants')

    def _iquery(self, qtype: str) -> ShowResult:
        """Run ``SHOW <qtype>`` and normalize the column names of each row."""
        rows = self._conn._iquery(f'show {qtype}')
        return ShowResult(
            {
                under2camel('Name' if j == 0 else k): v
                for j, (k, v) in enumerate(row.items())
            }
            for row in rows
        )
1038
1039
1040
class Connection(metaclass=abc.ABCMeta):
1041
"""
1042
SingleStoreDB connection.
1043
1044
Instances of this object are typically created through the
1045
:func:`singlestoredb.connect` function rather than creating them directly.
1046
See the :func:`singlestoredb.connect` function for parameter definitions.
1047
1048
See Also
1049
--------
1050
:func:`singlestoredb.connect`
1051
1052
"""
1053
1054
Warning = exceptions.Warning
1055
Error = exceptions.Error
1056
InterfaceError = exceptions.InterfaceError
1057
DataError = exceptions.DataError
1058
DatabaseError = exceptions.DatabaseError
1059
OperationalError = exceptions.OperationalError
1060
IntegrityError = exceptions.IntegrityError
1061
InternalError = exceptions.InternalError
1062
ProgrammingError = exceptions.ProgrammingError
1063
NotSupportedError = exceptions.NotSupportedError
1064
1065
#: Read-only DB-API parameter style
1066
paramstyle = 'pyformat'
1067
1068
# Must be set by subclass
1069
driver = ''
1070
1071
# Populated when first needed
1072
_map_param_converter: Optional[sqlparams.SQLParams] = None
1073
_positional_param_converter: Optional[sqlparams.SQLParams] = None
1074
1075
def __init__(self, **kwargs: Any):
    """
    Call :func:`singlestoredb.connect` instead.

    Stores the given connection parameters, derives the session encoding,
    swaps browser-SSO credentials for a JWT, and creates the
    server-variable accessors.

    """
    self.connection_params: Dict[str, Any] = kwargs
    self.errorhandler = None
    self._results_type: str = kwargs.get('results_type', None) or 'tuples'

    # Any 'mb4' in the charset name is removed ('utf8mb4' -> 'utf8').
    charset = self.connection_params.get('charset', None) or 'utf-8'
    #: Session encoding
    self.encoding = charset.replace('mb4', '')

    # Handle various authentication types: a browser-SSO login is traded
    # for a JWT, which is then passed as the password.
    if self.connection_params.get('credential_type', None) == auth.BROWSER_SSO:
        # TODO: Cache info for token refreshes
        token_info = auth.get_jwt(self.connection_params['user'])
        self.connection_params['password'] = str(token_info)
        self.connection_params['credential_type'] = auth.JWT

    #: Attribute-like access to global server variables
    self.globals = VariableAccessor(self, 'global')

    #: Attribute-like access to local / session server variables
    self.locals = VariableAccessor(self, 'local')

    #: Attribute-like access to cluster global server variables
    self.cluster_globals = VariableAccessor(self, 'cluster global')

    #: Attribute-like access to cluster local / session server variables
    self.cluster_locals = VariableAccessor(self, 'cluster local')

    #: Attribute-like access to all server variables
    self.vars = VariableAccessor(self, '')

    #: Attribute-like access to all cluster server variables
    self.cluster_vars = VariableAccessor(self, 'cluster')

    # For backwards compatibility with SQLAlchemy package
    self._driver = Driver(self.driver)

    # Output decoders, keyed by integer type code
    self.decoders: Dict[int, Callable[[Any], Any]] = {}
1116
1117
@classmethod
def _convert_params(
    cls, oper: str,
    params: Optional[Union[Sequence[Any], Dict[str, Any], Any]],
) -> Tuple[Any, ...]:
    """
    Convert query to correct parameter format.

    Mappings are converted from the ``pyformat`` style and sequences from
    the ``format`` style to this class's ``paramstyle``. Any other value is
    treated as a single positional parameter.

    Parameters
    ----------
    oper : str
        The SQL statement
    params : Sequence or dict or Any, optional
        Parameters for the statement

    Returns
    -------
    tuple
        ``(oper, None)`` when there are no parameters (``None`` or an empty
        sequence / mapping), otherwise the converted ``(query, params)``
        result of the ``sqlparams`` converter

    """
    if params is None:
        return (oper, None)

    # ``bytes``/``bytearray`` and ``str`` are sequences, but each one is a
    # single parameter value.
    is_sequence = isinstance(params, Sequence) \
        and not isinstance(params, (str, bytes, bytearray))
    is_mapping = isinstance(params, Mapping)

    # Empty containers mean "no parameters". Falsy scalars such as 0, ''
    # or False are real values and must still be bound.
    if (is_sequence or is_mapping) and not params:
        return (oper, None)

    if cls._map_param_converter is None:
        cls._map_param_converter = sqlparams.SQLParams(
            map_paramstyle, cls.paramstyle, escape_char=True,
        )

    if cls._positional_param_converter is None:
        cls._positional_param_converter = sqlparams.SQLParams(
            positional_paramstyle, cls.paramstyle, escape_char=True,
        )

    param_converter = cls._map_param_converter \
        if is_mapping else cls._positional_param_converter

    if not is_sequence and not is_mapping:
        params = [params]

    return param_converter.format(oper, params)
1149
1150
def autocommit(self, value: bool = True) -> None:
1151
"""Set autocommit mode."""
1152
self.locals.autocommit = bool(value)
1153
1154
@abc.abstractmethod
1155
def connect(self) -> 'Connection':
1156
"""Connect to the server."""
1157
raise NotImplementedError
1158
1159
def _iquery(
1160
self, oper: str,
1161
params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
1162
fix_names: bool = True,
1163
) -> List[Dict[str, Any]]:
1164
"""Return the results of a query as a list of dicts (for internal use)."""
1165
with self.cursor() as cur:
1166
cur.execute(oper, params)
1167
if not re.match(r'^\s*(select|show|call|echo)\s+', oper, flags=re.I):
1168
return []
1169
out = list(cur.fetchall())
1170
if not out:
1171
return []
1172
if isinstance(out, DataFrame):
1173
out = out.to_dict(orient='records')
1174
elif isinstance(out[0], (tuple, list)):
1175
if cur.description:
1176
names = [x[0] for x in cur.description]
1177
if fix_names:
1178
names = [under2camel(str(x).replace(' ', '')) for x in names]
1179
out = [{k: v for k, v in zip(names, row)} for row in out]
1180
return out
1181
1182
@abc.abstractmethod
1183
def close(self) -> None:
1184
"""Close the database connection."""
1185
raise NotImplementedError
1186
1187
@abc.abstractmethod
1188
def commit(self) -> None:
1189
"""Commit the pending transaction."""
1190
raise NotImplementedError
1191
1192
@abc.abstractmethod
1193
def rollback(self) -> None:
1194
"""Rollback the pending transaction."""
1195
raise NotImplementedError
1196
1197
@abc.abstractmethod
1198
def cursor(self) -> Cursor:
1199
"""
1200
Create a new cursor object.
1201
1202
See Also
1203
--------
1204
:class:`Cursor`
1205
1206
Returns
1207
-------
1208
:class:`Cursor`
1209
1210
"""
1211
raise NotImplementedError
1212
1213
@abc.abstractproperty
1214
def messages(self) -> List[Tuple[int, str]]:
1215
"""Messages generated during the connection."""
1216
raise NotImplementedError
1217
1218
def __enter__(self) -> 'Connection':
1219
"""Enter a context."""
1220
return self
1221
1222
def __exit__(
1223
self, exc_type: Optional[object],
1224
exc_value: Optional[Exception], exc_traceback: Optional[str],
1225
) -> None:
1226
"""Exit a context."""
1227
self.close()
1228
1229
@abc.abstractmethod
1230
def is_connected(self) -> bool:
1231
"""
1232
Determine if the database is still connected.
1233
1234
Returns
1235
-------
1236
bool
1237
1238
"""
1239
raise NotImplementedError
1240
1241
def enable_data_api(self, port: Optional[int] = None) -> int:
1242
"""
1243
Enable the data API in the server.
1244
1245
Use of this method requires privileges that allow setting global
1246
variables and starting the HTTP proxy.
1247
1248
Parameters
1249
----------
1250
port : int, optional
1251
The port number that the HTTP server should run on. If this
1252
value is not specified, the current value of the
1253
``http_proxy_port`` is used.
1254
1255
See Also
1256
--------
1257
:meth:`disable_data_api`
1258
1259
Returns
1260
-------
1261
int
1262
port number of the HTTP server
1263
1264
"""
1265
if port is not None:
1266
self.globals.http_proxy_port = int(port)
1267
self.globals.http_api = True
1268
self._iquery('restart proxy')
1269
return int(self.globals.http_proxy_port)
1270
1271
enable_http_api = enable_data_api
1272
1273
def disable_data_api(self) -> None:
1274
"""
1275
Disable the data API.
1276
1277
See Also
1278
--------
1279
:meth:`enable_data_api`
1280
1281
"""
1282
self.globals.http_api = False
1283
self._iquery('restart proxy')
1284
1285
disable_http_api = disable_data_api
1286
1287
@property
1288
def show(self) -> ShowAccessor:
1289
"""Access server properties managed by the SHOW statement."""
1290
return ShowAccessor(self)
1291
1292
@functools.cached_property
1293
def vector_db(self) -> Any:
1294
"""
1295
Get vectorstore API accessor
1296
"""
1297
from vectorstore import VectorDB
1298
return VectorDB(connection=self)
1299
1300
1301
#
1302
# NOTE: When adding parameters to this function, you should always
1303
# make the value optional with a default of None. The options
1304
# processing framework will fill in the default value based
1305
# on environment variables or other configuration sources.
1306
#
1307
def connect(
1308
host: Optional[str] = None, user: Optional[str] = None,
1309
password: Optional[str] = None, port: Optional[int] = None,
1310
database: Optional[str] = None, driver: Optional[str] = None,
1311
pure_python: Optional[bool] = None, local_infile: Optional[bool] = None,
1312
charset: Optional[str] = None,
1313
ssl_key: Optional[str] = None, ssl_cert: Optional[str] = None,
1314
ssl_ca: Optional[str] = None, ssl_disabled: Optional[bool] = None,
1315
ssl_cipher: Optional[str] = None, ssl_verify_cert: Optional[bool] = None,
1316
tls_sni_servername: Optional[str] = None,
1317
ssl_verify_identity: Optional[bool] = None,
1318
conv: Optional[Dict[int, Callable[..., Any]]] = None,
1319
credential_type: Optional[str] = None,
1320
autocommit: Optional[bool] = None,
1321
results_type: Optional[str] = None,
1322
buffered: Optional[bool] = None,
1323
results_format: Optional[str] = None,
1324
program_name: Optional[str] = None,
1325
conn_attrs: Optional[Dict[str, str]] = None,
1326
multi_statements: Optional[bool] = None,
1327
client_found_rows: Optional[bool] = None,
1328
connect_timeout: Optional[int] = None,
1329
nan_as_null: Optional[bool] = None,
1330
inf_as_null: Optional[bool] = None,
1331
encoding_errors: Optional[str] = None,
1332
track_env: Optional[bool] = None,
1333
enable_extended_data_types: Optional[bool] = None,
1334
vector_data_format: Optional[str] = None,
1335
parse_json: Optional[bool] = None,
1336
) -> Connection:
1337
"""
1338
Return a SingleStoreDB connection.
1339
1340
Parameters
1341
----------
1342
host : str, optional
1343
Hostname, IP address, or URL that describes the connection.
1344
The scheme or protocol defines which database connector to use.
1345
By default, the ``mysql`` scheme is used. To connect to the
1346
HTTP API, the scheme can be set to ``http`` or ``https``. The username,
1347
password, host, and port are specified as in a standard URL. The path
1348
indicates the database name. The overall form of the URL is:
1349
``scheme://user:password@host:port/db_name``. The scheme can
1350
typically be left off (unless you are using the HTTP API):
1351
``user:password@host:port/db_name``.
1352
user : str, optional
1353
Database user name
1354
password : str, optional
1355
Database user password
1356
port : int, optional
1357
Database port. This defaults to 3306 for non-HTTP connections, 80
1358
for HTTP connections, and 443 for HTTPS connections.
1359
database : str, optional
1360
Database name
1361
pure_python : bool, optional
1362
Use the connector in pure Python mode
1363
local_infile : bool, optional
1364
Allow local file uploads
1365
charset : str, optional
1366
Character set for string values
1367
ssl_key : str, optional
1368
File containing SSL key
1369
ssl_cert : str, optional
1370
File containing SSL certificate
1371
ssl_ca : str, optional
1372
File containing SSL certificate authority
1373
ssl_cipher : str, optional
1374
Sets the SSL cipher list
1375
ssl_disabled : bool, optional
1376
Disable SSL usage
1377
ssl_verify_cert : bool, optional
1378
Verify the server's certificate. This is automatically enabled if
1379
``ssl_ca`` is also specified.
1380
ssl_verify_identity : bool, optional
1381
Verify the server's identity
1382
conv : dict[int, Callable], optional
1383
Dictionary of data conversion functions
1384
credential_type : str, optional
1385
Type of authentication to use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO
1386
autocommit : bool, optional
1387
Enable autocommits
1388
results_type : str, optional
1389
The form of the query results: tuples, namedtuples, dicts,
1390
numpy, polars, pandas, arrow
1391
buffered : bool, optional
1392
Should the entire query result be buffered in memory? This is the default
1393
behavior which allows full cursor control of the result, but does consume
1394
more memory.
1395
results_format : str, optional
1396
Deprecated. This option has been renamed to results_type.
1397
program_name : str, optional
1398
Name of the program
1399
conn_attrs : dict, optional
1400
Additional connection attributes for telemetry. Example:
1401
{'program_version': "1.0.2", "_connector_name": "dbt connector"}
1402
multi_statements: bool, optional
1403
Should multiple statements be allowed within a single query?
1404
connect_timeout : int, optional
1405
The timeout for connecting to the database in seconds.
1406
(default: 10, min: 1, max: 31536000)
1407
nan_as_null : bool, optional
1408
Should NaN values be treated as NULLs when used in parameter
1409
substitutions including uploaded data?
1410
inf_as_null : bool, optional
1411
Should Inf values be treated as NULLs when used in parameter
1412
substitutions including uploaded data?
1413
encoding_errors : str, optional
1414
The error handler name for value decoding errors
1415
track_env : bool, optional
1416
Should the connection track the SINGLESTOREDB_URL environment variable?
1417
enable_extended_data_types : bool, optional
1418
Should extended data types (BSON, vector) be enabled?
1419
vector_data_format : str, optional
1420
Format for vector types: json or binary
1421
1422
Examples
1423
--------
1424
Standard database connection
1425
1426
>>> conn = s2.connect('me:[email protected]/my_db')
1427
1428
Connect to HTTP API on port 8080
1429
1430
>>> conn = s2.connect('http://me:[email protected]:8080/my_db')
1431
1432
Using an environment variable for connection string
1433
1434
>>> os.environ['SINGLESTOREDB_URL'] = 'me:[email protected]/my_db'
1435
>>> conn = s2.connect()
1436
1437
Specifying credentials using environment variables
1438
1439
>>> os.environ['SINGLESTOREDB_USER'] = 'me'
1440
>>> os.environ['SINGLESTOREDB_PASSWORD'] = 'p455w0rd'
1441
>>> conn = s2.connect('s2-host.com/my_db')
1442
1443
Specifying options with keyword parameters
1444
1445
>>> conn = s2.connect('s2-host.com/my_db', user='me', password='p455w0rd',
1446
local_infile=True)
1447
1448
Specifying options with URL parameters
1449
1450
>>> conn = s2.connect('s2-host.com/my_db?local_infile=True&charset=utf8')
1451
1452
Connecting within a context manager
1453
1454
>>> with s2.connect('...') as conn:
1455
... with conn.cursor() as cur:
1456
... cur.execute('...')
1457
1458
Setting session variables, the code below sets the ``autocommit`` option
1459
1460
>>> conn.locals.autocommit = True
1461
1462
Getting session variables
1463
1464
>>> conn.locals.autocommit
1465
True
1466
1467
See Also
1468
--------
1469
:class:`Connection`
1470
1471
Returns
1472
-------
1473
:class:`Connection`
1474
1475
"""
1476
params = build_params(**dict(locals()))
1477
driver = params.get('driver', 'mysql')
1478
1479
if not driver or driver == 'mysql':
1480
from .mysql.connection import Connection # type: ignore
1481
return Connection(**params)
1482
1483
if driver in ['http', 'https']:
1484
from .http.connection import Connection
1485
return Connection(**params)
1486
1487
raise ValueError(f'Unrecognized protocol: {driver}')
1488
1489