Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/database/test_async.py
8407 views
1
from __future__ import annotations
2
3
import asyncio
4
from math import ceil
5
from types import ModuleType
6
from typing import TYPE_CHECKING, Any, overload
7
8
import pytest
9
import sqlalchemy
10
from sqlalchemy.ext.asyncio import create_async_engine
11
12
import polars as pl
13
from polars._utils.various import parse_version
14
from polars.io.database._utils import _run_async
15
from polars.testing import assert_frame_equal
16
from tests.unit.conftest import mock_module_import
17
18
if TYPE_CHECKING:
19
from collections.abc import Iterable
20
from pathlib import Path
21
22
23
# Canned rows served by the mock SurrealDB connection. Each row has a string
# "id" and "name", a list of string "tags", and a "checked" flag that takes the
# values False, None and True across the three rows.
SURREAL_MOCK_DATA: list[dict[str, Any]] = [
    {
        "id": "item:8xj31jfpdkf9gvmxdxpi",
        "name": "abc",
        "tags": ["polars"],
        "checked": False,
    },
    {
        "id": "item:l59k19swv2adsv4q04cj",
        "name": "mno",
        "tags": ["async"],
        "checked": None,
    },
    {
        "id": "item:w831f1oyqnwztv5q03em",
        "name": "xyz",
        "tags": ["stroop", "wafel"],
        "checked": True,
    },
]
43
44
45
class MockSurrealConnection:
    """Stand-in for a SurrealDB connection/client that serves canned rows."""

    # report "surrealdb" as the module this class was defined in
    __module__ = "surrealdb"

    def __init__(self, url: str, mock_data: list[dict[str, Any]]) -> None:
        self.url = url
        # keep a shallow copy so later changes to the caller's list are not seen
        self._mock_data = mock_data.copy()

    async def connect(self) -> None:
        """Do nothing; no real connection is opened."""

    async def close(self) -> None:
        """Do nothing; there is no real connection to close."""

    async def use(self, namespace: str, database: str) -> None:
        """Do nothing; the namespace and database are ignored."""

    async def query(
        self, query: str, variables: dict[str, Any] | None = None
    ) -> list[dict[str, Any]]:
        """Return the stored rows in a one-element response list.

        Both `query` and `variables` are ignored.
        """
        response = {"result": self._mock_data, "status": "OK", "time": "32.083µs"}
        return [response]

    async def __aenter__(self) -> Any:
        await self.connect()
        return self

    async def __aexit__(self, *args: object, **kwargs: Any) -> None:
        await self.close()
74
75
76
class MockedSurrealModule(ModuleType):
    """Fake SurrealDB module exposing the mock connection as `AsyncSurrealDB`.

    This enables the internal `isinstance` check for AsyncSurrealDB.
    """

    AsyncSurrealDB = MockSurrealConnection
80
81
82
@pytest.mark.skipif(
    parse_version(sqlalchemy.__version__) < (2, 0),
    reason="SQLAlchemy 2.0+ required for async tests",
)
def test_read_async(tmp_sqlite_db: Path) -> None:
    """Load frame data through each of the core sqlalchemy async primitives.

    Covers AsyncEngine, AsyncConnection, async_sessionmaker, and AsyncSession.
    """
    from contextlib import AsyncExitStack

    from sqlalchemy.ext.asyncio import async_sessionmaker

    async def _test_impl() -> None:
        # Register each cleanup as soon as its resource exists: teardown then
        # never touches an unbound name, and a failing close cannot prevent the
        # remaining resources from being released. Callbacks run in reverse
        # order of registration: session, then connection, then engine.
        async with AsyncExitStack() as cleanup:
            async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
            cleanup.push_async_callback(async_engine.dispose)

            async_connection = await async_engine.connect()
            cleanup.push_async_callback(async_connection.close)

            async_session = async_sessionmaker(async_engine)
            async_session_inst = async_session()
            cleanup.push_async_callback(async_session_inst.close)

            expected_frame = pl.DataFrame(
                {"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]}
            )
            async_conn: Any
            for async_conn in (
                async_engine,
                async_connection,
                async_session,
                async_session_inst,
            ):
                # the session factory and session instance are queried without
                # the bound-parameter constraint; the engine and connection
                # are queried with it
                if async_conn in (async_session, async_session_inst):
                    constraint, execute_opts = "", {}
                else:
                    constraint = "WHERE value > :n"
                    execute_opts = {"parameters": {"n": -1000}}

                df = pl.read_database(
                    query=f"""
                        SELECT id, name, value
                        FROM test_data {constraint}
                        ORDER BY id DESC
                    """,
                    connection=async_conn,
                    execute_options=execute_opts,
                )
                assert_frame_equal(expected_frame, df)

    asyncio.run(_test_impl())
130
131
132
@pytest.mark.skipif(
    parse_version(sqlalchemy.__version__) < (2, 0),
    reason="SQLAlchemy 2.0+ required for async tests",
)
@pytest.mark.parametrize("started", [True, False])
def test_read_async_nested(tmp_sqlite_db: Path, started: bool) -> None:
    """Call `read_database` from inside a running event loop.

    Exercised with a connection that has already been awaited (`started`) and
    with one that has not.
    """
    expected_frame = pl.DataFrame({"id": [1, 2], "name": ["misc", "other"]})

    async def _read_within_loop() -> pl.DataFrame:
        engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
        conn = engine.connect()
        if started:
            # await the connection up-front for the "started" case
            conn = await conn
        try:
            frame = pl.read_database(
                query="SELECT id, name FROM test_data ORDER BY id",
                connection=conn,
            )
        finally:
            await conn.close()
            await engine.dispose()
        return frame

    df = asyncio.run(_read_within_loop())
    assert_frame_equal(expected_frame, df)
157
158
159
@overload
async def _surreal_query_as_frame(
    url: str, query: str, batch_size: None
) -> pl.DataFrame: ...


@overload
async def _surreal_query_as_frame(
    url: str, query: str, batch_size: int
) -> Iterable[pl.DataFrame]: ...


async def _surreal_query_as_frame(
    url: str, query: str, batch_size: int | None
) -> pl.DataFrame | Iterable[pl.DataFrame]:
    """Run `query` against a mock SurrealDB connection via `pl.read_database`.

    A truthy `batch_size` adds `iter_batches=True` and that `batch_size` to the
    call; otherwise no batching options are passed.
    """
    if batch_size:
        batch_params = {"iter_batches": True, "batch_size": batch_size}
    else:
        batch_params = {}

    connection = MockSurrealConnection(url=url, mock_data=SURREAL_MOCK_DATA)
    async with connection as client:
        await client.use(namespace="test", database="test")
        return pl.read_database(  # type: ignore[no-any-return,call-overload]
            query=query,
            connection=client,
            **batch_params,
        )
184
185
186
@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 4])
def test_surrealdb_fetchall(batch_size: int | None) -> None:
    """Read the mock SurrealDB rows as one frame, or in batches of `batch_size`."""
    fake_module = MockedSurrealModule("surrealdb")
    with mock_module_import("surrealdb", fake_module):
        df_expected = pl.DataFrame(SURREAL_MOCK_DATA)
        pending_query = _surreal_query_as_frame(
            url="ws://localhost:8000/rpc",
            query="SELECT * FROM item",
            batch_size=batch_size,
        )
        res = asyncio.run(pending_query)

        if not batch_size:
            # unbatched: the whole result is a single frame
            assert_frame_equal(df_expected, res)  # type: ignore[arg-type]
        else:
            frames = list(res)  # type: ignore[call-overload]
            expected_n_frames = ceil(len(SURREAL_MOCK_DATA) / batch_size)
            assert len(frames) == expected_n_frames
            # only the first batch is compared against the expected rows
            assert_frame_equal(df_expected[:batch_size], frames[0])
204
205
206
def test_async_nested_captured_loop_21263() -> None:
    """Await, inside `_run_async`, a task created on the outer event loop.

    The task has "captured" the original event loop before `_run_async` is
    entered.
    """

    async def await_task(task: Any) -> None:
        await task

    async def _outer() -> None:
        # create the task on the loop started by `asyncio.run` below
        outer_loop = asyncio.get_running_loop()
        pending = outer_loop.create_task(asyncio.sleep(0))

        _run_async(await_task(pending))

    asyncio.run(_outer())
219
220
221
# guarded in the same way as the other SQLAlchemy async tests in this file
@pytest.mark.skipif(
    parse_version(sqlalchemy.__version__) < (2, 0),
    reason="SQLAlchemy 2.0+ required for async tests",
)
def test_async_index_error_25209(tmp_sqlite_db: Path) -> None:
    """Gather two async-engine reads of the same table and check both results."""
    base_uri = f"sqlite:///{tmp_sqlite_db}"
    table_name = "test_25209"

    # write a single-row table through the synchronous sqlalchemy engine
    pl.select(x=1, y=2, z=3).write_database(
        table_name,
        connection=base_uri,
        engine="sqlalchemy",
        if_table_exists="replace",
    )

    async def run_async_query() -> Any:
        # "sqlite+aio" + "sqlite:///<path>" yields "sqlite+aiosqlite:///<path>"
        async_engine = create_async_engine(f"sqlite+aio{base_uri}")
        try:
            return pl.read_database(
                query=f"SELECT * FROM {table_name}",
                connection=async_engine,
            )
        finally:
            await async_engine.dispose()

    async def testing() -> Any:
        # return/await multiple queries
        return await asyncio.gather(run_async_query(), run_async_query())

    df1, df2 = asyncio.run(testing())

    assert_frame_equal(df1, df2)
    assert df1.rows() == [(1, 2, 3)]
250
251