# Path: blob/main/py-polars/tests/unit/lazyframe/test_predicates.py
# 8430 views
from __future__ import annotations

import re
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, InvalidOperationError
from polars.io.plugins import register_io_source
from polars.testing import assert_frame_equal
from polars.testing.asserts.series import assert_series_equal

if TYPE_CHECKING:
    from collections.abc import Iterator

    from tests.conftest import PlMonkeyPatch


def test_predicate_4906() -> None:
    """Filter on `min_horizontal(dt + 1 day, literal)` keeps the last two dates."""
    one_day = timedelta(days=1)

    ldf = pl.DataFrame(
        {
            "dt": [
                date(2022, 9, 1),
                date(2022, 9, 10),
                date(2022, 9, 20),
            ]
        }
    ).lazy()

    assert ldf.filter(
        pl.min_horizontal((pl.col("dt") + one_day), date(2022, 9, 30))
        > date(2022, 9, 10)
    ).collect().to_dict(as_series=False) == {
        "dt": [date(2022, 9, 10), date(2022, 9, 20)]
    }


def test_predicate_null_block_asof_join() -> None:
    """An `is_not_null` filter after a by-group asof join keeps matched rows only."""
    left = (
        pl.DataFrame(
            {
                "id": [1, 2, 3, 4],
                "timestamp": [
                    datetime(2022, 1, 1, 10, 0),
                    datetime(2022, 1, 1, 10, 1),
                    datetime(2022, 1, 1, 10, 2),
                    datetime(2022, 1, 1, 10, 3),
                ],
            }
        )
        .lazy()
        .set_sorted("timestamp")
    )

    right = (
        pl.DataFrame(
            {
                "id": [1, 2, 3] * 2,
                "timestamp": [
                    datetime(2022, 1, 1, 9, 59, 50),
                    datetime(2022, 1, 1, 10, 0, 50),
                    datetime(2022, 1, 1, 10, 1, 50),
                    datetime(2022, 1, 1, 8, 0, 0),
                    datetime(2022, 1, 1, 8, 0, 0),
                    datetime(2022, 1, 1, 8, 0, 0),
                ],
                "value": ["a", "b", "c"] * 2,
            }
        )
        .lazy()
        .set_sorted("timestamp")
    )

    assert_frame_equal(
        left.join_asof(right, by="id", on="timestamp")
        .filter(pl.col("value").is_not_null())
        .collect(),
        pl.DataFrame(
            {
                "id": [1, 2, 3],
                "timestamp": [
                    datetime(2022, 1, 1, 10, 0),
                    datetime(2022, 1, 1, 10, 1),
                    datetime(2022, 1, 1, 10, 2),
                ],
                "value": ["a", "b", "c"],
            }
        ),
        check_row_order=False,
    )


def test_predicate_strptime_6558() -> None:
    """Year/month filters after `str.strptime` select the single matching date."""
    assert (
        pl.DataFrame({"date": ["2022-01-03", "2020-01-04", "2021-02-03", "2019-01-04"]})
        .lazy()
        .select(pl.col("date").str.strptime(pl.Date, format="%F"))
        .filter((pl.col("date").dt.year() == 2022) & (pl.col("date").dt.month() == 1))
        .collect()
    ).to_dict(as_series=False) == {"date": [date(2022, 1, 3)]}


def test_predicate_arr_first_6573() -> None:
    """A filter after `implode` + `list.first` yields the one row where `a == b`."""
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6],
            "b": [6, 5, 4, 3, 2, 1],
        }
    )

    assert (
        df.lazy()
        .with_columns(pl.col("a").implode())
        .with_columns(pl.col("a").list.first())
        .filter(pl.col("a") == pl.col("b"))
        .collect()
    ).to_dict(as_series=False) == {"a": [1], "b": [1]}


def test_fast_path_comparisons() -> None:
    """Comparisons on a series flagged as sorted match the unflagged series."""
    # Seeded generator keeps the input reproducible if an assertion ever fails.
    rng = np.random.default_rng(0)
    s = pl.Series(np.sort(rng.integers(0, 50, 100)))

    assert_series_equal(s > 25, s.set_sorted() > 25)
    assert_series_equal(s >= 25, s.set_sorted() >= 25)
    assert_series_equal(s < 25, s.set_sorted() < 25)
    assert_series_equal(s <= 25, s.set_sorted() <= 25)


def test_predicate_pushdown_block_8661() -> None:
    """A windowed `shift` predicate after a sort returns rows in sorted order."""
    df = pl.DataFrame(
        {
            "g": [1, 1, 1, 1, 2, 2, 2, 2],
            "t": [1, 2, 3, 4, 4, 3, 2, 1],
            "x": [10, 20, 30, 40, 10, 20, 30, 40],
        }
    )
    assert df.lazy().sort(["g", "t"]).filter(
        (pl.col("x").shift() > 20).over("g")
    ).collect().to_dict(as_series=False) == {
        "g": [1, 2, 2],
        "t": [4, 2, 3],
        "x": [40, 30, 20],
    }


def test_predicate_pushdown_cumsum_9566() -> None:
    """A `cum_sum`-based predicate after a sort returns the expected `A` values."""
    df = pl.DataFrame({"A": range(10), "B": ["b"] * 5 + ["a"] * 5})

    q = df.lazy().sort(["B", "A"]).filter(pl.col("A").is_in([8, 2]).cum_sum() == 1)

    assert q.collect()["A"].to_list() == [8, 9, 0, 1]


def test_predicate_pushdown_join_fill_null_10058() -> None:
    """`fill_null(True)` on a left-joined column keeps the row without a match."""
    ids = pl.LazyFrame({"id": [0, 1, 2]})
    filters = pl.LazyFrame({"id": [0, 1], "filter": [True, False]})

    assert sorted(
        ids.join(filters, how="left", on="id")
        .filter(pl.col("filter").fill_null(True))
        .collect()
        .to_dict(as_series=False)["id"]
    ) == [0, 2]


def test_is_in_join_blocked() -> None:
    """Negated `is_in` on a left-joined column agrees between `filter` and `remove`."""
    lf1 = pl.LazyFrame(
        {"Groups": ["A", "B", "C", "D", "E", "F"], "values0": [1, 2, 3, 4, 5, 6]}
    )
    lf2 = pl.LazyFrame(
        {"values_22": [1, 2, None, 4, 5, 6], "values_20": [1, 2, 3, 4, 5, 6]}
    )
    lf_all = lf2.join(
        lf1,
        left_on="values_20",
        right_on="values0",
        how="left",
        maintain_order="right_left",
    )

    # The expectation is identical for both spellings, so build it once.
    expected = pl.LazyFrame(
        {
            "values_22": [None, 4, 5],
            "values_20": [3, 4, 5],
            "Groups": ["C", "D", "E"],
        }
    )

    for result in (
        lf_all.filter(~pl.col("Groups").is_in(["A", "B", "F"])),
        lf_all.remove(pl.col("Groups").is_in(["A", "B", "F"])),
    ):
        assert_frame_equal(result, expected)


def test_predicate_pushdown_group_by_keys() -> None:
    """A filter on a group-by key is the plan root only when pushdown is disabled."""
    lf = pl.LazyFrame(
        {"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]}
    )
    q = (
        lf.group_by("group")
        .agg([pl.len().alias("str_list")])
        .filter(pl.col("group") == 1)
    )
    assert not q.explain().startswith("FILTER")
    assert q.explain(
        optimizations=pl.QueryOptFlags(predicate_pushdown=False)
    ).startswith("FILTER")


def test_no_predicate_push_down_with_cast_and_alias_11883() -> None:
    """The optimized plan has no `FILTER ... FROM DF` on a single line."""
    df = pl.DataFrame({"a": [1, 2, 3]})
    out = (
        df.lazy()
        .select(pl.col("a").cast(pl.Int64).alias("b"))
        .filter(pl.col("b") == 1)
        .filter((pl.col("b") >= 1) & (pl.col("b") < 1))
    )
    assert (
        re.search(
            r"FILTER.*FROM\n\s*DF",
            out.explain(optimizations=pl.QueryOptFlags(predicate_pushdown=True)),
        )
        is None
    )


@pytest.mark.parametrize(
    "predicate",
    [
        0,
        "x",
        [2, 3],
        {"x": 1},
        pl.Series([1, 2, 3]),
        None,
    ],
)
def test_invalid_filter_predicates(predicate: Any) -> None:
    """`DataFrame.filter` raises `TypeError` for each parametrized predicate."""
    df = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]})
    with pytest.raises(TypeError, match="invalid predicate"):
        df.filter(predicate)


def test_fast_path_boolean_filter_predicates() -> None:
    """Literal `True`/`False` predicates keep or drop every row."""
    df = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]})
    df_empty = df.clear()

    assert_frame_equal(df.filter(False), df_empty)
    assert_frame_equal(df.filter(True), df)

    assert_frame_equal(df.remove(True), df_empty)
    assert_frame_equal(df.remove(False), df)


def test_predicate_pushdown_boundary_12102() -> None:
    """Filters around a `pl.min` comparison give the same result without pushdown."""
    df = pl.DataFrame({"x": [1, 2, 4], "y": [1, 2, 4]})

    lf = (
        df.lazy()
        .filter(pl.col("y") > 1)
        .filter(pl.col("x") == pl.min("x"))
        .filter(pl.col("y") > 2)
    )

    result = lf.collect()
    result_no_ppd = lf.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=False))
    assert_frame_equal(result, result_no_ppd)


def test_take_can_block_predicate_pushdown() -> None:
    """A `gather(0)` comparison between two filters returns the expected row."""
    df = pl.DataFrame({"x": [1, 2, 4], "y": [False, True, True]})
    lf = (
        df.lazy()
        .filter(pl.col("y"))
        .filter(pl.col("x") == pl.col("x").gather(0))
        .filter(pl.col("y"))
    )
    result = lf.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    assert result.to_dict(as_series=False) == {"x": [2], "y": [True]}


def test_literal_series_expr_predicate_pushdown() -> None:
    """Literal-series masks and `is_in` predicates filter correctly."""
    # No pushdown should occur in this case, because otherwise the filter will
    # attempt to filter 3 rows with a boolean mask of 2 rows.
    lf = pl.LazyFrame({"x": [0, 1, 2]})

    for res in (
        lf.filter(pl.col("x") > 0).filter(pl.Series([True, True])),
        lf.remove(pl.col("x") <= 0).remove(pl.Series([False, False])),
    ):
        assert res.collect().to_series().to_list() == [1, 2]

    # Pushdown should occur here; series is being used as part of an `is_in`.
    for res in (
        lf.filter(pl.col("x") > 0).filter(pl.col("x").is_in([0, 1])),
        lf.remove(pl.col("x") <= 0).remove(~pl.col("x").is_in([0, 1])),
    ):
        assert re.search(r"FILTER .*\nFROM\n\s*DF", res.explain(), re.DOTALL)
        assert res.collect().to_series().to_list() == [1]


def test_multi_alias_pushdown() -> None:
    """A predicate over two aliased columns yields one `FILTER` above the `DF`."""
    lf = pl.LazyFrame({"a": [1], "b": [1]})

    actual = lf.with_columns(m="a", n="b").filter((pl.col("m") + pl.col("n")) < 2)
    plan = actual.explain()

    assert plan.count("FILTER") == 1
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        # confirm we aren't using `eq_missing` in the query plan (denoted as " ==v ")
        assert " ==v " not in lf.select(pl.col("a").filter(a=None)).explain()


def test_predicate_pushdown_with_window_projections_12637() -> None:
    """Check where `FILTER` nodes land in the plan around `.over()` projections."""
    lf = pl.LazyFrame(
        {
            "key": [1],
            "key_2": [1],
            "key_3": [1],
            "value": [1],
            "value_2": [1],
            "value_3": [1],
        }
    )

    actual = lf.with_columns(
        (pl.col("value") * 2).over("key").alias("value_2"),
        (pl.col("value") * 2).over("key").alias("value_3"),
    ).filter(pl.col("key") == 5)

    plan = actual.explain()

    assert (
        re.search(
            r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL
        )
        is not None
    )
    assert plan.count("FILTER") == 1

    actual = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", "key_2").alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
    )

    plan = actual.explain()
    assert plan.count("FILTER") == 1
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None

    actual = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", "key_2").alias("value_2"),
            (pl.col("value") * 2).over("key", "key_3").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
    )

    plan = actual.explain()
    assert plan.count("FILTER") == 2
    assert (
        re.search(
            r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL
        )
        is not None
    )

    actual = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", pl.col("key_2") + 1).alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
    )
    plan = actual.explain()
    assert plan.count("FILTER") == 2
    assert (
        re.search(
            r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL
        )
        is not None
    )

    # Should block when .over() contains groups-sensitive expr
    actual = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", pl.sum("key_2")).alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
    )

    plan = actual.explain()
    assert plan.count("FILTER") == 1
    assert "FILTER" in plan
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is None

    # Ensure the implementation doesn't accidentally push a window expression
    # that only refers to the common window keys.
    actual = lf.with_columns(
        (pl.col("value") * 2).over("key").alias("value_2"),
    ).filter(pl.len().over("key") == 1)

    plan = actual.explain()
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is None
    assert plan.count("FILTER") == 1

    # Test window in filter
    actual = lf.filter(pl.len().over("key") == 1).filter(pl.col("key") == 1)
    plan = actual.explain()
    assert plan.count("FILTER") == 2
    assert (
        re.search(
            r'FILTER \[\(len\(\).over\(\[col\("key"\)\]\)\) == \(1\)\]\s*FROM\n\s*FILTER',
            plan,
        )
        is not None
    )
    assert (
        re.search(
            r'FILTER \[\(col\("key"\)\) == \(1\)\]\s*FROM\n\s*DF', plan, re.DOTALL
        )
        is not None
    )


def test_predicate_reduction() -> None:
    """Plans for two simple comparisons contain no `cast`."""
    # ensure we get clean reduction without casts
    lf = pl.LazyFrame({"a": [1], "b": [2]})
    for filter_frame in (lf.filter, lf.remove):
        assert (
            "cast"
            not in filter_frame(
                pl.col("a") > 1,
                pl.col("b") > 1,
            ).explain()
        )


def test_all_any_cleanup_at_single_predicate_case() -> None:
    """`drop_nulls` on one column leaves no `horizontal`/`all` text in the plan."""
    plan = pl.LazyFrame({"a": [1], "b": [2]}).select(["a"]).drop_nulls().explain()
    assert "horizontal" not in plan
    assert "all" not in plan


def test_hconcat_predicate() -> None:
    """A filter applied after a horizontal concat returns the expected rows."""
    # Predicates shouldn't be pushed down past an hconcat as we can't filter
    # across the different inputs
    lf1 = pl.LazyFrame(
        {
            "a1": [0, 1, 2, 3, 4],
            "a2": [5, 6, 7, 8, 9],
        }
    )
    lf2 = pl.LazyFrame(
        {
            "b1": [0, 1, 2, 3, 4],
            "b2": [5, 6, 7, 8, 9],
        }
    )

    query = pl.concat(
        [
            lf1.filter(pl.col("a1") < 4),
            lf2.filter(pl.col("b1") > 0),
        ],
        how="horizontal",
    ).filter(pl.col("b2") < 9)

    expected = pl.DataFrame(
        {
            "a1": [0, 1, 2],
            "a2": [5, 6, 7],
            "b1": [1, 2, 3],
            "b2": [6, 7, 8],
        }
    )
    result = query.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    assert_frame_equal(result, expected)


def test_predicate_pd_join_13300() -> None:
    """Filtering the left key column after a left join returns the expected row."""
    # https://github.com/pola-rs/polars/issues/13300

    lf = pl.LazyFrame({"col3": range(10, 14), "new_col": range(11, 15)})
    lf_other = pl.LazyFrame({"col4": [0, 11, 2, 13]})

    lf = lf.join(lf_other, left_on="new_col", right_on="col4", how="left")
    for res in (
        lf.filter(pl.col("new_col") < 12),
        lf.remove(pl.col("new_col") >= 12),
    ):
        assert res.collect().to_dict(as_series=False) == {"col3": [10], "new_col": [11]}


def test_filter_eq_missing_13861() -> None:
    """`filter(a=None)` warns and matches nothing; `eq_missing` matches the null."""
    lf = pl.LazyFrame({"a": [1, None, 3], "b": ["xx", "yy", None]})
    lf_empty = lf.clear()

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        assert_frame_equal(lf.collect().filter(a=None), lf_empty.collect())

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        assert_frame_equal(lf.collect().remove(a=None), lf.collect())

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        lff = lf.filter(a=None)
        assert lff.collect().rows() == []
        assert " ==v " not in lff.explain()  # check no `eq_missing` op

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        assert_frame_equal(lf.collect().filter(a=None), lf_empty.collect())

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        assert_frame_equal(lf.collect().remove(a=None), lf.collect())

    for filter_expr in (
        pl.col("a").eq_missing(None),
        pl.col("a").is_null(),
    ):
        assert lf.collect().filter(filter_expr).rows() == [(None, "yy")]


@pytest.mark.parametrize("how", ["left", "inner"])
def test_predicate_pushdown_block_join(how: Any) -> None:
    """A join-key filter gives the same result with and without optimizations."""
    q = (
        pl.LazyFrame({"a": [1]})
        .join(
            pl.LazyFrame({"a": [2], "b": [1]}),
            left_on=["a"],
            right_on=["b"],
            how=how,
        )
        .filter(pl.col("a") == 1)
    )
    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), q.collect())


def test_predicate_push_down_with_alias_15442() -> None:
    """A predicate containing `alias` + `drop_nulls` collects with pushdown on."""
    df = pl.DataFrame({"a": [1]})
    output = (
        df.lazy()
        .filter(pl.col("a").alias("x").drop_nulls() > 0)
        .collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    )
    assert output.to_dict(as_series=False) == {"a": [1]}


def test_predicate_slice_pushdown_list_gather_17492(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """`list.get` predicates and projections stay correct around filter/slice."""
    lf = pl.LazyFrame({"val": [[1], [1, 1]], "len": [1, 2]})

    assert_frame_equal(
        lf.filter(pl.col("len") == 2).filter(pl.col("val").list.get(1) == 1),
        lf.slice(1, 1),
    )

    # null_on_oob=True can pass

    plan = (
        lf.filter(pl.col("len") == 2)
        .filter(pl.col("val").list.get(1, null_on_oob=True) == 1)
        .explain()
    )

    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None

    # Also check slice pushdown
    q = lf.with_columns(pl.col("val").list.get(1).alias("b")).slice(1, 1)

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "val": [[1, 1]],
                "len": pl.Series([2], dtype=pl.Int64),
                "b": pl.Series([1], dtype=pl.Int64),
            }
        ),
    )


def test_predicate_pushdown_struct_unnest_19632() -> None:
    """A filter on an unnested field appears before `UNNEST` in the plan text."""
    lf = pl.LazyFrame({"a": [{"a": 1, "b": 2}]}).unnest("a")

    q = lf.filter(pl.col("a") == 1)
    plan = q.explain()

    assert "FILTER" in plan
    assert plan.index("FILTER") < plan.index("UNNEST")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame({"a": 1, "b": 2}),
    )

    # With `pl.struct()`
    lf = pl.LazyFrame({"a": 1, "b": 2}).select(pl.struct(pl.all())).unnest("a")

    q = lf.filter(pl.col("a") == 1)
    plan = q.explain()

    assert "FILTER" in plan
    assert plan.index("FILTER") < plan.index("UNNEST")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame({"a": 1, "b": 2}),
    )

    # With `value_counts()`
    lf = pl.LazyFrame({"a": [1]}).select(pl.col("a").value_counts()).unnest("a")

    q = lf.filter(pl.col("a") == 1)
    plan = q.explain()

    assert plan.index("FILTER") < plan.index("UNNEST")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {"a": 1, "count": 1}, schema={"a": pl.Int64, "count": pl.get_index_type()}
        ),
    )


@pytest.mark.parametrize(
    "predicate",
    [
        pl.col("v") == 7,
        pl.col("v") != 99,
        pl.col("v") > 0,
        pl.col("v") < 999,
        pl.col("v").is_in([7]),
        pl.col("v").cast(pl.Boolean),
        pl.col("b"),
    ],
)
@pytest.mark.parametrize("alias", [True, False])
@pytest.mark.parametrize("join_type", ["left", "right"])
def test_predicate_pushdown_join_19772(
    predicate: pl.Expr, join_type: str, alias: bool
) -> None:
    """Filters after left/right joins match with and without optimizations."""
    left = pl.LazyFrame({"k": [1, 2]})
    right = pl.LazyFrame({"k": [1], "v": [7], "b": True})

    if join_type == "right":
        [left, right] = [right, left]

    if alias:
        predicate = predicate.alias(":V")

    q = left.join(right, on="k", how=join_type).filter(predicate)  # type: ignore[arg-type]

    expect = pl.DataFrame({"k": 1, "v": 7, "b": True})

    if join_type == "right":
        expect = expect.select("v", "b", "k")

    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)
    assert_frame_equal(q.collect(), expect)


def test_predicate_pushdown_scalar_20489() -> None:
    """A one-row `False` mask after `with_columns` yields an empty typed frame."""
    df = pl.DataFrame({"a": [1]})
    mask = pl.Series([False])

    assert_frame_equal(
        df.lazy().with_columns(b=pl.Series([2])).filter(mask).collect(),
        pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64}),
    )


def test_predicates_not_split_when_pushdown_disabled_20475() -> None:
    """With pushdown disabled, three predicates stay in a single `FILTER`."""
    # This is important for the eager `DataFrame.filter()`, as that runs without
    # predicate pushdown enabled. Splitting the predicates in that case can
    # severely degrade performance.
    q = pl.LazyFrame({"a": 1, "b": 1, "c": 1}).filter(
        pl.col("a") > 0, pl.col("b") > 0, pl.col("c") > 0
    )
    assert (
        q.explain(optimizations=pl.QueryOptFlags(predicate_pushdown=False)).count(
            "FILTER"
        )
        == 1
    )


def test_predicate_filtering_against_nulls() -> None:
    """`filter` drops null-predicate rows while `remove` keeps them."""
    df = pl.DataFrame({"num": [1, 2, None, 4]})

    for res in (
        df.filter(pl.col("num") > 2),
        df.filter(pl.col("num").is_in([3, 4, 5])),
    ):
        assert res["num"].to_list() == [4]

    for res in (
        df.remove(pl.col("num") <= 2),
        df.remove(pl.col("num").is_in([1, 2, 3])),
    ):
        assert res["num"].to_list() == [None, 4]

    for res in (
        df.filter(pl.col("num").ne_missing(None)),
        df.remove(pl.col("num").eq_missing(None)),
    ):
        assert res["num"].to_list() == [1, 2, 4]


@pytest.mark.parametrize(
    ("query", "expected"),
    [
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"a": "A", "b": "a"})
                .select("A", "c")
                .filter(pl.col("A") == 1)
            ),
            pl.DataFrame({"A": 1, "c": 3}),
        ),
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"b": "a", "a": "A"})
                .select("A", "c")
                .filter(pl.col("A") == 1)
            ),
            pl.DataFrame({"A": 1, "c": 3}),
        ),
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"a": "b", "b": "a"})
                .select("a", "b", "c")
                .filter(pl.col("b") == 1)
            ),
            pl.DataFrame({"a": 2, "b": 1, "c": 3}),
        ),
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"a": "b", "b": "a"})
                .select("b", "c")
                .filter(pl.col("b") == 1)
            ),
            pl.DataFrame({"b": 1, "c": 3}),
        ),
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"b": "a", "a": "b"})
                .select("a", "b", "c")
                .filter(pl.col("b") == 1)
            ),
            pl.DataFrame({"a": 2, "b": 1, "c": 3}),
        ),
    ],
)
def test_predicate_pushdown_lazy_rename_22373(
    query: pl.LazyFrame,
    expected: pl.DataFrame,
) -> None:
    """Filters after `rename` + `select` are correct and sit below `SELECT`."""
    assert_frame_equal(
        query.collect(),
        expected,
    )

    # Ensure filter is pushed past rename
    plan = query.explain()
    assert plan.index("FILTER") > plan.index("SELECT")


@pytest.mark.parametrize(
    "base_query",
    [
        (  # Fallible expr in earlier `with_columns()`
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .with_columns(b=pl.col("a").list.get(1, null_on_oob=False))
        ),
        (  # Fallible expr in earlier `filter()`
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .filter(
                pl.col("a")
                .list.get(1, null_on_oob=False)
                .cast(pl.Boolean, strict=False)
            )
        ),
        (  # Fallible expr in earlier `select()`
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .select("a", "MARKER", b=pl.col("a").list.get(1, null_on_oob=False))
        ),
    ],
)
def test_predicate_pushdown_pushes_past_fallible(
    base_query: pl.LazyFrame, plmonkeypatch: PlMonkeyPatch
) -> None:
    """A `list.len` filter lands after `MARKER` in the plan, avoiding the error."""
    # Ensure baseline fails
    with pytest.raises(ComputeError, match="index is out of bounds"):
        base_query.collect()

    q = base_query.filter(pl.col("a").list.len() > 1)

    plan = q.explain()

    assert plan.index("list.len") > plan.index("MARKER")

    assert_frame_equal(q.collect(), pl.DataFrame(schema=q.collect_schema()))

    plmonkeypatch.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

    with pytest.raises(ComputeError, match="index is out of bounds"):
        q.collect()


def test_predicate_pushdown_fallible_exprs_22284(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Strict-cast / strict-parse predicates keep their position and results."""
    q = (
        pl.LazyFrame({"a": ["xyz", "123", "456", "789"]})
        .with_columns(MARKER=1)
        .filter(pl.col.a.str.contains(r"^\d{3}$"))
        .filter(pl.col.a.cast(pl.Int64) >= 123)
    )

    plan = q.explain()

    assert (
        plan.index('FILTER [(col("a").strict_cast(Int64)) >= (123)]')
        < plan.index("MARKER")
        < plan.index(r'FILTER col("a").str.contains(["^\d{3}$"])')
    )

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "a": ["123", "456", "789"],
                "MARKER": 1,
            }
        ),
    )

    lf = pl.LazyFrame(
        {
            "str_date": ["2025-01-01", "20250101"],
            "data_source": ["system_1", "system_2"],
        }
    )

    q = lf.filter(pl.col("data_source") == "system_1").filter(
        pl.col("str_date").str.to_datetime("%Y-%m-%d", strict=True)
        == datetime(2025, 1, 1)
    )

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "str_date": ["2025-01-01"],
                "data_source": ["system_1"],
            }
        ),
    )

    q = lf.with_columns(
        pl.col("str_date").str.to_datetime("%Y-%m-%d", strict=True)
    ).filter(pl.col("data_source") == "system_1")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "str_date": [datetime(2025, 1, 1)],
                "data_source": ["system_1"],
            }
        ),
    )

    plmonkeypatch.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

    with pytest.raises(
        InvalidOperationError, match=r"`str` to `datetime\[μs\]` failed"
    ):
        q.collect()


def test_predicate_pushdown_single_fallible() -> None:
    """A lone strict-cast predicate appears after `MARKER` in the plan text."""
    lf = pl.LazyFrame({"a": [0, 1]}).with_columns(MARKER=pl.lit(1, dtype=pl.Int64))

    q = lf.filter(pl.col("a").cast(pl.Boolean))

    plan = q.explain()

    assert plan.index('FILTER col("a").strict_cast(Boolean)') > plan.index("MARKER")

    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))


def test_predicate_pushdown_split_pushable(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Check plan placement of plain vs. strict-cast predicates, incl. joins."""
    lf = pl.LazyFrame({"a": [1, 999]}).with_columns(MARKER=pl.lit(1, dtype=pl.Int64))

    q = lf.filter(
        pl.col("a") == 1,  # pushable
        pl.col("a").cast(pl.Int8) == 1,  # fallible
    )

    plan = q.explain()

    assert (
        plan.index('FILTER [(col("a").strict_cast(Int8)) == (1)]')
        < plan.index("MARKER")
        < plan.index('FILTER [(col("a")) == (1)]')
    )

    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

        with pytest.raises(
            InvalidOperationError, match="conversion from `i64` to `i8` failed"
        ):
            q.collect()

    q = lf.filter(
        pl.col("a").cast(pl.UInt16) == 1,
        pl.col("a").sort() == 1,
    )

    plan = q.explain()

    assert plan.index(
        'FILTER [([(col("a").strict_cast(UInt16)) == (1)]) & ([(col("a").sort(asc)) == (1)])]'
    ) < plan.index("MARKER")

    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")
        assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    # Ensure it is not pushed past a join

    # Baseline
    q = lf.join(
        lf.drop("MARKER").collect().lazy(),
        on="a",
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("a_right") == 1)

    plan = q.explain()

    assert not plan.startswith("FILTER")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "a": 1,
                "MARKER": 1,
                "a_right": 1,
            }
        ),
    )

    q = lf.join(
        lf.drop("MARKER").collect().lazy(),
        on="a",
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("a_right").cast(pl.Int16) == 1)

    plan = q.explain()

    assert plan.startswith("FILTER")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "a": 1,
                "MARKER": 1,
                "a_right": 1,
            }
        ),
    )

    # With a select node in between

    q = (
        lf.join(
            lf.drop("MARKER").collect().lazy(),
            on="a",
            how="inner",
            coalesce=False,
            maintain_order="left_right",
        )
        .select(
            "a",
            "a_right",
            "MARKER",
        )
        .filter(pl.col("a_right").cast(pl.Int16) == 1)
    )

    plan = q.explain()

    assert plan.startswith("FILTER")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "a": 1,
                "a_right": 1,
                "MARKER": 1,
            }
        ),
    )


def test_predicate_pushdown_fallible_literal_in_filter_expr() -> None:
    """Strict operations on literals leave the filter after `MARKER` in the plan."""
    # Fallible operations on literals inside of the predicate expr should not
    # block pushdown.

    # Pushdown will also push any fallible expression if it's the only accumulated
    # predicate, we insert this dummy predicate to ensure the predicate is being
    # pushed solely because it is considered infallible.
    dummy_predicate = pl.lit(1) == pl.lit(1)

    lf = pl.LazyFrame(
        {"column": "2025-01-01", "column_date": datetime(2025, 1, 1), "integer": 1}
    )

    q = lf.with_columns(
        MARKER=1,
    ).filter(
        pl.col("column_date")
        == pl.lit("2025-01-01").str.to_datetime("%Y-%m-%d", strict=True),
        dummy_predicate,
    )

    plan = q.explain()

    assert plan.index("FILTER") > plan.index("MARKER")

    assert q.collect().height == 1

    q = lf.with_columns(
        MARKER=1,
    ).filter(
        pl.col("column_date") == pl.lit("2025-01-01").str.strptime(pl.Datetime),
        dummy_predicate,
    )

    plan = q.explain()

    assert plan.index("FILTER") > plan.index("MARKER")

    assert q.collect().height == 1

    q = lf.with_columns(
        MARKER=1,
    ).filter(
        pl.col("integer") == pl.lit("1").cast(pl.Int64, strict=True), dummy_predicate
    )

    plan = q.explain()

    assert plan.index("FILTER") > plan.index("MARKER")

    assert q.collect().height == 1


def test_predicate_does_not_split_barrier_expr() -> None:
    """A predicate pair containing `sort()` stays one `FILTER` at the plan root."""
    q = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .with_row_index()
        .filter(pl.col("a") > 1, pl.col("a").sort() == 3)
    )

    plan = q.explain()

    assert plan.startswith(
        'FILTER [([(col("a")) > (1)]) & ([(col("a").sort(asc)) == (3)])]'
    )

    assert_frame_equal(
        q.collect(),
        pl.DataFrame({"a": 3}).with_row_index(offset=2),
    )


def test_predicate_passes_set_sorted_22397() -> None:
    """A filter after `set_sorted` appears after `MARKER` in the plan text."""
    plan = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .with_columns(MARKER=1, b=pl.lit(1))
        .set_sorted("a")
        .filter(pl.col("a") <= 1)
        .explain()
    )
    assert plan.index("FILTER") > plan.index("MARKER")


@pytest.mark.filterwarnings("ignore")
def test_predicate_pass() -> None:
    """A `map_elements` predicate appears after `MARKER` in the plan text."""
    plan = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .with_columns(MARKER=pl.col("a"))
        .filter(pl.col("a").map_elements(lambda x: x > 2, return_dtype=pl.Boolean))
        .explain()
    )
    assert plan.index("FILTER") > plan.index("MARKER")


def test_predicate_pushdown_auto_disable_strict() -> None:
    """Predicates with compatible-type casts appear after `MARKER` in the plan."""
    # Test that type-coercion automatically switches strict cast to
    # non-strict/overflowing for compatible types, allowing the predicate to be
    # pushed.
    lf = pl.LazyFrame(
        {"column": "2025-01-01", "column_date": datetime(2025, 1, 1), "integer": 1},
        schema={
            "column": pl.String,
            "column_date": pl.Datetime("ns"),
            "integer": pl.Int64,
        },
    )

    q = lf.with_columns(
        MARKER=1,
    ).filter(
        pl.col("column_date").cast(pl.Datetime("us")) == pl.lit(datetime(2025, 1, 1)),
        pl.col("integer") == 1,
    )

    plan = q.explain()
    assert plan.index("FILTER") > plan.index("MARKER")

    q = lf.with_columns(
        MARKER=1,
    ).filter(
        pl.col("column_date").cast(pl.Datetime("us"), strict=False)
        == pl.lit(datetime(2025, 1, 1)),
        pl.col("integer").cast(pl.Int128, strict=True) == 1,
    )

    plan = q.explain()
    assert plan.index("FILTER") > plan.index("MARKER")


@pytest.mark.may_fail_auto_streaming  # IO plugin validate=False schema mismatch
def test_predicate_pushdown_map_elements_io_plugin_22860() -> None:
    """A `map_elements` predicate reaches the IO source as a `SELECTION`."""

    def generator(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        n_rows: int | None,
        batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        df = pl.DataFrame({"row_nr": [1, 2, 3, 4, 5], "y": [0, 1, 0, 1, 1]})
        # The source applies the predicate itself, so it must have been passed in.
        assert predicate is not None
        yield df.filter(predicate)

    q = register_io_source(
        io_source=generator, schema={"x": pl.Int64, "y": pl.Int64}
    ).filter(pl.col("y").map_elements(bool, return_dtype=pl.Boolean))

    plan = q.explain()
    assert plan.index("SELECTION") > plan.index("PYTHON SCAN")

    assert_frame_equal(q.collect(), pl.DataFrame({"row_nr": [2, 4, 5], "y": [1, 1, 1]}))


def test_duplicate_filter_removal_23243() -> None:
    """Two identical predicates show up as a single comparison in the plan."""
    lf = pl.LazyFrame({"x": [1, 2, 3]})

    q = lf.filter(pl.col("x") == 2, pl.col("x") == 2)

    expect = pl.DataFrame({"x": [2]})

    plan = q.explain()

    assert plan.split("\n", 1)[0] == 'FILTER [(col("x")) == (2)]'

    assert_frame_equal(q.collect(), expect)


@pytest.mark.parametrize("maintain_order", [True, False])
def test_no_predicate_pushdown_on_modified_groupby_keys_21439(
    maintain_order: bool,
) -> None:
    """Filters on computed group-by keys apply to the computed key values."""
    df = pl.DataFrame({"a": [1, 2, 3]})
    q = (
        df.lazy()
        .group_by(pl.col.a + 1, maintain_order=maintain_order)
        .agg()
        .filter(pl.col.a <= 3)
    )
    expected = pl.DataFrame({"a": [2, 3]})
    assert_frame_equal(q.collect(), expected, check_row_order=maintain_order)

    df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
    q = (
        df.lazy()
        .group_by([(pl.col.a + 1).alias("b"), pl.col.b.alias("a")], maintain_order=True)
        .agg()
        .filter(pl.col.b <= 2)
        .select(pl.col.b)
    )
    expected = pl.DataFrame({"b": [2]})
    assert_frame_equal(q.collect(), expected, check_row_order=maintain_order)


def test_no_predicate_pushdown_on_modified_groupby_keys_21439b() -> None:
    """Eager and lazy agree when filtering on a `dt.hour()` group-by key."""
    df = pl.DataFrame(
        {
            "time": pl.datetime_range(
                datetime(2021, 1, 1),
                datetime(2021, 1, 2),
                timedelta(minutes=15),
                eager=True,
            )
        }
    )
    eager = (
        df.group_by(pl.col("time").dt.hour())
        .agg()
        .filter(pl.col("time").is_between(0, 10))
    )
    lazy = (
        df.lazy()
        .group_by(pl.col("time").dt.hour())
        .agg()
        .filter(pl.col("time").is_between(0, 10))
        .collect()
    )
    assert_frame_equal(eager, lazy, check_row_order=False)


def test_no_predicate_pushdown_unpivot() -> None:
    """A filter on an unpivot index column appears after `UNPIVOT` in the plan."""
    data = {"a": [5, 2, 8, 2], "b": [99, 33, 77, 44]}

    for index, pred in [("a", pl.col.a == 2), (["b", "a"], pl.col.b != 33)]:
        lf = pl.LazyFrame(data).unpivot(on="b", index=index).filter(pred)
        plan = lf.explain()
        assert plan.index("FILTER") > plan.index("UNPIVOT")


def test_replace_strict_predicate_merging() -> None:
    """Chaining a `replace_strict` filter after a boolean filter keeps three rows."""
    df = pl.LazyFrame({"x": [True, True, True, False]})
    out = (
        df.filter(pl.col("x")).filter(pl.col("x").replace_strict(True, True)).collect()
    )
    assert out.height == 3