Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/connection.py
801 views
1
#!/usr/bin/env python
2
"""SingleStoreDB connections and cursors."""
3
import abc
4
import functools
5
import inspect
6
import io
7
import queue
8
import re
9
import sys
10
import warnings
11
import weakref
12
from collections.abc import Iterator
13
from collections.abc import Mapping
14
from collections.abc import MutableMapping
15
from collections.abc import Sequence
16
from typing import Any
17
from typing import Callable
18
from typing import Dict
19
from typing import List
20
from typing import Optional
21
from typing import Tuple
22
from typing import Union
23
from urllib.parse import parse_qs
24
from urllib.parse import unquote_plus
25
from urllib.parse import urlparse
26
27
import sqlparams
28
try:
29
from pandas import DataFrame
30
except ImportError:
31
class DataFrame(object): # type: ignore
32
def itertuples(self, *args: Any, **kwargs: Any) -> None:
33
pass
34
35
from . import auth
36
from . import exceptions
37
from .config import get_option
38
from .utils.results import Description
39
from .utils.results import Result
40
41
# Queue type accepted as a ``LOCAL INFILE`` data stream by ``Cursor.execute``.
# On interpreters older than 3.10 the bare class is used; newer ones get the
# alias parameterized with the item type (``bytes`` or ``str``).
InfileQueue = (
    queue.Queue
    if sys.version_info < (3, 10)
    else queue.Queue[Union[bytes, str]]
)
45
46
47
# DB-API 2.0 (PEP 249) module globals.
apilevel = '2.0'
# PEP 249 level 1: threads may share the module, but not connections.
threadsafety = 1
# Parameter style used for mapping (named) parameters; also the public style.
paramstyle = 'pyformat'
map_paramstyle = paramstyle
# Parameter style used for positional parameters.
positional_paramstyle = 'format'


# Type codes for character-based columns: 245 plus 247 through 255.
CHAR_COLUMNS = {245, *range(247, 256)}
56
57
58
def under2camel(s: str) -> str:
    """
    Convert an underscore-delimited string to CamelCase.

    The tokens ``xml``, ``sql`` and ``json`` (any case) are fully
    upper-cased when delimited by a word boundary or underscore. Every
    character following the start of the string or a run of underscores
    is upper-cased and those underscores are dropped; trailing underscores
    are removed.

    Parameters
    ----------
    s : str
        String to convert

    Returns
    -------
    str

    """
    with_acronyms = re.sub(
        r'(\b|_)(xml|sql|json)(\b|_)',
        lambda m: m.group(1) + m.group(2).upper() + m.group(3),
        s,
        flags=re.I,
    )
    capitalized = re.sub(
        r'(?:^|_+)(\w)',
        lambda m: m.group(1).upper(),
        with_acronyms,
    )
    return re.sub(r'_+$', r'', capitalized)
74
75
76
def nested_converter(
    conv: Callable[[Any], Any],
    inner: Callable[[Any], Any],
) -> Callable[[Any], Any]:
    """
    Compose two single-argument functions.

    Parameters
    ----------
    conv : Callable
        Function applied last, to the result of ``inner``
    inner : Callable
        Function applied first, to the input value

    Returns
    -------
    Callable
        A function computing ``conv(inner(value))``

    """
    def converter(value: Any) -> Any:
        intermediate = inner(value)
        return conv(intermediate)

    return converter
84
85
86
def cast_bool_param(val: Any) -> bool:
    """
    Cast a connection parameter value to a bool.

    Accepted inputs:

    * ``None`` / ``False`` give ``False``; ``True`` gives ``True``
    * values converting to the integer 0 or 1 (e.g. ``0``, ``1.0``,
      ``'1'``); non-text numbers must convert exactly, so ``0.5`` is
      rejected rather than truncated to ``0``
    * the words on/t/true/y/yes/enabled/enable and
      off/f/false/n/no/disabled/disable, case-insensitively and ignoring
      surrounding whitespace

    Parameters
    ----------
    val : Any
        Value to convert

    Returns
    -------
    bool

    Raises
    ------
    ValueError
        If the value can not be interpreted as a bool

    """
    if val is None or val is False:
        return False

    if val is True:
        return True

    # Test ints. ``int()`` signals unconvertible input with TypeError,
    # ValueError or OverflowError; other exceptions are real errors.
    try:
        ival = int(val)
    except (TypeError, ValueError, OverflowError):
        pass
    else:
        # Text is parsed by ``int``; other numbers must convert without
        # loss so that e.g. 0.5 is not silently truncated to 0.
        if isinstance(val, (str, bytes, bytearray)) or ival == val:
            if ival == 1:
                return True
            if ival == 0:
                return False

    # Lowercase strings
    if hasattr(val, 'lower'):
        lowered = val.lower()
        if isinstance(lowered, str):
            lowered = lowered.strip()
        if lowered in ['on', 't', 'true', 'y', 'yes', 'enabled', 'enable']:
            return True
        if lowered in ['off', 'f', 'false', 'n', 'no', 'disabled', 'disable']:
            return False

    raise ValueError('Unrecognized value for bool: {}'.format(val))
112
113
114
def build_params(**kwargs: Any) -> Dict[str, Any]:
    """
    Assemble the complete set of connection parameters.

    Every argument of :func:`connect` is taken from ``kwargs`` when given
    (a ``None`` value counts as not given) and otherwise from the matching
    configuration option. If the resulting ``host`` looks like a URL, the
    parts parsed from it override the other values. Values are then cast
    to their declared types and a default port is chosen from the driver.

    Parameters
    ----------
    **kwargs : keyword-parameters, optional
        Arbitrary keyword parameters corresponding to connection parameters

    Returns
    -------
    dict

    """
    supplied = {key: value for key, value in kwargs.items() if value is not None}
    params: Dict[str, Any] = {}

    # Fill in every parameter accepted by ``connect``.
    for arg in inspect.getfullargspec(connect).args:
        if arg == 'conv':
            params[arg] = supplied.get(arg, None)
        elif arg == 'results_format':  # deprecated alias of ``results_type``
            if supplied.get(arg, None) is not None:
                warnings.warn(
                    'The `results_format=` parameter has been '
                    'renamed to `results_type=`.',
                    DeprecationWarning,
                )
            params['results_type'] = supplied.get(arg, get_option('results.type'))
        elif arg == 'results_type':
            params[arg] = supplied.get(arg, get_option('results.type'))
        else:
            params[arg] = supplied.get(arg, get_option(arg))

    # A host containing URL punctuation is treated as a connection URL.
    # This is a heuristic, not a validation.
    host = params['host']
    if host and any(marker in host for marker in (':', '/', '@', '?')):
        url_params = _parse_url(host)
        if 'driver' not in url_params:
            url_params['driver'] = get_option('driver')
        params.update(url_params)

    params = _cast_params(params)

    # Choose a default port based on the driver when none was given.
    if not params.get('port'):
        if params['driver'] == 'http':
            params['port'] = int(get_option('http_port') or 80)
        elif params['driver'] == 'https':
            params['port'] = int(get_option('http_port') or 443)
        else:
            params['port'] = int(get_option('port') or 3306)

    # With no user and an empty password, drop the password key entirely.
    if 'user' not in params and not params.get('password', None):
        params.pop('password', None)

    # Supplying a CA file turns certificate verification on.
    # NOTE(review): this also overrides an explicit ``ssl_verify_cert=False``;
    # confirm that is intended.
    if params.get('ssl_ca', '') and not params.get('ssl_verify_cert', None):
        params['ssl_verify_cert'] = True

    return params
176
177
178
def _get_param_types(func: Any) -> Dict[str, Any]:
179
"""
180
Retrieve the types for the parameters to the given function.
181
182
Note that if a parameter has multiple possible types, only the
183
first one is returned.
184
185
Parameters
186
----------
187
func : callable
188
Callable object to inspect the parameters of
189
190
Returns
191
-------
192
dict
193
194
"""
195
out = {}
196
args = inspect.getfullargspec(func)
197
for name in args.args:
198
ann = args.annotations[name]
199
if isinstance(ann, str):
200
ann = eval(ann)
201
if hasattr(ann, '__args__'):
202
out[name] = ann.__args__[0]
203
else:
204
out[name] = ann
205
return out
206
207
208
def _cast_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Coerce connection parameters to the types declared by :func:`connect`.

    Keys are lower-cased and entries whose value is ``None`` are dropped.
    Boolean parameters go through :func:`cast_bool_param`. Dict/Mapping-,
    List- and Tuple-typed parameters are converted with the matching
    builtin. Anything else is passed to its declared type.

    Parameters
    ----------
    params : dict
        Dictionary of connection parameters

    Returns
    -------
    dict

    Raises
    ------
    ValueError
        If a key is not a parameter of :func:`connect`; errors from the
        individual conversions propagate unchanged

    """
    declared = _get_param_types(connect)
    result = {}
    for raw_key, value in params.items():
        key = raw_key.lower()
        if value is None:
            continue
        if key not in declared:
            raise ValueError('Unrecognized connection parameter: {}'.format(key))

        target = declared[key]
        # ``typing`` generics expose their name (``Dict``, ``List``, ...)
        # through ``_name``; plain classes do not.
        generic_name = getattr(target, '_name', '')
        if target is bool:
            value = cast_bool_param(value)
        elif generic_name in ['Dict', 'Mapping'] or \
                str(target).startswith('typing.Dict'):
            value = dict(value)
        elif generic_name == 'List':
            value = list(value)
        elif generic_name == 'Tuple':
            value = tuple(value)
        else:
            value = target(value)
        result[key] = value
    return result
244
245
246
def _parse_url(url: str) -> Dict[str, Any]:
247
"""
248
Parse a connection URL and return only the defined parts.
249
250
Parameters
251
----------
252
url : str
253
The URL passed in can be a full URL or a partial URL. At a minimum,
254
a host name must be specified. All other parts are optional.
255
256
Returns
257
-------
258
dict
259
260
"""
261
out: Dict[str, Any] = {}
262
263
if '//' not in url:
264
url = '//' + url
265
266
if url.startswith('singlestoredb+'):
267
url = re.sub(r'^singlestoredb\+', r'', url)
268
269
parts = urlparse(url, scheme='singlestoredb', allow_fragments=True)
270
271
url_db = parts.path
272
if url_db.startswith('/'):
273
url_db = url_db.split('/')[1].strip()
274
url_db = url_db.split('/')[0].strip() or ''
275
276
# Retrieve basic connection parameters
277
out['host'] = parts.hostname or None
278
out['port'] = parts.port or None
279
out['database'] = url_db or None
280
out['user'] = parts.username or None
281
282
# Allow an empty string for password
283
if out['user'] and parts.password is not None:
284
out['password'] = parts.password
285
286
if parts.scheme != 'singlestoredb':
287
out['driver'] = parts.scheme.lower()
288
289
if out.get('user'):
290
out['user'] = unquote_plus(out['user'])
291
292
if out.get('password'):
293
out['password'] = unquote_plus(out['password'])
294
295
if out.get('database'):
296
out['database'] = unquote_plus(out['database'])
297
298
# Convert query string to parameters
299
out.update({k.lower(): v[-1] for k, v in parse_qs(parts.query).items()})
300
301
return {k: v for k, v in out.items() if v is not None}
302
303
304
def _name_check(name: str) -> str:
305
"""
306
Make sure the given name is a legal variable name.
307
308
Parameters
309
----------
310
name : str
311
Name to check
312
313
Returns
314
-------
315
str
316
317
"""
318
name = name.strip()
319
if not re.match(r'^[A-Za-z_][\w+_]*$', name):
320
raise ValueError('Name contains invalid characters')
321
return name
322
323
324
def quote_identifier(name: str) -> str:
    """
    Quote and escape a SQL identifier with backticks.

    Backticks inside the name are doubled, which is how MySQL-compatible
    servers escape them in a quoted identifier. Without this, a name
    containing a backtick could close the quote early and inject SQL.

    Parameters
    ----------
    name : str
        Identifier to quote

    Returns
    -------
    str

    """
    escaped = name.replace('`', '``')
    return f'`{escaped}`'
327
328
329
class Driver:
    """
    Minimal holder for a driver name.

    A connection keeps one of these as ``_driver`` for backwards
    compatibility with the SQLAlchemy package.

    Parameters
    ----------
    name : str
        Name of the driver

    """

    def __init__(self, name: str):
        #: Name of the driver
        self.name = name
334
335
336
class VariableAccessor(MutableMapping):  # type: ignore
    """
    Mapping- and attribute-style access to server variables.

    Reads run ``SHOW <vtype> VARIABLES LIKE ...``. Writes run
    ``SET <vtype> <name>=...``, with ``local`` sent as ``session``.
    ``True`` / ``False`` are written as ``ON`` / ``OFF``. The strings
    ``on`` / ``true`` / ``off`` / ``false`` (any case) are read back as
    booleans. Deleting a variable raises :class:`TypeError`.

    Parameters
    ----------
    conn : Connection
        Connection used to run the queries; only a weak proxy is kept
    vtype : str
        Variable scope: ``global``, ``local``, ``cluster``,
        ``cluster global``, ``cluster local`` or the empty string
        (case-insensitive)

    Raises
    ------
    ValueError
        If ``vtype`` is not one of the allowed scopes

    """

    def __init__(self, conn: 'Connection', vtype: str):
        # ``object.__setattr__`` bypasses ``__setattr__`` below, which is
        # redirected to server variable assignment.
        object.__setattr__(self, 'connection', weakref.proxy(conn))
        object.__setattr__(self, 'vtype', vtype.lower())
        if self.vtype not in [
            'global', 'local', '',
            'cluster', 'cluster global', 'cluster local',
        ]:
            raise ValueError(
                'Variable type must be global, local, cluster, '
                'cluster global, cluster local, or empty',
            )

    def _cast_value(self, value: Any) -> Any:
        """Convert ``on``/``true`` and ``off``/``false`` strings to bools."""
        if isinstance(value, str):
            if value.lower() in ['on', 'true']:
                return True
            if value.lower() in ['off', 'false']:
                return False
        return value

    def __getitem__(self, name: str) -> Any:
        name = _name_check(name)
        out = self.connection._iquery(
            'show {} variables like %s;'.format(self.vtype),
            [name],
        )
        if not out:
            raise KeyError(f"No variable found with the name '{name}'.")
        if len(out) > 1:
            raise KeyError(f"Multiple variables found with the name '{name}'.")
        return self._cast_value(out[0]['Value'])

    def __setitem__(self, name: str, value: Any) -> None:
        name = _name_check(name)
        if value is True:
            value = 'ON'
        elif value is False:
            value = 'OFF'
        if 'local' in self.vtype:
            self.connection._iquery(
                'set {} {}=%s;'.format(
                    self.vtype.replace('local', 'session'), name,
                ), [value],
            )
        else:
            self.connection._iquery('set {} {}=%s;'.format(self.vtype, name), [value])

    def __delitem__(self, name: str) -> None:
        raise TypeError('Variables can not be deleted.')

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails. Private and dunder
        # names (probes such as ``__deepcopy__`` or ``_repr_html_`` from
        # ``copy`` / ``pickle`` / IPython) are never server variables, and the
        # internal attributes are only missing on an instance built without
        # ``__init__``. Raising AttributeError keeps ``getattr(obj, n, default)``
        # working, instead of raising KeyError from a server query or
        # recursing forever through ``self.connection``.
        if name.startswith('_') or name in ('connection', 'vtype'):
            raise AttributeError(name)
        return self[name]

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]

    def __len__(self) -> int:
        out = self.connection._iquery('show {} variables;'.format(self.vtype))
        return len(list(out))

    def __iter__(self) -> Iterator[str]:
        out = self.connection._iquery('show {} variables;'.format(self.vtype))
        # The first column of each row holds the variable name.
        return iter(list(x.values())[0] for x in out)
405
406
407
class Cursor(metaclass=abc.ABCMeta):
    """
    Database cursor for submitting commands and queries.

    This object should not be instantiated directly.
    The ``Connection.cursor`` method should be used.

    """

    def __init__(self, connection: 'Connection'):
        """Call ``Connection.cursor`` instead."""
        self.errorhandler = connection.errorhandler
        # Only a weak proxy is kept, so the cursor does not keep the
        # connection alive.
        self._connection: Optional[Connection] = weakref.proxy(connection)

        self._rownumber: Optional[int] = None

        self._description: Optional[List[Description]] = None

        #: Default batch size of ``fetchmany`` calls.
        self.arraysize = get_option('results.arraysize')

        self._converters: List[
            Tuple[
                int, Optional[str],
                Optional[Callable[..., Any]],
            ]
        ] = []

        #: Number of rows affected by the last query.
        self.rowcount: int = -1

        self._messages: List[Tuple[int, str]] = []

        #: Row ID of the last modified row.
        self.lastrowid: Optional[int] = None

    @property
    def messages(self) -> List[Tuple[int, str]]:
        """Messages created by the server."""
        return self._messages

    # ``abc.abstractproperty`` is deprecated; ``property`` wrapping
    # ``abc.abstractmethod`` is the equivalent modern spelling.
    @property
    @abc.abstractmethod
    def description(self) -> Optional[List[Description]]:
        """The field descriptions of the last query."""
        return self._description

    @property
    @abc.abstractmethod
    def rownumber(self) -> Optional[int]:
        """The last modified row number."""
        return self._rownumber

    @property
    def connection(self) -> Optional['Connection']:
        """The connection that the cursor belongs to."""
        return self._connection

    @abc.abstractmethod
    def callproc(
        self, name: str,
        params: Optional[Sequence[Any]] = None,
    ) -> None:
        """
        Call a stored procedure.

        The result sets generated by a stored procedure can be retrieved
        like the results of any other query using :meth:`fetchone`,
        :meth:`fetchmany`, or :meth:`fetchall`. If the procedure generates
        multiple result sets, subsequent result sets can be accessed
        using :meth:`nextset`.

        Examples
        --------
        >>> cur.callproc('myprocedure', ['arg1', 'arg2'])
        >>> print(cur.fetchall())

        Parameters
        ----------
        name : str
            Name of the stored procedure
        params : iterable, optional
            Parameters to the stored procedure

        Raises
        ------
        InterfaceError
            If the cursor is closed
        ValueError
            If ``name`` is not a legal procedure name

        """
        # NOTE: The `callproc` interface varies quite a bit between drivers
        # so it is implemented using `execute` here.

        if not self.is_connected():
            raise exceptions.InterfaceError(2048, 'Cursor is closed.')

        name = _name_check(name)

        if not params:
            self.execute(f'CALL {name}();')
        else:
            # ``execute`` takes ``format``-style (``%s``) placeholders for
            # positional parameters; the previous ``:1, :2`` numeric
            # placeholders did not match that style.
            keys = ', '.join(['%s'] * len(params))
            self.execute(f'CALL {name}({keys});', params)

    @abc.abstractmethod
    def is_connected(self) -> bool:
        """Is the cursor still connected?"""
        raise NotImplementedError

    @abc.abstractmethod
    def close(self) -> None:
        """Close the cursor."""
        raise NotImplementedError

    @abc.abstractmethod
    def execute(
        self, query: str,
        args: Optional[Union[Sequence[Any], Dict[str, Any], Any]] = None,
        infile_stream: Optional[  # type: ignore
            Union[
                io.RawIOBase,
                io.TextIOBase,
                Iterator[Union[bytes, str]],
                InfileQueue,
            ]
        ] = None,
    ) -> int:
        """
        Execute a SQL statement.

        Queries can use the ``format``-style parameters (``%s``) when using a
        list of parameters or ``pyformat``-style parameters (``%(key)s``)
        when using a dictionary of parameters.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : Sequence or dict, optional
            Parameters to substitute into the SQL code
        infile_stream : io.RawIOBase or io.TextIOBase or Iterator[bytes|str] or InfileQueue, optional
            Data stream for ``LOCAL INFILE`` statement

        Examples
        --------
        Query with no parameters

        >>> cur.execute('select * from mytable')

        Query with positional parameters

        >>> cur.execute('select * from mytable where id < %s', [100])

        Query with named parameters

        >>> cur.execute('select * from mytable where id < %(max)s', dict(max=100))

        Returns
        -------
        Number of rows affected

        """
        raise NotImplementedError

    def executemany(
        self, query: str,
        args: Optional[Sequence[Union[Sequence[Any], Dict[str, Any], Any]]] = None,
    ) -> int:
        """
        Execute SQL code against multiple sets of parameters.

        Queries can use the ``format``-style parameters (``%s``) when using
        lists of parameters or ``pyformat``-style parameters (``%(key)s``)
        when using dictionaries of parameters.

        Parameters
        ----------
        query : str
            The SQL statement to execute
        args : iterable of iterables or dicts, optional
            Sets of parameters to substitute into the SQL code

        Examples
        --------
        >>> cur.executemany('select * from mytable where id < %s',
        ...                 [[100], [200], [300]])

        >>> cur.executemany('select * from mytable where id < %(max)s',
        ...                 [dict(max=100), dict(max=200), dict(max=300)])

        Returns
        -------
        Number of rows affected
            The value of ``rowcount`` after the last ``execute`` call

        """
        # NOTE: Just implement using `execute` to cover driver inconsistencies
        if not args:
            self.execute(query)
        else:
            for params in args:
                self.execute(query, params)
        return self.rowcount

    @abc.abstractmethod
    def fetchone(self) -> Optional[Result]:
        """
        Fetch a single row from the result set.

        Examples
        --------
        >>> while True:
        ...    row = cur.fetchone()
        ...    if row is None:
        ...       break
        ...    print(row)

        Returns
        -------
        tuple
            Values of the returned row if there are rows remaining

        """
        raise NotImplementedError

    @abc.abstractmethod
    def fetchmany(self, size: Optional[int] = None) -> Result:
        """
        Fetch `size` rows from the result.

        If `size` is not specified, the `arraysize` attribute is used.

        Examples
        --------
        >>> while True:
        ...    out = cur.fetchmany(100)
        ...    if not len(out):
        ...        break
        ...    for row in out:
        ...        print(row)

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining

        """
        raise NotImplementedError

    @abc.abstractmethod
    def fetchall(self) -> Result:
        """
        Fetch all rows in the result set.

        Examples
        --------
        >>> for row in cur.fetchall():
        ...    print(row)

        Returns
        -------
        list of tuples
            Values of the returned rows if there are rows remaining
        None
            If there are no rows to return

        """
        raise NotImplementedError

    @abc.abstractmethod
    def nextset(self) -> Optional[bool]:
        """
        Skip to the next available result set.

        This is used when calling a procedure that returns multiple
        results sets.

        Note
        ----
        The ``nextset`` method must be called until it returns an empty
        set (i.e., once more than the number of expected result sets).
        This is to retain compatibility with PyMySQL and MySQLdb.

        Returns
        -------
        ``True``
            If another result set is available
        ``False``
            If no other result set is available

        """
        raise NotImplementedError

    @abc.abstractmethod
    def setinputsizes(self, sizes: Sequence[int]) -> None:
        """Predefine memory areas for parameters."""
        raise NotImplementedError

    @abc.abstractmethod
    def setoutputsize(self, size: int, column: Optional[str] = None) -> None:
        """Set a column buffer size for fetches of large columns."""
        raise NotImplementedError

    @abc.abstractmethod
    def scroll(self, value: int, mode: str = 'relative') -> None:
        """
        Scroll the cursor to the position in the result set.

        Parameters
        ----------
        value : int
            Value of the positional move
        mode : str
            Where to move the cursor from: 'relative' or 'absolute'

        """
        raise NotImplementedError

    def next(self) -> Optional[Result]:
        """
        Return the next row from the result set for use in iterators.

        Raises
        ------
        StopIteration
            If no more results exist

        Returns
        -------
        tuple of values

        """
        if not self.is_connected():
            raise exceptions.InterfaceError(2048, 'Cursor is closed.')
        out = self.fetchone()
        if out is None:
            raise StopIteration
        return out

    __next__ = next

    def __iter__(self) -> Any:
        """Return result iterator."""
        return self

    def __enter__(self) -> 'Cursor':
        """Enter a context."""
        return self

    def __exit__(
        self, exc_type: Optional[object],
        exc_value: Optional[Exception], exc_traceback: Optional[str],
    ) -> None:
        """Exit a context by closing the cursor."""
        self.close()
754
755
756
class ShowResult(Sequence[Any]):
    """
    Simple result object.

    This object is primarily used for displaying results to a
    terminal or web browser, but it can also be treated like a
    simple data frame where columns are accessible using either
    dictionary key-like syntax or attribute syntax.

    Examples
    --------
    >>> conn.show.status().Value[10]

    >>> conn.show.status()[10]['Value']

    Parameters
    ----------
    *args : Any
        Parameters to send to underlying list constructor
    **kwargs : Any
        Keyword parameters to send to underlying list constructor

    See Also
    --------
    :attr:`Connection.show`

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        #: Rows of the result, one dict per row
        self._data: List[Dict[str, Any]] = list(*args, **kwargs)

    def __getitem__(self, item: Union[int, slice]) -> Any:
        return self._data[item]

    def __getattr__(self, name: str) -> List[Any]:
        # Only reached when normal lookup fails. IPython probes, dunder
        # protocol lookups (``copy`` / ``pickle``) and ``_data`` itself
        # (missing before ``__init__`` runs) are never column names.
        # Looking them up as columns would raise KeyError or recurse
        # endlessly through this method.
        if name.startswith('_ipython') or name == '_data' or \
                (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        return [item[name] for item in self._data]

    def __len__(self) -> int:
        return len(self._data)

    def __repr__(self) -> str:
        if not self._data:
            return ''
        return '\n{}\n'.format(self._format_table(self._data))

    @property
    def columns(self) -> List[str]:
        """The columns in the result."""
        if not self._data:
            return []
        return list(self._data[0].keys())

    def _format_table(self, rows: Sequence[Dict[str, Any]]) -> str:
        """
        Render ``rows`` as a fixed-width, pipe-delimited text table.

        A column is left-aligned if any of its values is text (``str``,
        ``bytes`` or ``bytearray``) and right-aligned otherwise.
        """
        if not rows:
            return ''

        keys = rows[0].keys()
        lens = [len(x) for x in keys]
        # Decide alignment over all rows. It used to be recomputed per row,
        # so only the last row counted, and a trailing NULL right-aligned a
        # text column.
        has_text = [False] * len(keys)

        for row in rows:
            for i, k in enumerate(keys):
                value = row[k]
                lens[i] = max(lens[i], len(str(value)))
                if isinstance(value, (bytes, bytearray, str)):
                    has_text[i] = True

        align = ['<' if text else '>' for text in has_text]
        fmt = '| %s |' % '|'.join([' {:%s%d} ' % (x, y) for x, y in zip(align, lens)])

        out = []
        out.append(fmt.format(*keys))
        out.append('-' * len(out[0]))
        for row in rows:
            out.append(fmt.format(*[str(x) for x in row.values()]))
        return '\n'.join(out)

    def __str__(self) -> str:
        return self.__repr__()

    def _repr_html_(self) -> str:
        """Render the rows as an HTML table (IPython rich-display hook)."""
        import html

        if not self._data:
            return ''
        cell_style = 'style="text-align: left; vertical-align: top"'
        out = []
        out.append('<table border="1" class="dataframe">')
        out.append('<thead>')
        out.append('<tr>')
        for name in self._data[0].keys():
            # Escape so ``<``, ``>``, ``&`` and quotes display literally
            # instead of being parsed as markup.
            out.append(f'<th {cell_style}>{html.escape(str(name))}</th>')
        out.append('</tr>')
        out.append('</thead>')
        out.append('<tbody>')
        for row in self._data:
            out.append('<tr>')
            for item in row.values():
                out.append(f'<td {cell_style}>{html.escape(str(item))}</td>')
            out.append('</tr>')
        out.append('</tbody>')
        out.append('</table>')
        return ''.join(out)
862
863
864
class ShowAccessor(object):
    """
    Accessor for ``SHOW`` commands.

    Each method runs one ``SHOW ...`` statement and returns a
    :class:`ShowResult`. Identifiers are quoted with
    :func:`quote_identifier` before being placed in the SQL text.

    See Also
    --------
    :attr:`Connection.show`

    """

    def __init__(self, conn: 'Connection'):
        self._conn = conn

    def columns(self, table: str, full: bool = False) -> ShowResult:
        """Show the column information for the given table."""
        table = quote_identifier(table)
        if full:
            return self._iquery(f'full columns in {table}')
        return self._iquery(f'columns in {table}')

    def tables(self, extended: bool = False) -> ShowResult:
        """Show tables in the current database."""
        if extended:
            return self._iquery('tables extended')
        return self._iquery('tables')

    def warnings(self) -> ShowResult:
        """Show warnings."""
        return self._iquery('warnings')

    def errors(self) -> ShowResult:
        """Show errors."""
        return self._iquery('errors')

    def databases(self, extended: bool = False) -> ShowResult:
        """Show all databases in the server."""
        if extended:
            return self._iquery('databases extended')
        return self._iquery('databases')

    def database_status(self) -> ShowResult:
        """Show status of the current database."""
        return self._iquery('database status')

    def global_status(self) -> ShowResult:
        """Show global status of the current server."""
        return self._iquery('global status')

    def indexes(self, table: str) -> ShowResult:
        """Show all indexes in the given table."""
        table = quote_identifier(table)
        return self._iquery(f'indexes in {table}')

    def functions(self) -> ShowResult:
        """Show all functions in the current database."""
        return self._iquery('functions')

    def partitions(self, extended: bool = False) -> ShowResult:
        """Show partitions in the current database."""
        if extended:
            return self._iquery('partitions extended')
        return self._iquery('partitions')

    def pipelines(self) -> ShowResult:
        """Show all pipelines in the current database."""
        return self._iquery('pipelines')

    def plan(self, plan_id: int, json: bool = False) -> ShowResult:
        """Show the plan for the given plan ID."""
        # ``int()`` guarantees only a number is placed in the SQL text.
        plan_id = int(plan_id)
        if json:
            return self._iquery(f'plan json {plan_id}')
        return self._iquery(f'plan {plan_id}')

    def plancache(self) -> ShowResult:
        """Show all query statements compiled and executed."""
        return self._iquery('plancache')

    def processlist(self) -> ShowResult:
        """Show details about currently running threads."""
        return self._iquery('processlist')

    def reproduction(self, outfile: Optional[str] = None) -> ShowResult:
        """
        Show troubleshooting data for query optimizer and code generation.

        Parameters
        ----------
        outfile : str, optional
            When given, the statement is sent as
            ``SHOW REPRODUCTION INTO OUTFILE "<outfile>"``

        """
        if outfile:
            # Escape backslashes first, then double quotes, so the path can
            # not terminate the quoted string literal early.
            outfile = outfile.replace('\\', '\\\\').replace('"', '\\"')
            # This must be an f-string; previously the literal text
            # ``{outfile}`` was sent to the server.
            return self._iquery(f'reproduction into outfile "{outfile}"')
        return self._iquery('reproduction')

    def schemas(self) -> ShowResult:
        """Show schemas in the server."""
        return self._iquery('schemas')

    def session_status(self) -> ShowResult:
        """Show server status information for a session."""
        return self._iquery('session status')

    def status(self, extended: bool = False) -> ShowResult:
        """Show server status information."""
        if extended:
            return self._iquery('status extended')
        return self._iquery('status')

    def table_status(self) -> ShowResult:
        """Show table status information for the current database."""
        return self._iquery('table status')

    def procedures(self) -> ShowResult:
        """Show all procedures in the current database."""
        return self._iquery('procedures')

    def aggregates(self) -> ShowResult:
        """Show all aggregate functions in the current database."""
        return self._iquery('aggregates')

    def create_aggregate(self, name: str) -> ShowResult:
        """Show the function creation code for the given aggregate function."""
        name = quote_identifier(name)
        return self._iquery(f'create aggregate {name}')

    def create_function(self, name: str) -> ShowResult:
        """Show the function creation code for the given function."""
        name = quote_identifier(name)
        return self._iquery(f'create function {name}')

    def create_pipeline(self, name: str, extended: bool = False) -> ShowResult:
        """Show the pipeline creation code for the given pipeline."""
        name = quote_identifier(name)
        if extended:
            return self._iquery(f'create pipeline {name} extended')
        return self._iquery(f'create pipeline {name}')

    def create_table(self, name: str) -> ShowResult:
        """Show the table creation code for the given table."""
        name = quote_identifier(name)
        return self._iquery(f'create table {name}')

    def create_view(self, name: str) -> ShowResult:
        """Show the view creation code for the given view."""
        name = quote_identifier(name)
        return self._iquery(f'create view {name}')

    # def grants(
    #     self,
    #     user: Optional[str] = None,
    #     hostname: Optional[str] = None,
    #     role: Optional[str] = None
    # ) -> ShowResult:
    #     """Show the privileges for the given user or role."""
    #     if user:
    #         if not re.match(r'^[\w+-_]+$', user):
    #             raise ValueError(f'User name is not valid: {user}')
    #         if hostname and not re.match(r'^[\w+-_\.]+$', hostname):
    #             raise ValueError(f'Hostname is not valid: {hostname}')
    #         if hostname:
    #             return self._iquery(f"grants for '{user}@{hostname}'")
    #         return self._iquery(f"grants for '{user}'")
    #     if role:
    #         if not re.match(r'^[\w+-_]+$', role):
    #             raise ValueError(f'Role is not valid: {role}')
    #         return self._iquery(f"grants for role '{role}'")
    #     return self._iquery('grants')

    def _iquery(self, qtype: str) -> ShowResult:
        """
        Run ``SHOW <qtype>`` and normalize the column names.

        The first column of every row is renamed ``Name``, and all column
        names are passed through :func:`under2camel`.
        """
        out = self._conn._iquery(f'show {qtype}')
        for i, row in enumerate(out):
            new_row = {}
            for j, (k, v) in enumerate(row.items()):
                if j == 0:
                    k = 'Name'
                new_row[under2camel(k)] = v
            out[i] = new_row
        return ShowResult(out)
1038
1039
1040
class Connection(metaclass=abc.ABCMeta):
1041
"""
1042
SingleStoreDB connection.
1043
1044
Instances of this object are typically created through the
1045
:func:`singlestoredb.connect` function rather than creating them directly.
1046
See the :func:`singlestoredb.connect` function for parameter definitions.
1047
1048
See Also
1049
--------
1050
:func:`singlestoredb.connect`
1051
1052
"""
1053
1054
Warning = exceptions.Warning
1055
Error = exceptions.Error
1056
InterfaceError = exceptions.InterfaceError
1057
DataError = exceptions.DataError
1058
DatabaseError = exceptions.DatabaseError
1059
OperationalError = exceptions.OperationalError
1060
IntegrityError = exceptions.IntegrityError
1061
InternalError = exceptions.InternalError
1062
ProgrammingError = exceptions.ProgrammingError
1063
NotSupportedError = exceptions.NotSupportedError
1064
1065
#: Read-only DB-API parameter style
1066
paramstyle = 'pyformat'
1067
1068
# Must be set by subclass
1069
driver = ''
1070
1071
# Populated when first needed
1072
_map_param_converter: Optional[sqlparams.SQLParams] = None
1073
_positional_param_converter: Optional[sqlparams.SQLParams] = None
1074
1075
def __init__(self, **kwargs: Any):
    """Call :func:`singlestoredb.connect` instead."""
    params: Dict[str, Any] = kwargs
    self.connection_params: Dict[str, Any] = params
    self.errorhandler = None
    self._results_type: str = params.get('results_type', None) or 'tuples'

    # Strip the MySQL ``mb4`` suffix (e.g. ``utf8mb4`` -> ``utf8``).
    charset = params.get('charset', None) or 'utf-8'
    #: Session encoding
    self.encoding = charset.replace('mb4', '')

    # Browser-based SSO: obtain a JWT now and connect with it as the
    # password, switching the credential type to JWT.
    if params.get('credential_type', None) == auth.BROWSER_SSO:
        # TODO: Cache info for token refreshes
        jwt_info = auth.get_jwt(params['user'])
        params['password'] = str(jwt_info)
        params['credential_type'] = auth.JWT

    #: Attribute-like access to global server variables
    self.globals = VariableAccessor(self, 'global')

    #: Attribute-like access to local / session server variables
    self.locals = VariableAccessor(self, 'local')

    #: Attribute-like access to cluster global server variables
    self.cluster_globals = VariableAccessor(self, 'cluster global')

    #: Attribute-like access to cluster local / session server variables
    self.cluster_locals = VariableAccessor(self, 'cluster local')

    #: Attribute-like access to all server variables
    self.vars = VariableAccessor(self, '')

    #: Attribute-like access to all cluster server variables
    self.cluster_vars = VariableAccessor(self, 'cluster')

    # For backwards compatibility with SQLAlchemy package
    self._driver = Driver(self.driver)

    # Output decoders (presumably keyed by column type code; confirm)
    self.decoders: Dict[int, Callable[[Any], Any]] = {}
1116
1117
@classmethod
def _convert_params(
    cls, oper: str,
    params: Optional[Union[Sequence[Any], Dict[str, Any], Any]],
    interpolate_query_with_empty_args: bool = False,
) -> Tuple[Any, ...]:
    """
    Rewrite a query and its parameters into the driver's parameter style.

    Mappings use the ``pyformat`` (``%(name)s``) converter. Anything else
    uses the ``format`` (``%s``) converter, with a lone non-sequence value
    wrapped in a one-element list. Strings and bytes count as single
    values. Both :mod:`sqlparams` converters are created lazily and cached
    on the class.

    Parameters
    ----------
    oper : str
        The SQL text
    params : Sequence or dict or Any, optional
        Query parameters
    interpolate_query_with_empty_args : bool, optional
        When true, convert for any non-``None`` ``params`` (even empty
        ones); otherwise only for truthy ``params``

    Returns
    -------
    tuple
        The converter's ``format`` result, or ``(oper, None)`` when no
        conversion is performed

    """
    if interpolate_query_with_empty_args:
        should_convert = params is not None
    else:
        should_convert = bool(params)

    if not should_convert:
        return (oper, None)

    if cls._map_param_converter is None:
        cls._map_param_converter = sqlparams.SQLParams(
            map_paramstyle, cls.paramstyle, escape_char=True,
        )

    if cls._positional_param_converter is None:
        cls._positional_param_converter = sqlparams.SQLParams(
            positional_paramstyle, cls.paramstyle, escape_char=True,
        )

    is_mapping = isinstance(params, Mapping)
    is_sequence = (
        isinstance(params, Sequence)
        and not isinstance(params, (str, bytes))
    )

    if is_mapping:
        converter = cls._map_param_converter
    else:
        converter = cls._positional_param_converter
        if not is_sequence:
            params = [params]

    return converter.format(oper, params)
1155
1156
def autocommit(self, value: bool = True) -> None:
    """
    Set autocommit mode.

    Parameters
    ----------
    value : bool, optional
        Enable (the default) or disable autocommits. The value is coerced
        with ``bool()`` and written to the local / session ``autocommit``
        server variable.

    """
    enabled = bool(value)
    self.locals.autocommit = enabled
1159
1160
@abc.abstractmethod
def connect(self) -> 'Connection':
    """
    Connect to the server.

    Abstract; concrete connection classes must provide the implementation.

    """
    raise NotImplementedError()
1164
1165
def _iquery(
    self, oper: str,
    params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
    fix_names: bool = True,
) -> List[Dict[str, Any]]:
    """
    Return the results of a query as a list of dicts (for internal use).

    Parameters
    ----------
    oper : str
        The query to execute
    params : Sequence or Dict, optional
        Parameters to substitute into the query
    fix_names : bool, optional
        When rows come back as tuples / lists, strip spaces from the column
        names in the cursor description and convert them to camel-case

    Returns
    -------
    List[Dict[str, Any]]
        One dict per row. An empty list is returned when the statement does
        not start with SELECT, SHOW, CALL, or ECHO, or when it produced no
        rows.

    """
    with self.cursor() as cur:
        cur.execute(oper, params)
        if not re.match(r'^\s*(select|show|call|echo)\s+', oper, flags=re.I):
            return []

        rows = cur.fetchall()

        # Check for a DataFrame *before* converting to a list: list() on a
        # DataFrame yields its column labels rather than its rows, and the
        # truthiness test below would raise on a DataFrame.
        if isinstance(rows, DataFrame):
            return rows.to_dict(orient='records')

        out = list(rows)
        if not out:
            return []

        # Tuple / list rows are keyed by the column names from the cursor
        if isinstance(out[0], (tuple, list)) and cur.description:
            names = [x[0] for x in cur.description]
            if fix_names:
                names = [under2camel(str(x).replace(' ', '')) for x in names]
            out = [dict(zip(names, row)) for row in out]

        return out
1187
1188
@abc.abstractmethod
def close(self) -> None:
    """
    Close the database connection.

    Abstract; concrete connection classes must provide the implementation.

    """
    raise NotImplementedError()
1192
1193
@abc.abstractmethod
def commit(self) -> None:
    """
    Commit the pending transaction.

    Abstract; concrete connection classes must provide the implementation.

    """
    raise NotImplementedError()
1197
1198
@abc.abstractmethod
def rollback(self) -> None:
    """
    Roll back the pending transaction.

    Abstract; concrete connection classes must provide the implementation.

    """
    raise NotImplementedError()
1202
1203
@abc.abstractmethod
def cursor(self) -> Cursor:
    """
    Create a new cursor object bound to this connection.

    Abstract; concrete connection classes must provide the implementation.

    See Also
    --------
    :class:`Cursor`

    Returns
    -------
    :class:`Cursor`

    """
    raise NotImplementedError()
1218
1219
@property
@abc.abstractmethod
def messages(self) -> List[Tuple[int, str]]:
    """
    Messages generated during the connection.

    Abstract read-only property. ``abc.abstractproperty`` has been
    deprecated since Python 3.3; stacking ``property`` over
    ``abc.abstractmethod`` is the documented equivalent.

    Returns
    -------
    List[Tuple[int, str]]

    """
    raise NotImplementedError
1223
1224
def __enter__(self) -> 'Connection':
    """
    Enter a ``with`` block.

    Returns
    -------
    Connection
        This same connection object

    """
    conn = self
    return conn
1227
1228
def __exit__(
    self, exc_type: Optional[object],
    exc_value: Optional[Exception], exc_traceback: Optional[str],
) -> None:
    """
    Leave a ``with`` block, closing the connection.

    The connection is closed whether or not an exception occurred.
    Because ``None`` is returned, any active exception propagates.

    """
    self.close()
    return None
1234
1235
@abc.abstractmethod
def is_connected(self) -> bool:
    """
    Determine if the database is still connected.

    Abstract; concrete connection classes must provide the implementation.

    Returns
    -------
    bool

    """
    raise NotImplementedError()
1246
1247
def enable_data_api(self, port: Optional[int] = None) -> int:
    """
    Enable the data API in the server.

    Sets the ``http_api`` global variable (and ``http_proxy_port`` when a
    port is given), then restarts the HTTP proxy. This requires privileges
    that allow setting global variables and restarting the proxy.

    Parameters
    ----------
    port : int, optional
        Port number for the HTTP server. When omitted, the current value
        of ``http_proxy_port`` is left unchanged and used.

    See Also
    --------
    :meth:`disable_data_api`

    Returns
    -------
    int
        port number of the HTTP server

    """
    server_globals = self.globals
    if port is not None:
        server_globals.http_proxy_port = int(port)
    server_globals.http_api = True
    # Restart the proxy, presumably so the updated settings are applied
    self._iquery('restart proxy')
    return int(server_globals.http_proxy_port)

enable_http_api = enable_data_api
1278
1279
def disable_data_api(self) -> None:
    """
    Disable the data API.

    Clears the ``http_api`` global variable and restarts the HTTP proxy.

    See Also
    --------
    :meth:`enable_data_api`

    """
    server_globals = self.globals
    server_globals.http_api = False
    self._iquery('restart proxy')

disable_http_api = disable_data_api
1292
1293
@property
def show(self) -> ShowAccessor:
    """
    Access server properties managed by the SHOW statement.

    A new accessor bound to this connection is created on every access.

    """
    accessor = ShowAccessor(self)
    return accessor
1297
1298
@functools.cached_property
def vector_db(self) -> Any:
    """
    Get vectorstore API accessor.

    The ``vectorstore`` package is imported lazily on first access, and the
    resulting ``VectorDB`` object is cached on the instance by
    ``functools.cached_property``.

    """
    from vectorstore import VectorDB

    vdb = VectorDB(connection=self)
    return vdb
1305
1306
1307
#
# NOTE: When adding parameters to this function, you should always
# make the value optional with a default of None. The options
# processing framework will fill in the default value based
# on environment variables or other configuration sources.
#
def connect(
    host: Optional[str] = None, user: Optional[str] = None,
    password: Optional[str] = None, port: Optional[int] = None,
    database: Optional[str] = None, driver: Optional[str] = None,
    pure_python: Optional[bool] = None, local_infile: Optional[bool] = None,
    charset: Optional[str] = None,
    ssl_key: Optional[str] = None, ssl_cert: Optional[str] = None,
    ssl_ca: Optional[str] = None, ssl_disabled: Optional[bool] = None,
    ssl_cipher: Optional[str] = None, ssl_verify_cert: Optional[bool] = None,
    tls_sni_servername: Optional[str] = None,
    ssl_verify_identity: Optional[bool] = None,
    conv: Optional[Dict[int, Callable[..., Any]]] = None,
    credential_type: Optional[str] = None,
    autocommit: Optional[bool] = None,
    results_type: Optional[str] = None,
    buffered: Optional[bool] = None,
    results_format: Optional[str] = None,
    program_name: Optional[str] = None,
    conn_attrs: Optional[Dict[str, str]] = None,
    multi_statements: Optional[bool] = None,
    client_found_rows: Optional[bool] = None,
    connect_timeout: Optional[int] = None,
    nan_as_null: Optional[bool] = None,
    inf_as_null: Optional[bool] = None,
    encoding_errors: Optional[str] = None,
    track_env: Optional[bool] = None,
    enable_extended_data_types: Optional[bool] = None,
    vector_data_format: Optional[str] = None,
    parse_json: Optional[bool] = None,
    interpolate_query_with_empty_args: Optional[bool] = None,
) -> Connection:
    """
    Return a SingleStoreDB connection.

    Parameters
    ----------
    host : str, optional
        Hostname, IP address, or URL that describes the connection.
        The scheme or protocol defines which database connector to use.
        By default, the ``mysql`` scheme is used. To connect to the
        HTTP API, the scheme can be set to ``http`` or ``https``. The username,
        password, host, and port are specified as in a standard URL. The path
        indicates the database name. The overall form of the URL is:
        ``scheme://user:password@host:port/db_name``. The scheme can
        typically be left off (unless you are using the HTTP API):
        ``user:password@host:port/db_name``.
    user : str, optional
        Database user name
    password : str, optional
        Database user password
    port : int, optional
        Database port. This defaults to 3306 for non-HTTP connections, 80
        for HTTP connections, and 443 for HTTPS connections.
    database : str, optional
        Database name
    driver : str, optional
        Name of the database connector to use: ``mysql`` (the default),
        ``http``, or ``https``. Any other value raises ``ValueError``.
    pure_python : bool, optional
        Use the connector in pure Python mode
    local_infile : bool, optional
        Allow local file uploads
    charset : str, optional
        Character set for string values
    ssl_key : str, optional
        File containing SSL key
    ssl_cert : str, optional
        File containing SSL certificate
    ssl_ca : str, optional
        File containing SSL certificate authority
    ssl_cipher : str, optional
        Sets the SSL cipher list
    ssl_disabled : bool, optional
        Disable SSL usage
    ssl_verify_cert : bool, optional
        Verify the server's certificate. This is automatically enabled if
        ``ssl_ca`` is also specified.
    ssl_verify_identity : bool, optional
        Verify the server's identity
    tls_sni_servername : str, optional
        Server name to send in the TLS Server Name Indication (SNI) extension
    conv : dict[int, Callable], optional
        Dictionary of data conversion functions
    credential_type : str, optional
        Type of authentication to use: auth.PASSWORD, auth.JWT, or auth.BROWSER_SSO
    autocommit : bool, optional
        Enable autocommits
    results_type : str, optional
        The form of the query results: tuples, namedtuples, dicts,
        numpy, polars, pandas, arrow
    buffered : bool, optional
        Should the entire query result be buffered in memory? This is the default
        behavior which allows full cursor control of the result, but does consume
        more memory.
    results_format : str, optional
        Deprecated. This option has been renamed to results_type.
    program_name : str, optional
        Name of the program
    conn_attrs : dict, optional
        Additional connection attributes for telemetry. Example:
        ``{'program_version': '1.0.2', '_connector_name': 'dbt connector'}``
    multi_statements : bool, optional
        Should multiple statements be allowed within a single query?
    client_found_rows : bool, optional
        Should the affected-row count of UPDATE statements report the number
        of rows matched rather than the number actually changed (the MySQL
        ``CLIENT_FOUND_ROWS`` capability flag)?
    connect_timeout : int, optional
        The timeout for connecting to the database in seconds.
        (default: 10, min: 1, max: 31536000)
    nan_as_null : bool, optional
        Should NaN values be treated as NULLs when used in parameter
        substitutions including uploaded data?
    inf_as_null : bool, optional
        Should Inf values be treated as NULLs when used in parameter
        substitutions including uploaded data?
    encoding_errors : str, optional
        The error handler name for value decoding errors
    track_env : bool, optional
        Should the connection track the SINGLESTOREDB_URL environment variable?
    enable_extended_data_types : bool, optional
        Should extended data types (BSON, vector) be enabled?
    vector_data_format : str, optional
        Format for vector types: json or binary
    parse_json : bool, optional
        Should JSON values in query results be parsed into Python objects?
    interpolate_query_with_empty_args : bool, optional
        Should the connector apply parameter interpolation even when the
        parameters are empty (but not ``None``)? Enabling this mirrors the
        way PyMySQL / mysqlclient handle empty parameter collections, e.g.
        escaped ``%%`` sequences are unescaped even when no parameters are
        given.

    Raises
    ------
    ValueError
        If the resolved driver is not ``mysql``, ``http``, or ``https``

    Examples
    --------
    Standard database connection

    >>> conn = s2.connect('me:p455w0rd@s2-host.com/my_db')

    Connect to HTTP API on port 8080

    >>> conn = s2.connect('http://me:p455w0rd@s2-host.com:8080/my_db')

    Using an environment variable for connection string

    >>> os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db'
    >>> conn = s2.connect()

    Specifying credentials using environment variables

    >>> os.environ['SINGLESTOREDB_USER'] = 'me'
    >>> os.environ['SINGLESTOREDB_PASSWORD'] = 'p455w0rd'
    >>> conn = s2.connect('s2-host.com/my_db')

    Specifying options with keyword parameters

    >>> conn = s2.connect('s2-host.com/my_db', user='me', password='p455w0rd',
                          local_infile=True)

    Specifying options with URL parameters

    >>> conn = s2.connect('s2-host.com/my_db?local_infile=True&charset=utf8')

    Connecting within a context manager

    >>> with s2.connect('...') as conn:
    ...     with conn.cursor() as cur:
    ...         cur.execute('...')

    Setting session variables, the code below sets the ``autocommit`` option

    >>> conn.locals.autocommit = True

    Getting session variables

    >>> conn.locals.autocommit
    True

    See Also
    --------
    :class:`Connection`

    Returns
    -------
    :class:`Connection`

    """
    # locals() must be captured first so it contains only the arguments
    params = build_params(**dict(locals()))
    driver = params.get('driver', 'mysql')

    if not driver or driver == 'mysql':
        from .mysql.connection import Connection  # type: ignore
        return Connection(**params)

    if driver in ['http', 'https']:
        from .http.connection import Connection
        return Connection(**params)

    raise ValueError(f'Unrecognized protocol: {driver}')
1498
1499