Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_async.py
8430 views
1
from __future__ import annotations
2
3
import asyncio
4
import sys
5
import time
6
from functools import partial
7
from typing import TYPE_CHECKING, Any
8
9
import pytest
10
11
import polars as pl
12
from polars._dependencies import gevent
13
from polars.exceptions import ColumnNotFoundError
14
15
if TYPE_CHECKING:
16
from collections.abc import Callable
17
18
# Module-level pytest marker: applies the ``slow`` mark to every test in this file.
pytestmark = pytest.mark.slow()
19
20
21
async def _aio_collect_async(raises: bool = False) -> pl.DataFrame:
    """
    Build a small grouped aggregation and collect it with ``collect_async``.

    Parameters
    ----------
    raises
        When true, a selection of the column ``"foo_bar"`` is appended to the
        query; the aggregated frame only has the columns ``a``, ``b`` and ``c``.

    Returns
    -------
    The awaited result of ``LazyFrame.collect_async()``.
    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        # "foo_bar" is not one of the columns produced by the aggregation.
        query = query.select(pl.col("foo_bar"))
    return await query.collect_async()
36
37
38
async def _aio_collect_all_async(raises: bool = False) -> list[pl.DataFrame]:
    """
    Build two lazy queries and collect them together with ``collect_all_async``.

    Parameters
    ----------
    raises
        When true, a selection of the column ``"foo_bar"`` is appended to the
        first query; the aggregated frame only has the columns ``a``, ``b``
        and ``c``.

    Returns
    -------
    The awaited result of ``pl.collect_all_async`` for both queries.
    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    first = grouped.agg(pl.all().sum())
    if raises:
        # "foo_bar" is not one of the columns produced by the aggregation.
        first = first.select(pl.col("foo_bar"))

    second = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}).group_by("a").sum()

    queries = [first, second]
    return await pl.collect_all_async(queries)
56
57
58
_aio_collect = pytest.mark.parametrize(
59
("collect", "raises"),
60
[
61
(_aio_collect_async, None),
62
(_aio_collect_all_async, None),
63
(partial(_aio_collect_async, True), ColumnNotFoundError),
64
(partial(_aio_collect_all_async, True), ColumnNotFoundError),
65
],
66
)
67
68
69
def _aio_run(coroutine: Any, raises: Exception | None = None) -> None:
    """
    Run ``coroutine`` with ``asyncio.run`` and check its outcome.

    Parameters
    ----------
    coroutine
        The coroutine object to execute.
    raises
        If not ``None``, it is handed to ``pytest.raises`` around the run.
        If ``None``, the result must have a length greater than zero.
    """
    if raises is None:
        # Success path: the coroutine must produce a non-empty result.
        outcome = asyncio.run(coroutine)
        assert len(outcome) > 0
        return

    with pytest.raises(raises):  # type: ignore[call-overload]
        asyncio.run(coroutine)
75
76
77
@_aio_collect
def test_collect_async_switch(
    collect: Callable[[], Any],
    raises: Exception | None,
) -> None:
    """Create the collect coroutine, sleep on the event loop, then await it."""

    async def scenario() -> Any:
        # The coroutine object is created here and only awaited after the sleep.
        pending = collect()
        await asyncio.sleep(0.3)
        outcome = await pending
        return outcome

    _aio_run(scenario(), raises)
88
89
90
@_aio_collect
def test_collect_async_task(
    collect: Callable[[], Any], raises: Exception | None
) -> None:
    """Wrap the collect coroutine in a task, sleep, then await the task."""

    async def scenario() -> Any:
        # The task is created before the sleep and awaited after it.
        task = asyncio.create_task(collect())
        await asyncio.sleep(0.3)
        outcome = await task
        return outcome

    _aio_run(scenario(), raises)
100
101
102
def _gevent_collect_async(raises: bool = False) -> Any:
    """
    Build a small grouped aggregation and start ``collect_async(gevent=True)``.

    Parameters
    ----------
    raises
        When true, a selection of the column ``"foo_bar"`` is appended to the
        query; the aggregated frame only has the columns ``a``, ``b`` and ``c``.

    Returns
    -------
    Whatever ``LazyFrame.collect_async(gevent=True)`` returns; the tests in
    this module call ``.get()`` on it.
    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        # "foo_bar" is not one of the columns produced by the aggregation.
        query = query.select(pl.col("foo_bar"))
    return query.collect_async(gevent=True)
117
118
119
def _gevent_collect_all_async(raises: bool = False) -> Any:
    """
    Build a grouped aggregation and start ``collect_all_async(gevent=True)``.

    Parameters
    ----------
    raises
        When true, a selection of the column ``"foo_bar"`` is appended to the
        query; the aggregated frame only has the columns ``a``, ``b`` and ``c``.

    Returns
    -------
    Whatever ``pl.collect_all_async([...], gevent=True)`` returns for the
    single query; the tests in this module call ``.get()`` on it.
    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        # "foo_bar" is not one of the columns produced by the aggregation.
        query = query.select(pl.col("foo_bar"))
    queries = [query]
    return pl.collect_all_async(queries, gevent=True)
134
135
136
_gevent_collect = pytest.mark.parametrize(
137
("get_result", "raises"),
138
[
139
(_gevent_collect_async, None),
140
(_gevent_collect_all_async, None),
141
(partial(_gevent_collect_async, True), ColumnNotFoundError),
142
(partial(_gevent_collect_all_async, True), ColumnNotFoundError),
143
],
144
)
145
146
147
def _gevent_run(callback: Callable[[], Any], raises: Exception | None = None) -> None:
    """
    Invoke ``callback`` and check its outcome.

    Parameters
    ----------
    callback
        Zero-argument callable that produces the collected result.
    raises
        If not ``None``, it is handed to ``pytest.raises`` around the call.
        If ``None``, the result must have a length greater than zero.
    """
    if raises is None:
        # Success path: the callback must produce a non-empty result.
        outcome = callback()
        assert len(outcome) > 0
        return

    with pytest.raises(raises):  # type: ignore[call-overload]
        callback()
153
154
155
@_gevent_collect
def test_gevent_collect_async_without_hub(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Start the collect and immediately ``get()`` it, without touching the hub first."""

    def fetch() -> Any:
        pending = get_result()
        return pending.get()

    _gevent_run(fetch, raises)
163
164
165
@_gevent_collect
def test_gevent_collect_async_with_hub(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Same as the without-hub test, but ``gevent.get_hub()`` is called up front."""
    # Only the call matters here; the returned hub is not used again.
    _hub = gevent.get_hub()

    def fetch() -> Any:
        pending = get_result()
        return pending.get()

    _gevent_run(fetch, raises)
175
176
177
@pytest.mark.skipif(sys.platform == "win32", reason="May time out on Windows")
@_gevent_collect
def test_gevent_collect_async_switch(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Start the collect, ``gevent.sleep`` briefly, then do a non-blocking ``get``."""

    def fetch() -> Any:
        pending = get_result()
        # Cooperative sleep before the result is requested without blocking.
        gevent.sleep(0.1)
        outcome = pending.get(block=False, timeout=3)
        return outcome

    _gevent_run(fetch, raises)
188
189
190
@_gevent_collect
def test_gevent_collect_async_no_switch(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Start the collect, wait with ``time.sleep``, then do a non-blocking ``get``."""

    def fetch() -> Any:
        pending = get_result()
        # Standard-library sleep here, in contrast to ``gevent.sleep`` in the
        # "switch" variant of this test.
        time.sleep(1)
        outcome = pending.get(block=False, timeout=None)
        return outcome

    _gevent_run(fetch, raises)
200
201
202
@_gevent_collect
def test_gevent_collect_async_spawn(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Start the collect from a spawned greenlet alongside a sleeping greenlet."""

    def fetch() -> Any:
        collector = gevent.spawn(get_result)
        # A second greenlet that only sleeps; its handle is not kept.
        gevent.spawn(gevent.sleep, 0.1)
        # First ``get`` yields what ``get_result`` returned; second ``get``
        # retrieves the collected value from it.
        pending = collector.get()
        return pending.get()

    _gevent_run(fetch, raises)
212
213