Path: blob/main/py-polars/tests/unit/io/database/test_async.py
6939 views
from __future__ import annotations12import asyncio3from math import ceil4from types import ModuleType5from typing import TYPE_CHECKING, Any, overload67import pytest8import sqlalchemy9from sqlalchemy.ext.asyncio import create_async_engine1011import polars as pl12from polars._utils.various import parse_version13from polars.io.database._utils import _run_async14from polars.testing import assert_frame_equal15from tests.unit.conftest import mock_module_import1617if TYPE_CHECKING:18from collections.abc import Iterable19from pathlib import Path2021SURREAL_MOCK_DATA: list[dict[str, Any]] = [22{23"id": "item:8xj31jfpdkf9gvmxdxpi",24"name": "abc",25"tags": ["polars"],26"checked": False,27},28{29"id": "item:l59k19swv2adsv4q04cj",30"name": "mno",31"tags": ["async"],32"checked": None,33},34{35"id": "item:w831f1oyqnwztv5q03em",36"name": "xyz",37"tags": ["stroop", "wafel"],38"checked": True,39},40]414243class MockSurrealConnection:44"""Mock SurrealDB connection/client object."""4546__module__ = "surrealdb"4748def __init__(self, url: str, mock_data: list[dict[str, Any]]) -> None:49self._mock_data = mock_data.copy()50self.url = url5152async def __aenter__(self) -> Any:53await self.connect()54return self5556async def __aexit__(self, *args: object, **kwargs: Any) -> None:57await self.close()5859async def close(self) -> None:60pass6162async def connect(self) -> None:63pass6465async def use(self, namespace: str, database: str) -> None:66pass6768async def query(69self, query: str, variables: dict[str, Any] | None = None70) -> list[dict[str, Any]]:71return [{"result": self._mock_data, "status": "OK", "time": "32.083µs"}]727374class MockedSurrealModule(ModuleType):75"""Mock SurrealDB module; enables internal `isinstance` check for AsyncSurrealDB."""7677AsyncSurrealDB = MockSurrealConnection787980@pytest.mark.skipif(81parse_version(sqlalchemy.__version__) < (2, 0),82reason="SQLAlchemy 2.0+ required for async tests",83)84def test_read_async(tmp_sqlite_db: Path) -> None:85# confirm that we can load frame data from the core sqlalchemy async86# primitives: AsyncConnection, AsyncEngine, and async_sessionmaker87from sqlalchemy.ext.asyncio import async_sessionmaker8889async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")90async_connection = async_engine.connect()91async_session = async_sessionmaker(async_engine)92async_session_inst = async_session()9394expected_frame = pl.DataFrame(95{"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]}96)97async_conn: Any98for async_conn in (99async_engine,100async_connection,101async_session,102async_session_inst,103):104if async_conn in (async_session, async_session_inst):105constraint, execute_opts = "", {}106else:107constraint = "WHERE value > :n"108execute_opts = {"parameters": {"n": -1000}}109110df = pl.read_database(111query=f"""112SELECT id, name, value113FROM test_data {constraint}114ORDER BY id DESC115""",116connection=async_conn,117execute_options=execute_opts,118)119assert_frame_equal(expected_frame, df)120121122async def _nested_async_test(tmp_sqlite_db: Path) -> pl.DataFrame:123async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")124return pl.read_database(125query="SELECT id, name FROM test_data ORDER BY id",126connection=async_engine.connect(),127)128129130@pytest.mark.skipif(131parse_version(sqlalchemy.__version__) < (2, 0),132reason="SQLAlchemy 2.0+ required for async tests",133)134def test_read_async_nested(tmp_sqlite_db: Path) -> None:135# This tests validates that we can handle nested async calls136expected_frame = pl.DataFrame({"id": [1, 2], "name": ["misc", "other"]})137df = asyncio.run(_nested_async_test(tmp_sqlite_db))138assert_frame_equal(expected_frame, df)139140141@overload142async def _surreal_query_as_frame(143url: str, query: str, batch_size: None144) -> pl.DataFrame: ...145146147@overload148async def _surreal_query_as_frame(149url: str, query: str, batch_size: int150) -> Iterable[pl.DataFrame]: ...151152153async def _surreal_query_as_frame(154url: str, query: str, batch_size: int | None155) -> pl.DataFrame | Iterable[pl.DataFrame]:156batch_params = (157{"iter_batches": True, "batch_size": batch_size} if batch_size else {}158)159async with MockSurrealConnection(url=url, mock_data=SURREAL_MOCK_DATA) as client:160await client.use(namespace="test", database="test")161return pl.read_database( # type: ignore[no-any-return,call-overload]162query=query,163connection=client,164**batch_params,165)166167168@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 4])169def test_surrealdb_fetchall(batch_size: int | None) -> None:170with mock_module_import("surrealdb", MockedSurrealModule("surrealdb")):171df_expected = pl.DataFrame(SURREAL_MOCK_DATA)172res = asyncio.run(173_surreal_query_as_frame(174url="ws://localhost:8000/rpc",175query="SELECT * FROM item",176batch_size=batch_size,177)178)179if batch_size:180frames = list(res) # type: ignore[call-overload]181n_mock_rows = len(SURREAL_MOCK_DATA)182assert len(frames) == ceil(n_mock_rows / batch_size)183assert_frame_equal(df_expected[:batch_size], frames[0])184else:185assert_frame_equal(df_expected, res) # type: ignore[arg-type]186187188def test_async_nested_captured_loop_21263() -> None:189# Tests awaiting a future that has "captured" the original event loop from190# within a `_run_async` context.191async def test_impl() -> None:192loop = asyncio.get_running_loop()193task = loop.create_task(asyncio.sleep(0))194195_run_async(await_task(task))196197async def await_task(task: Any) -> None:198await task199200asyncio.run(test_impl())201202203