Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/database/test_read.py
6939 views
1
from __future__ import annotations
2
3
import os
4
import sqlite3
5
import sys
6
from contextlib import suppress
7
from datetime import date
8
from pathlib import Path
9
from types import GeneratorType
10
from typing import TYPE_CHECKING, Any, NamedTuple, cast
11
from unittest.mock import Mock, patch
12
13
with suppress(ModuleNotFoundError): # not available on windows
14
import adbc_driver_sqlite.dbapi
15
import pyarrow as pa
16
import pytest
17
import sqlalchemy
18
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text
19
from sqlalchemy.orm import sessionmaker
20
from sqlalchemy.sql.expression import cast as alchemy_cast
21
22
import polars as pl
23
from polars._utils.various import parse_version
24
from polars.exceptions import DuplicateError, UnsuitableSQLError
25
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
26
from polars.testing import assert_frame_equal, assert_series_equal
27
28
if TYPE_CHECKING:
29
from polars._typing import (
30
ConnectionOrCursor,
31
DbReadEngine,
32
SchemaDefinition,
33
SchemaDict,
34
)
35
36
37
def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any:
    """Open an ADBC SQLite connection, coercing any ``Path`` positional args to ``str``.

    The ADBC sqlite driver expects string paths; tests pass ``pathlib.Path``
    fixtures, so convert those before delegating to the real ``connect``.
    """
    coerced = []
    for arg in args:
        coerced.append(str(arg) if isinstance(arg, Path) else arg)
    return adbc_driver_sqlite.dbapi.connect(*coerced, **kwargs)
40
41
42
class MockConnection:
    """Mock connection class for databases we can't test in CI."""

    def __init__(
        self,
        driver: str,
        batch_size: int | None,
        exact_batch_size: bool,
        test_data: pa.Table,
        repeat_batch_calls: bool,
    ) -> None:
        # spoof the class's module name so polars' driver-detection logic
        # (which inspects the connection's module) sees the given `driver`
        self.__class__.__module__ = driver
        # single shared cursor; batching is enabled iff a batch_size was given
        self._cursor = MockCursor(
            repeat_batch_calls=repeat_batch_calls,
            exact_batch_size=exact_batch_size,
            batched=(batch_size is not None),
            test_data=test_data,
        )

    def close(self) -> None:
        # no-op: nothing to release on the mock
        pass

    def cursor(self) -> Any:
        # always return the same mock cursor so tests can inspect `.called`
        return self._cursor
66
67
68
class MockCursor:
    """Mock cursor class for databases we can't test in CI."""

    def __init__(
        self,
        batched: bool,
        exact_batch_size: bool,
        test_data: pa.Table,
        repeat_batch_calls: bool,
    ) -> None:
        # the resultset is returned in place of any `fetch*` method (see __getattr__)
        self.resultset = MockResultSet(
            test_data=test_data,
            batched=batched,
            exact_batch_size=exact_batch_size,
            repeat_batch_calls=repeat_batch_calls,
        )
        self.exact_batch_size = exact_batch_size
        # records every fetch* attribute name accessed, so tests can assert
        # which driver-specific fetch method polars chose
        self.called: list[str] = []
        self.batched = batched
        self.n_calls = 1

    def __getattr__(self, name: str) -> Any:
        # intercept any fetch* attribute access: log it and serve the mock
        # resultset (a callable) in place of the real fetch method
        if "fetch" in name:
            self.called.append(name)
            return self.resultset
        # non-fetch attributes fall through; this raises AttributeError
        # (object has no __getattr__), which is the desired behavior here
        super().__getattr__(name)  # type: ignore[misc]

    def close(self) -> Any:
        # no-op: nothing to release on the mock
        pass

    def execute(self, query: str) -> Any:
        # accept any query and return self, mimicking DBAPI cursor chaining
        return self
100
101
102
class MockResultSet:
    """Mock resultset class for databases we can't test in CI."""

    def __init__(
        self,
        test_data: pa.Table,
        batched: bool,
        exact_batch_size: bool,
        repeat_batch_calls: bool = False,
    ) -> None:
        self.n_calls = 1
        self.batched = batched
        self.test_data = test_data
        self.exact_batch_size = exact_batch_size
        self.repeat_batched_calls = repeat_batch_calls

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Serve the mock data: repeatedly sliced, or as a one-shot iterator.

        Drivers without an exact batch size must call fetch with no
        positional args; that contract is asserted here.
        """
        if not self.exact_batch_size:
            assert not args
        if not self.repeat_batched_calls:
            # single-shot: yield the full table exactly once via an iterator
            return iter((self.test_data,))
        # repeat-call protocol: full data while calls remain, then empty
        stop = None if self.n_calls else 0
        self.n_calls -= 1
        return self.test_data[:stop]
127
128
129
class DatabaseReadTestParams(NamedTuple):
    """Clarify read test params."""

    read_method: str  # polars entry point: "read_database" or "read_database_uri"
    connect_using: Any  # engine name (str) or a callable producing a connection
    expected_dtypes: SchemaDefinition  # schema the resulting frame must have
    expected_dates: list[date | str]  # expected "date" column values (driver-dependent type)
    schema_overrides: SchemaDict | None = None  # overrides passed through to polars
    batch_size: int | None = None  # None disables batched reads
138
139
140
class ExceptionTestParams(NamedTuple):
    """Clarify exception test params."""

    read_method: str  # polars entry point: "read_database" or "read_database_uri"
    query: str | list[str]  # SQL under test (list form only valid for some engines)
    protocol: Any  # URI scheme string, or a connection-like object
    errclass: type[Exception]  # exception type expected to be raised
    errmsg: str  # regex matched against the raised message
    engine: str | None = None  # engine to request (read_database_uri only)
    execute_options: dict[str, Any] | None = None  # passed through to cursor execute
    pre_execution_query: str | list[str] | None = None  # setup SQL (connectorx only)
    kwargs: dict[str, Any] | None = None  # extra kwargs merged into the call
152
153
154
@pytest.mark.write_disk
@pytest.mark.parametrize(
    (
        "read_method",
        "connect_using",
        "expected_dtypes",
        "expected_dates",
        "schema_overrides",
        "batch_size",
    ),
    [
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database_uri",
                connect_using="connectorx",
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.UInt8},
            ),
            id="uri: connectorx",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database_uri",
                connect_using="adbc",
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                schema_overrides={"id": pl.UInt8},
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="uri: adbc",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.UInt8, "value": pl.Float32},
            ),
            id="conn: sqlite3",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.Int32,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.Int32, "value": pl.Float32},
                batch_size=1,
            ),
            id="conn: sqlite3",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: create_engine(
                    f"sqlite:///{path}",
                    connect_args={"detect_types": sqlite3.PARSE_DECLTYPES},
                ).connect(),
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
            ),
            id="conn: sqlalchemy",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (fetchall)",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                batch_size=1,
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (batched)",
        ),
    ],
)
def test_read_database(
    read_method: str,
    connect_using: Any,
    expected_dtypes: dict[str, pl.DataType],
    expected_dates: list[date | str],
    schema_overrides: SchemaDict | None,
    batch_size: int | None,
    tmp_sqlite_db: Path,
) -> None:
    """Read the fixture db through each engine/connection flavour and check schema + data."""
    if read_method == "read_database_uri":
        connect_using = cast("DbReadEngine", connect_using)
        # instantiate the connection ourselves, using connectorx/adbc
        df = pl.read_database_uri(
            uri=f"sqlite:///{tmp_sqlite_db}",
            query="SELECT * FROM test_data",
            engine=connect_using,
            schema_overrides=schema_overrides,
        )
        df_empty = pl.read_database_uri(
            uri=f"sqlite:///{tmp_sqlite_db}",
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            engine=connect_using,
            schema_overrides=schema_overrides,
        )
    elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
        # externally instantiated adbc connections
        with connect_using(tmp_sqlite_db) as conn:
            df = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data",
                schema_overrides=schema_overrides,
                batch_size=batch_size,
            )
            df_empty = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
                schema_overrides=schema_overrides,
                batch_size=batch_size,
            )
    else:
        # other user-supplied connections
        df = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'",
            schema_overrides=schema_overrides,
            batch_size=batch_size,
        )
        df_empty = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            schema_overrides=schema_overrides,
            batch_size=batch_size,
        )

    # validate the expected query return (data and schema)
    assert df.schema == expected_dtypes
    assert df.shape == (2, 4)
    assert df["date"].to_list() == expected_dates

    # note: 'cursor.description' is not reliable when no query
    # data is returned, so no point comparing expected dtypes
    assert df_empty.columns == ["id", "name", "value", "date"]
    assert df_empty.shape == (0, 4)
    assert df_empty["date"].to_list() == []
350
351
352
def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None:
    """A sqlalchemy ``select(...)`` object works as the query for every connection flavour."""
    # various flavours of alchemy connection
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

    # reflect the fixture table so we can build a typed selectable from it
    t = Table("test_data", MetaData(), autoload_with=alchemy_engine)

    # establish sqlalchemy "selectable" and validate usage
    selectable_query = select(
        alchemy_cast(func.strftime("%Y", t.c.date), Integer).label("year"),
        t.c.name,
        t.c.value,
    ).where(t.c.value < 0)

    expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn):
        # full read and batched read must produce the same single-row result
        assert_frame_equal(
            pl.read_database(selectable_query, connection=conn),
            expected,
        )

        batches = list(
            pl.read_database(
                selectable_query,
                connection=conn,
                iter_batches=True,
                batch_size=1,
            )
        )
        assert len(batches) == 1
        assert_frame_equal(batches[0], expected)
385
386
387
def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None:
    """A sqlalchemy ``text(...)`` clause works as the query for every connection flavour."""
    # various flavours of alchemy connection
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

    # establish sqlalchemy "textclause" and validate usage
    textclause_query = text(
        """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < 0
        """
    )

    expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn):
        # full read and batched read must produce the same single-row result
        assert_frame_equal(
            pl.read_database(textclause_query, connection=conn),
            expected,
        )

        batches = list(
            pl.read_database(
                textclause_query,
                connection=conn,
                iter_batches=True,
                batch_size=1,
            )
        )
        assert len(batches) == 1
        assert_frame_equal(batches[0], expected)
420
421
422
@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        (":n", {"n": 0}),
        ("?", (0,)),
        ("?", [0]),
    ],
)
def test_read_database_parameterised(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    """Named (:n) and positional (?) query parameters work across connection flavours."""
    # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)

    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
        if conn is alchemy_session and param == "?":
            continue  # alchemy session.execute() doesn't support positional params
        if parse_version(sqlalchemy.__version__) < (2, 0) and param == ":n":
            continue  # skip for older sqlalchemy versions

        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(n=param),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )
461
462
463
@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        pytest.param(
            ":n",
            pa.Table.from_pydict({"n": [0]}),
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        pytest.param(
            ":n",
            {"n": 0},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262",
            ),
        ),
        ("?", pa.Table.from_pydict({"data": [0]})),
        ("?", pl.DataFrame({"data": [0]})),
        ("?", pl.Series([{"data": 0}])),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_parameterised_adbc(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    """Positional parameter binding works on an ADBC sqlite connection (various value types)."""
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # ADBC will complain in pytest if the connection isn't closed
    with adbc_driver_sqlite.dbapi.connect(str(tmp_sqlite_db)) as conn:
        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(n=param),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )
511
512
513
@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        ([":lo", ":hi"], {"lo": 90, "hi": 100}),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
def test_read_database_parameterised_multiple(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    """Multiple named/positional parameters bind correctly across connection flavours."""
    param_1, param_2 = params
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)
    for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
        if alchemy_session is conn and param_1 == "?":
            continue  # alchemy session.execute() doesn't support positional params
        if parse_version(sqlalchemy.__version__) < (2, 0) and isinstance(
            param_value, dict
        ):
            continue  # skip for older sqlalchemy versions

        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(param_1=param_1, param_2=param_2),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )
554
555
556
@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        pytest.param(
            [":lo", ":hi"],
            {"lo": 90, "hi": 100},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        (["?", "?"], pa.Table.from_pydict({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.DataFrame({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.Series([{"data_1": 90, "data_2": 100}])),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_parameterised_multiple_adbc(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    """Multiple positional parameters bind correctly on an ADBC sqlite connection."""
    param_1, param_2 = params
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # ADBC will complain in pytest if the connection isn't closed
    with adbc_driver_sqlite.dbapi.connect(str(tmp_sqlite_db)) as conn:
        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(param_1=param_1, param_2=param_2),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )
598
599
600
@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        pytest.param(
            ":n",
            pa.Table.from_pydict({"n": [0]}),
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        pytest.param(
            ":n",
            {"n": 0},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262",
            ),
        ),
        ("?", pa.Table.from_pydict({"data": [0]})),
        ("?", pl.DataFrame({"data": [0]})),
        ("?", pl.Series([{"data": 0}])),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_uri_parameterised(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    """`read_database_uri` binds parameters via adbc; connectorx must reject them."""
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # test URI read method (adbc only)
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(n=param),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match="connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(n=":n"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (":n", {"n": 0})},
        )
661
662
663
@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        pytest.param(
            [":lo", ":hi"],
            {"lo": 90, "hi": 100},
            marks=pytest.mark.xfail(
                reason="Named binding not supported. See https://github.com/apache/arrow-adbc/issues/3262",
                strict=True,
            ),
        ),
        (["?", "?"], pa.Table.from_pydict({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.DataFrame({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.Series([{"data_1": 90, "data_2": 100}])),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_uri_parameterised_multiple(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    """`read_database_uri` binds multiple parameters via adbc; connectorx must reject them."""
    param_1, param_2 = params
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # test URI read method (ADBC only)
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(param_1=param_1, param_2=param_2),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match="connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(param_1="?", param_2="?"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (90, 100)},
        )
719
720
721
@pytest.mark.parametrize(
    ("driver", "batch_size", "iter_batches", "expected_call"),
    [
        ("snowflake", None, False, "fetch_arrow_all"),
        ("snowflake", 10_000, False, "fetch_arrow_all"),
        ("snowflake", 10_000, True, "fetch_arrow_batches"),
        ("databricks", None, False, "fetchall_arrow"),
        ("databricks", 25_000, False, "fetchall_arrow"),
        ("databricks", 25_000, True, "fetchmany_arrow"),
        ("turbodbc", None, False, "fetchallarrow"),
        ("turbodbc", 50_000, False, "fetchallarrow"),
        ("turbodbc", 50_000, True, "fetcharrowbatches"),
        ("adbc_driver_postgresql", None, False, "fetch_arrow_table"),
        ("adbc_driver_postgresql", 75_000, False, "fetch_arrow_table"),
        ("adbc_driver_postgresql", 75_000, True, "fetch_arrow_table"),
    ],
)
def test_read_database_mocked(
    driver: str, batch_size: int | None, iter_batches: bool, expected_call: str
) -> None:
    """Each driver/batching combo must route through the expected arrow fetch method."""
    # since we don't have access to snowflake/databricks/etc from CI we
    # mock them so we can check that we're calling the expected methods
    arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()

    # pull the driver's registry entry so the mock mirrors its batch semantics
    reg = ARROW_DRIVER_REGISTRY.get(driver, {})  # type: ignore[var-annotated]
    exact_batch_size = reg.get("exact_batch_size", False)
    repeat_batch_calls = reg.get("repeat_batch_calls", False)

    mc = MockConnection(
        driver,
        batch_size,
        test_data=arrow,
        repeat_batch_calls=repeat_batch_calls,
        exact_batch_size=exact_batch_size,  # type: ignore[arg-type]
    )
    res = pl.read_database(
        query="SELECT * FROM test_data",
        connection=mc,
        iter_batches=iter_batches,
        batch_size=batch_size,
    )
    if iter_batches:
        # batched reads are lazy; materialize so the row check below applies
        assert isinstance(res, GeneratorType)
        res = pl.concat(res)

    res = cast(pl.DataFrame, res)
    # the mock cursor records every fetch* attribute polars accessed
    assert expected_call in mc.cursor().called
    assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]
769
770
771
@pytest.mark.parametrize(
    (
        "read_method",
        "query",
        "protocol",
        "errclass",
        "errmsg",
        "engine",
        "execute_options",
        "pre_execution_query",
        "kwargs",
    ),
    [
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="sqlite",
                errclass=ValueError,
                errmsg="engine must be one of {'connectorx', 'adbc'}, got 'not_an_engine'",
                engine="not_an_engine",
            ),
            id="Not an available sql engine",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query=["SELECT * FROM test_data", "SELECT * FROM test_data"],
                protocol="sqlite",
                errclass=ValueError,
                errmsg="only a single SQL query string is accepted for adbc",
                engine="adbc",
            ),
            id="Unavailable list of queries for adbc",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="mysql",
                errclass=ModuleNotFoundError,
                errmsg="ADBC 'adbc_driver_mysql' driver not detected.",
                engine="adbc",
            ),
            id="Unavailable adbc driver",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol=sqlite3.connect(":memory:"),
                errclass=TypeError,
                errmsg="expected connection to be a URI string",
                engine="adbc",
            ),
            id="Invalid connection URI",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM imaginary_table",
                protocol=sqlite3.connect(":memory:"),
                errclass=sqlite3.OperationalError,
                errmsg="no such table: imaginary_table",
            ),
            id="Invalid query (unrecognised table name)",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM imaginary_table",
                protocol=sys.getsizeof,  # not a connection
                errclass=TypeError,
                errmsg="Unrecognised connection .* no 'execute' or 'cursor' method",
            ),
            id="Invalid read DB kwargs",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="/* tag: misc */ INSERT INTO xyz VALUES ('polars')",
                protocol=sqlite3.connect(":memory:"),
                errclass=UnsuitableSQLError,
                errmsg="INSERT statements are not valid 'read' queries",
            ),
            id="Invalid statement type",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="DELETE FROM xyz WHERE id = 'polars'",
                protocol=sqlite3.connect(":memory:"),
                errclass=UnsuitableSQLError,
                errmsg="DELETE statements are not valid 'read' queries",
            ),
            id="Invalid statement type",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM sqlite_master",
                protocol=sqlite3.connect(":memory:"),
                errclass=ValueError,
                kwargs={"iter_batches": True},
                errmsg="Cannot set `iter_batches` without also setting a non-zero `batch_size`",
            ),
            id="Invalid batch_size",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                engine="adbc",
                query="SELECT * FROM test_data",
                protocol=sqlite3.connect(":memory:"),
                errclass=TypeError,
                errmsg=r"unexpected keyword argument 'partition_on'",
                kwargs={"partition_on": "id"},
            ),
            id="Invalid kwargs",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                engine="adbc",
                query="SELECT * FROM test_data",
                protocol="{not:a, valid:odbc_string}",
                errclass=ValueError,
                errmsg=r"unable to identify string connection as valid ODBC",
            ),
            id="Invalid ODBC string",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="sqlite",
                errclass=ValueError,
                errmsg="the 'adbc' engine does not support use of `pre_execution_query`",
                engine="adbc",
                pre_execution_query="SET statement_timeout = 2151",
            ),
            id="Unavailable `pre_execution_query` for adbc",
        ),
    ],
)
def test_read_database_exceptions(
    read_method: str,
    query: str,
    protocol: Any,
    errclass: type[Exception],
    errmsg: str,
    engine: DbReadEngine | None,
    execute_options: dict[str, Any] | None,
    pre_execution_query: str | list[str] | None,
    kwargs: dict[str, Any] | None,
) -> None:
    """Invalid engines/queries/connections/kwargs raise the expected error class + message."""
    if read_method == "read_database_uri":
        # string protocols become "<scheme>://test" URIs; other objects pass as-is
        conn = f"{protocol}://test" if isinstance(protocol, str) else protocol
        params = {
            "uri": conn,
            "query": query,
            "engine": engine,
            "pre_execution_query": pre_execution_query,
        }
    else:
        params = {"connection": protocol, "query": query}
        if execute_options:
            params["execute_options"] = execute_options
    if kwargs is not None:
        params.update(kwargs)

    read_database = getattr(pl, read_method)
    with pytest.raises(errclass, match=errmsg):
        read_database(**params)
945
946
947
@pytest.mark.parametrize(
    "query",
    [
        "SELECT 1, 1 FROM test_data",
        'SELECT 1 AS "n", 2 AS "n" FROM test_data',
        'SELECT name, value AS "name" FROM test_data',
    ],
)
def test_read_database_duplicate_column_error(tmp_sqlite_db: Path, query: str) -> None:
    """Queries projecting duplicate column names must raise a `DuplicateError`."""
    engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    conn = engine.connect()
    expected_msg = "column .+ appears more than once in the query/result cursor"
    with pytest.raises(DuplicateError, match=expected_msg):
        pl.read_database(query, connection=conn)
962
963
964
@pytest.mark.parametrize(
    "uri",
    [
        "fakedb://123:456@account/database/schema?warehouse=warehouse&role=role",
        "fakedb://my#%us3r:p433w0rd@not_a_real_host:9999/database",
    ],
)
def test_read_database_cx_credentials(uri: str) -> None:
    """Unsupported URI schemes (with embedded credentials) surface connectorx's RuntimeError."""
    query = "SELECT * FROM data"
    with pytest.raises(RuntimeError, match=r"Source.*not supported"):
        pl.read_database_uri(query, uri=uri, engine="connectorx")
974
975
976
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="kuzu segfaults on windows: https://github.com/pola-rs/polars/actions/runs/12502055945/job/34880479875?pr=20462",
)
@pytest.mark.write_disk
def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None:
    """Cypher queries against a kuzu graph db load into polars frames with expected schemas."""
    import kuzu

    # start from a clean database file on every run
    tmp_path.mkdir(exist_ok=True)
    if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists():
        kuzu_test_db.unlink()

    # kuzu expects forward-slash paths, even on windows
    test_db = str(kuzu_test_db).replace("\\", "/")

    db = kuzu.Database(test_db)
    conn = kuzu.Connection(db)
    conn.execute("CREATE NODE TABLE User(name STRING, age UINT64, PRIMARY KEY (name))")
    conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")

    users = str(io_files_path / "graph-data" / "user.csv").replace("\\", "/")
    follows = str(io_files_path / "graph-data" / "follows.csv").replace("\\", "/")

    conn.execute(f'COPY User FROM "{users}"')
    conn.execute(f'COPY Follows FROM "{follows}"')

    # basic: single relation
    df1 = pl.read_database(
        query="MATCH (u:User) RETURN u.name, u.age",
        connection=conn,
    )
    assert_frame_equal(
        df1,
        pl.DataFrame(
            {
                "u.name": ["Adam", "Karissa", "Zhang", "Noura"],
                "u.age": [30, 40, 50, 25],
            },
            schema={"u.name": pl.Utf8, "u.age": pl.UInt64},
        ),
    )

    # join: connected edges/relations
    df2 = pl.read_database(
        query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name",
        connection=conn,
        schema_overrides={"f.since": pl.Int16},
    )
    assert_frame_equal(
        df2,
        pl.DataFrame(
            {
                "a.name": ["Adam", "Adam", "Karissa", "Zhang"],
                "f.since": [2020, 2020, 2021, 2022],
                "b.name": ["Karissa", "Zhang", "Zhang", "Noura"],
            },
            schema={"a.name": pl.Utf8, "f.since": pl.Int16, "b.name": pl.Utf8},
        ),
    )

    # empty: no results for the given query
    df3 = pl.read_database(
        query="MATCH (a:User)-[f:Follows]->(b:User) WHERE a.name = '🔎️' RETURN a.name, f.since, b.name",
        connection=conn,
    )
    assert_frame_equal(
        df3,
        pl.DataFrame(
            schema={"a.name": pl.Utf8, "f.since": pl.Int64, "b.name": pl.Utf8}
        ),
    )
1046
1047
1048
def test_sqlalchemy_row_init(tmp_sqlite_db: Path) -> None:
    """Frames/series can be initialised from sqlalchemy `Row` objects and their mappings."""
    expected_frame = pl.DataFrame(
        {
            "id": [1, 2],
            "name": ["misc", "other"],
            "value": [100.0, -99.5],
            "date": ["2020-01-01", "2021-12-31"],
        }
    )
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    query = text("SELECT * FROM test_data ORDER BY name")

    with alchemy_engine.connect() as conn:
        # note: sqlalchemy `Row` is a NamedTuple-like object; it additionally has
        # a `_mapping` attribute that returns a `RowMapping` dict-like object. we
        # validate frame/series init from each flavour of query result.
        query_result = list(conn.execute(query))
        for df in (
            pl.DataFrame(query_result),
            pl.DataFrame([row._mapping for row in query_result]),
            pl.from_records([row._mapping for row in query_result]),
        ):
            assert_frame_equal(expected_frame, df)

        # series init from rows produces a struct series with the same fields
        expected_series = expected_frame.to_struct()
        for s in (
            pl.Series(query_result),
            pl.Series([row._mapping for row in query_result]),
        ):
            assert_series_equal(expected_series, s)
1078
1079
1080
@patch("polars.io.database._utils.from_arrow")
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_query_success(
    import_mock: Mock, from_arrow_mock: Mock
) -> None:
    """`pre_execution_query` is forwarded to connectorx when the version supports it."""
    # pretend a connectorx build new enough to accept pre_execution_query is installed
    mock_cx = Mock()
    mock_cx.__version__ = "0.4.2"
    import_mock.return_value = mock_cx

    setup_sql = "SET statement_timeout = 2151"
    pl.read_database_uri(
        query="SELECT 1",
        uri="mysql://test",
        engine="connectorx",
        pre_execution_query=setup_sql,
    )

    # the setup statement must be passed through to connectorx verbatim
    forwarded = mock_cx.read_sql.call_args.kwargs["pre_execution_query"]
    assert forwarded == setup_sql
1102
1103
1104
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_not_supported_exception(
    import_mock: Mock,
) -> None:
    """Old connectorx (<0.4.2) must reject `pre_execution_query` with a ValueError."""
    # pretend an older connectorx build is installed
    mock_cx = Mock()
    mock_cx.__version__ = "0.4.0"
    import_mock.return_value = mock_cx

    expected_msg = (
        "'pre_execution_query' is only supported in connectorx version 0.4.2 or later"
    )
    with pytest.raises(ValueError, match=expected_msg):
        pl.read_database_uri(
            query="SELECT 1",
            uri="mysql://test",
            engine="connectorx",
            pre_execution_query="SET statement_timeout = 2151",
        )
1125
1126
1127
@patch("polars.io.database._utils.from_arrow")
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_query_not_supported_success(
    import_mock: Mock, from_arrow_mock: Mock
) -> None:
    """Without `pre_execution_query`, old connectorx builds still work (arg simply omitted)."""
    # pretend an older connectorx build is installed
    mock_cx = Mock()
    mock_cx.__version__ = "0.4.0"
    import_mock.return_value = mock_cx

    pl.read_database_uri(
        query="SELECT 1",
        uri="mysql://test",
        engine="connectorx",
    )

    # the unsupported kwarg must not have been forwarded at all
    forwarded_kwargs = mock_cx.read_sql.call_args.kwargs
    assert forwarded_kwargs.get("pre_execution_query") is None
1144
1145