Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/test_async.py
6939 views
1
from __future__ import annotations
2
3
import asyncio
4
import sys
5
import time
6
from functools import partial
7
from typing import Any, Callable
8
9
import pytest
10
11
import polars as pl
12
from polars.dependencies import gevent
13
from polars.exceptions import ColumnNotFoundError
14
15
# Module-level pytest mark: tags every test collected from this file as `slow`.
pytestmark = pytest.mark.slow()
16
17
18
async def _aio_collect_async(raises: bool = False) -> pl.DataFrame:
    """Build a grouped-sum lazy query and await ``LazyFrame.collect_async``.

    Parameters
    ----------
    raises
        When true, a projection of the column ``"foo_bar"`` (which is not
        part of the frame built here) is appended before collecting.
    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        query = query.select(pl.col("foo_bar"))
    return await query.collect_async()
33
34
35
async def _aio_collect_all_async(raises: bool = False) -> list[pl.DataFrame]:
    """Build two lazy queries and await ``pl.collect_all_async`` on both.

    Parameters
    ----------
    raises
        When true, the first query additionally selects the column
        ``"foo_bar"``, which is not part of the frame built here.
    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    first = grouped.agg(pl.all().sum())
    if raises:
        first = first.select(pl.col("foo_bar"))

    # The second query is always the plain grouped sum of a two-row frame.
    second = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}).group_by("a").sum()

    queries = [first, second]
    return await pl.collect_all_async(queries)
53
54
55
# Shared parametrization for the asyncio tests: each case pairs a coroutine
# factory with the exception expected from running it (None = must succeed).
_aio_collect = pytest.mark.parametrize(
    ("collect", "raises"),
    [
        pytest.param(_aio_collect_async, None),
        pytest.param(_aio_collect_all_async, None),
        pytest.param(partial(_aio_collect_async, True), ColumnNotFoundError),
        pytest.param(partial(_aio_collect_all_async, True), ColumnNotFoundError),
    ],
)
64
65
66
def _aio_run(coroutine: Any, raises: Exception | None = None) -> None:
    """Run `coroutine` with ``asyncio.run`` and check the outcome.

    Parameters
    ----------
    coroutine
        The awaitable handed to ``asyncio.run``.
    raises
        If not None, running the coroutine must raise this exception
        (checked with ``pytest.raises``). If None, the value returned by
        the coroutine must have a length greater than zero.
    """
    if raises is None:
        outcome = asyncio.run(coroutine)
        assert len(outcome) > 0
        return

    with pytest.raises(raises):  # type: ignore[call-overload]
        asyncio.run(coroutine)
72
73
74
@_aio_collect
def test_collect_async_switch(
    collect: Callable[[], Any],
    raises: Exception | None,
) -> None:
    """Create the collect coroutine, sleep on the event loop, then await it."""

    async def main() -> Any:
        # The coroutine object is created before the sleep and awaited after.
        pending = collect()
        await asyncio.sleep(0.3)
        outcome = await pending
        return outcome

    _aio_run(main(), raises)
85
86
87
@_aio_collect
def test_collect_async_task(
    collect: Callable[[], Any], raises: Exception | None
) -> None:
    """Wrap the collect coroutine in a task, sleep, then await the task."""

    async def main() -> Any:
        # The task is created before the sleep and awaited after it.
        task = asyncio.create_task(collect())
        await asyncio.sleep(0.3)
        outcome = await task
        return outcome

    _aio_run(main(), raises)
97
98
99
def _gevent_collect_async(raises: bool = False) -> Any:
    """Build a grouped-sum lazy query and start ``collect_async(gevent=True)``.

    Returns whatever ``collect_async(gevent=True)`` returns; the tests in
    this module call ``.get()`` on it.

    Parameters
    ----------
    raises
        When true, a projection of the column ``"foo_bar"`` (which is not
        part of the frame built here) is appended before collecting.
    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        query = query.select(pl.col("foo_bar"))
    return query.collect_async(gevent=True)
114
115
116
def _gevent_collect_all_async(raises: bool = False) -> Any:
    """Build one lazy query and start ``pl.collect_all_async(..., gevent=True)``.

    Returns whatever ``collect_all_async(gevent=True)`` returns; the tests
    in this module call ``.get()`` on it.

    Parameters
    ----------
    raises
        When true, a projection of the column ``"foo_bar"`` (which is not
        part of the frame built here) is appended before collecting.
    """
    data = {
        "a": ["a", "b", "a", "b", "b", "c"],
        "b": [1, 2, 3, 4, 5, 6],
        "c": [6, 5, 4, 3, 2, 1],
    }
    grouped = pl.LazyFrame(data).group_by("a", maintain_order=True)
    query = grouped.agg(pl.all().sum())
    if raises:
        query = query.select(pl.col("foo_bar"))
    queries = [query]
    return pl.collect_all_async(queries, gevent=True)
131
132
133
# Shared parametrization for the gevent tests: each case pairs a factory that
# starts a collect with the exception expected from it (None = must succeed).
_gevent_collect = pytest.mark.parametrize(
    ("get_result", "raises"),
    [
        pytest.param(_gevent_collect_async, None),
        pytest.param(_gevent_collect_all_async, None),
        pytest.param(partial(_gevent_collect_async, True), ColumnNotFoundError),
        pytest.param(partial(_gevent_collect_all_async, True), ColumnNotFoundError),
    ],
)
142
143
144
def _gevent_run(callback: Callable[[], Any], raises: Exception | None = None) -> None:
    """Invoke `callback` and check the outcome.

    Parameters
    ----------
    callback
        Zero-argument callable to invoke.
    raises
        If not None, calling `callback` must raise this exception (checked
        with ``pytest.raises``). If None, the value returned by `callback`
        must have a length greater than zero.
    """
    if raises is None:
        outcome = callback()
        assert len(outcome) > 0
        return

    with pytest.raises(raises):  # type: ignore[call-overload]
        callback()
150
151
152
@_gevent_collect
def test_gevent_collect_async_without_hub(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Start a gevent collect and immediately call ``.get()`` on its result."""

    def main() -> Any:
        async_result = get_result()
        return async_result.get()

    _gevent_run(main, raises)
160
161
162
@_gevent_collect
def test_gevent_collect_async_with_hub(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Same as the without-hub test, but ``gevent.get_hub()`` is called first."""
    # The returned hub is not used again; only the call itself matters here.
    _hub = gevent.get_hub()

    def main() -> Any:
        async_result = get_result()
        return async_result.get()

    _gevent_run(main, raises)
172
173
174
@pytest.mark.skipif(sys.platform == "win32", reason="May time out on Windows")
@_gevent_collect
def test_gevent_collect_async_switch(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Start a collect, ``gevent.sleep`` briefly, then fetch with ``block=False``."""

    def main() -> Any:
        async_result = get_result()
        # Sleep through gevent before the non-blocking fetch below.
        gevent.sleep(0.1)
        outcome = async_result.get(block=False, timeout=3)
        return outcome

    _gevent_run(main, raises)
185
186
187
@_gevent_collect
def test_gevent_collect_async_no_switch(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Start a collect, ``time.sleep`` for a second, then fetch with ``block=False``."""

    def main() -> Any:
        async_result = get_result()
        # Uses the stdlib sleep here, not gevent.sleep.
        time.sleep(1)
        outcome = async_result.get(block=False, timeout=None)
        return outcome

    _gevent_run(main, raises)
197
198
199
@_gevent_collect
def test_gevent_collect_async_spawn(
    get_result: Callable[[], Any], raises: Exception | None
) -> None:
    """Start the collect from a spawned greenlet alongside a sleeping greenlet."""

    def main() -> Any:
        collector = gevent.spawn(get_result)
        gevent.spawn(gevent.sleep, 0.1)
        # First .get() yields the greenlet's return value (the object produced
        # by get_result); the second .get() is called on that object.
        async_result = collector.get()
        return async_result.get()

    _gevent_run(main, raises)
209
210