# Path: blob/main/py-polars/tests/unit/io/test_pyarrow_dataset.py
# 6939 views
from __future__ import annotations

from datetime import date, datetime, time
from typing import TYPE_CHECKING, Callable

import pyarrow.dataset as ds

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from pathlib import Path

    import pytest


def helper_dataset_test(
    file_path: Path,
    query: Callable[[pl.LazyFrame], pl.LazyFrame],
    batch_size: int | None = None,
    n_expected: int | None = None,
    check_predicate_pushdown: bool = False,
) -> None:
    """
    Run `query` via `scan_ipc` and via `scan_pyarrow_dataset` and compare.

    Parameters
    ----------
    file_path
        IPC file that is opened both with `pl.scan_ipc` and as a pyarrow
        dataset (`format="ipc"`).
    query
        Function applied (through `LazyFrame.pipe`) to both lazy frames.
    batch_size
        Forwarded to `pl.scan_pyarrow_dataset`.
    n_expected
        If given, the number of rows the pyarrow-dataset result must have.
    check_predicate_pushdown
        If true, assert that the string "FILTER" does not occur in the
        explained plan of the `scan_ipc` query.
    """
    dset = ds.dataset(file_path, format="ipc")
    q = pl.scan_ipc(file_path).pipe(query)

    # the native ipc scan acts as the reference result
    expected = q.collect()
    out = pl.scan_pyarrow_dataset(dset, batch_size=batch_size).pipe(query).collect()
    assert_frame_equal(out, expected)
    if n_expected is not None:
        assert len(out) == n_expected

    if check_predicate_pushdown:
        # NOTE: the plan inspected here is the one of the `scan_ipc` query
        # `q`, not the plan of the pyarrow dataset query.
        assert "FILTER" not in q.explain()


# @pytest.mark.write_disk()
def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None:
    """
    Check a range of filter/select queries on a pyarrow dataset source.

    The `df` fixture is written to an IPC file under `tmp_path` and every
    query is run through `helper_dataset_test`.
    """
    file_path = tmp_path / "small.ipc"
    df.write_ipc(file_path)

    helper_dataset_test(
        file_path,
        lambda lf: lf.filter("bools").select("bools", "floats", "date"),
        n_expected=1,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(~pl.col("bools")).select("bools", "floats", "date"),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int_nulls").is_null()).select(
            "bools", "floats", "date"
        ),
        n_expected=1,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int_nulls").is_not_null()).select(
            "bools", "floats", "date"
        ),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(
            pl.col("int_nulls").is_not_null() == pl.col("bools")
        ).select("bools", "floats", "date"),
        n_expected=0,
        check_predicate_pushdown=True,
    )
    # this equality on a column with nulls fails as pyarrow handles kleene
    # logic differently. We leave it for now and document it in the function.
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int") == 10).select(
            "bools", "floats", "int_nulls"
        ),
        n_expected=0,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int") != 10).select(
            "bools", "floats", "int_nulls"
        ),
        n_expected=3,
        check_predicate_pushdown=True,
    )

    for closed, n_expected in zip(["both", "left", "right", "none"], [3, 2, 2, 1]):
        helper_dataset_test(
            file_path,
            # `closed=closed` binds the loop variable at definition time
            lambda lf, closed=closed: lf.filter(  # type: ignore[misc]
                pl.col("int").is_between(1, 3, closed=closed)
            ).select("bools", "floats", "date"),
            n_expected=n_expected,
            check_predicate_pushdown=True,
        )
    # this predicate is not supported by pyarrow
    # check if we still do it on our side
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("floats").sum().over("date") == 10).select(
            "bools", "floats", "date"
        ),
        n_expected=0,
    )
    # temporal types
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("date") < date(1972, 1, 1)).select(
            "bools", "floats", "date"
        ),
        n_expected=1,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(
            pl.col("datetime") > datetime(1970, 1, 1, second=13)
        ).select("bools", "floats", "date"),
        n_expected=1,
        check_predicate_pushdown=True,
    )
    # not yet supported in pyarrow
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("time") >= time(microsecond=100)).select(
            "bools", "time", "date"
        ),
        n_expected=3,
        check_predicate_pushdown=True,
    )
    # pushdown is_in
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int").is_in([1, 3, 20])).select(
            "bools", "floats", "date"
        ),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(
            pl.col("date").is_in([date(1973, 8, 17), date(1973, 5, 19)])
        ).select("bools", "floats", "date"),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(
            pl.col("datetime").is_in(
                [
                    datetime(1970, 1, 1, 0, 0, 12, 341234),
                    datetime(1970, 1, 1, 0, 0, 13, 241324),
                ]
            )
        ).select("bools", "floats", "date"),
        n_expected=2,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("int").is_in(list(range(120)))).select(
            "bools", "floats", "date"
        ),
        n_expected=3,
        check_predicate_pushdown=True,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("cat").is_in([])).select("bools", "floats", "date"),
        n_expected=0,
    )
    helper_dataset_test(
        file_path,
        lambda lf: lf.select(pl.exclude("enum")),
        batch_size=2,
        n_expected=3,
    )

    # direct filter
    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.Series([True, False, True])).select(
            "bools", "floats", "date"
        ),
        n_expected=2,
    )

    helper_dataset_test(
        file_path,
        lambda lf: lf.filter(pl.col("bools") & pl.col("int").is_in([1, 2])).select(
            "bools", "floats"
        ),
        n_expected=1,
        check_predicate_pushdown=True,
    )


def test_pyarrow_dataset_comm_subplan_elim(tmp_path: Path) -> None:
    """Inner-join two pyarrow dataset scans backed by separate parquet files."""
    df0 = pl.DataFrame({"a": [1, 2, 3]})

    df1 = pl.DataFrame({"a": [1, 2]})

    file_path_0 = tmp_path / "0.parquet"
    file_path_1 = tmp_path / "1.parquet"

    df0.write_parquet(file_path_0)
    df1.write_parquet(file_path_1)

    ds0 = ds.dataset(file_path_0, format="parquet")
    ds1 = ds.dataset(file_path_1, format="parquet")

    lf0 = pl.scan_pyarrow_dataset(ds0)
    lf1 = pl.scan_pyarrow_dataset(ds1)

    assert lf0.join(lf1, on="a", how="inner").collect().to_dict(as_series=False) == {
        "a": [1, 2]
    }


def test_pyarrow_dataset_predicate_verbose_log(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """
    Check the predicate log line written to stderr for pyarrow dataset scans.

    With `POLARS_VERBOSE_SENSITIVE=1` set, the captured stderr is expected to
    contain the predicate node together with either the converted pyarrow
    predicate or a `<conversion failed>` marker.
    """
    monkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "1")

    df = pl.DataFrame({"a": [1, 2, 3]})
    file_path_0 = tmp_path / "0"

    df.write_parquet(file_path_0)
    dset = ds.dataset(file_path_0, format="parquet")

    q = pl.scan_pyarrow_dataset(dset).filter(pl.col("a") < 3)

    # discard anything captured so far so only this collect is inspected
    capfd.readouterr()
    assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 2]}))
    capture = capfd.readouterr().err

    assert (
        "[SENSITIVE]: python_scan_predicate: "
        'predicate node: [(col("a")) < (3)], '
        "converted pyarrow predicate: (pa.compute.field('a') < 3)"
    ) in capture

    q = pl.scan_pyarrow_dataset(dset).filter(pl.col("a").cast(pl.String) < "3")

    capfd.readouterr()
    assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 2]}))
    capture = capfd.readouterr().err

    assert (
        "[SENSITIVE]: python_scan_predicate: "
        'predicate node: [(col("a").strict_cast(String)) < ("3")], '
        "converted pyarrow predicate: <conversion failed>\n"
    ) in capture