Path: blob/main/py-polars/tests/unit/io/test_collect_batches.py
8479 views
from __future__ import annotations12import subprocess3import sys4from typing import TYPE_CHECKING56import pytest78import polars as pl9from polars.testing import assert_frame_equal1011if TYPE_CHECKING:12from polars._typing import EngineType131415@pytest.mark.parametrize("engine", ["in-memory", "streaming"])16def test_sink_batches(engine: EngineType) -> None:17df = pl.DataFrame({"a": range(100)})18frames: list[pl.DataFrame] = []1920df.lazy().sink_batches(lambda df: frames.append(df), engine=engine) # type: ignore[call-overload]2122assert_frame_equal(pl.concat(frames), df)232425@pytest.mark.parametrize("engine", ["in-memory", "streaming"])26def test_sink_batches_early_stop(engine: EngineType) -> None:27df = pl.DataFrame({"a": range(1000)})28stopped = False2930def cb(_: pl.DataFrame) -> bool | None:31nonlocal stopped32assert not stopped33stopped = True34return True3536df.lazy().sink_batches(cb, chunk_size=100, engine=engine) # type: ignore[call-overload]37assert stopped383940def test_collect_batches() -> None:41df = pl.DataFrame({"a": range(100)})42frames = []4344for f in df.lazy().collect_batches():45frames += [f]4647assert_frame_equal(pl.concat(frames), df)484950def test_chunk_size() -> None:51df = pl.DataFrame({"a": range(113)})5253for f in df.lazy().collect_batches(chunk_size=17):54expected = df.head(17)55df = df.slice(17)5657assert_frame_equal(f, expected)5859df = pl.DataFrame({"a": range(10)})6061for f in df.lazy().collect_batches(chunk_size=10):62assert not f.is_empty()6364expected = df.head(10)65df = df.slice(10)6667assert_frame_equal(f, expected)686970@pytest.mark.slow71def test_collect_batches_releases_gil_26031() -> None:72out = subprocess.check_output(73[74sys.executable,75"-c",76"""\77import polars as pl78from polars.testing import assert_frame_equal7980def reentrant_add(x: int):81return next(82pl.DataFrame({"": x})83.lazy()84.select(pl.first().map_elements(lambda x: x + 1, return_dtype=pl.UInt32))85.collect_batches(engine="streaming")86).item()8788assert_frame_equal(89pl.concat(90pl.LazyFrame({"a": range(10)})91.with_columns(92out=pl.col("a").map_elements(reentrant_add, return_dtype=pl.UInt32)93)94.collect_batches(engine="streaming")95),96pl.DataFrame(97[98pl.Series("a", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=pl.Int64),99pl.Series("out", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=pl.UInt32),100]101),102)103104print("OK", end="")105""",106],107timeout=5,108)109110assert out == b"OK"111112113