Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/database/test_write.py
6939 views
1
from __future__ import annotations
2
3
import sys
4
from typing import TYPE_CHECKING, Any
5
6
import pytest
7
from sqlalchemy import create_engine
8
from sqlalchemy.orm import Session
9
from sqlalchemy.pool import NullPool
10
11
import polars as pl
12
from polars.io.database._utils import _open_adbc_connection
13
from polars.testing import assert_frame_equal
14
15
if TYPE_CHECKING:
16
from pathlib import Path
17
18
from polars._typing import DbWriteEngine
19
20
21
@pytest.mark.write_disk
@pytest.mark.parametrize(
    ("engine", "uri_connection"),
    [
        ("sqlalchemy", True),
        ("sqlalchemy", False),
        pytest.param(
            "adbc",
            True,
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc not available on Windows",
            ),
        ),
        pytest.param(
            "adbc",
            False,
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc not available on Windows",
            ),
        ),
    ],
)
class TestWriteDatabase:
    """Database write tests that share common pytest/parametrize options.

    Every test method receives ``engine`` ("sqlalchemy" or "adbc") and
    ``uri_connection`` (whether the raw URI string or an opened
    engine/connection object is handed to ``write_database``).
    """

    @staticmethod
    def _get_connection(uri: str, engine: DbWriteEngine, uri_connection: bool) -> Any:
        """Return the object to pass as ``connection`` for the given params.

        The URI string itself when ``uri_connection`` is true; otherwise a
        SQLAlchemy engine or an ADBC connection, depending on ``engine``.
        """
        if uri_connection:
            return uri
        elif engine == "sqlalchemy":
            return create_engine(uri)
        else:
            return _open_adbc_connection(uri)

    @staticmethod
    def _close_connection(conn: Any) -> None:
        """Close ``conn`` if it exposes ``close``; URI strings are left alone."""
        if hasattr(conn, "close"):
            conn.close()

    def test_write_database_create(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Test basic database table creation."""
        df = pl.DataFrame(
            {
                "id": [1234, 5678],
                "name": ["misc", "other"],
                "value": [1000.0, -9999.0],
            }
        )
        tmp_path.mkdir(exist_ok=True)
        test_db_uri = f"sqlite:///{tmp_path}/test_create_{int(uri_connection)}.db"

        table_name = "test_create"
        conn = self._get_connection(test_db_uri, engine, uri_connection)

        # try/finally so the connection is released even if an assertion fails
        try:
            assert (
                df.write_database(
                    table_name=table_name,
                    connection=conn,
                    engine=engine,
                )
                == 2
            )
            result = pl.read_database(
                query=f"SELECT * FROM {table_name}",
                connection=create_engine(test_db_uri),
            )
            assert_frame_equal(result, df)
        finally:
            self._close_connection(conn)

    def test_write_database_append_replace(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Test append/replace ops against existing database table."""
        df = pl.DataFrame(
            {
                "key": ["xx", "yy", "zz"],
                "value": [123, None, 789],
                "other": [5.5, 7.0, None],
            }
        )
        tmp_path.mkdir(exist_ok=True)
        test_db_uri = f"sqlite:///{tmp_path}/test_append_{int(uri_connection)}.db"

        table_name = "test_append"
        conn = self._get_connection(test_db_uri, engine, uri_connection)

        # try/finally so the connection is released even if an assertion fails
        try:
            # initial write creates the table
            assert (
                df.write_database(
                    table_name=table_name,
                    connection=conn,
                    engine=engine,
                )
                == 3
            )

            # writing again with "fail" must raise, as the table now exists
            with pytest.raises(Exception):  # noqa: B017
                df.write_database(
                    table_name=table_name,
                    connection=conn,
                    if_table_exists="fail",
                    engine=engine,
                )

            # "replace" leaves exactly the newly written rows
            assert (
                df.write_database(
                    table_name=table_name,
                    connection=conn,
                    if_table_exists="replace",
                    engine=engine,
                )
                == 3
            )
            result = pl.read_database(
                query=f"SELECT * FROM {table_name}",
                connection=create_engine(test_db_uri),
            )
            assert_frame_equal(result, df)

            # "append" adds the new rows after the existing ones
            assert (
                df[:2].write_database(
                    table_name=table_name,
                    connection=conn,
                    if_table_exists="append",
                    engine=engine,
                )
                == 2
            )
            result = pl.read_database(
                query=f"SELECT * FROM {table_name}",
                connection=create_engine(test_db_uri),
            )
            assert_frame_equal(result, pl.concat([df, df[:2]]))

            # a caller-supplied adbc connection must still be open afterwards
            if engine == "adbc" and not uri_connection:
                assert conn._closed is False
        finally:
            self._close_connection(conn)

    def test_write_database_create_quoted_tablename(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Test parsing/handling of quoted database table names."""
        df = pl.DataFrame(
            {
                "col x": [100, 200, 300],
                "col y": ["a", "b", "c"],
            }
        )
        tmp_path.mkdir(exist_ok=True)
        test_db_uri = f"sqlite:///{tmp_path}/test_create_quoted.db"

        # table name has some special chars, so requires quoting, and
        # is explicitly qualified with the sqlite 'main' schema
        qualified_table_name = f'main."test-append-{engine}-{int(uri_connection)}"'
        conn = self._get_connection(test_db_uri, engine, uri_connection)

        # try/finally so the connection is released even if an assertion fails
        try:
            assert (
                df.write_database(
                    table_name=qualified_table_name,
                    connection=conn,
                    engine=engine,
                )
                == 3
            )
            assert (
                df.write_database(
                    table_name=qualified_table_name,
                    connection=conn,
                    if_table_exists="replace",
                    engine=engine,
                )
                == 3
            )
            result = pl.read_database(
                query=f"SELECT * FROM {qualified_table_name}",
                connection=create_engine(test_db_uri),
            )
            assert_frame_equal(result, df)

            # a caller-supplied adbc connection must still be open afterwards
            if engine == "adbc" and not uri_connection:
                assert conn._closed is False
        finally:
            self._close_connection(conn)

    def test_write_database_errors(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Confirm that expected errors are raised."""
        df = pl.DataFrame({"colx": [1, 2, 3]})

        with pytest.raises(
            ValueError, match="`table_name` appears to be invalid: 'w.x.y.z'"
        ):
            df.write_database(
                connection="sqlite:///:memory:",
                table_name="w.x.y.z",
                engine=engine,
            )

        with pytest.raises(
            ValueError,
            match="`if_table_exists` must be one of .* got 'do_something'",
        ):
            df.write_database(
                connection="sqlite:///:memory:",
                table_name="main.test_errs",
                if_table_exists="do_something",  # type: ignore[arg-type]
                engine=engine,
            )

        with pytest.raises(
            TypeError,
            match="unrecognised connection type.*",
        ):
            df.write_database(connection=True, table_name="misc", engine=engine)  # type: ignore[arg-type]

    def test_write_database_adbc_missing_driver_error(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Confirm a missing ADBC driver is reported as ModuleNotFoundError."""
        # Only meaningful for adbc; skip explicitly so the sqlalchemy params
        # are reported as skipped rather than as (vacuously) passing.
        if engine == "sqlalchemy":
            pytest.skip("missing-driver error only applies to the adbc engine")

        df = pl.DataFrame({"colx": [1, 2, 3]})
        with pytest.raises(
            ModuleNotFoundError, match="ADBC 'adbc_driver_mysql' driver not detected."
        ):
            df.write_database(
                table_name="my_schema.my_table",
                connection="mysql:///:memory:",
                engine=engine,
            )
254
255
256
@pytest.mark.write_disk
def test_write_database_using_sa_session(tmp_path: Path) -> None:
    """Write through a SQLAlchemy ``Session``, commit, and read the rows back."""
    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "zz"],
            "value": [123, None, 789],
            "other": [5.5, 7.0, None],
        }
    )
    table_name = "test_sa_session"
    test_db_uri = f"sqlite:///{tmp_path}/test_sa_session.db"
    engine = create_engine(test_db_uri, poolclass=NullPool)

    # write and commit inside one session...
    with Session(engine) as session:
        df.write_database(table_name, session)
        session.commit()

    # ...then read back through a separate session
    with Session(engine) as session:
        result = pl.read_database(
            query=f"select * from {table_name}", connection=session
        )

    assert_frame_equal(result, df)
278
279
280
@pytest.mark.write_disk
@pytest.mark.parametrize("pass_connection", [True, False])
def test_write_database_sa_rollback(tmp_path: Path, pass_connection: bool) -> None:
    """Check that rows written in a session are gone after ``rollback``.

    ``pass_connection`` selects whether ``write_database`` is given the
    session's underlying connection or the session object itself.
    """
    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "zz"],
            "value": [123, None, 789],
            "other": [5.5, 7.0, None],
        }
    )
    table_name = "test_sa_rollback"
    test_db_uri = f"sqlite:///{tmp_path}/test_sa_rollback.db"
    engine = create_engine(test_db_uri, poolclass=NullPool)

    with Session(engine) as session:
        if pass_connection:
            conn = session.connection()
            df.write_database(table_name, conn)
        else:
            df.write_database(table_name, session)
        session.rollback()

    # the count query is expected to succeed and report no rows
    with Session(engine) as session:
        count = pl.read_database(
            query=f"select count(*) from {table_name}", connection=session
        ).item(0, 0)

    assert isinstance(count, int)
    assert count == 0
308
309
310
@pytest.mark.write_disk
@pytest.mark.parametrize("pass_connection", [True, False])
def test_write_database_sa_commit(tmp_path: Path, pass_connection: bool) -> None:
    """Check that rows written in a session are readable after ``commit``.

    ``pass_connection`` selects whether ``write_database`` is given the
    session's underlying connection or the session object itself.
    """
    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "zz"],
            "value": [123, None, 789],
            "other": [5.5, 7.0, None],
        }
    )
    table_name = "test_sa_commit"
    test_db_uri = f"sqlite:///{tmp_path}/test_sa_commit.db"
    engine = create_engine(test_db_uri, poolclass=NullPool)

    with Session(engine) as session:
        if pass_connection:
            conn = session.connection()
            df.write_database(table_name, conn)
        else:
            df.write_database(table_name, session)
        session.commit()

    # read back through a separate session
    with Session(engine) as session:
        result = pl.read_database(
            query=f"select * from {table_name}", connection=session
        )

    assert_frame_equal(result, df)
337
338
339
@pytest.mark.skipif(sys.platform == "win32", reason="adbc not available on Windows")
def test_write_database_adbc_temporary_table() -> None:
    """Confirm that `engine_options` are passed along to create temporary tables."""
    df = pl.DataFrame({"colx": [1, 2, 3]})
    temp_tbl_name = "should_be_temptable"
    expected_temp_table_create_sql = (
        """CREATE TABLE "should_be_temptable" ("colx" INTEGER)"""
    )

    # test with sqlite in memory
    conn = _open_adbc_connection("sqlite:///:memory:")

    # try/finally so the connection is released even if an assertion fails
    try:
        assert (
            df.write_database(
                temp_tbl_name,
                connection=conn,
                if_table_exists="fail",
                engine_options={"temporary": True},
            )
            == 3
        )

        # look the table up in sqlite's temp-table catalog, on the same connection
        temp_tbl_sql_df = pl.read_database(
            "select sql from sqlite_temp_master where type='table' and tbl_name = ?",
            connection=conn,
            execute_options={"parameters": [temp_tbl_name]},
        )
        assert temp_tbl_sql_df.shape[0] == 1, "no temp table created"
        actual_temp_table_create_sql = temp_tbl_sql_df["sql"][0]
        assert expected_temp_table_create_sql == actual_temp_table_create_sql
    finally:
        if hasattr(conn, "close"):
            conn.close()
370
371