Path: blob/main/py-polars/tests/unit/io/database/test_async.py
8407 views
from __future__ import annotations12import asyncio3from math import ceil4from types import ModuleType5from typing import TYPE_CHECKING, Any, overload67import pytest8import sqlalchemy9from sqlalchemy.ext.asyncio import create_async_engine1011import polars as pl12from polars._utils.various import parse_version13from polars.io.database._utils import _run_async14from polars.testing import assert_frame_equal15from tests.unit.conftest import mock_module_import1617if TYPE_CHECKING:18from collections.abc import Iterable19from pathlib import Path202122SURREAL_MOCK_DATA: list[dict[str, Any]] = [23{24"id": "item:8xj31jfpdkf9gvmxdxpi",25"name": "abc",26"tags": ["polars"],27"checked": False,28},29{30"id": "item:l59k19swv2adsv4q04cj",31"name": "mno",32"tags": ["async"],33"checked": None,34},35{36"id": "item:w831f1oyqnwztv5q03em",37"name": "xyz",38"tags": ["stroop", "wafel"],39"checked": True,40},41]424344class MockSurrealConnection:45"""Mock SurrealDB connection/client object."""4647__module__ = "surrealdb"4849def __init__(self, url: str, mock_data: list[dict[str, Any]]) -> None:50self._mock_data = mock_data.copy()51self.url = url5253async def __aenter__(self) -> Any:54await self.connect()55return self5657async def __aexit__(self, *args: object, **kwargs: Any) -> None:58await self.close()5960async def close(self) -> None:61pass6263async def connect(self) -> None:64pass6566async def use(self, namespace: str, database: str) -> None:67pass6869async def query(70self, query: str, variables: dict[str, Any] | None = None71) -> list[dict[str, Any]]:72return [{"result": self._mock_data, "status": "OK", "time": "32.083µs"}]737475class MockedSurrealModule(ModuleType):76"""Mock SurrealDB module; enables internal `isinstance` check for AsyncSurrealDB."""7778AsyncSurrealDB = MockSurrealConnection798081@pytest.mark.skipif(82parse_version(sqlalchemy.__version__) < (2, 0),83reason="SQLAlchemy 2.0+ required for async tests",84)85def test_read_async(tmp_sqlite_db: Path) -> None:86# confirm that we can load frame data from the core sqlalchemy async87# primitives: AsyncEngine, AsyncConnection, async_sessionmaker, and AsyncSession88from sqlalchemy.ext.asyncio import async_sessionmaker8990async def _test_impl() -> None:91async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")92async_connection = await async_engine.connect()93try:94async_session = async_sessionmaker(async_engine)95async_session_inst = async_session()9697expected_frame = pl.DataFrame(98{"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]}99)100async_conn: Any101for async_conn in (102async_engine,103async_connection,104async_session,105async_session_inst,106):107if async_conn in (async_session, async_session_inst):108constraint, execute_opts = "", {}109else:110constraint = "WHERE value > :n"111execute_opts = {"parameters": {"n": -1000}}112113df = pl.read_database(114query=f"""115SELECT id, name, value116FROM test_data {constraint}117ORDER BY id DESC118""",119connection=async_conn,120execute_options=execute_opts,121)122assert_frame_equal(expected_frame, df)123finally:124await async_session_inst.close()125await async_connection.close()126await async_engine.dispose()127128asyncio.run(_test_impl())129130131@pytest.mark.skipif(132parse_version(sqlalchemy.__version__) < (2, 0),133reason="SQLAlchemy 2.0+ required for async tests",134)135@pytest.mark.parametrize("started", [True, False])136def test_read_async_nested(tmp_sqlite_db: Path, started: bool) -> None:137# validate that we can handle nested async calls; check138# this works with connections that are started/unstarted139async def _test_impl() -> pl.DataFrame:140async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")141async_connection = async_engine.connect()142if started:143async_connection = await async_connection144try:145return pl.read_database(146query="SELECT id, name FROM test_data ORDER BY id",147connection=async_connection,148)149finally:150await async_connection.close()151await async_engine.dispose()152153expected_frame = pl.DataFrame({"id": [1, 2], "name": ["misc", "other"]})154df = asyncio.run(_test_impl())155assert_frame_equal(expected_frame, df)156157158@overload159async def _surreal_query_as_frame(160url: str, query: str, batch_size: None161) -> pl.DataFrame: ...162163164@overload165async def _surreal_query_as_frame(166url: str, query: str, batch_size: int167) -> Iterable[pl.DataFrame]: ...168169170async def _surreal_query_as_frame(171url: str, query: str, batch_size: int | None172) -> pl.DataFrame | Iterable[pl.DataFrame]:173batch_params = (174{"iter_batches": True, "batch_size": batch_size} if batch_size else {}175)176async with MockSurrealConnection(url=url, mock_data=SURREAL_MOCK_DATA) as client:177await client.use(namespace="test", database="test")178return pl.read_database( # type: ignore[no-any-return,call-overload]179query=query,180connection=client,181**batch_params,182)183184185@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 4])186def test_surrealdb_fetchall(batch_size: int | None) -> None:187with mock_module_import("surrealdb", MockedSurrealModule("surrealdb")):188df_expected = pl.DataFrame(SURREAL_MOCK_DATA)189res = asyncio.run(190_surreal_query_as_frame(191url="ws://localhost:8000/rpc",192query="SELECT * FROM item",193batch_size=batch_size,194)195)196if batch_size:197frames = list(res) # type: ignore[call-overload]198n_mock_rows = len(SURREAL_MOCK_DATA)199assert len(frames) == ceil(n_mock_rows / batch_size)200assert_frame_equal(df_expected[:batch_size], frames[0])201else:202assert_frame_equal(df_expected, res) # type: ignore[arg-type]203204205def test_async_nested_captured_loop_21263() -> None:206# tests awaiting a future that has "captured" the original event loop from207# within a `_run_async` context.208async def test_impl() -> None:209loop = asyncio.get_running_loop()210task = loop.create_task(asyncio.sleep(0))211212_run_async(await_task(task))213214async def await_task(task: Any) -> None:215await task216217asyncio.run(test_impl())218219220def test_async_index_error_25209(tmp_sqlite_db: Path) -> None:221base_uri = f"sqlite:///{tmp_sqlite_db}"222table_name = "test_25209"223224pl.select(x=1, y=2, z=3).write_database(225table_name,226connection=base_uri,227engine="sqlalchemy",228if_table_exists="replace",229)230231async def run_async_query() -> Any:232async_engine = create_async_engine(f"sqlite+aio{base_uri}")233try:234return pl.read_database(235query=f"SELECT * FROM {table_name}",236connection=async_engine,237)238finally:239await async_engine.dispose()240241async def testing() -> Any:242# return/await multiple queries243return await asyncio.gather(*(run_async_query(), run_async_query()))244245df1, df2 = asyncio.run(testing())246247assert_frame_equal(df1, df2)248assert df1.rows() == [(1, 2, 3)]249250251