# Path: py-polars/tests/unit/io/test_pyarrow_dataset.py
from __future__ import annotations12from datetime import date, datetime, time, timezone3from typing import TYPE_CHECKING45import pyarrow as pa6import pyarrow.dataset as ds7import pytest89import polars as pl10from polars.testing import assert_frame_equal1112if TYPE_CHECKING:13from collections.abc import Callable14from pathlib import Path1516from tests.conftest import PlMonkeyPatch171819def helper_dataset_test(20file_path: Path,21query: Callable[[pl.LazyFrame], pl.LazyFrame],22batch_size: int | None = None,23n_expected: int | None = None,24check_predicate_pushdown: bool = False,25) -> None:26dset = ds.dataset(file_path, format="ipc")27q = pl.scan_ipc(file_path).pipe(query)2829expected = q.collect()30out = pl.scan_pyarrow_dataset(dset, batch_size=batch_size).pipe(query).collect()31assert_frame_equal(out, expected)32if n_expected is not None:33assert len(out) == n_expected3435if check_predicate_pushdown:36assert "FILTER" not in q.explain()373839# @pytest.mark.write_disk()40def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None:41file_path = tmp_path / "small.ipc"42df.write_ipc(file_path)4344helper_dataset_test(45file_path,46lambda lf: lf.filter("bools").select("bools", "floats", "date"),47n_expected=1,48check_predicate_pushdown=True,49)50helper_dataset_test(51file_path,52lambda lf: lf.filter(~pl.col("bools")).select("bools", "floats", "date"),53n_expected=2,54check_predicate_pushdown=True,55)56helper_dataset_test(57file_path,58lambda lf: lf.filter(pl.col("int_nulls").is_null()).select(59"bools", "floats", "date"60),61n_expected=1,62check_predicate_pushdown=True,63)64helper_dataset_test(65file_path,66lambda lf: lf.filter(pl.col("int_nulls").is_not_null()).select(67"bools", "floats", "date"68),69n_expected=2,70check_predicate_pushdown=True,71)72helper_dataset_test(73file_path,74lambda lf: lf.filter(75pl.col("int_nulls").is_not_null() == pl.col("bools")76).select("bools", "floats", "date"),77n_expected=0,78check_predicate_pushdown=True,79)80# this 
equality on a column with nulls fails as pyarrow has different81# handling kleene logic. We leave it for now and document it in the function.82helper_dataset_test(83file_path,84lambda lf: lf.filter(pl.col("int") == 10).select(85"bools", "floats", "int_nulls"86),87n_expected=0,88check_predicate_pushdown=True,89)90helper_dataset_test(91file_path,92lambda lf: lf.filter(pl.col("int") != 10).select(93"bools", "floats", "int_nulls"94),95n_expected=3,96check_predicate_pushdown=True,97)9899for closed, n_expected in zip(100["both", "left", "right", "none"], [3, 2, 2, 1], strict=True101):102helper_dataset_test(103file_path,104lambda lf, closed=closed: lf.filter( # type: ignore[misc]105pl.col("int").is_between(1, 3, closed=closed)106).select("bools", "floats", "date"),107n_expected=n_expected,108check_predicate_pushdown=True,109)110# this predicate is not supported by pyarrow111# check if we still do it on our side112helper_dataset_test(113file_path,114lambda lf: lf.filter(pl.col("floats").sum().over("date") == 10).select(115"bools", "floats", "date"116),117n_expected=0,118)119# temporal types120helper_dataset_test(121file_path,122lambda lf: lf.filter(pl.col("date") < date(1972, 1, 1)).select(123"bools", "floats", "date"124),125n_expected=1,126check_predicate_pushdown=True,127)128helper_dataset_test(129file_path,130lambda lf: lf.filter(131pl.col("datetime") > datetime(1970, 1, 1, second=13)132).select("bools", "floats", "date"),133n_expected=1,134check_predicate_pushdown=True,135)136# not yet supported in pyarrow137helper_dataset_test(138file_path,139lambda lf: lf.filter(pl.col("time") >= time(microsecond=100)).select(140"bools", "time", "date"141),142n_expected=3,143check_predicate_pushdown=True,144)145# pushdown is_in146helper_dataset_test(147file_path,148lambda lf: lf.filter(pl.col("int").is_in([1, 3, 20])).select(149"bools", "floats", "date"150),151n_expected=2,152check_predicate_pushdown=True,153)154helper_dataset_test(155file_path,156lambda lf: 
lf.filter(157pl.col("date").is_in([date(1973, 8, 17), date(1973, 5, 19)])158).select("bools", "floats", "date"),159n_expected=2,160check_predicate_pushdown=True,161)162helper_dataset_test(163file_path,164lambda lf: lf.filter(165pl.col("datetime").is_in(166[167datetime(1970, 1, 1, 0, 0, 12, 341234),168datetime(1970, 1, 1, 0, 0, 13, 241324),169]170)171).select("bools", "floats", "date"),172n_expected=2,173check_predicate_pushdown=True,174)175helper_dataset_test(176file_path,177lambda lf: lf.filter(pl.col("int").is_in(list(range(120)))).select(178"bools", "floats", "date"179),180n_expected=3,181check_predicate_pushdown=True,182)183helper_dataset_test(184file_path,185lambda lf: lf.filter(pl.col("cat").is_in([])).select("bools", "floats", "date"),186n_expected=0,187)188helper_dataset_test(189file_path,190lambda lf: lf.select(pl.exclude("enum")),191batch_size=2,192n_expected=3,193)194195# direct filter196helper_dataset_test(197file_path,198lambda lf: lf.filter(pl.Series([True, False, True])).select(199"bools", "floats", "date"200),201n_expected=2,202)203204helper_dataset_test(205file_path,206lambda lf: lf.filter(pl.col("bools") & pl.col("int").is_in([1, 2])).select(207"bools", "floats"208),209n_expected=1,210check_predicate_pushdown=True,211)212213214def test_pyarrow_dataset_partial_predicate_pushdown(215tmp_path: Path,216plmonkeypatch: PlMonkeyPatch,217capfd: pytest.CaptureFixture[str],218) -> None:219plmonkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "1")220221df = pl.DataFrame({"a": [1, 2, 3], "b": [10.0, 20.0, 30.0]})222file_path = tmp_path / "0"223df.write_parquet(file_path)224dset = ds.dataset(file_path, format="parquet")225226# col("a") > 1 is convertible; col("a") * col("b") > 25 is not (arithmetic227# on two columns cannot be expressed as a pyarrow compute expression).228# The optimizer pushes both terms into the scan's SELECTION, so our229# MintermIter-based partial conversion should push the convertible part.230q = 
pl.scan_pyarrow_dataset(dset).filter(231(pl.col("a") > 1) & (pl.col("a") * pl.col("b") > 25)232)233234capfd.readouterr()235result = q.collect()236capture = capfd.readouterr().err237238# Verify: partial predicate was pushed to pyarrow239assert "(pa.compute.field('a') > 1)" in capture240assert (241'residual predicate: Some([([(col("a").cast(Float64)) * (col("b"))]) > (25.0)])'242in capture243)244# Verify: correctness245expected = (246df.lazy().filter((pl.col("a") > 1) & (pl.col("a") * pl.col("b") > 25)).collect()247)248assert_frame_equal(result, expected)249250251def test_pyarrow_dataset_comm_subplan_elim(tmp_path: Path) -> None:252df0 = pl.DataFrame({"a": [1, 2, 3]})253254df1 = pl.DataFrame({"a": [1, 2]})255256file_path_0 = tmp_path / "0.parquet"257file_path_1 = tmp_path / "1.parquet"258259df0.write_parquet(file_path_0)260df1.write_parquet(file_path_1)261262ds0 = ds.dataset(file_path_0, format="parquet")263ds1 = ds.dataset(file_path_1, format="parquet")264265lf0 = pl.scan_pyarrow_dataset(ds0)266lf1 = pl.scan_pyarrow_dataset(ds1)267268assert_frame_equal(269lf0.join(lf1, on="a", how="inner").collect(),270pl.DataFrame({"a": [1, 2]}),271check_row_order=False,272)273274275def test_pyarrow_dataset_predicate_verbose_log(276tmp_path: Path,277plmonkeypatch: PlMonkeyPatch,278capfd: pytest.CaptureFixture[str],279) -> None:280plmonkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "1")281282df = pl.DataFrame({"a": [1, 2, 3]})283file_path_0 = tmp_path / "0"284285df.write_parquet(file_path_0)286dset = ds.dataset(file_path_0, format="parquet")287288q = pl.scan_pyarrow_dataset(dset).filter(pl.col("a") < 3)289290capfd.readouterr()291assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 2]}))292capture = capfd.readouterr().err293294assert (295"[SENSITIVE]: python_scan_predicate: "296'predicate node: [(col("a")) < (3)], '297"converted pyarrow predicate: (pa.compute.field('a') < 3), "298"residual predicate: None"299) in capture300301q = 
pl.scan_pyarrow_dataset(dset).filter(pl.col("a").cast(pl.String) < "3")302303capfd.readouterr()304assert_frame_equal(q.collect(), pl.DataFrame({"a": [1, 2]}))305capture = capfd.readouterr().err306307assert (308"[SENSITIVE]: python_scan_predicate: "309'predicate node: [(col("a").strict_cast(String)) < ("3")], '310"converted pyarrow predicate: <conversion failed>, "311'residual predicate: Some([(col("a").strict_cast(String)) < ("3")])'312) in capture313314315@pytest.mark.write_disk316def test_pyarrow_dataset_python_scan(tmp_path: Path) -> None:317df = pl.DataFrame({"x": [0, 1, 2, 3]})318file_path = tmp_path / "0.parquet"319df.write_parquet(file_path)320321dataset = ds.dataset(file_path)322lf = pl.scan_pyarrow_dataset(dataset)323out = lf.collect(engine="streaming")324325assert_frame_equal(df, out)326327328def test_pyarrow_dataset_allow_pyarrow_filter_false() -> None:329df = pl.DataFrame({"item": ["foo", "bar", "baz"], "price": [10.0, 20.0, 30.0]})330dataset = ds.dataset(df.to_arrow(compat_level=pl.CompatLevel.oldest()))331332# basic scan without filter333result = pl.scan_pyarrow_dataset(dataset, allow_pyarrow_filter=False).collect()334assert_frame_equal(result, df)335336# with filter (predicate should be applied by Polars, not PyArrow)337result = (338pl.scan_pyarrow_dataset(dataset, allow_pyarrow_filter=False)339.filter(pl.col("price") > 15)340.collect()341)342343expected = pl.DataFrame({"item": ["bar", "baz"], "price": [20.0, 30.0]})344assert_frame_equal(result, expected)345346# check user-specified `batch_size` doesn't error (ref: #25316)347result = (348pl.scan_pyarrow_dataset(dataset, allow_pyarrow_filter=False, batch_size=1000)349.filter(pl.col("price") > 15)350.collect()351)352assert_frame_equal(result, expected)353354# check `allow_pyarrow_filter=True` still works355result = (356pl.scan_pyarrow_dataset(dataset, allow_pyarrow_filter=True)357.filter(pl.col("price") > 15)358.collect()359)360assert_frame_equal(result, expected)361362363def 
test_scan_pyarrow_dataset_filter_with_timezone_26029() -> None:364table = pa.table(365{366"valid_from": [367datetime(2025, 8, 26, 10, 0, 0, tzinfo=timezone.utc),368datetime(2025, 8, 26, 11, 0, 0, tzinfo=timezone.utc),369],370"valid_to": [371datetime(2025, 8, 26, 12, 0, 0, tzinfo=timezone.utc),372datetime(2025, 8, 26, 13, 0, 0, tzinfo=timezone.utc),373],374"value": [1, 2],375}376)377dataset = ds.dataset(table)378379lower_bound_time = datetime(2025, 8, 26, 11, 30, 0, tzinfo=timezone.utc)380lf = pl.scan_pyarrow_dataset(dataset).filter(381(pl.col("valid_from") <= lower_bound_time)382& (pl.col("valid_to") > lower_bound_time)383)384385assert_frame_equal(lf.collect(), pl.DataFrame(table))386387388def test_scan_pyarrow_dataset_filter_slice_order() -> None:389table = pa.table(390{391"index": [0, 1, 2],392"year": [2025, 2026, 2026],393"month": [0, 0, 0],394}395)396dataset = ds.dataset(table)397398q = pl.scan_pyarrow_dataset(dataset).head(2).filter(pl.col("year") == 2026)399400assert_frame_equal(401q.collect(),402pl.DataFrame({"index": 1, "year": 2026, "month": 0}),403)404405import polars.io.pyarrow_dataset.anonymous_scan406407assert_frame_equal(408polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl(409dataset,410n_rows=2,411predicate="pa.compute.field('year') == 2026",412with_columns=None,413),414pl.DataFrame({"index": 1, "year": 2026, "month": 0}),415)416417assert_frame_equal(418polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl(419dataset,420n_rows=0,421predicate="pa.compute.field('year') == 2026",422with_columns=None,423),424pl.DataFrame(schema={"index": pl.Int64, "year": pl.Int64, "month": pl.Int64}),425)426427assert_frame_equal(428pl.concat(429polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl(430dataset,431n_rows=1,432predicate=None,433with_columns=None,434allow_pyarrow_filter=False,435)[0]436),437pl.DataFrame({"index": 0, "year": 2025, "month": 0}),438)439440assert not 
polars.io.pyarrow_dataset.anonymous_scan._scan_pyarrow_dataset_impl(441dataset,442n_rows=0,443predicate="pa.compute.field('year') == 2026",444with_columns=None,445allow_pyarrow_filter=False,446)[1]447448449