# Path: blob/main/py-polars/tests/unit/lazyframe/test_cse.py
# 8430 views
from __future__ import annotations12import re3from datetime import date, datetime, timedelta4from pathlib import Path5from tempfile import NamedTemporaryFile, TemporaryDirectory6from typing import TYPE_CHECKING, Any, TypeVar7from unittest.mock import Mock89import numpy as np10import pytest1112import polars as pl13from polars.io.plugins import register_io_source14from polars.testing import assert_frame_equal1516if TYPE_CHECKING:17from collections.abc import Iterator1819from tests.conftest import PlMonkeyPatch202122def num_cse_occurrences(explanation: str) -> int:23"""The number of unique CSE columns in an explain string."""24return len(set(re.findall(r'__POLARS_CSER_0x[^"]+"', explanation)))252627def create_dataframe_source(28source_df: pl.DataFrame,29is_pure: bool,30validate_schame: bool = False,31) -> pl.LazyFrame:32"""Generates a custom io source based on the provided pl.DataFrame."""3334def dataframe_source(35with_columns: list[str] | None,36predicate: pl.Expr | None,37_n_rows: int | None,38_batch_size: int | None,39) -> Iterator[pl.DataFrame]:40df = source_df.clone()41if predicate is not None:42df = df.filter(predicate)43if with_columns is not None:44df = df.select(with_columns)45yield df4647return register_io_source(48dataframe_source,49schema=source_df.schema,50validate_schema=validate_schame,51is_pure=is_pure,52)535455@pytest.mark.parametrize("use_custom_io_source", [True, False])56def test_cse_rename_cross_join_5405(use_custom_io_source: bool) -> None:57# https://github.com/pola-rs/polars/issues/54055859right = pl.DataFrame({"A": [1, 2], "B": [3, 4], "D": [5, 6]}).lazy()60if use_custom_io_source:61right = create_dataframe_source(right.collect(), is_pure=True)62left = pl.DataFrame({"C": [3, 4]}).lazy().join(right.select("A"), how="cross")6364result = left.join(right.rename({"B": "C"}), on=["A", "C"], how="left").collect(65optimizations=pl.QueryOptFlags(comm_subplan_elim=True)66)6768expected = pl.DataFrame(69{70"C": [3, 3, 4, 4],71"A": [1, 2, 1, 2],72"D": [5, 
None, None, 6],73}74)75assert_frame_equal(result, expected, check_row_order=False)767778def test_union_duplicates() -> None:79n_dfs = 1080df_lazy = pl.DataFrame({}).lazy()81lazy_dfs = [df_lazy for _ in range(n_dfs)]8283matches = re.findall(r"CACHE\[id: (.*)]", pl.concat(lazy_dfs).explain())8485assert len(matches) == 1086assert len(set(matches)) == 1878889def test_cse_with_struct_expr_11116() -> None:90# https://github.com/pola-rs/polars/issues/111169192df = pl.DataFrame([{"s": {"a": 1, "b": 4}, "c": 3}]).lazy()9394result = df.with_columns(95pl.col("s").struct.field("a").alias("s_a"),96pl.col("s").struct.field("b").alias("s_b"),97(98(pl.col("s").struct.field("a") <= pl.col("c"))99& (pl.col("s").struct.field("b") > pl.col("c"))100).alias("c_between_a_and_b"),101).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))102103expected = pl.DataFrame(104{105"s": [{"a": 1, "b": 4}],106"c": [3],107"s_a": [1],108"s_b": [4],109"c_between_a_and_b": [True],110}111)112assert_frame_equal(result, expected)113114115def test_cse_schema_6081() -> None:116# https://github.com/pola-rs/polars/issues/6081117118df = pl.DataFrame(119data=[120[date(2022, 12, 12), 1, 1],121[date(2022, 12, 12), 1, 2],122[date(2022, 12, 13), 5, 2],123],124schema=["date", "id", "value"],125orient="row",126).lazy()127128min_value_by_group = df.group_by(["date", "id"]).agg(129pl.col("value").min().alias("min_value")130)131132result = df.join(min_value_by_group, on=["date", "id"], how="left").collect(133optimizations=pl.QueryOptFlags(comm_subplan_elim=True, projection_pushdown=True)134)135expected = pl.DataFrame(136{137"date": [date(2022, 12, 12), date(2022, 12, 12), date(2022, 12, 13)],138"id": [1, 1, 5],139"value": [1, 2, 2],140"min_value": [1, 1, 2],141}142)143assert_frame_equal(result, expected, check_row_order=False)144145146def test_cse_9630() -> None:147lf1 = pl.LazyFrame({"key": [1], "x": [1]})148lf2 = pl.LazyFrame({"key": [1], "y": [2]})149150joined_lf2 = lf1.join(lf2, 
on="key")151152all_subsections = (153pl.concat(154[155lf1.select("key", pl.col("x").alias("value")),156joined_lf2.select("key", pl.col("y").alias("value")),157]158)159.group_by("key")160.agg(pl.col("value"))161)162163intersected_df1 = all_subsections.join(lf1, on="key")164intersected_df2 = all_subsections.join(lf2, on="key")165166result = intersected_df1.join(intersected_df2, on=["key"], how="left").collect(167optimizations=pl.QueryOptFlags(comm_subplan_elim=True)168)169170expected = pl.DataFrame(171{172"key": [1],173"value": [[1, 2]],174"x": [1],175"value_right": [[1, 2]],176"y": [2],177}178)179assert_frame_equal(result, expected)180181182@pytest.mark.write_disk183@pytest.mark.parametrize("maintain_order", [False, True])184def test_schema_row_index_cse(maintain_order: bool) -> None:185with NamedTemporaryFile() as csv_a:186csv_a.write(b"A,B\nGr1,A\nGr1,B")187csv_a.seek(0)188189df_a = pl.scan_csv(csv_a.name).with_row_index("Idx")190191result = (192df_a.join(df_a, on="B", maintain_order="left" if maintain_order else "none")193.group_by("A", maintain_order=maintain_order)194.all()195.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))196)197198expected = pl.DataFrame(199{200"A": ["Gr1"],201"Idx": [[0, 1]],202"B": [["A", "B"]],203"Idx_right": [[0, 1]],204"A_right": [["Gr1", "Gr1"]],205},206schema_overrides={"Idx": pl.List(pl.UInt32), "Idx_right": pl.List(pl.UInt32)},207)208assert_frame_equal(result, expected, check_row_order=maintain_order)209210211@pytest.mark.debug212def test_cse_expr_selection_context() -> None:213q = pl.LazyFrame(214{215"a": [1, 2, 3, 4],216"b": [1, 2, 3, 4],217"c": [1, 2, 3, 4],218}219)220221derived = (pl.col("a") * pl.col("b")).sum()222derived2 = derived * derived223224exprs = [225derived.alias("d1"),226(derived * pl.col("c").sum() - 1).alias("foo"),227derived2.alias("d2"),228(derived2 * 10).alias("d3"),229]230231result = q.select(exprs).collect(232optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)233)234assert 
(235num_cse_occurrences(236q.select(exprs).explain(237optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)238)239)240== 2241)242expected = pl.DataFrame(243{244"d1": [30],245"foo": [299],246"d2": [900],247"d3": [9000],248}249)250assert_frame_equal(result, expected)251252result = q.with_columns(exprs).collect(253optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)254)255assert (256num_cse_occurrences(257q.with_columns(exprs).explain(258optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)259)260)261== 2262)263expected = pl.DataFrame(264{265"a": [1, 2, 3, 4],266"b": [1, 2, 3, 4],267"c": [1, 2, 3, 4],268"d1": [30, 30, 30, 30],269"foo": [299, 299, 299, 299],270"d2": [900, 900, 900, 900],271"d3": [9000, 9000, 9000, 9000],272}273)274assert_frame_equal(result, expected)275276277def test_windows_cse_excluded() -> None:278lf = pl.LazyFrame(279data=[280("a", "aaa", 1),281("a", "bbb", 3),282("a", "ccc", 1),283("c", "xxx", 2),284("c", "yyy", 3),285("c", "zzz", 4),286("b", "qqq", 0),287],288schema=["a", "b", "c"],289orient="row",290)291292result = lf.select(293c_diff=pl.col("c").diff(1),294c_diff_by_a=pl.col("c").diff(1).over("a"),295).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))296297expected = pl.DataFrame(298{299"c_diff": [None, 2, -2, 1, 1, 1, -4],300"c_diff_by_a": [None, 2, -2, None, 1, 1, None],301}302)303assert_frame_equal(result, expected)304305306def test_cse_group_by_10215() -> None:307lf = pl.LazyFrame({"a": [1], "b": [1]})308309result = lf.group_by("b").agg(310(pl.col("a").sum() * pl.col("a").sum()).alias("x"),311(pl.col("b").sum() * pl.col("b").sum()).alias("y"),312(pl.col("a").sum() * pl.col("a").sum()).alias("x2"),313((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"),314((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"),315((pl.col("a") + 2).sum() * pl.col("b").sum()),316)317318assert "__POLARS_CSER" in result.explain(319optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)320)321expected = pl.DataFrame(322{323"b": 
[1],324"x": [1],325"y": [1],326"x2": [1],327"x3": [3],328"x4": [3],329"a": [3],330}331)332assert_frame_equal(333result.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)), expected334)335336337def test_cse_mixed_window_functions() -> None:338# checks if the window caches are cleared339# there are windows in the cse's and the default expressions340lf = pl.LazyFrame({"a": [1], "b": [1], "c": [1]})341342result = lf.select(343pl.col("a"),344pl.col("b"),345pl.col("c"),346pl.col("b").rank().alias("rank"),347pl.col("b").rank().alias("d_rank"),348pl.col("b").first().over([pl.col("a")]).alias("b_first"),349pl.col("b").last().over([pl.col("a")]).alias("b_last"),350pl.col("b").item().over([pl.col("a")]).alias("b_item"),351pl.col("b").shift().alias("b_lag_1"),352pl.col("b").shift().alias("b_lead_1"),353pl.col("c").cum_sum().alias("c_cumsum"),354pl.col("c").cum_sum().over([pl.col("a")]).alias("c_cumsum_by_a"),355pl.col("c").diff().alias("c_diff"),356pl.col("c").diff().over([pl.col("a")]).alias("c_diff_by_a"),357).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))358359expected = pl.DataFrame(360{361"a": [1],362"b": [1],363"c": [1],364"rank": [1.0],365"d_rank": [1.0],366"b_first": [1],367"b_last": [1],368"b_item": [1],369"b_lag_1": [None],370"b_lead_1": [None],371"c_cumsum": [1],372"c_cumsum_by_a": [1],373"c_diff": [None],374"c_diff_by_a": [None],375},376).with_columns(pl.col(pl.Null).cast(pl.Int64))377assert_frame_equal(result, expected)378379380def test_cse_10401() -> None:381df = pl.LazyFrame({"clicks": [1.0, float("nan"), None]})382383q = df.with_columns(pl.all().fill_null(0).fill_nan(0))384385assert r"""col("clicks").fill_null([0.0]).alias("__POLARS_CSER""" in q.explain()386387expected = pl.DataFrame({"clicks": [1.0, 0.0, 0.0]})388assert_frame_equal(389q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)), expected390)391392393def test_cse_10441() -> None:394lf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 2, 1]})395396result = 
lf.select(397pl.col("a").sum() + pl.col("a").sum() + pl.col("b").sum()398).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))399400expected = pl.DataFrame({"a": [18]})401assert_frame_equal(result, expected)402403404def test_cse_10452() -> None:405lf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 2, 1]})406q = lf.select(407pl.col("b").sum() + pl.col("a").sum().over(pl.col("b")) + pl.col("b").sum()408)409410assert "__POLARS_CSE" in q.explain(411optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)412)413414expected = pl.DataFrame({"b": [13, 14, 15]})415assert_frame_equal(416q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)), expected417)418419420def test_cse_group_by_ternary_10490() -> None:421lf = pl.LazyFrame(422{423"a": [1, 1, 2, 2],424"b": [1, 2, 3, 4],425"c": [2, 3, 4, 5],426}427)428429result = (430lf.group_by("a")431.agg(432[433pl.when(pl.col(col).is_null().all()).then(None).otherwise(1).alias(col)434for col in ["b", "c"]435]436+ [437(pl.col("a").sum() * pl.col("a").sum()).alias("x"),438(pl.col("b").sum() * pl.col("b").sum()).alias("y"),439(pl.col("a").sum() * pl.col("a").sum()).alias("x2"),440((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"),441((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"),442]443)444.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))445.sort("a")446)447448expected = pl.DataFrame(449{450"a": [1, 2],451"b": [1, 1],452"c": [1, 1],453"x": [4, 16],454"y": [9, 49],455"x2": [4, 16],456"x3": [12, 32],457"x4": [18, 56],458},459schema_overrides={"b": pl.Int32, "c": pl.Int32},460)461assert_frame_equal(result, expected)462463464def test_cse_quantile_10815() -> None:465np.random.seed(1)466a = np.random.random(10)467b = np.random.random(10)468df = pl.DataFrame({"a": a, "b": b})469cols = ["a", "b"]470q = df.lazy().select(471*(472pl.col(c).quantile(0.75, interpolation="midpoint").name.suffix("_3")473for c in cols474),475*(476pl.col(c).quantile(0.25, interpolation="midpoint").name.suffix("_1")477for c in 
cols478),479)480assert "__POLARS_CSE" not in q.explain()481assert q.collect().to_dict(as_series=False) == {482"a_3": [0.40689473946662197],483"b_3": [0.6145786693120769],484"a_1": [0.16650805109739197],485"b_1": [0.2012768694081981],486}487488489def test_cse_nan_10824() -> None:490v = pl.col("a") / pl.col("b")491magic = pl.when(v > 0).then(pl.lit(float("nan"))).otherwise(v)492assert (493str(494(495pl.DataFrame(496{497"a": [1.0],498"b": [1.0],499}500)501.lazy()502.select(magic)503.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))504).to_dict(as_series=False)505)506== "{'literal': [nan]}"507)508509510def test_cse_10901() -> None:511df = pl.DataFrame(data=range(6), schema={"a": pl.Int64})512a = pl.col("a").rolling_sum(window_size=2)513b = pl.col("a").rolling_sum(window_size=3)514exprs = {515"ax1": a,516"ax2": a * 2,517"bx1": b,518"bx2": b * 2,519}520521expected = pl.DataFrame(522{523"a": [0, 1, 2, 3, 4, 5],524"ax1": [None, 1, 3, 5, 7, 9],525"ax2": [None, 2, 6, 10, 14, 18],526"bx1": [None, None, 3, 6, 9, 12],527"bx2": [None, None, 6, 12, 18, 24],528}529)530531assert_frame_equal(df.lazy().with_columns(**exprs).collect(), expected)532533534def test_cse_count_in_group_by() -> None:535q = (536pl.LazyFrame({"a": [1, 1, 2], "b": [1, 2, 3], "c": [40, 51, 12]})537.group_by("a")538.agg(pl.all().slice(0, pl.len() - 1))539)540541assert "POLARS_CSER" not in q.explain()542assert q.collect().sort("a").to_dict(as_series=False) == {543"a": [1, 2],544"b": [[1], []],545"c": [[40], []],546}547548549def test_cse_slice_11594() -> None:550df = pl.LazyFrame({"a": [1, 2, 1, 2, 1, 2]})551552q = df.select(553pl.col("a").slice(offset=1, length=pl.len() - 1).alias("1"),554pl.col("a").slice(offset=1, length=pl.len() - 1).alias("2"),555)556557assert "__POLARS_CSE" in q.explain(558optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)559)560561assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(562as_series=False563) == {564"1": [2, 1, 2, 1, 2],565"2": [2, 
1, 2, 1, 2],566}567568q = df.select(569pl.col("a").slice(offset=1, length=pl.len() - 1).alias("1"),570pl.col("a").slice(offset=0, length=pl.len() - 1).alias("2"),571)572573assert "__POLARS_CSE" in q.explain(574optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)575)576577assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(578as_series=False579) == {580"1": [2, 1, 2, 1, 2],581"2": [1, 2, 1, 2, 1],582}583584585def test_cse_is_in_11489() -> None:586df = pl.DataFrame(587{"cond": [1, 2, 3, 2, 1], "x": [1.0, 0.20, 3.0, 4.0, 0.50]}588).lazy()589any_cond = (590pl.when(pl.col("cond").is_in([2, 3]))591.then(True)592.when(pl.col("cond").is_in([1]))593.then(False)594.otherwise(None)595.alias("any_cond")596)597val = (598pl.when(any_cond)599.then(1.0)600.when(~any_cond)601.then(0.0)602.otherwise(None)603.alias("val")604)605assert df.select("cond", any_cond, val).collect().to_dict(as_series=False) == {606"cond": [1, 2, 3, 2, 1],607"any_cond": [False, True, True, True, False],608"val": [0.0, 1.0, 1.0, 1.0, 0.0],609}610611612def test_cse_11958() -> None:613df = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})614vector_losses = []615for lag in range(1, 5):616difference = pl.col("a") - pl.col("a").shift(lag)617component_loss = pl.when(difference >= 0).then(difference * 10)618vector_losses.append(component_loss.alias(f"diff{lag}"))619620q = df.select(vector_losses)621assert "__POLARS_CSE" in q.explain(622optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)623)624assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(625as_series=False626) == {627"diff1": [None, 10, 10, 10, 10],628"diff2": [None, None, 20, 20, 20],629"diff3": [None, None, None, 30, 30],630"diff4": [None, None, None, None, 40],631}632633634def test_cse_14047() -> None:635ldf = pl.LazyFrame(636{637"timestamp": pl.datetime_range(638datetime(2024, 1, 12),639datetime(2024, 1, 12, 0, 0, 0, 150_000),640"10ms",641eager=True,642closed="left",643),644"price": 
list(range(15)),645}646)647648def count_diff(649price: pl.Expr, upper_bound: float = 0.1, lower_bound: float = 0.001650) -> pl.Expr:651span_end_to_curr = (652price.count()653.cast(int)654.rolling("timestamp", period=timedelta(seconds=lower_bound))655)656span_start_to_curr = (657price.count()658.cast(int)659.rolling("timestamp", period=timedelta(seconds=upper_bound))660)661return (span_start_to_curr - span_end_to_curr).alias(662f"count_diff_{upper_bound}_{lower_bound}"663)664665def s_per_count(count_diff: pl.Expr, span: tuple[float, float]) -> pl.Expr:666return (span[1] * 1000 - span[0] * 1000) / count_diff667668spans = [(0.001, 0.1), (1, 10)]669count_diff_exprs = [count_diff(pl.col("price"), span[0], span[1]) for span in spans]670s_per_count_exprs = [671s_per_count(count_diff, span).alias(f"zz_{span}")672for count_diff, span in zip(count_diff_exprs, spans, strict=True)673]674675exprs = count_diff_exprs + s_per_count_exprs676ldf = ldf.with_columns(*exprs)677assert_frame_equal(678ldf.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)),679ldf.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=False)),680)681682683def test_cse_15536() -> None:684source = pl.DataFrame({"a": range(10)})685686data = source.lazy().filter(pl.col("a") >= 5)687688assert pl.concat(689[690data.filter(pl.lit(True) & (pl.col("a") == 6) | (pl.col("a") == 9)),691data.filter(pl.lit(True) & (pl.col("a") == 7) | (pl.col("a") == 8)),692]693).collect()["a"].to_list() == [6, 9, 7, 8]694695696def test_cse_15548() -> None:697ldf = pl.LazyFrame({"a": [1, 2, 3]})698ldf2 = ldf.filter(pl.col("a") == 1).cache()699ldf3 = pl.concat([ldf, ldf2])700701assert (702len(ldf3.collect(optimizations=pl.QueryOptFlags(comm_subplan_elim=False))) == 4703)704assert (705len(ldf3.collect(optimizations=pl.QueryOptFlags(comm_subplan_elim=True))) == 4706)707708709@pytest.mark.debug710def test_cse_and_schema_update_projection_pd() -> None:711df = pl.LazyFrame({"a": [1, 2], "b": [99, 99]})712713q = 
(714df.lazy()715.with_row_index()716.select(717pl.when(pl.col("b") < 10)718.then(0.1 * pl.col("b"))719.when(pl.col("b") < 100)720.then(0.2 * pl.col("b"))721)722)723assert q.collect(optimizations=pl.QueryOptFlags(comm_subplan_elim=False)).to_dict(724as_series=False725) == {"literal": [19.8, 19.8]}726assert (727num_cse_occurrences(728q.explain(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))729)730== 1731)732733734@pytest.mark.debug735@pytest.mark.may_fail_auto_streaming736@pytest.mark.parametrize("use_custom_io_source", [True, False])737def test_cse_predicate_self_join(738capfd: Any, plmonkeypatch: PlMonkeyPatch, use_custom_io_source: bool739) -> None:740plmonkeypatch.setenv("POLARS_VERBOSE", "1")741y = pl.LazyFrame({"a": [1], "b": [2], "y": [3]})742if use_custom_io_source:743y = create_dataframe_source(y.collect(), is_pure=True)744745xf = y.filter(pl.col("y") == 2).select(["a", "b"])746y_xf = y.join(xf, on=["a", "b"], how="left")747748y_xf_c = y_xf.select("a", "b")749assert y_xf_c.collect().to_dict(as_series=False) == {"a": [1], "b": [2]}750captured = capfd.readouterr().err751assert "CACHE HIT" in captured752753754def test_cse_manual_cache_15688() -> None:755df = pl.LazyFrame(756{"a": [1, 2, 3, 1, 2, 3], "b": [1, 1, 1, 1, 1, 1], "id": [1, 1, 1, 2, 2, 2]}757)758759df1 = df.filter(id=1).join(df.filter(id=2), on=["a", "b"], how="semi")760df2 = df.filter(id=1).join(df1, on=["a", "b"], how="semi")761df2 = df2.cache()762res = df2.group_by("b").agg(pl.all().sum())763764assert res.cache().with_columns(foo=1).collect().to_dict(as_series=False) == {765"b": [1],766"a": [6],767"id": [3],768"foo": [1],769}770771772def test_cse_drop_nulls_15795() -> None:773A = pl.LazyFrame({"X": 1})774B = pl.LazyFrame({"X": 1, "Y": 0}).filter(pl.col("Y").is_not_null())775C = A.join(B, on="X").select("X")776D = B.select("X")777assert C.join(D, on="X").collect().shape == (1, 1)778779780def test_cse_no_projection_15980() -> None:781df = pl.LazyFrame({"x": "a", "y": 1})782df = 
pl.concat(df.with_columns(pl.col("y").add(n)) for n in range(2))783784assert df.filter(pl.col("x").eq("a")).select("x").collect().to_dict(785as_series=False786) == {"x": ["a", "a"]}787788789@pytest.mark.debug790def test_cse_series_collision_16138() -> None:791holdings = pl.DataFrame(792{793"fund_currency": ["CLP", "CLP"],794"asset_currency": ["EUR", "USA"],795}796)797798usd = ["USD"]799eur = ["EUR"]800clp = ["CLP"]801802currency_factor_query_dict = [803pl.col("asset_currency").is_in(eur) & pl.col("fund_currency").is_in(clp),804pl.col("asset_currency").is_in(eur) & pl.col("fund_currency").is_in(usd),805pl.col("asset_currency").is_in(clp) & pl.col("fund_currency").is_in(clp),806pl.col("asset_currency").is_in(usd) & pl.col("fund_currency").is_in(usd),807]808809factor_holdings = holdings.lazy().with_columns(810pl.coalesce(currency_factor_query_dict).alias("currency_factor"),811)812813assert factor_holdings.collect(814optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)815).to_dict(as_series=False) == {816"fund_currency": ["CLP", "CLP"],817"asset_currency": ["EUR", "USA"],818"currency_factor": [True, False],819}820assert (821num_cse_occurrences(822factor_holdings.explain(823optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)824)825)826== 3827)828829830def test_nested_cache_no_panic_16553() -> None:831assert pl.LazyFrame().select(a=[[[1]]]).collect(832optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)833).to_dict(as_series=False) == {"a": [[[[1]]]]}834835836def test_hash_empty_series_16577() -> None:837s = pl.Series(values=None)838out = pl.LazyFrame().select(s).collect()839assert out.equals(s.to_frame())840841842def test_cse_non_scalar_length_mismatch_17732() -> None:843df = pl.LazyFrame({"a": pl.Series(range(30), dtype=pl.Int32)})844got = (845df.lazy()846.with_columns(847pl.col("a").head(5).min().alias("b"),848pl.col("a").head(5).max().alias("c"),849)850.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))851)852expect = pl.DataFrame(853{854"a": 
pl.Series(range(30), dtype=pl.Int32),855"b": pl.Series([0] * 30, dtype=pl.Int32),856"c": pl.Series([4] * 30, dtype=pl.Int32),857}858)859860assert_frame_equal(expect, got)861862863def test_cse_chunks_18124() -> None:864df = pl.DataFrame(865{866"ts_diff": [timedelta(seconds=60)] * 2,867"ts_diff_after": [timedelta(seconds=120)] * 2,868}869)870df = pl.concat([df, df], rechunk=False)871assert (872df.lazy()873.with_columns(874ts_diff_sign=pl.col("ts_diff") > pl.duration(seconds=0),875ts_diff_after_sign=pl.col("ts_diff_after") > pl.duration(seconds=0),876)877.filter(pl.col("ts_diff") > 1)878).collect().shape == (4, 4)879880881@pytest.mark.may_fail_auto_streaming882def test_eager_cse_during_struct_expansion_18411() -> None:883df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]})884vc = pl.col("foo").value_counts()885classes = vc.struct[0]886counts = vc.struct[1]887# Check if output is stable888assert (889df.select(pl.col("foo").replace(classes, counts))890== df.select(pl.col("foo").replace(classes, counts))891)["foo"].all()892893894def test_cse_as_struct_19253() -> None:895df = pl.LazyFrame({"x": [1, 2], "y": [4, 5]})896897assert (898df.with_columns(899q1=pl.struct(pl.col.x - pl.col.y.mean()),900q2=pl.struct(pl.col.x - pl.col.y.mean().over("y")),901).collect()902).to_dict(as_series=False) == {903"x": [1, 2],904"y": [4, 5],905"q1": [{"x": -3.5}, {"x": -2.5}],906"q2": [{"x": -3.0}, {"x": -3.0}],907}908909910@pytest.mark.may_fail_auto_streaming911def test_cse_as_struct_value_counts_20927() -> None:912assert pl.DataFrame({"x": [i for i in range(1, 6) for _ in range(i)]}).select(913pl.struct("x").value_counts().struct.unnest()914).sort("count").to_dict(as_series=False) == {915"x": [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}, {"x": 5}],916"count": [1, 2, 3, 4, 5],917}918919920def test_cse_union_19227() -> None:921lf = pl.LazyFrame({"A": [1], "B": [2]})922lf_1 = lf.select(C="A", B="B")923lf_2 = lf.select(C="A", A="B")924925direct = lf_2.join(lf, on=["A"]).select("C", "A", "B")926927indirect = 
lf_1.join(direct, on=["C", "B"]).select("C", "A", "B")928929out = pl.concat([direct, indirect])930assert out.collect().schema == pl.Schema(931[("C", pl.Int64), ("A", pl.Int64), ("B", pl.Int64)]932)933934935def test_cse_21115() -> None:936lf = pl.LazyFrame({"x": 1, "y": 5})937938assert lf.with_columns(939pl.all().exp() + pl.min_horizontal(pl.all().exp())940).collect().to_dict(as_series=False) == {941"x": [5.43656365691809],942"y": [151.13144093103566],943}944945946@pytest.mark.parametrize("use_custom_io_source", [True, False])947def test_cse_cache_leakage_22339(use_custom_io_source: bool) -> None:948lf1 = pl.LazyFrame({"x": [True] * 2})949lf2 = pl.LazyFrame({"x": [True] * 3})950if use_custom_io_source:951lf1 = create_dataframe_source(lf1.collect(), is_pure=True)952lf2 = create_dataframe_source(lf2.collect(), is_pure=True)953954a = lf1955b = lf1.filter(pl.col("x").not_().over(1))956c = lf2.filter(pl.col("x").not_().over(1))957958ab = a.join(b, on="x")959bc = b.join(c, on="x")960ac = a.join(c, on="x")961962assert pl.concat([ab, bc, ac]).collect().to_dict(as_series=False) == {"x": []}963964965@pytest.mark.write_disk966def test_multiplex_predicate_pushdown() -> None:967ldf = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})968with TemporaryDirectory() as f:969tmppath = Path(f)970ldf.sink_parquet(971pl.PartitionBy(tmppath, key="a", include_key=True),972sync_on_close="all",973mkdir=True,974)975ldf = pl.scan_parquet(tmppath, hive_partitioning=True)976ldf = ldf.filter(pl.col("a").eq(1)).select("b")977assert 'SELECTION: [(col("a")) == (1)]' in pl.explain_all([ldf, ldf])978979980def test_cse_custom_io_source_same_object() -> None:981df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})982983io_source = Mock(wraps=lambda *_: iter([df]))984985lf = register_io_source(986io_source,987schema=df.schema,988validate_schema=True,989is_pure=True,990)991992lfs = [lf, lf]993994plan = pl.explain_all(lfs)995caches: list[str] = [996x for x in map(str.strip, plan.splitlines()) if 
x.startswith("CACHE[")997]998assert len(caches) == 2999assert len(set(caches)) == 110001001assert io_source.call_count == 010021003assert_frame_equal(1004pl.concat(pl.collect_all(lfs)),1005pl.DataFrame({"a": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]}),1006)10071008assert io_source.call_count == 110091010io_source = Mock(wraps=lambda *_: iter([df]))10111012# Without explicit is_pure parameter should default to False1013lf = register_io_source(1014io_source,1015schema=df.schema,1016validate_schema=True,1017)10181019lfs = [lf, lf]10201021plan = pl.explain_all(lfs)10221023caches = [x for x in map(str.strip, plan.splitlines()) if x.startswith("CACHE[")]1024assert len(caches) == 010251026assert io_source.call_count == 010271028assert_frame_equal(1029pl.concat(pl.collect_all(lfs)),1030pl.DataFrame({"a": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]}),1031)10321033assert io_source.call_count == 210341035io_source = Mock(wraps=lambda *_: iter([df]))10361037# LazyFrames constructed from separate calls do not CSE even if the1038# io_source function is the same.1039#1040# Note: This behavior is achieved by having `register_io_source` wrap1041# the user-provided io plugin with a locally constructed wrapper before1042# passing to the Rust-side.1043lfs = [1044register_io_source(1045io_source,1046schema=df.schema,1047validate_schema=True,1048is_pure=True,1049),1050register_io_source(1051io_source,1052schema=df.schema,1053validate_schema=True,1054is_pure=True,1055),1056]10571058caches = [x for x in map(str.strip, plan.splitlines()) if x.startswith("CACHE[")]1059assert len(caches) == 010601061assert io_source.call_count == 010621063assert_frame_equal(1064pl.concat(pl.collect_all(lfs)),1065pl.DataFrame({"a": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]}),1066)10671068assert io_source.call_count == 2106910701071@pytest.mark.write_disk1072def test_cse_preferred_over_slice() -> None:1073# This test asserts that even if we slice disjoint sections of a lazyframe, caching1074# is preferred, and slicing is not pushed down1075df = 
pl.DataFrame({"a": list(range(1, 21))})1076with NamedTemporaryFile() as f:1077val = df.write_csv()1078f.write(val.encode())1079f.seek(0)1080ldf = pl.scan_csv(f.name)1081left = ldf.slice(0, 5)1082right = ldf.slice(6, 5)1083q = left.join(right, on="a", how="inner")1084assert "CACHE[id:" in q.explain(1085optimizations=pl.QueryOptFlags(comm_subplan_elim=True)1086)108710881089def test_cse_preferred_over_slice_custom_io_source() -> None:1090# This test asserts that even if we slice disjoint sections of a custom io source,1091# caching is preferred, and slicing is not pushed down1092df = pl.DataFrame({"a": list(range(1, 21))})1093lf = create_dataframe_source(df, is_pure=True)1094left = lf.slice(0, 5)1095right = lf.slice(6, 5)1096q = left.join(right, on="a", how="inner")1097assert "CACHE[id:" in q.explain(1098optimizations=pl.QueryOptFlags(comm_subplan_elim=True)1099)11001101lf = create_dataframe_source(df, is_pure=False)1102left = lf.slice(0, 5)1103right = lf.slice(6, 5)1104q = left.join(right, on="a", how="inner")1105assert "CACHE[id:" not in q.explain(1106optimizations=pl.QueryOptFlags(comm_subplan_elim=True)1107)110811091110def test_cse_custom_io_source_diff_columns() -> None:1111df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [10, 11, 12, 13, 14]})1112lf = create_dataframe_source(df, is_pure=True)1113collection = [lf.select("a"), lf.select("b")]1114assert "CACHE[id:" in pl.explain_all(collection)1115collected = pl.collect_all(1116collection, optimizations=pl.QueryOptFlags(comm_subplan_elim=True)1117)1118assert_frame_equal(df.select("a"), collected[0])1119assert_frame_equal(df.select("b"), collected[1])112011211122def test_cse_custom_io_source_diff_filters() -> None:1123df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [10, 11, 12, 13, 14]})1124lf = create_dataframe_source(df, is_pure=True)11251126# We use this so that the true type of the input is passed through1127# to the output1128PolarsFrame = TypeVar("PolarsFrame", pl.DataFrame, pl.LazyFrame)11291130def 
left_pipe(df_or_lf: PolarsFrame) -> PolarsFrame:1131return df_or_lf.select("a").filter(pl.col("a").is_between(2, 6))11321133def right_pipe(df_or_lf: PolarsFrame) -> PolarsFrame:1134return df_or_lf.select("b").filter(pl.col("b").is_between(10, 13))11351136collection = [lf.pipe(left_pipe), lf.pipe(right_pipe)]1137explanation = pl.explain_all(collection)1138# we prefer predicate pushdown over CSE1139assert "CACHE[id:" not in explanation1140assert 'SELECTION: col("a").is_between([2, 6])' in explanation1141assert 'SELECTION: col("b").is_between([10, 13])' in explanation11421143res = pl.collect_all(collection)1144expected = [df.pipe(left_pipe), df.pipe(right_pipe)]1145assert_frame_equal(expected[0], res[0])1146assert_frame_equal(expected[1], res[1])114711481149@pytest.mark.skip1150def test_cspe_recursive_24744() -> None:1151df_a = pl.DataFrame([pl.Series("x", [0, 1, 2, 3], dtype=pl.UInt32)])11521153def convoluted_inner_join(1154lf_left: pl.LazyFrame,1155lf_right: pl.LazyFrame,1156) -> pl.LazyFrame:1157lf_left = lf_left.with_columns(pl.col("x").alias("index"))11581159lf_joined = lf_left.join(1160lf_right,1161how="inner",1162on=["x"],1163)11641165lf_joined_final = lf_left.join(1166lf_joined,1167how="inner",1168on=["index", "x"],1169).drop("index")1170return lf_joined_final11711172lf_a = df_a.lazy()1173lf_j1 = convoluted_inner_join(lf_left=lf_a, lf_right=lf_a)1174lf_j2 = convoluted_inner_join(lf_left=lf_j1, lf_right=lf_a)1175lf_j3 = convoluted_inner_join(lf_left=lf_j2, lf_right=lf_a).sort("x")11761177assert lf_j3.explain().count("CACHE") == 141178assert_frame_equal(1179lf_j3.collect(),1180lf_j3.collect(optimizations=pl.QueryOptFlags(comm_subplan_elim=False)),1181)1182assert (1183lf_j3.show_graph( # type: ignore[union-attr]1184engine="streaming", plan_stage="physical", raw_output=True1185).count("multiplexer")1186== 31187)1188assert (1189lf_j3.show_graph( # type: ignore[union-attr]1190engine="in-memory", plan_stage="physical", raw_output=True1191).count("CACHE")1192== 
31193)119411951196def test_cpse_predicates_25030() -> None:1197df = pl.LazyFrame({"key": [1, 2, 2], "x": [6, 2, 3], "y": [0, 1, 4]})11981199q1 = df.group_by("key").len().filter(pl.col("len") > 1)1200q2 = df.filter(pl.col.x > pl.col.y)12011202q3 = q1.join(q2, on="key")12031204q4 = q3.group_by("key").len().join(q3, on="key")12051206got = q4.collect()1207expected = q4.collect(optimizations=pl.QueryOptFlags(comm_subplan_elim=False))12081209assert_frame_equal(got, expected)1210assert q4.explain().count("CACHE") == 2121112121213def test_asof_join_25699() -> None:1214df = pl.LazyFrame({"a": [10], "b": [10]})12151216df = df.with_columns(pl.col("a"))1217df = df.with_columns(pl.col("b"))12181219assert_frame_equal(1220df.join_asof(df, on="b").collect(),1221pl.DataFrame({"a": [10], "b": [10], "a_right": [10]}),1222)122312241225def test_csee_python_function() -> None:1226# Make sure to use the same expression1227# This only works for functions on the same address1228expr = pl.col("a").map_elements(lambda x: hash(x))1229q = pl.LazyFrame({"a": [10], "b": [10]}).with_columns(1230a=expr * 10,1231b=expr * 100,1232)12331234assert "__POLARS_CSER" in q.explain()1235assert_frame_equal(1236q.collect(), q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=False))1237)123812391240def test_csee_streaming() -> None:1241lf = pl.LazyFrame({"a": [10], "b": [10]})12421243# elementwise is allowed1244expr = pl.col("a") * pl.col("b")1245q = lf.with_columns(1246a=expr * 10,1247b=expr * 100,1248)1249assert "__POLARS_CSER" in q.explain(engine="streaming")12501251# non-elementwise not1252expr = pl.col("a").sum()1253q = lf.with_columns(1254a=expr * 10,1255b=expr * 100,1256)1257assert "__POLARS_CSER" not in q.explain(engine="streaming")125812591260