Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/tests/test_basics.py
469 views
1
#!/usr/bin/env python
2
# type: ignore
3
"""Basic SingleStoreDB connection testing."""
4
import datetime
5
import decimal
6
import math
7
import os
8
import unittest
9
from typing import Optional
10
11
from requests.exceptions import InvalidJSONError
12
13
try:
14
import numpy as np
15
has_numpy = True
16
except ImportError:
17
has_numpy = False
18
19
try:
20
import shapely.wkt
21
has_shapely = True
22
except ImportError:
23
has_shapely = False
24
25
try:
26
import pygeos
27
from pygeos.testing import assert_geometries_equal
28
has_pygeos = True
29
except ImportError:
30
has_pygeos = False
31
32
try:
33
import pydantic
34
has_pydantic = True
35
except ImportError:
36
has_pydantic = False
37
38
import singlestoredb as s2
39
from . import utils
40
# import traceback
41
42
43
class TestBasics(unittest.TestCase):
44
45
dbname: str = ''
46
dbexisted: bool = False
47
48
@classmethod
def setUpClass(cls):
    """Load the ``test.sql`` fixture located next to this module.

    The pair returned by ``utils.load_sql`` is stored on the class as
    ``dbname`` and ``dbexisted``.
    """
    fixture_dir = os.path.dirname(__file__)
    fixture_path = os.path.join(fixture_dir, 'test.sql')
    loaded = utils.load_sql(fixture_path)
    cls.dbname, cls.dbexisted = loaded
52
53
@classmethod
def tearDownClass(cls):
    """Drop ``cls.dbname`` unless ``cls.dbexisted`` is truthy.

    NOTE(review): presumably this leaves alone a database that was already
    present before the fixture was loaded -- confirm against ``utils``.
    """
    if cls.dbexisted:
        return
    utils.drop_database(cls.dbname)
57
58
def setUp(self):
    """Open a connection to the class database and a cursor on it."""
    database = type(self).dbname
    connection = s2.connect(database=database)
    self.conn = connection
    self.cur = connection.cursor()
61
62
def tearDown(self):
    """Close the cursor, then the connection, on a best-effort basis.

    Any exception raised while looking up or closing either handle is
    deliberately ignored so that cleanup never masks a test result.
    """
    # Order matters: the cursor is closed before its connection.
    for attr in ('cur', 'conn'):
        try:
            handle = getattr(self, attr)
            if handle is not None:
                handle.close()
        except Exception:
            # traceback.print_exc()
            pass
76
77
def test_connection(self):
    """The class database must appear in ``show databases``."""
    self.cur.execute('show databases')
    dbs = {record[0] for record in self.cur.fetchall()}
    expected = type(self).dbname
    assert expected in dbs, dbs
81
82
def test_fetchall(self):
    """``fetchall`` returns all five ``data`` rows and sane cursor metadata."""
    cursor = self.cur
    cursor.execute('select * from data')

    fetched = cursor.fetchall()

    # Snapshot the cursor metadata right after the fetch, before asserting.
    description = cursor.description
    n_rows = cursor.rowcount
    position = cursor.rownumber
    last_id = cursor.lastrowid

    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
        ('d', 'dogs', 4),
        ('e', 'elephants', 0),
    ]
    # Row order is not part of the contract, so compare sorted copies.
    assert sorted(fetched) == sorted(expected), fetched

    assert n_rows in (5, -1), n_rows
    assert position == 5, position
    assert last_id is None, last_id
    assert len(description) == 3, description

    id_col = description[0]
    assert id_col.name == 'id', id_col.name
    assert id_col.type_code in [253, 15], id_col.type_code

    name_col = description[1]
    assert name_col.name == 'name', name_col.name
    assert name_col.type_code in [253, 15], name_col.type_code

    value_col = description[2]
    assert value_col.name == 'value', value_col.name
    assert value_col.type_code == 8, value_col.type_code
110
111
def test_fetchone(self):
    """Repeated ``fetchone`` calls walk every ``data`` row exactly once."""
    cursor = self.cur
    cursor.execute('select * from data')

    collected = []
    record = cursor.fetchone()
    while record is not None:
        collected.append(record)
        # rownumber must track the number of rows consumed so far.
        assert cursor.rownumber == len(collected), cursor.rownumber
        record = cursor.fetchone()

    description = cursor.description
    n_rows = cursor.rowcount
    position = cursor.rownumber
    last_id = cursor.lastrowid

    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
        ('d', 'dogs', 4),
        ('e', 'elephants', 0),
    ]
    assert sorted(collected) == sorted(expected), collected

    assert n_rows in (5, -1), n_rows
    assert position == 5, position
    assert last_id is None, last_id
    assert len(description) == 3, description

    id_col = description[0]
    assert id_col.name == 'id', id_col.name
    assert id_col.type_code in [253, 15], id_col.type_code

    name_col = description[1]
    assert name_col.name == 'name', name_col.name
    assert name_col.type_code in [253, 15], name_col.type_code

    value_col = description[2]
    assert value_col.name == 'value', value_col.name
    assert value_col.type_code == 8, value_col.type_code
145
146
def test_fetchmany(self):
    """``fetchmany(size=3)`` yields batches of at most three rows.

    Batches are pulled until an empty one comes back; together they must
    contain every row of ``data``, and ``rownumber`` must advance with each
    batch.  The column description is then checked with the same
    diagnostics as the other fetch tests.
    """
    self.cur.execute('select * from data')

    out = []
    while True:
        rows = self.cur.fetchmany(size=3)
        assert len(rows) <= 3, rows
        if not rows:
            break
        out.extend(rows)
        assert self.cur.rownumber == len(out), self.cur.rownumber

    desc = self.cur.description
    rowcount = self.cur.rowcount
    rownumber = self.cur.rownumber
    lastrowid = self.cur.lastrowid

    # Row order is not part of the contract, so compare sorted copies.
    assert sorted(out) == sorted([
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
        ('d', 'dogs', 4),
        ('e', 'elephants', 0),
    ]), out

    assert rowcount in (5, -1), rowcount
    assert rownumber == 5, rownumber
    assert lastrowid is None, lastrowid
    assert len(desc) == 3, desc
    # Each assertion carries the offending value so a failure is
    # self-explanatory.
    assert desc[0].name == 'id', desc[0].name
    assert desc[0].type_code in [253, 15], desc[0].type_code
    assert desc[1].name == 'name', desc[1].name
    assert desc[1].type_code in [253, 15], desc[1].type_code
    assert desc[2].name == 'value', desc[2].name
    assert desc[2].type_code == 8, desc[2].type_code
181
182
def test_arraysize(self):
    """``fetchmany()`` without a size honours ``cursor.arraysize``."""
    cursor = self.cur
    cursor.execute('select * from data')

    cursor.arraysize = 3
    assert cursor.arraysize == 3

    batch = cursor.fetchmany()
    assert len(batch) == 3, batch
    assert cursor.rownumber == 3, cursor.rownumber

    cursor.arraysize = 1
    assert cursor.arraysize == 1

    # Two single-row fetches consume the remaining rows 4 and 5.
    for expected_position in (4, 5):
        batch = cursor.fetchmany()
        assert len(batch) == 1, batch
        assert cursor.rownumber == expected_position, cursor.rownumber

    # Nothing is left, and the position must not move.
    leftover = cursor.fetchall()
    assert len(leftover) == 0, leftover
    assert cursor.rownumber == 5, cursor.rownumber
206
207
def test_execute_with_dict_params(self):
    """Named ``%(name)s`` placeholders are filled from a parameter dict.

    A dict that lacks the placeholder's key must raise ``KeyError``.
    """
    cursor = self.cur
    query = 'select * from data where id < %(name)s'

    cursor.execute(query, {'name': 'd'})
    fetched = cursor.fetchall()

    description = cursor.description
    n_rows = cursor.rowcount
    last_id = cursor.lastrowid

    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
    ]
    assert sorted(fetched) == sorted(expected), fetched

    assert n_rows in (3, -1), n_rows
    assert last_id is None, last_id
    assert len(description) == 3, description

    id_col = description[0]
    assert id_col.name == 'id', id_col.name
    assert id_col.type_code in [253, 15], id_col.type_code

    name_col = description[1]
    assert name_col.name == 'name', name_col.name
    assert name_col.type_code in [253, 15], name_col.type_code

    value_col = description[2]
    assert value_col.name == 'value', value_col.name
    assert value_col.type_code == 8, value_col.type_code

    with self.assertRaises(KeyError):
        cursor.execute(query, {'foo': 'd'})
233
234
def test_execute_with_positional_params(self):
    """Positional ``%s`` placeholders are filled from a parameter list.

    Supplying too many or too few parameters must raise ``TypeError``.
    """
    cursor = self.cur
    cursor.execute('select * from data where id < %s', ['d'])
    fetched = cursor.fetchall()

    description = cursor.description
    n_rows = cursor.rowcount
    last_id = cursor.lastrowid

    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
    ]
    assert sorted(fetched) == sorted(expected), fetched

    assert n_rows in (3, -1), n_rows
    assert last_id is None, last_id
    assert len(description) == 3, description

    id_col = description[0]
    assert id_col.name == 'id', id_col.name
    assert id_col.type_code in [253, 15], id_col.type_code

    name_col = description[1]
    assert name_col.name == 'name', name_col.name
    assert name_col.type_code in [253, 15], name_col.type_code

    value_col = description[2]
    assert value_col.name == 'value', value_col.name
    assert value_col.type_code == 8, value_col.type_code

    # Two placeholders: first three parameters (too many), then one (too few).
    two_slot_query = 'select * from data where id < %s and id > %s'
    for bad_params in (['d', 'e', 'f'], ['d']):
        with self.assertRaises(TypeError):
            cursor.execute(two_slot_query, bad_params)
265
266
def test_execute_with_escaped_positional_substitutions(self):
    """Colons inside values and literals do not disturb ``%s`` handling.

    The same ``alltypes`` row (id 0, time 00:07:00) must come back whether
    the time is passed as a parameter, written as a literal, or combined
    with a positional ``id`` parameter.
    """
    cursor = self.cur
    seven_minutes_row = (0, datetime.timedelta(seconds=420))

    cursor.execute(
        'select `id`, `time` from alltypes where `time` = %s', ['00:07:00'],
    )
    first = cursor.fetchall()[0]
    assert first == seven_minutes_row, first

    cursor.execute('select `id`, `time` from alltypes where `time` = "00:07:00"')
    first = cursor.fetchall()[0]
    assert first == seven_minutes_row, first

    # Disabled check kept from the original test:
    # with self.assertRaises(IndexError):
    #     self.cur.execute(
    #         'select `id`, `time` from alltypes where `id` = %1s '
    #         'or `time` = "00:07:00"', [0],
    #     )

    mixed_query = (
        'select `id`, `time` from alltypes where `id` = %s '
        'or `time` = "00:07:00"'
    )
    cursor.execute(mixed_query, [0])
    first = cursor.fetchall()[0]
    assert first == seven_minutes_row, first
289
290
def test_execute_with_escaped_substitutions(self):
    """Colons inside values and literals do not disturb ``%(name)s`` handling.

    Also checks that a parameter dict without the placeholder's key raises
    ``KeyError`` even when the query text contains a colon-bearing literal.
    """
    cursor = self.cur
    by_time = 'select `id`, `time` from alltypes where `time` = %(time)s'
    seven_minutes_row = (0, datetime.timedelta(seconds=420))

    cursor.execute(by_time, {'time': '00:07:00'})
    fetched = cursor.fetchall()
    assert fetched[0] == seven_minutes_row, fetched[0]

    # Same query again: this time only the number of matches is checked.
    cursor.execute(by_time, {'time': '00:07:00'})
    fetched = cursor.fetchall()
    assert len(fetched) == 1, fetched

    with self.assertRaises(KeyError):
        cursor.execute(
            'select `id`, `time`, `char_100` from alltypes '
            'where `time` = %(time)s or `char_100` like "foo:bar"',
            {'x': '00:07:00'},
        )

    cursor.execute(
        'select `id`, `time`, `char_100` from alltypes '
        'where `time` = %(time)s or `char_100` like "foo::bar"',
        {'time': '00:07:00'},
    )
    fetched = cursor.fetchall()
    assert fetched[0][:2] == seven_minutes_row, fetched[0]
319
320
def test_is_connected(self):
    """``is_connected`` tracks cursor and connection closing independently."""
    connection = self.conn
    cursor = self.cur

    # Both start out live.
    assert connection.is_connected()
    assert cursor.is_connected()

    # Closing the cursor must not take the connection down with it.
    cursor.close()
    assert not cursor.is_connected()
    assert connection.is_connected()

    # After closing the connection, neither reports as connected.
    connection.close()
    assert not cursor.is_connected()
    assert not connection.is_connected()
329
330
def test_connection_attr(self):
    """Entering the connection as a context manager yields ``self.conn``."""
    # Use context manager to get to underlying object
    # (self.conn is a weakref.proxy)
    with self.conn as entered:
        assert entered is self.conn
334
335
def test_executemany(self):
    """``executemany`` accepts a list of parameter dicts without raising."""
    # NOTE: Doesn't actually do anything since no rows match
    statement = 'delete from data where id > %(name)s'
    parameter_sets = [{'name': 'z'}, {'name': 'y'}]
    self.cur.executemany(statement, parameter_sets)
341
342
def test_executemany_no_args(self):
    """``executemany`` may be called with a query and no parameters."""
    statement = 'select * from data where id > "z"'
    self.cur.executemany(statement)
344
345
def test_context_managers(self):
    """Leaving a ``with`` block closes the cursor / connection it opened."""
    with s2.connect() as connection:
        with connection.cursor() as cursor:
            # Inside both blocks everything is live.
            assert cursor.is_connected()
            assert connection.is_connected()
        # Cursor block exited.
        assert not cursor.is_connected()
    # Connection block exited.
    assert not connection.is_connected()
352
353
def test_iterator(self):
    """Iterating the cursor directly yields every row of ``data``."""
    self.cur.execute('select * from data')

    collected = [record for record in self.cur]

    expected = [
        ('a', 'antelopes', 2),
        ('b', 'bears', 2),
        ('c', 'cats', 5),
        ('d', 'dogs', 4),
        ('e', 'elephants', 0),
    ]
    assert sorted(collected) == sorted(expected), collected
367
368
def test_urls(self):
    """``build_params`` must split a connection URL into its components.

    Covers URLs with and without scheme, port, credentials, database and
    query options, plus the ``ValueError`` raised for a non-numeric port
    and for an unknown query option.

    The credentialed URLs spell out ``me:p455w0rd@s2host.com``; these are
    exactly the user, password and host values the assertions check for.
    """
    from singlestoredb.connection import build_params
    from singlestoredb.config import get_option

    def check_default_port(out):
        # No port in the URL: accept the configured option or one of the
        # literal fallbacks for whichever driver family was selected.
        if out['driver'] in ['http', 'https']:
            assert out['port'] in [get_option('http_port'), 80, 443], out['port']
        else:
            assert out['port'] in [get_option('port'), 3306], out['port']

    def check_mydb_credentials(out, password='p455w0rd'):
        # Database and credentials were given explicitly in the URL.
        assert out['database'] == 'mydb', out['database']
        assert out['user'] == 'me', out['user']
        assert out['password'] == password, out['password']

    def check_default_credentials(out):
        # No credentials in the URL: either absent or the configured option.
        assert 'user' not in out or out['user'] == get_option('user'), out['user']
        assert 'password' not in out or out['password'] == get_option(
            'password',
        ), out['password']

    # Full URL (without scheme)
    url = 'me:p455w0rd@s2host.com:3307/mydb'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] == 3307, out['port']
    check_mydb_credentials(out)

    # Full URL (with scheme)
    url = 'http://me:p455w0rd@s2host.com:3307/mydb'
    out = build_params(host=url)
    assert out['driver'] == 'http', out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] == 3307, out['port']
    check_mydb_credentials(out)

    # No port
    url = 'me:p455w0rd@s2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    check_default_port(out)
    check_mydb_credentials(out)

    # No http port
    url = 'http://me:p455w0rd@s2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == 'http', out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] in [get_option('http_port'), 80], out['port']
    check_mydb_credentials(out)

    # No https port
    url = 'https://me:p455w0rd@s2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == 'https', out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] in [get_option('http_port'), 443], out['port']
    check_mydb_credentials(out)

    # Invalid port
    url = 'https://me:p455w0rd@s2host.com:foo/mydb'
    with self.assertRaises(ValueError):
        build_params(host=url)

    # Empty password
    url = 'me:@s2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    check_default_port(out)
    check_mydb_credentials(out, password='')

    # No user/password
    url = 's2host.com/mydb'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    check_default_port(out)
    assert out['database'] == 'mydb', out['database']
    check_default_credentials(out)

    # Just hostname
    url = 's2host.com'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    check_default_port(out)
    assert 'database' not in out
    check_default_credentials(out)

    # Just hostname and port
    url = 's2host.com:1000'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] == 1000, out['port']
    assert 'database' not in out
    check_default_credentials(out)

    # Query options
    url = 's2host.com:1000?local_infile=1&charset=utf8'
    out = build_params(host=url)
    assert out['driver'] == get_option('driver'), out['driver']
    assert out['host'] == 's2host.com', out['host']
    assert out['port'] == 1000, out['port']
    assert 'database' not in out
    check_default_credentials(out)
    assert out['local_infile'] is True, out['local_infile']
    assert out['charset'] == 'utf8', out['charset']

    # Bad query option
    url = 's2host.com:1000?bad_param=10'
    with self.assertRaises(ValueError):
        build_params(host=url)
503
504
def test_wrap_exc(self):
    """Invalid SQL surfaces as ``s2.ProgrammingError`` with errno 1064."""
    with self.assertRaises(s2.ProgrammingError) as caught:
        self.cur.execute('garbage syntax')

    error = caught.exception
    assert error.errno == 1064, error.errno
    assert 'You have an error in your SQL syntax' in error.errmsg, error.errmsg
511
512
def test_extended_types(self):
    """Round-trip geography, vector and temporal values through the driver.

    The same five template rows are inserted twice: once with the two
    geometry columns as shapely objects (ids 1-5) and once as pygeos
    objects (ids 6-10).  Every row is tagged with a fresh UUID in
    ``testkey`` and the read-back queries filter on it.

    Skipped unless numpy, pygeos and shapely are all importable.
    """
    if not has_numpy or not has_pygeos or not has_shapely:
        self.skipTest('Test requires numpy, pygeos, and shapely')

    import uuid

    key = str(uuid.uuid4())
    uses_http = 'http' in self.conn.driver

    # (geography WKT, geographypoint WKT, vector, dt, d, t, td)
    template = [
        (
            'POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))', 'POINT(1.5 1.5)',
            [0.5, 0.6], datetime.datetime(1950, 1, 2, 12, 13, 14),
            datetime.date(1950, 1, 2), datetime.time(12, 13, 14),
            datetime.timedelta(seconds=123456),
        ),
        (
            'POLYGON((5 1, 6 1, 6 2, 5 2, 5 1))', 'POINT(5.5 1.5)',
            [1.3, 2.5], datetime.datetime(1960, 3, 4, 15, 16, 17),
            datetime.date(1960, 3, 4), datetime.time(15, 16, 17),
            datetime.timedelta(seconds=2),
        ),
        (
            'POLYGON((5 5, 6 5, 6 6, 5 6, 5 5))', 'POINT(5.5 5.5)',
            [10.3, 11.1], datetime.datetime(1970, 6, 7, 18, 19, 20),
            datetime.date(1970, 5, 6), datetime.time(18, 19, 20),
            datetime.timedelta(seconds=-2),
        ),
        (
            'POLYGON((1 5, 2 5, 2 6, 1 6, 1 5))', 'POINT(1.5 5.5)',
            [3.3, 3.4], datetime.datetime(1980, 8, 9, 21, 22, 23),
            datetime.date(1980, 7, 8), datetime.time(21, 22, 23),
            datetime.timedelta(seconds=-123456),
        ),
        (
            'POLYGON((3 3, 4 3, 4 4, 3 4, 3 3))', 'POINT(3.5 3.5)',
            [2.9, 9.5], datetime.datetime(2010, 10, 11, 1, 2, 3),
            datetime.date(2010, 8, 9), datetime.time(1, 2, 3),
            datetime.timedelta(seconds=0),
        ),
    ]

    insert_sql = (
        'INSERT INTO extended_types '
        '(id, geography, geographypoint, vectors, dt, d, t, td, testkey) '
        'VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)'
    )

    def build_rows(first_id, parse_wkt):
        # Turn the template into insertable rows, numbering them from
        # ``first_id`` and parsing both WKT strings with ``parse_wkt``.
        rows = []
        for offset, (polygon, point, vec, dt, d, t, td) in enumerate(template):
            # With an HTTP driver the vector column is sent as an empty
            # string; otherwise as a little-endian float32 array.
            vector = '' if uses_http else np.array(vec, dtype='<f4')
            rows.append([
                first_id + offset, parse_wkt(polygon), parse_wkt(point),
                vector, dt, d, t, td, key,
            ])
        return rows

    def check_vector(sent, received):
        if uses_http:
            assert received == b''
        else:
            assert (sent == np.frombuffer(received, dtype='<f4')).all()

    # NOTE(review): the dt / d / t / td columns are written but, as in the
    # original test, not compared on read-back.

    # shapely data
    sent_rows = build_rows(1, shapely.wkt.loads)
    self.cur.executemany(insert_sql, sent_rows)

    self.cur.execute(
        'SELECT * FROM extended_types WHERE testkey = %s ORDER BY id', [key],
    )
    fetched = self.cur.fetchall()
    # zip() stops at the shorter input, so make sure no row went missing.
    assert len(fetched) == len(sent_rows), fetched

    for sent, received in zip(sent_rows, fetched):
        assert sent[0] == received[0]
        assert sent[1].equals_exact(shapely.wkt.loads(received[1]), 1e-4)
        assert sent[2].equals_exact(shapely.wkt.loads(received[2]), 1e-4)
        check_vector(sent[3], received[3])

    # pygeos data
    sent_rows = build_rows(6, pygeos.io.from_wkt)
    self.cur.executemany(insert_sql, sent_rows)

    self.cur.execute(
        'SELECT * FROM extended_types WHERE id >= 6 and testkey = %s ORDER BY id', [
            key,
        ],
    )
    fetched = self.cur.fetchall()
    assert len(fetched) == len(sent_rows), fetched

    for sent, received in zip(sent_rows, fetched):
        assert sent[0] == received[0]
        assert_geometries_equal(sent[1], pygeos.io.from_wkt(received[1]))
        assert_geometries_equal(sent[2], pygeos.io.from_wkt(received[2]))
        check_vector(sent[3], received[3])
649
650
def test_alltypes(self):
    """Row ``id = 0`` of ``alltypes`` has the expected values and type codes.

    Each checked column is compared against its expected Python value and
    against the type code reported in ``cursor.description``.  Every
    assertion reports the offending value on failure.
    """
    self.cur.execute('select * from alltypes where id = 0')
    names = [x[0] for x in self.cur.description]
    types = [x[1] for x in self.cur.description]
    out = self.cur.fetchone()
    row = dict(zip(names, out))
    typ = dict(zip(names, types))

    bits = list(range(16))

    # (column, expected value, expected type code), in checking order.
    expected = [
        ('id', 0, 3),
        ('tinyint', 80, 1),
        ('bool', 0, 1),
        ('boolean', 1, 1),
        ('smallint', -27897, 2),
        ('mediumint', 104729, 9),
        ('int24', -200899, 9),
        ('int', -1295369311, 3),
        ('integer', -1741727421, 3),
        ('bigint', -266883847, 8),
        ('float', -146487000.0, 4),
        ('double', -474646154.719356, 5),
        ('real', -901409776.279346, 5),
        ('decimal', decimal.Decimal('28111097.610822'), 246),
        ('dec', decimal.Decimal('389451155.931428'), 246),
        ('fixed', decimal.Decimal('-143773416.044092'), 246),
        ('numeric', decimal.Decimal('866689461.300046'), 246),
        ('date', datetime.date(8524, 11, 10), 10),
        ('time', datetime.timedelta(minutes=7), 11),
        (
            'time_6',
            datetime.timedelta(hours=1, minutes=10, microseconds=2),
            11,
        ),
        ('datetime', datetime.datetime(9948, 3, 11, 15, 29, 22), 12),
        ('datetime_6', datetime.datetime(1756, 10, 29, 2, 2, 42, 8), 12),
        ('timestamp', datetime.datetime(1980, 12, 31, 1, 10, 23), 7),
        ('timestamp_6', datetime.datetime(1991, 1, 2, 22, 15, 10, 6), 7),
        ('year', 1923, 13),
        ('char_100', 'This is a test of a 100 character column.', 254),
        ('binary_100', bytearray(bits + [0] * 84), 254),
        # why not 15?
        ('varchar_200', 'This is a test of a variable character column.', 253),
        # why not 15?
        ('varbinary_200', bytearray(bits * 2), 253),
        ('longtext', 'This is a longtext column.', 251),
        ('mediumtext', 'This is a mediumtext column.', 250),
        ('text', 'This is a text column.', 252),
        ('tinytext', 'This is a tinytext column.', 249),
        ('longblob', bytearray(bits * 3), 251),
        ('mediumblob', bytearray(bits * 2), 250),
        ('blob', bytearray(bits), 252),
        ('tinyblob', bytearray([10, 11, 12, 13, 14, 15]), 249),
        ('json', {'a': 10, 'b': 2.75, 'c': 'hello world'}, 245),
        ('enum', 'one', 253),  # mysql code: 247
    ]

    for column, value, type_code in expected:
        assert row[column] == value, row[column]
        assert typ[column] == type_code, typ[column]

    # TODO: HTTP sees this as a varchar, so it doesn't become a set.
    assert row['set'] in [{'two'}, 'two'], row['set']
    assert typ['set'] == 253, typ['set']  # mysql code: 248

    assert row['bit'] == b'\x00\x00\x00\x00\x00\x00\x00\x80', row['bit']
    assert typ['bit'] == 16, typ['bit']
798
799
def test_alltypes_nulls(self):
    """Row ``id = 1`` of ``alltypes`` is NULL in every checked column.

    For each column the fetched value must be ``None`` while the type code
    in ``cursor.description`` must still be the column's usual code.

    The checks are table-driven so that each column's value and type code
    are always tested under the same key.
    """
    self.cur.execute('select * from alltypes where id = 1')
    names = [x[0] for x in self.cur.description]
    types = [x[1] for x in self.cur.description]
    out = self.cur.fetchone()
    row = dict(zip(names, out))
    typ = dict(zip(names, types))

    assert row['id'] == 1, row['id']
    assert typ['id'] == 3, typ['id']

    # (column, expected type code), in checking order.
    nullable_columns = [
        ('tinyint', 1),
        ('bool', 1),
        ('boolean', 1),
        ('smallint', 2),
        ('mediumint', 9),
        ('int24', 9),
        ('int', 3),
        ('integer', 3),
        ('bigint', 8),
        ('float', 4),
        ('double', 5),
        ('real', 5),
        ('decimal', 246),
        ('dec', 246),
        ('fixed', 246),
        ('numeric', 246),
        ('date', 10),
        ('time', 11),
        ('time_6', 11),
        ('datetime', 12),
        ('datetime_6', 12),
        ('timestamp', 7),
        ('timestamp_6', 7),
        ('year', 13),
        ('char_100', 254),
        ('binary_100', 254),
        ('varchar_200', 253),  # why not 15?
        ('varbinary_200', 253),  # why not 15?
        ('longtext', 251),
        ('mediumtext', 250),
        ('text', 252),
        ('tinytext', 249),
        ('longblob', 251),
        ('mediumblob', 250),
        ('blob', 252),
        ('tinyblob', 249),
        ('json', 245),
        ('enum', 253),  # mysql code: 247
        ('set', 253),  # mysql code: 248
        ('bit', 16),
    ]

    for column, type_code in nullable_columns:
        assert row[column] is None, row[column]
        assert typ[column] == type_code, typ[column]
932
933
def test_alltypes_mins(self):
    """Compare the ``alltypes`` row with id = 2 against its expected values.

    Judging by the test name and the expected values below, this row
    holds the smallest value accepted by each column type.
    """
    self.cur.execute('select * from alltypes where id = 2')
    columns = [desc[0] for desc in self.cur.description]
    values = self.cur.fetchone()
    row = dict(zip(columns, values))

    # Values shared by several columns.
    min_decimal = decimal.Decimal('-99999999999999.999999')
    min_time = -1 * datetime.timedelta(hours=838, minutes=59, seconds=59)

    expected = {
        'id': 2,
        'tinyint': -128,
        'unsigned_tinyint': 0,
        'bool': -128,
        'boolean': -128,
        'smallint': -32768,
        'unsigned_smallint': 0,
        'mediumint': -8388608,
        'unsigned_mediumint': 0,
        'int24': -8388608,
        'unsigned_int24': 0,
        'int': -2147483648,
        'unsigned_int': 0,
        'integer': -2147483648,
        'unsigned_integer': 0,
        'bigint': -9223372036854775808,
        'unsigned_bigint': 0,
        'float': 0,
        'double': -1.7976931348623158e308,
        'real': -1.7976931348623158e308,
        'decimal': min_decimal,
        'dec': min_decimal,
        'fixed': min_decimal,
        'numeric': min_decimal,
        'date': datetime.date(1000, 1, 1),
        'time': min_time,
        'time_6': min_time,
        'datetime': datetime.datetime(1000, 1, 1, 0, 0, 0),
        'datetime_6': datetime.datetime(1000, 1, 1, 0, 0, 0, 0),
        'timestamp': datetime.datetime(1970, 1, 1, 0, 0, 1),
        'timestamp_6': datetime.datetime(1970, 1, 1, 0, 0, 1, 0),
        'year': 1901,
        'char_100': '',
        'binary_100': b'\x00' * 100,
        'varchar_200': '',
        'varbinary_200': b'',
        'longtext': '',
        'mediumtext': '',
        'text': '',
        'tinytext': '',
        'longblob': b'',
        'mediumblob': b'',
        'blob': b'',
        'tinyblob': b'',
        'json': {},
        'enum': 'one',
        'set': 'two',
        'bit': b'\x00\x00\x00\x00\x00\x00\x00\x00',
    }

    # Walk the columns in name order so failures are reported deterministically.
    for key in sorted(row):
        value = row[key]
        assert value == expected[key], \
            '{} != {} in key {}'.format(value, expected[key], key)
992
993
def test_alltypes_maxs(self):
    """Compare the ``alltypes`` row with id = 3 against its expected values.

    Judging by the test name and the expected values below, this row
    holds the largest value accepted by each column type. Columns whose
    name contains ``timestamp`` are not compared (see the TODO below).
    """
    self.cur.execute('select * from alltypes where id = 3')
    columns = [desc[0] for desc in self.cur.description]
    values = self.cur.fetchone()
    row = dict(zip(columns, values))

    # Values shared by several columns.
    max_decimal = decimal.Decimal('99999999999999.999999')
    max_time = datetime.timedelta(hours=838, minutes=59, seconds=59)

    expected = {
        'id': 3,
        'tinyint': 127,
        'unsigned_tinyint': 255,
        'bool': 127,
        'boolean': 127,
        'smallint': 32767,
        'unsigned_smallint': 65535,
        'mediumint': 8388607,
        'unsigned_mediumint': 16777215,
        'int24': 8388607,
        'unsigned_int24': 16777215,
        'int': 2147483647,
        'unsigned_int': 4294967295,
        'integer': 2147483647,
        'unsigned_integer': 4294967295,
        'bigint': 9223372036854775807,
        'unsigned_bigint': 18446744073709551615,
        'float': 0,
        'double': 1.7976931348623158e308,
        'real': 1.7976931348623158e308,
        'decimal': max_decimal,
        'dec': max_decimal,
        'fixed': max_decimal,
        'numeric': max_decimal,
        'date': datetime.date(9999, 12, 31),
        'time': max_time,
        'time_6': max_time,
        'datetime': datetime.datetime(9999, 12, 31, 23, 59, 59),
        'datetime_6': datetime.datetime(9999, 12, 31, 23, 59, 59, 999999),
        'timestamp': datetime.datetime(2038, 1, 19, 3, 14, 7),
        'timestamp_6': datetime.datetime(2038, 1, 19, 3, 14, 7, 999999),
        'year': 2155,
        'char_100': '',
        'binary_100': b'\x00' * 100,
        'varchar_200': '',
        'varbinary_200': b'',
        'longtext': '',
        'mediumtext': '',
        'text': '',
        'tinytext': '',
        'longblob': b'',
        'mediumblob': b'',
        'blob': b'',
        'tinyblob': b'',
        'json': {},
        'enum': 'one',
        'set': 'two',
        'bit': b'\xff\xff\xff\xff\xff\xff\xff\xff',
    }

    for key in sorted(row):
        # TODO: Figure out how to get time zones working
        if 'timestamp' in key:
            continue
        value = row[key]
        assert value == expected[key], \
            '{} != {} in key {}'.format(value, expected[key], key)
1055
1056
def test_alltypes_zeros(self):
    """Compare the ``alltypes`` row with id = 4 against its expected values.

    Judging by the test name, this row holds zero / empty values. The
    expected values treat the date, datetime, timestamp and year columns
    of this row as ``None``.
    """
    self.cur.execute('select * from alltypes where id = 4')
    columns = [desc[0] for desc in self.cur.description]
    values = self.cur.fetchone()
    row = dict(zip(columns, values))

    # Value shared by all fixed-point columns.
    zero_decimal = decimal.Decimal('0.0')

    expected = {
        'id': 4,
        'tinyint': 0,
        'unsigned_tinyint': 0,
        'bool': 0,
        'boolean': 0,
        'smallint': 0,
        'unsigned_smallint': 0,
        'mediumint': 0,
        'unsigned_mediumint': 0,
        'int24': 0,
        'unsigned_int24': 0,
        'int': 0,
        'unsigned_int': 0,
        'integer': 0,
        'unsigned_integer': 0,
        'bigint': 0,
        'unsigned_bigint': 0,
        'float': 0,
        'double': 0,
        'real': 0,
        'decimal': zero_decimal,
        'dec': zero_decimal,
        'fixed': zero_decimal,
        'numeric': zero_decimal,
        'date': None,
        'time': datetime.timedelta(hours=0, minutes=0, seconds=0),
        'time_6': datetime.timedelta(
            hours=0, minutes=0, seconds=0, microseconds=0,
        ),
        'datetime': None,
        'datetime_6': None,
        'timestamp': None,
        'timestamp_6': None,
        'year': None,
        'char_100': '',
        'binary_100': b'\x00' * 100,
        'varchar_200': '',
        'varbinary_200': b'',
        'longtext': '',
        'mediumtext': '',
        'text': '',
        'tinytext': '',
        'longblob': b'',
        'mediumblob': b'',
        'blob': b'',
        'tinyblob': b'',
        'json': {},
        'enum': 'one',
        'set': 'two',
        'bit': b'\x00\x00\x00\x00\x00\x00\x00\x00',
    }

    # Walk the columns in name order so failures are reported deterministically.
    for key in sorted(row):
        value = row[key]
        assert value == expected[key], \
            '{} != {} in key {}'.format(value, expected[key], key)
1115
1116
def _test_MySQLdb(self):
    """Check that ``alltypes`` rows match what the MySQLdb driver returns.

    The same query is run through the connection under test and through
    a MySQLdb connection built from the same host / user / password, and
    the two result sets are compared row by row.

    The leading underscore keeps the method out of default ``test*``
    discovery, so it only runs when called explicitly. The test is
    skipped when MySQLdb cannot be imported.
    """
    try:
        import json
        import MySQLdb
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        self.skipTest('MySQLdb is not installed')

    self.cur.execute('select * from alltypes order by id')
    s2_out = self.cur.fetchall()

    port = self.conn.connection_params['port']
    if 'http' in self.conn.driver:
        # The HTTP driver's port is not usable by MySQLdb; 3306 is
        # presumably the server's MySQL-protocol port -- confirm for
        # non-default deployments.
        port = 3306

    args = dict(
        host=self.conn.connection_params['host'],
        port=port,
        user=self.conn.connection_params['user'],
        password=self.conn.connection_params['password'],
        database=type(self).dbname,
    )

    with MySQLdb.connect(**args) as conn:
        # Type code 245 is the JSON column type; decode it so the values
        # are comparable with the parsed objects in ``s2_out``.
        conn.converter[245] = json.loads
        with conn.cursor() as cur:
            cur.execute('select * from alltypes order by id')
            mydb_out = cur.fetchall()

    # zip() stops at the shorter input, so a missing row would otherwise
    # go unnoticed.
    assert len(s2_out) == len(mydb_out), (len(s2_out), len(mydb_out))

    for a, b in zip(s2_out, mydb_out):
        assert a == b, (a, b)
1146
1147
def test_int_string(self):
    """Select an integer literal next to a 48-character string literal."""
    text = 'a' * 48
    query = f"SELECT 1, '{text}'"
    self.cur.execute(query)
    fetched = self.cur.fetchone()
    self.assertEqual((1, text), fetched)
1151
1152
def test_double_string(self):
    """Select a DOUBLE-cast literal next to a 49-character string literal."""
    text = 'a' * 49
    query = f"SELECT 1.2 :> DOUBLE, '{text}'"
    self.cur.execute(query)
    fetched = self.cur.fetchone()
    self.assertEqual((1.2, text), fetched)
1156
1157
def test_year_string(self):
    """Select a YEAR-cast literal next to a 49-character string literal."""
    text = 'a' * 49
    query = f"SELECT 1999 :> YEAR, '{text}'"
    self.cur.execute(query)
    fetched = self.cur.fetchone()
    self.assertEqual((1999, text), fetched)
1161
1162
def test_nan_as_null(self):
    """Exercise the ``nan_as_null`` connection option.

    On the default connection a NaN parameter must raise; on a
    connection opened with ``nan_as_null=True`` it must come back as
    ``None`` while an ordinary float is returned unchanged.
    """
    query = 'SELECT %s :> DOUBLE AS X'
    dbname = type(self).dbname

    with self.assertRaises((s2.ProgrammingError, InvalidJSONError)):
        self.cur.execute(query, [math.nan])

    # (parameter sent, value expected back); one fresh connection per case.
    cases = (
        (math.nan, None),
        (1.234, 1.234),
    )
    for param, expected in cases:
        with s2.connect(database=dbname, nan_as_null=True) as conn:
            with conn.cursor() as cur:
                cur.execute(query, [param])
                rows = list(cur)
                self.assertEqual(expected, rows[0][0])
1175
1176
def test_inf_as_null(self):
    """Exercise the ``inf_as_null`` connection option.

    On the default connection an infinite parameter must raise; on a
    connection opened with ``inf_as_null=True`` it must come back as
    ``None`` while an ordinary float is returned unchanged.
    """
    query = 'SELECT %s :> DOUBLE AS X'
    dbname = type(self).dbname

    with self.assertRaises((s2.ProgrammingError, InvalidJSONError)):
        self.cur.execute(query, [math.inf])

    # (parameter sent, value expected back); one fresh connection per case.
    cases = (
        (math.inf, None),
        (1.234, 1.234),
    )
    for param, expected in cases:
        with s2.connect(database=dbname, inf_as_null=True) as conn:
            with conn.cursor() as cur:
                cur.execute(query, [param])
                rows = list(cur)
                self.assertEqual(expected, rows[0][0])
1189
1190
def test_encoding_errors(self):
    """Read the ``badutf8`` table under two ``encoding_errors`` settings.

    The test only requires that fetching every row completes without
    raising for both ``'strict'`` and ``'backslashreplace'``; the
    returned values themselves are not inspected.
    """
    dbname = type(self).dbname

    for mode in ('strict', 'backslashreplace'):
        with s2.connect(database=dbname, encoding_errors=mode) as conn:
            with conn.cursor() as cur:
                cur.execute('SELECT * FROM badutf8')
                # Drain the cursor so every row is actually decoded.
                list(cur)
1206
1207
def test_character_lengths(self):
    """Round-trip strings whose lengths straddle 2**8, 2**16 and 2**24.

    A scratch table named after ``id(self)`` is created, one row per
    string length is inserted, and each row is read back and compared
    with what was sent. The table is dropped afterwards even when an
    assertion fails. Skipped on HTTP drivers.
    """
    if 'http' in self.conn.driver:
        self.skipTest('Character lengths too long for HTTP interface')

    tbl_name = f'test_character_lengths_{id(self)}'

    # Remove a leftover table of the same name (object ids can repeat
    # between runs) before creating a fresh one.
    self.cur.execute(f'DROP TABLE IF EXISTS {tbl_name}')
    self.cur.execute(rf'''
        CREATE TABLE `{tbl_name}` (
            `id` text CHARACTER SET utf8 COLLATE utf8_general_ci NOT NULL,
            `char_col` longtext CHARACTER SET utf8 COLLATE utf8_general_ci NOT NULL,
            `int_col` INT,
            PRIMARY KEY (`id`),
            SORT KEY `id` (`id`)
        ) AUTOSTATS_CARDINALITY_MODE=INCREMENTAL
        AUTOSTATS_HISTOGRAM_MODE=CREATE
        AUTOSTATS_SAMPLING=ON
        SQL_MODE='STRICT_ALL_TABLES'
    ''')

    # (row id, string length). Each pair of rows sits on either side of a
    # power-of-two boundary; the labels presumably refer to the width of
    # the length prefix used on the wire -- confirm against the protocol.
    lengths = [
        ('CHAR_SHORT', 1),
        ('CHAR_LONG', 2**8 - 1),
        ('SHORT_SHORT', 2**8),
        ('SHORT_LONG', 2**16 - 1),
        ('INT24_SHORT', 2**16),
        ('INT24_LONG', 2**24 - 1),
        ('INT64_SHORT', 2**24),
        ('INT64_LONG', (2**24 - 1) + 100000),
    ]
    data = [[label, 'a' * length, 123456] for label, length in lengths]

    try:
        self.cur.executemany(
            f'INSERT INTO {tbl_name}(id, char_col, int_col) '
            'VALUES (%s, %s, %s)', data,
        )

        for expected_row in data:
            self.cur.execute(
                f'SELECT id, char_col, int_col FROM {tbl_name} '
                'WHERE id = %s',
                [expected_row[0]],
            )
            fetched = list(list(self.cur)[0])
            # Report only the row label: the payloads are up to ~16 MB
            # and would make the failure output unreadable.
            assert expected_row == fetched, expected_row[0]
    finally:
        # Best-effort cleanup; a failed drop must not mask the real
        # test outcome.
        try:
            self.cur.execute(f'DROP TABLE {tbl_name}')
        except Exception:
            pass
1264
1265
def test_pydantic(self):
    """Use pydantic models as query parameters.

    A model instance is passed to ``execute`` and a list of instances to
    ``executemany`` with named (``%(name)s``) placeholders; the inserted
    rows are then read back and compared. Model fields that have no
    placeholder in the statement are expected to be ignored. The scratch
    table is dropped afterwards even when an assertion fails. Skipped
    when pydantic is not installed.
    """
    if not has_pydantic:
        self.skipTest('Test requires pydantic')

    tblname = 'foo_' + str(id(self))

    class FooData(pydantic.BaseModel):
        x: Optional[int]
        y: Optional[float]
        z: Optional[str] = None

    # Remove a leftover table of the same name (object ids can repeat
    # between runs) before creating a fresh one.
    self.cur.execute(f'DROP TABLE IF EXISTS {tblname}')
    self.cur.execute(f'''
        CREATE TABLE {tblname}(
            x INT,
            y DOUBLE,
            z TEXT
        )
    ''')

    try:
        self.cur.execute(
            f'INSERT INTO {tblname}(x, y) VALUES (%(x)s, %(y)s)',
            FooData(x=2, y=3.23),
        )

        self.cur.execute('SELECT * FROM ' + tblname)

        assert sorted(self.cur.fetchall()) == [(2, 3.23, None)]

        # Only ``x`` appears in the statement, so ``y`` stays NULL for
        # these rows even though the models carry a value for it.
        self.cur.executemany(
            f'INSERT INTO {tblname}(x) VALUES (%(x)s)',
            [FooData(x=3, y=3.12), FooData(x=10, y=100.11)],
        )

        self.cur.execute('SELECT * FROM ' + tblname)

        assert sorted(self.cur.fetchall()) == sorted([
            (2, 3.23, None),
            (3, None, None),
            (10, None, None),
        ])
    finally:
        # Best-effort cleanup; a failed drop must not mask the real
        # test outcome.
        try:
            self.cur.execute(f'DROP TABLE IF EXISTS {tblname}')
        except Exception:
            pass
1309
1310
def test_charset(self):
    """Check behaviour of the ``charset`` connection option (disabled).

    The unconditional ``skipTest`` call below raises, so the remainder
    of the body does not run at present. It is kept for when the charset
    commands are available again.
    """
    self.skipTest('Skip until charset commands are re-implemented')

    dbname = type(self).dbname

    # Default connection: the emoji-bearing query is expected to work.
    with s2.connect(database=dbname) as connection:
        with connection.cursor() as cursor:
            cursor.execute('''
                select json_extract_string('{"foo":"😀"}', "bar");
            ''')

    if 'http' in self.conn.driver:
        self.skipTest('Charset is not use in HTTP interface')

    # With charset='utf8' the same query is expected to fail.
    with self.assertRaises(s2.OperationalError):
        with s2.connect(database=dbname, charset='utf8') as connection:
            with connection.cursor() as cursor:
                cursor.execute('''
                    select json_extract_string('{"foo":"😀"}', "bar");
                ''')
1328
1329
1330
if __name__ == '__main__':
    # Run this module's tests through nose2 when executed as a script.
    from nose2 import main as run_tests
    run_tests()
1333
1334