Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_collect_batches.py
8479 views
1
from __future__ import annotations
2
3
import subprocess
4
import sys
5
from typing import TYPE_CHECKING
6
7
import pytest
8
9
import polars as pl
10
from polars.testing import assert_frame_equal
11
12
if TYPE_CHECKING:
13
from polars._typing import EngineType
14
15
16
@pytest.mark.parametrize("engine", ["in-memory", "streaming"])
def test_sink_batches(engine: EngineType) -> None:
    """Check that ``sink_batches`` delivers every row to the callback.

    Concatenating the batches handed to the callback, in arrival order, must
    reproduce the source frame for both the in-memory and streaming engines.
    """
    df = pl.DataFrame({"a": range(100)})
    frames: list[pl.DataFrame] = []

    # ``list.append`` returns None; this test expects every batch to arrive,
    # so a None return must not stop the sink. Passing the bound method
    # directly avoids a lambda whose parameter shadowed ``df``.
    df.lazy().sink_batches(frames.append, engine=engine)  # type: ignore[call-overload]

    assert_frame_equal(pl.concat(frames), df)
24
25
26
@pytest.mark.parametrize("engine", ["in-memory", "streaming"])
def test_sink_batches_early_stop(engine: EngineType) -> None:
    """Check that returning ``True`` from the callback stops ``sink_batches``.

    The callback requests a stop on its very first invocation. Any further
    invocation trips the assertion inside the callback.
    """
    df = pl.DataFrame({"a": range(1000)})
    # Mutable cell recording whether the callback has already run.
    seen: list[bool] = []

    def on_batch(_: pl.DataFrame) -> bool | None:
        # A second call would mean the stop request was ignored.
        assert not seen
        seen.append(True)
        return True

    # chunk_size=100 over 1000 rows, presumably so that more than one batch
    # would be produced if the sink did not honour the stop request.
    df.lazy().sink_batches(on_batch, chunk_size=100, engine=engine)  # type: ignore[call-overload]
    assert seen
39
40
41
def test_collect_batches() -> None:
    """Check that ``collect_batches`` yields batches covering the whole frame.

    Concatenating every yielded batch in order must reproduce the source.
    """
    df = pl.DataFrame({"a": range(100)})

    # Drain the iterator in one go instead of building a one-element list
    # per batch and concatenating it onto an unannotated accumulator.
    frames = list(df.lazy().collect_batches())

    assert_frame_equal(pl.concat(frames), df)
49
50
51
def test_chunk_size() -> None:
    """Check that ``collect_batches(chunk_size=...)`` slices output exactly.

    Each yielded batch must equal the next ``chunk_size`` rows of the source,
    with the last one possibly shorter. Together the batches must cover every
    row. When the row count equals ``chunk_size``, no empty batch may be
    emitted.
    """
    # 113 = 6 * 17 + 11: six full batches followed by an 11-row remainder.
    df = pl.DataFrame({"a": range(113)})
    remaining = df
    for batch in df.lazy().collect_batches(chunk_size=17):
        expected = remaining.head(17)
        remaining = remaining.slice(17)

        assert_frame_equal(batch, expected)
    # Without this check, an iterator that yielded nothing (or stopped early)
    # would let the test pass vacuously.
    assert remaining.is_empty()

    # Row count equal to chunk_size: the single batch must be non-empty and
    # no trailing empty batch may follow.
    df = pl.DataFrame({"a": range(10)})
    remaining = df
    for batch in df.lazy().collect_batches(chunk_size=10):
        assert not batch.is_empty()

        expected = remaining.head(10)
        remaining = remaining.slice(10)

        assert_frame_equal(batch, expected)
    assert remaining.is_empty()
69
70
71
@pytest.mark.slow
def test_collect_batches_releases_gil_26031() -> None:
    """Regression test: a re-entrant streaming ``collect_batches`` must finish.

    The child script runs a streaming ``collect_batches`` whose Python UDF
    (``reentrant_add``) itself calls ``collect_batches``. Judging by the test
    name, the outer call must release the GIL so that the inner UDF can run
    (presumably polars issue #26031). The script runs in a fresh interpreter
    so that a hang is bounded by ``timeout`` instead of blocking the whole
    test session.
    """
    script = """\
import polars as pl
from polars.testing import assert_frame_equal

def reentrant_add(x: int):
    return next(
        pl.DataFrame({"": x})
        .lazy()
        .select(pl.first().map_elements(lambda x: x + 1, return_dtype=pl.UInt32))
        .collect_batches(engine="streaming")
    ).item()

assert_frame_equal(
    pl.concat(
        pl.LazyFrame({"a": range(10)})
        .with_columns(
            out=pl.col("a").map_elements(reentrant_add, return_dtype=pl.UInt32)
        )
        .collect_batches(engine="streaming")
    ),
    pl.DataFrame(
        [
            pl.Series("a", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=pl.Int64),
            pl.Series("out", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=pl.UInt32),
        ]
    ),
)

print("OK", end="")
"""

    # check_output kills the child and raises TimeoutExpired if it hangs.
    # It raises CalledProcessError if the child exits non-zero, e.g. when its
    # assertion fails.
    # NOTE(review): the 5 s budget also covers interpreter start-up and the
    # polars import; confirm there is enough headroom on slow CI machines.
    out = subprocess.check_output([sys.executable, "-c", script], timeout=5)

    # The child prints exactly "OK" (no newline) only after its assertion passes.
    assert out == b"OK"
112
113