# Path: blob/main/py-polars/tests/unit/lazyframe/test_async.py
# 8430 views
from __future__ import annotations12import asyncio3import sys4import time5from functools import partial6from typing import TYPE_CHECKING, Any78import pytest910import polars as pl11from polars._dependencies import gevent12from polars.exceptions import ColumnNotFoundError1314if TYPE_CHECKING:15from collections.abc import Callable1617pytestmark = pytest.mark.slow()181920async def _aio_collect_async(raises: bool = False) -> pl.DataFrame:21lf = (22pl.LazyFrame(23{24"a": ["a", "b", "a", "b", "b", "c"],25"b": [1, 2, 3, 4, 5, 6],26"c": [6, 5, 4, 3, 2, 1],27}28)29.group_by("a", maintain_order=True)30.agg(pl.all().sum())31)32if raises:33lf = lf.select(pl.col("foo_bar"))34return await lf.collect_async()353637async def _aio_collect_all_async(raises: bool = False) -> list[pl.DataFrame]:38lf = (39pl.LazyFrame(40{41"a": ["a", "b", "a", "b", "b", "c"],42"b": [1, 2, 3, 4, 5, 6],43"c": [6, 5, 4, 3, 2, 1],44}45)46.group_by("a", maintain_order=True)47.agg(pl.all().sum())48)49if raises:50lf = lf.select(pl.col("foo_bar"))5152lf2 = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}).group_by("a").sum()5354return await pl.collect_all_async([lf, lf2])555657_aio_collect = pytest.mark.parametrize(58("collect", "raises"),59[60(_aio_collect_async, None),61(_aio_collect_all_async, None),62(partial(_aio_collect_async, True), ColumnNotFoundError),63(partial(_aio_collect_all_async, True), ColumnNotFoundError),64],65)666768def _aio_run(coroutine: Any, raises: Exception | None = None) -> None:69if raises is not None:70with pytest.raises(raises): # type: ignore[call-overload]71asyncio.run(coroutine)72else:73assert len(asyncio.run(coroutine)) > 0747576@_aio_collect77def test_collect_async_switch(78collect: Callable[[], Any],79raises: Exception | None,80) -> None:81async def main() -> Any:82df = collect()83await asyncio.sleep(0.3)84return await df8586_aio_run(main(), raises)878889@_aio_collect90def test_collect_async_task(91collect: Callable[[], Any], raises: Exception | None92) -> None:93async def main() -> 
Any:94df = asyncio.create_task(collect())95await asyncio.sleep(0.3)96return await df9798_aio_run(main(), raises)99100101def _gevent_collect_async(raises: bool = False) -> Any:102lf = (103pl.LazyFrame(104{105"a": ["a", "b", "a", "b", "b", "c"],106"b": [1, 2, 3, 4, 5, 6],107"c": [6, 5, 4, 3, 2, 1],108}109)110.group_by("a", maintain_order=True)111.agg(pl.all().sum())112)113if raises:114lf = lf.select(pl.col("foo_bar"))115return lf.collect_async(gevent=True)116117118def _gevent_collect_all_async(raises: bool = False) -> Any:119lf = (120pl.LazyFrame(121{122"a": ["a", "b", "a", "b", "b", "c"],123"b": [1, 2, 3, 4, 5, 6],124"c": [6, 5, 4, 3, 2, 1],125}126)127.group_by("a", maintain_order=True)128.agg(pl.all().sum())129)130if raises:131lf = lf.select(pl.col("foo_bar"))132return pl.collect_all_async([lf], gevent=True)133134135_gevent_collect = pytest.mark.parametrize(136("get_result", "raises"),137[138(_gevent_collect_async, None),139(_gevent_collect_all_async, None),140(partial(_gevent_collect_async, True), ColumnNotFoundError),141(partial(_gevent_collect_all_async, True), ColumnNotFoundError),142],143)144145146def _gevent_run(callback: Callable[[], Any], raises: Exception | None = None) -> None:147if raises is not None:148with pytest.raises(raises): # type: ignore[call-overload]149callback()150else:151assert len(callback()) > 0152153154@_gevent_collect155def test_gevent_collect_async_without_hub(156get_result: Callable[[], Any], raises: Exception | None157) -> None:158def main() -> Any:159return get_result().get()160161_gevent_run(main, raises)162163164@_gevent_collect165def test_gevent_collect_async_with_hub(166get_result: Callable[[], Any], raises: Exception | None167) -> None:168_hub = gevent.get_hub()169170def main() -> Any:171return get_result().get()172173_gevent_run(main, raises)174175176@pytest.mark.skipif(sys.platform == "win32", reason="May time out on Windows")177@_gevent_collect178def test_gevent_collect_async_switch(179get_result: Callable[[], Any], raises: 
Exception | None180) -> None:181def main() -> Any:182result = get_result()183gevent.sleep(0.1)184return result.get(block=False, timeout=3)185186_gevent_run(main, raises)187188189@_gevent_collect190def test_gevent_collect_async_no_switch(191get_result: Callable[[], Any], raises: Exception | None192) -> None:193def main() -> Any:194result = get_result()195time.sleep(1)196return result.get(block=False, timeout=None)197198_gevent_run(main, raises)199200201@_gevent_collect202def test_gevent_collect_async_spawn(203get_result: Callable[[], Any], raises: Exception | None204) -> None:205def main() -> Any:206result_greenlet = gevent.spawn(get_result)207gevent.spawn(gevent.sleep, 0.1)208return result_greenlet.get().get()209210_gevent_run(main, raises)211212213