Path: blob/main/py-polars/tests/unit/io/test_sink_batches.py
7884 views
"""Tests for batched output via ``LazyFrame.sink_batches`` and ``collect_batches``."""

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from polars._typing import EngineType


@pytest.mark.parametrize("engine", ["in-memory", "streaming"])
def test_sink_batches(engine: EngineType) -> None:
    """Every row reaches the callback, in order, across the emitted batches."""
    df = pl.DataFrame({"a": range(100)})
    frames: list[pl.DataFrame] = []

    # ``list.append`` returns None (the same as the original lambda), so it can
    # be used as the callback directly without shadowing ``df``.
    df.lazy().sink_batches(frames.append, engine=engine)  # type: ignore[call-overload]

    assert_frame_equal(pl.concat(frames), df)


@pytest.mark.parametrize("engine", ["in-memory", "streaming"])
def test_sink_batches_early_stop(engine: EngineType) -> None:
    """Returning True from the callback stops delivery after that batch."""
    df = pl.DataFrame({"a": range(1000)})
    stopped = False

    def cb(_: pl.DataFrame) -> bool | None:
        nonlocal stopped
        # A second invocation would mean the stop request was ignored.
        assert not stopped
        stopped = True
        return True

    # chunk_size=100 on 1000 rows guarantees more than one batch would be
    # produced if the stop request were not honoured.
    df.lazy().sink_batches(cb, chunk_size=100, engine=engine)  # type: ignore[call-overload]
    assert stopped


def test_collect_batches() -> None:
    """Concatenating every yielded batch reproduces the original frame."""
    df = pl.DataFrame({"a": range(100)})
    frames = list(df.lazy().collect_batches())

    assert_frame_equal(pl.concat(frames), df)


def _assert_exact_batches(df: pl.DataFrame, chunk_size: int) -> None:
    """Assert ``collect_batches(chunk_size=...)`` splits ``df`` into exact slices.

    Each yielded batch must be non-empty and equal to the next ``chunk_size``
    rows (the final batch holds the remainder), and together the batches must
    cover every row of ``df``.
    """
    remaining = df
    for batch in df.lazy().collect_batches(chunk_size=chunk_size):
        assert not batch.is_empty()
        assert_frame_equal(batch, remaining.head(chunk_size))
        remaining = remaining.slice(chunk_size)
    # Without this check the loop passes vacuously if too few (or no) batches
    # are yielded.
    assert remaining.is_empty()


def test_chunk_size() -> None:
    """``chunk_size`` fixes the batch length; the last batch holds the remainder."""
    # 113 = 6 * 17 + 11: six full batches followed by one short batch.
    _assert_exact_batches(pl.DataFrame({"a": range(113)}), 17)

    # An exact multiple of chunk_size must not yield a trailing empty batch.
    _assert_exact_batches(pl.DataFrame({"a": range(10)}), 10)