Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/database/test_async.py
6939 views
1
from __future__ import annotations
2
3
import asyncio
4
from math import ceil
5
from types import ModuleType
6
from typing import TYPE_CHECKING, Any, overload
7
8
import pytest
9
import sqlalchemy
10
from sqlalchemy.ext.asyncio import create_async_engine
11
12
import polars as pl
13
from polars._utils.various import parse_version
14
from polars.io.database._utils import _run_async
15
from polars.testing import assert_frame_equal
16
from tests.unit.conftest import mock_module_import
17
18
if TYPE_CHECKING:
19
from collections.abc import Iterable
20
from pathlib import Path
21
22
# Canned rows handed to the mock SurrealDB connection in the tests below.
# Across the three rows `checked` takes the values False, None and True,
# and `tags` holds lists of one or two strings.
SURREAL_MOCK_DATA: list[dict[str, Any]] = [
    {"id": "item:8xj31jfpdkf9gvmxdxpi", "name": "abc", "tags": ["polars"], "checked": False},
    {"id": "item:l59k19swv2adsv4q04cj", "name": "mno", "tags": ["async"], "checked": None},
    {
        "id": "item:w831f1oyqnwztv5q03em",
        "name": "xyz",
        "tags": ["stroop", "wafel"],
        "checked": True,
    },
]
42
43
44
class MockSurrealConnection:
    """
    Stand-in for a SurrealDB connection/client that serves canned rows.

    Usable as an async context manager: entering awaits `connect` and yields
    the instance itself; exiting awaits `close`.

    Parameters
    ----------
    url
        Stored on the instance as `url`; nothing is connected to.
    mock_data
        Rows returned by every `query` call. The list is shallow-copied, so
        the row dicts themselves are shared with the caller.
    """

    # Make the class report "surrealdb" as its defining module.
    __module__ = "surrealdb"

    def __init__(self, url: str, mock_data: list[dict[str, Any]]) -> None:
        self.url = url
        self._mock_data = mock_data.copy()

    async def connect(self) -> None:
        """Do nothing; there is no server behind the mock."""

    async def close(self) -> None:
        """Do nothing; there is no connection to release."""

    async def use(self, namespace: str, database: str) -> None:
        """Accept a namespace/database selection and ignore it."""

    async def query(
        self, query: str, variables: dict[str, Any] | None = None
    ) -> list[dict[str, Any]]:
        """Return the canned rows in a one-element result list, ignoring the inputs."""
        envelope: dict[str, Any] = {
            "result": self._mock_data,
            "status": "OK",
            "time": "32.083µs",
        }
        return [envelope]

    async def __aenter__(self) -> Any:
        await self.connect()
        return self

    async def __aexit__(self, *exc_info: object, **kwargs: Any) -> None:
        await self.close()
73
74
75
class MockedSurrealModule(ModuleType):
    """
    Fake `surrealdb` module whose `AsyncSurrealDB` is the mock connection class.

    Installed in place of the real module so that an internal `isinstance`
    check against `AsyncSurrealDB` accepts `MockSurrealConnection` objects.
    """

    AsyncSurrealDB = MockSurrealConnection
79
80
81
@pytest.mark.skipif(
    parse_version(sqlalchemy.__version__) < (2, 0),
    reason="SQLAlchemy 2.0+ required for async tests",
)
def test_read_async(tmp_sqlite_db: Path) -> None:
    """
    Check `pl.read_database` against the core SQLAlchemy async primitives.

    The same query is run through an AsyncEngine, an AsyncConnection, an
    `async_sessionmaker` and a session created from it. The two session
    objects are queried without the bound-parameter filter; every variant
    must produce the same frame.
    """
    from sqlalchemy.ext.asyncio import async_sessionmaker

    engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
    connection = engine.connect()
    session_factory = async_sessionmaker(engine)
    session = session_factory()

    expected_frame = pl.DataFrame(
        {"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]}
    )
    session_like = (session_factory, session)

    async_conn: Any
    for async_conn in (engine, connection, session_factory, session):
        is_session = async_conn in session_like
        constraint = "" if is_session else "WHERE value > :n"
        execute_opts: dict[str, Any] = (
            {} if is_session else {"parameters": {"n": -1000}}
        )

        df = pl.read_database(
            query=f"""
                SELECT id, name, value
                FROM test_data {constraint}
                ORDER BY id DESC
            """,
            connection=async_conn,
            execute_options=execute_opts,
        )
        assert_frame_equal(expected_frame, df)
121
122
123
async def _nested_async_test(tmp_sqlite_db: Path) -> pl.DataFrame:
    """
    Call the synchronous `pl.read_database` from inside a coroutine.

    Opens an async SQLAlchemy connection to the SQLite file at
    `tmp_sqlite_db` and returns the `id`/`name` columns of `test_data`,
    ordered by `id`.
    """
    db_url = f"sqlite+aiosqlite:///{tmp_sqlite_db}"
    connection = create_async_engine(db_url).connect()
    return pl.read_database(
        query="SELECT id, name FROM test_data ORDER BY id",
        connection=connection,
    )
129
130
131
@pytest.mark.skipif(
    parse_version(sqlalchemy.__version__) < (2, 0),
    reason="SQLAlchemy 2.0+ required for async tests",
)
def test_read_async_nested(tmp_sqlite_db: Path) -> None:
    """Validate that `read_database` copes with nested async calls."""
    # The read happens inside a coroutine that is itself driven by asyncio.run.
    result = asyncio.run(_nested_async_test(tmp_sqlite_db))
    expected_frame = pl.DataFrame({"id": [1, 2], "name": ["misc", "other"]})
    assert_frame_equal(expected_frame, result)
140
141
142
@overload
async def _surreal_query_as_frame(
    url: str, query: str, batch_size: None
) -> pl.DataFrame: ...


@overload
async def _surreal_query_as_frame(
    url: str, query: str, batch_size: int
) -> Iterable[pl.DataFrame]: ...


async def _surreal_query_as_frame(
    url: str, query: str, batch_size: int | None
) -> pl.DataFrame | Iterable[pl.DataFrame]:
    """
    Run `query` through `pl.read_database` on a mock SurrealDB connection.

    A falsy `batch_size` performs a plain read. Any other value is forwarded
    together with `iter_batches=True`, so an iterable of frames is returned
    instead of a single frame.
    """
    if batch_size:
        read_kwargs = {"iter_batches": True, "batch_size": batch_size}
    else:
        read_kwargs = {}

    async with MockSurrealConnection(url=url, mock_data=SURREAL_MOCK_DATA) as conn:
        await conn.use(namespace="test", database="test")
        return pl.read_database(  # type: ignore[no-any-return,call-overload]
            query=query,
            connection=conn,
            **read_kwargs,
        )
167
168
169
@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 4])
def test_surrealdb_fetchall(batch_size: int | None) -> None:
    """
    Read the mock SurrealDB rows, either in one go or in batches.

    Without a batch size the full frame is compared. With one, the number of
    returned batches and the contents of the first batch are checked.
    """
    with mock_module_import("surrealdb", MockedSurrealModule("surrealdb")):
        df_expected = pl.DataFrame(SURREAL_MOCK_DATA)
        res = asyncio.run(
            _surreal_query_as_frame(
                url="ws://localhost:8000/rpc",
                query="SELECT * FROM item",
                batch_size=batch_size,
            )
        )
        if not batch_size:
            assert_frame_equal(df_expected, res)  # type: ignore[arg-type]
            return

        # Batched read: materialise the batches while the mock module is active.
        frames = list(res)  # type: ignore[call-overload]
        n_mock_rows = len(SURREAL_MOCK_DATA)
        n_expected_batches = ceil(n_mock_rows / batch_size)
        assert len(frames) == n_expected_batches
        assert_frame_equal(df_expected[:batch_size], frames[0])
187
188
189
def test_async_nested_captured_loop_21263() -> None:
    """
    Await, via `_run_async`, a task bound to the already-running event loop.

    The task is created on the loop started by `asyncio.run`, so it has
    "captured" that loop; it is then awaited from within a `_run_async`
    context while the loop is still running.
    """

    async def _await(pending: Any) -> None:
        await pending

    async def _outer() -> None:
        captured = asyncio.get_running_loop().create_task(asyncio.sleep(0))
        _run_async(_await(captured))

    asyncio.run(_outer())
202
203