GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/database/test_read.py
from __future__ import annotations

import os
import sqlite3
import sys
from contextlib import suppress
from datetime import date
from pathlib import Path
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast
from unittest.mock import Mock, patch

with suppress(ModuleNotFoundError):  # not available on windows
    import adbc_driver_sqlite.dbapi
import pyarrow as pa
import pytest
import sqlalchemy
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import cast as alchemy_cast

import polars as pl
from polars._utils.various import parse_version
from polars.exceptions import DuplicateError, UnsuitableSQLError
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import (
        ConnectionOrCursor,
        DbReadEngine,
        SchemaDefinition,
        SchemaDict,
    )


def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any:
    args = tuple(str(a) if isinstance(a, Path) else a for a in args)
    return adbc_driver_sqlite.dbapi.connect(*args, **kwargs)


class MockConnection:
    """Mock connection class for databases we can't test in CI."""

    def __init__(
        self,
        driver: str,
        batch_size: int | None,
        exact_batch_size: bool,
        test_data: pa.Table,
        repeat_batch_calls: bool,
    ) -> None:
        self.__class__.__module__ = driver
        self._cursor = MockCursor(
            repeat_batch_calls=repeat_batch_calls,
            exact_batch_size=exact_batch_size,
            batched=(batch_size is not None),
            test_data=test_data,
        )

    def close(self) -> None:
        pass

    def cursor(self) -> Any:
        return self._cursor


class MockCursor:
    """Mock cursor class for databases we can't test in CI."""

    def __init__(
        self,
        batched: bool,
        exact_batch_size: bool,
        test_data: pa.Table,
        repeat_batch_calls: bool,
    ) -> None:
        self.resultset = MockResultSet(
            test_data=test_data,
            batched=batched,
            exact_batch_size=exact_batch_size,
            repeat_batch_calls=repeat_batch_calls,
        )
        self.exact_batch_size = exact_batch_size
        self.called: list[str] = []
        self.batched = batched
        self.n_calls = 1

    def __getattr__(self, name: str) -> Any:
        if "fetch" in name:
            self.called.append(name)
            return self.resultset
        super().__getattr__(name)  # type: ignore[misc]

    def close(self) -> Any:
        pass

    def execute(self, query: str) -> Any:
        return self


class MockResultSet:
    """Mock resultset class for databases we can't test in CI."""

    def __init__(
        self,
        test_data: pa.Table,
        batched: bool,
        exact_batch_size: bool,
        repeat_batch_calls: bool = False,
    ) -> None:
        self.test_data = test_data
        self.repeat_batched_calls = repeat_batch_calls
        self.exact_batch_size = exact_batch_size
        self.batched = batched
        self.n_calls = 1

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if not self.exact_batch_size:
            assert len(args) == 0
        if self.repeat_batched_calls:
            res = self.test_data[: None if self.n_calls else 0]
            self.n_calls -= 1
        else:
            res = iter((self.test_data,))
        return res


class DatabaseReadTestParams(NamedTuple):
    """Clarify read test params."""

    read_method: Literal["read_database", "read_database_uri"]
    connect_using: Any
    expected_dtypes: SchemaDefinition
    expected_dates: list[date | str]
    schema_overrides: SchemaDict | None = None
    batch_size: int | None = None


class ExceptionTestParams(NamedTuple):
    """Clarify exception test params."""

    read_method: str
    query: str | list[str]
    protocol: Any
    errclass: type[Exception]
    errmsg: str
    engine: str | None = None
    execute_options: dict[str, Any] | None = None
    pre_execution_query: str | list[str] | None = None
    kwargs: dict[str, Any] | None = None


@pytest.mark.write_disk
@pytest.mark.parametrize(
    (
        "read_method",
        "connect_using",
        "expected_dtypes",
        "expected_dates",
        "schema_overrides",
        "batch_size",
    ),
    [
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database_uri",
                connect_using="connectorx",
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.UInt8},
            ),
            id="uri: connectorx",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database_uri",
                connect_using="adbc",
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                schema_overrides={"id": pl.UInt8},
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="uri: adbc",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.UInt8, "value": pl.Float32},
            ),
            id="conn: sqlite3",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.Int32,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.Int32, "value": pl.Float32},
                batch_size=1,
            ),
            id="conn: sqlite3 (batched)",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: create_engine(
                    f"sqlite:///{path}",
                    connect_args={"detect_types": sqlite3.PARSE_DECLTYPES},
                ).connect(),
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
            ),
            id="conn: sqlalchemy",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (fetchall)",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                batch_size=1,
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (batched)",
        ),
    ],
)
def test_read_database(
    read_method: Literal["read_database", "read_database_uri"],
    connect_using: Any,
    expected_dtypes: dict[str, pl.DataType],
    expected_dates: list[date | str],
    schema_overrides: SchemaDict | None,
    batch_size: int | None,
    tmp_sqlite_db: Path,
) -> None:
    if read_method == "read_database_uri":
        connect_using = cast("DbReadEngine", connect_using)
        # instantiate the connection ourselves, using connectorx/adbc
        df = pl.read_database_uri(
            uri=f"sqlite:///{tmp_sqlite_db}",
            query="SELECT * FROM test_data",
            engine=connect_using,
            schema_overrides=schema_overrides,
        )
        df_empty = pl.read_database_uri(
            uri=f"sqlite:///{tmp_sqlite_db}",
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            engine=connect_using,
            schema_overrides=schema_overrides,
        )
    elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
        # externally instantiated adbc connections
        with connect_using(tmp_sqlite_db) as conn:
            df = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data",
                schema_overrides=schema_overrides,
                batch_size=batch_size,
            )
            df_empty = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
                schema_overrides=schema_overrides,
                batch_size=batch_size,
            )
    else:
        # other user-supplied connections
        df = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'",
            schema_overrides=schema_overrides,
            batch_size=batch_size,
        )
        df_empty = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            schema_overrides=schema_overrides,
            batch_size=batch_size,
        )

    # validate the expected query return (data and schema)
    assert df.schema == expected_dtypes
    assert df.shape == (2, 4)
    assert df["date"].to_list() == expected_dates

    # note: 'cursor.description' is not reliable when no query
    # data is returned, so no point comparing expected dtypes
    assert df_empty.columns == ["id", "name", "value", "date"]
    assert df_empty.shape == (0, 4)
    assert df_empty["date"].to_list() == []


@pytest.mark.write_disk
@pytest.mark.parametrize(
    (
        "read_method",
        "connect_using",
        "expected_dtypes",
        "expected_dates",
        "schema_overrides",
        "batch_size",
    ),
    [
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.Int32,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.Int32, "value": pl.Float32},
                batch_size=1,
            ),
            id="conn: sqlite3",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: create_engine(
                    f"sqlite:///{path}",
                    connect_args={"detect_types": sqlite3.PARSE_DECLTYPES},
                ).connect(),
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                batch_size=1,
            ),
            id="conn: sqlalchemy",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                batch_size=1,
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (ignore batch_size)",
        ),
    ],
)
def test_read_database_iter_batches(
    read_method: Literal["read_database"],
    connect_using: Any,
    expected_dtypes: dict[str, pl.DataType],
    expected_dates: list[date | str],
    schema_overrides: SchemaDict | None,
    batch_size: int | None,
    tmp_sqlite_db: Path,
) -> None:
    if "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
        # externally instantiated adbc connections
        with connect_using(tmp_sqlite_db) as conn:
            dfs = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data",
                schema_overrides=schema_overrides,
                iter_batches=True,
                batch_size=batch_size,
            )
            empty_dfs = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
                schema_overrides=schema_overrides,
                iter_batches=True,
                batch_size=batch_size,
            )
            # must consume the iterators while the connection is open
            dfs = iter(list(dfs))
            empty_dfs = iter(list(empty_dfs))
    else:
        # other user-supplied connections
        dfs = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'",
            schema_overrides=schema_overrides,
            iter_batches=True,
            batch_size=batch_size,
        )
        empty_dfs = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            schema_overrides=schema_overrides,
            iter_batches=True,
            batch_size=batch_size,
        )

    df: pl.DataFrame = pl.concat(dfs)
    # validate the expected query return (data and schema)
    assert df.schema == expected_dtypes
    assert df.shape == (2, 4)
    assert df["date"].to_list() == expected_dates

    # some drivers return an empty iterator when there is no result
    try:
        df_empty: pl.DataFrame = pl.concat(empty_dfs)
    except ValueError:
        return

    # note: 'cursor.description' is not reliable when no query
    # data is returned, so no point comparing expected dtypes
    assert df_empty.columns == ["id", "name", "value", "date"]
    assert df_empty.shape == (0, 4)
    assert df_empty["date"].to_list() == []


def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None:
    # various flavours of alchemy connection
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

    t = Table("test_data", MetaData(), autoload_with=alchemy_engine)

    # establish sqlalchemy "selectable" and validate usage
    selectable_query = select(
        alchemy_cast(func.strftime("%Y", t.c.date), Integer).label("year"),
        t.c.name,
        t.c.value,
    ).where(t.c.value < 0)

    expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn):
        assert_frame_equal(
            pl.read_database(selectable_query, connection=conn),
            expected,
        )

        batches = list(
            pl.read_database(
                selectable_query,
                connection=conn,
                iter_batches=True,
                batch_size=1,
            )
        )
        assert len(batches) == 1
        assert_frame_equal(batches[0], expected)


def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None:
    # various flavours of alchemy connection
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

    # establish sqlalchemy "textclause" and validate usage
    textclause_query = text(
        """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < 0
        """
    )

    expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn):
        assert_frame_equal(
            pl.read_database(textclause_query, connection=conn),
            expected,
        )

        batches = list(
            pl.read_database(
                textclause_query,
                connection=conn,
                iter_batches=True,
                batch_size=1,
            )
        )
        assert len(batches) == 1
        assert_frame_equal(batches[0], expected)


@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        (":n", {"n": 0}),
        ("?", (0,)),
        ("?", [0]),
    ],
)
def test_read_database_parameterised(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)

    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
        if conn is alchemy_session and param == "?":
            continue  # alchemy session.execute() doesn't support positional params
        if parse_version(sqlalchemy.__version__) < (2, 0) and param == ":n":
            continue  # skip for older sqlalchemy versions

        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(n=param),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        pytest.param(
            ":n",
            pa.Table.from_pydict({"n": [0]}),
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        pytest.param(
            ":n",
            {"n": 0},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262",
            ),
        ),
        ("?", pa.Table.from_pydict({"data": [0]})),
        ("?", pl.DataFrame({"data": [0]})),
        ("?", pl.Series([{"data": 0}])),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_parameterised_adbc(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # ADBC will complain in pytest if the connection isn't closed
    with adbc_driver_sqlite.dbapi.connect(str(tmp_sqlite_db)) as conn:
        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(n=param),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        ([":lo", ":hi"], {"lo": 90, "hi": 100}),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
def test_read_database_parameterised_multiple(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)
    for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
        if alchemy_session is conn and param_1 == "?":
            continue  # alchemy session.execute() doesn't support positional params
        if parse_version(sqlalchemy.__version__) < (2, 0) and isinstance(
            param_value, dict
        ):
            continue  # skip for older sqlalchemy versions

        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(param_1=param_1, param_2=param_2),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        pytest.param(
            [":lo", ":hi"],
            {"lo": 90, "hi": 100},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        (["?", "?"], pa.Table.from_pydict({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.DataFrame({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.Series([{"data_1": 90, "data_2": 100}])),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_parameterised_multiple_adbc(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # ADBC will complain in pytest if the connection isn't closed
    with adbc_driver_sqlite.dbapi.connect(str(tmp_sqlite_db)) as conn:
        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(param_1=param_1, param_2=param_2),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        pytest.param(
            ":n",
            pa.Table.from_pydict({"n": [0]}),
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        pytest.param(
            ":n",
            {"n": 0},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262",
            ),
        ),
        ("?", pa.Table.from_pydict({"data": [0]})),
        ("?", pl.DataFrame({"data": [0]})),
        ("?", pl.Series([{"data": 0}])),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_uri_parameterised(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # test URI read method (adbc only)
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(n=param),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match=r"connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(n=":n"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (":n", {"n": 0})},
        )


@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        pytest.param(
            [":lo", ":hi"],
            {"lo": 90, "hi": 100},
            marks=pytest.mark.xfail(
                reason="Named binding not supported. See https://github.com/apache/arrow-adbc/issues/3262",
                strict=True,
            ),
        ),
        (["?", "?"], pa.Table.from_pydict({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.DataFrame({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.Series([{"data_1": 90, "data_2": 100}])),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_uri_parameterised_multiple(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # test URI read method (ADBC only)
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(param_1=param_1, param_2=param_2),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match=r"connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(param_1="?", param_2="?"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (90, 100)},
        )


@pytest.mark.parametrize(
    ("driver", "batch_size", "iter_batches", "expected_call"),
    [
        ("snowflake", None, False, "fetch_arrow_all"),
        ("snowflake", 10_000, False, "fetch_arrow_all"),
        ("snowflake", 10_000, True, "fetch_arrow_batches"),
        ("databricks", None, False, "fetchall_arrow"),
        ("databricks", 25_000, False, "fetchall_arrow"),
        ("databricks", 25_000, True, "fetchmany_arrow"),
        ("turbodbc", None, False, "fetchallarrow"),
        ("turbodbc", 50_000, False, "fetchallarrow"),
        ("turbodbc", 50_000, True, "fetcharrowbatches"),
        pytest.param(
            "adbc_driver_postgresql",
            None,
            False,
            "fetch_arrow",
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_postgresql not available on Windows",
            ),
        ),
        pytest.param(
            "adbc_driver_postgresql",
            75_000,
            False,
            "fetch_arrow",
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_postgresql not available on Windows",
            ),
        ),
        pytest.param(
            "adbc_driver_postgresql",
            75_000,
            True,
            "fetch_record_batch",
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_postgresql not available on Windows",
            ),
        ),
    ],
)
def test_read_database_mocked(
    driver: str, batch_size: int | None, iter_batches: bool, expected_call: str
) -> None:
    # since we don't have access to snowflake/databricks/etc from CI we
    # mock them so we can check that we're calling the expected methods
    arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()

    reg = ARROW_DRIVER_REGISTRY.get(driver, [{}])[0]  # type: ignore[var-annotated]
    exact_batch_size = reg.get("exact_batch_size", False)
    repeat_batch_calls = reg.get("repeat_batch_calls", False)

    mc = MockConnection(
        driver,
        batch_size,
        test_data=arrow,
        repeat_batch_calls=repeat_batch_calls,
        exact_batch_size=exact_batch_size,  # type: ignore[arg-type]
    )
    res = pl.read_database(
        query="SELECT * FROM test_data",
        connection=mc,
        iter_batches=iter_batches,
        batch_size=batch_size,
    )
    if iter_batches:
        assert isinstance(res, GeneratorType)
        res = pl.concat(res)

    res = cast("pl.DataFrame", res)
    assert expected_call in mc.cursor().called
    assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]


@pytest.mark.parametrize(
    (
        "read_method",
        "query",
        "protocol",
        "errclass",
        "errmsg",
        "engine",
        "execute_options",
        "pre_execution_query",
        "kwargs",
    ),
    [
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="sqlite",
                errclass=ValueError,
                errmsg="engine must be one of {'connectorx', 'adbc'}, got 'not_an_engine'",
                engine="not_an_engine",
            ),
            id="Not an available sql engine",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query=["SELECT * FROM test_data", "SELECT * FROM test_data"],
                protocol="sqlite",
                errclass=ValueError,
                errmsg="only a single SQL query string is accepted for adbc, got a 'list' type",
                engine="adbc",
            ),
            id="Unavailable list of queries for adbc",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="mysql",
                errclass=ModuleNotFoundError,
                errmsg="ADBC 'adbc_driver_mysql' driver not detected.",
                engine="adbc",
            ),
            id="Unavailable adbc driver",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol=sqlite3.connect(":memory:"),
                errclass=TypeError,
                errmsg="expected connection to be a URI string",
                engine="adbc",
            ),
            id="Invalid connection URI",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM imaginary_table",
                protocol=sqlite3.connect(":memory:"),
                errclass=sqlite3.OperationalError,
                errmsg="no such table: imaginary_table",
            ),
            id="Invalid query (unrecognised table name)",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM imaginary_table",
                protocol=sys.getsizeof,  # not a connection
                errclass=TypeError,
                errmsg="Unrecognised connection .* no 'execute' or 'cursor' method",
            ),
            id="Invalid read DB kwargs",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="/* tag: misc */ INSERT INTO xyz VALUES ('polars')",
                protocol=sqlite3.connect(":memory:"),
                errclass=UnsuitableSQLError,
                errmsg="INSERT statements are not valid 'read' queries",
            ),
            id="Invalid statement type",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="DELETE FROM xyz WHERE id = 'polars'",
                protocol=sqlite3.connect(":memory:"),
                errclass=UnsuitableSQLError,
                errmsg="DELETE statements are not valid 'read' queries",
            ),
            id="Invalid statement type",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM sqlite_master",
                protocol=sqlite3.connect(":memory:"),
                errclass=ValueError,
                kwargs={"iter_batches": True},
                errmsg="Cannot set `iter_batches` without also setting a non-zero `batch_size`",
            ),
            id="Invalid batch_size",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                engine="adbc",
                query="SELECT * FROM test_data",
                protocol=sqlite3.connect(":memory:"),
                errclass=TypeError,
                errmsg=r"unexpected keyword argument 'partition_on'",
                kwargs={"partition_on": "id"},
            ),
            id="Invalid kwargs",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                engine="adbc",
                query="SELECT * FROM test_data",
                protocol="{not:a, valid:odbc_string}",
                errclass=ValueError,
                errmsg=r"unable to identify string connection as valid ODBC",
            ),
            id="Invalid ODBC string",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="sqlite",
                errclass=ValueError,
                errmsg="the 'adbc' engine does not support use of `pre_execution_query`",
                engine="adbc",
                pre_execution_query="SET statement_timeout = 2151",
            ),
            id="Unavailable `pre_execution_query` for adbc",
        ),
    ],
)
def test_read_database_exceptions(
    read_method: str,
    query: str,
    protocol: Any,
    errclass: type[Exception],
    errmsg: str,
    engine: DbReadEngine | None,
    execute_options: dict[str, Any] | None,
    pre_execution_query: str | list[str] | None,
    kwargs: dict[str, Any] | None,
) -> None:
    if read_method == "read_database_uri":
        conn = f"{protocol}://test" if isinstance(protocol, str) else protocol
        params = {
            "uri": conn,
            "query": query,
            "engine": engine,
            "pre_execution_query": pre_execution_query,
        }
    else:
        params = {"connection": protocol, "query": query}
        if execute_options:
            params["execute_options"] = execute_options
        if kwargs is not None:
            params.update(kwargs)

    read_database = getattr(pl, read_method)
    with pytest.raises(errclass, match=errmsg):
        read_database(**params)


@pytest.mark.parametrize(
    "query",
    [
        "SELECT 1, 1 FROM test_data",
        'SELECT 1 AS "n", 2 AS "n" FROM test_data',
        'SELECT name, value AS "name" FROM test_data',
    ],
)
def test_read_database_duplicate_column_error(tmp_sqlite_db: Path, query: str) -> None:
    alchemy_conn = create_engine(f"sqlite:///{tmp_sqlite_db}").connect()
    with pytest.raises(
        DuplicateError,
        match=r"column .+ appears more than once in the query/result cursor",
    ):
        pl.read_database(query, connection=alchemy_conn)


@pytest.mark.parametrize(
    "uri",
    [
        "fakedb://123:456@account/database/schema?warehouse=warehouse&role=role",
        "fakedb://my#%us3r:p433w0rd@not_a_real_host:9999/database",
    ],
)
def test_read_database_cx_credentials(uri: str) -> None:
    with pytest.raises(RuntimeError, match=r"Source.*not supported"):
        pl.read_database_uri("SELECT * FROM data", uri=uri, engine="connectorx")


def test_sqlalchemy_row_init(tmp_sqlite_db: Path) -> None:
    expected_frame = pl.DataFrame(
        {
            "id": [1, 2],
            "name": ["misc", "other"],
            "value": [100.0, -99.5],
            "date": ["2020-01-01", "2021-12-31"],
        }
    )
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    query = text("SELECT * FROM test_data ORDER BY name")

    with alchemy_engine.connect() as conn:
        # note: sqlalchemy `Row` is a NamedTuple-like object; it additionally has
        # a `_mapping` attribute that returns a `RowMapping` dict-like object. we
        # validate frame/series init from each flavour of query result.
        query_result = list(conn.execute(query))
        for df in (
            pl.DataFrame(query_result),
            pl.DataFrame([row._mapping for row in query_result]),
            pl.from_records([row._mapping for row in query_result]),
        ):
            assert_frame_equal(expected_frame, df)

        expected_series = expected_frame.to_struct()
        for s in (
            pl.Series(query_result),
            pl.Series([row._mapping for row in query_result]),
        ):
            assert_series_equal(expected_series, s)


@patch("polars.io.database._utils.from_arrow")
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_query_success(
    import_mock: Mock, from_arrow_mock: Mock
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.2"

    import_mock.return_value = cx_mock

    pre_execution_query = "SET statement_timeout = 2151"

    pl.read_database_uri(
        query="SELECT 1",
        uri="mysql://test",
        engine="connectorx",
        pre_execution_query=pre_execution_query,
    )

    assert (
        cx_mock.read_sql.call_args.kwargs["pre_execution_query"] == pre_execution_query
    )


@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_not_supported_exception(
    import_mock: Mock,
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.0"

    import_mock.return_value = cx_mock

    with (
        pytest.raises(
            ValueError,
            match=r"'pre_execution_query' is only supported in connectorx version 0\.4\.2 or later",
        ),
    ):
        pl.read_database_uri(
            query="SELECT 1",
            uri="mysql://test",
            engine="connectorx",
            pre_execution_query="SET statement_timeout = 2151",
        )


@patch("polars.io.database._utils.from_arrow")
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_query_not_supported_success(
    import_mock: Mock, from_arrow_mock: Mock
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.0"

    import_mock.return_value = cx_mock

    pl.read_database_uri(
        query="SELECT 1",
        uri="mysql://test",
        engine="connectorx",
    )

    assert cx_mock.read_sql.call_args.kwargs.get("pre_execution_query") is None