Path: blob/main/py-polars/tests/unit/lazyframe/test_optimizations.py
8415 views
"""Tests for LazyFrame query-plan optimizations.

Each test builds a lazy query, inspects the optimized plan via ``explain()``
(or the physical-plan dot output) to confirm a specific rewrite fired, and
then collects to confirm the rewrite preserves results.
"""

import datetime as dt
import io
import itertools
import typing

import pytest

import polars as pl
from polars.testing import assert_frame_equal


def test_is_null_followed_by_all() -> None:
    # is_null().all() on a bare column should be rewritten by the optimizer
    # to len() == null_count().
    lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, None, None]})

    expected_df = pl.DataFrame({"group": [0, 1], "val": [False, True]})
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().all()
    )

    assert r'[[(col("val").len()) == (col("val").null_count())]]' in result_lf.explain()
    # FIX: membership must be tested against the plan string, not the
    # LazyFrame object itself (which is neither iterable nor a container).
    assert "is_null" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().is_null().all())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "is_null" in non_optimized_result_plan

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [True]})
    result_df = lf.select(pl.col("val").is_null().all()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_null_followed_by_any() -> None:
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, True, False]})
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().any()
    )
    assert_frame_equal(expected_df, result_lf.collect())

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [False]})
    result_df = lf.select(pl.col("val").is_null().any()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_not_null_followed_by_all() -> None:
    lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, 5, None]})

    expected_df = pl.DataFrame({"group": [0, 1], "val": [True, False]})
    result_df = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").is_not_null().all())
        .collect()
    )

    assert_frame_equal(expected_df, result_df)

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [True]})
    result_df = lf.select(pl.col("val").is_not_null().all()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_not_null_followed_by_any() -> None:
    # is_not_null().any() should be rewritten to null_count() < len().
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, False, True]})
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_not_null().any()
    )

    assert r'[[(col("val").null_count()) < (col("val").len())]]' in result_lf.explain()
    assert "is_not_null" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().is_not_null().any())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "is_not_null" in non_optimized_result_plan

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [False]})
    result_df = lf.select(pl.col("val").is_not_null().any()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_null_followed_by_sum() -> None:
    # is_null().sum() should be rewritten to null_count().
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [1, 1, 0]},
        schema_overrides={"val": pl.get_index_type()},
    )
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().sum()
    )

    assert r'[col("val").null_count()]' in result_lf.explain()
    assert "is_null" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.get_index_type()})
    result_df = lf.select(pl.col("val").is_null().sum()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_not_null_followed_by_sum() -> None:
    # is_not_null().sum() should be rewritten to len() - null_count().
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]},
        schema_overrides={"val": pl.get_index_type()},
    )
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_not_null().sum()
    )

    assert r'[[(col("val").len()) - (col("val").null_count())]]' in result_lf.explain()
    assert "is_not_null" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").abs().is_not_null().sum()
    )
    assert "null_count" not in non_optimized_result_lf.explain()
    assert "is_not_null" in non_optimized_result_lf.explain()

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.get_index_type()})
    result_df = lf.select(pl.col("val").is_not_null().sum()).collect()
    assert_frame_equal(expected_df, result_df)


def test_drop_nulls_followed_by_len() -> None:
    # drop_nulls().len() should be rewritten to len() - null_count().
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]},
        schema_overrides={"val": pl.get_index_type()},
    )
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").drop_nulls().len()
    )

    assert r'[[(col("val").len()) - (col("val").null_count())]]' in result_lf.explain()
    assert "drop_nulls" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().drop_nulls().len())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "drop_nulls" in non_optimized_result_plan


def test_drop_nulls_followed_by_count() -> None:
    # drop_nulls().count() gets the same rewrite as drop_nulls().len().
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]},
        schema_overrides={"val": pl.get_index_type()},
    )
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").drop_nulls().count()
    )

    assert r'[[(col("val").len()) - (col("val").null_count())]]' in result_lf.explain()
    assert "drop_nulls" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().drop_nulls().count())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "drop_nulls" in non_optimized_result_plan


def test_collapse_joins() -> None:
    # Filters on top of a cross join should be collapsed into the join
    # itself (equality -> INNER JOIN, inequality -> IEJOIN, other
    # expressions -> NESTED LOOP JOIN).
    a = pl.LazyFrame({"a": [1, 2, 3], "b": [2, 2, 2]})
    b = pl.LazyFrame({"x": [7, 1, 2]})

    cross = a.join(b, how="cross")

    inner_join = cross.filter(pl.col.a == pl.col.x)
    e = inner_join.explain()
    assert "INNER JOIN" in e
    assert "FILTER" not in e
    assert_frame_equal(
        inner_join.collect(optimizations=pl.QueryOptFlags.none()),
        inner_join.collect(),
        check_row_order=False,
    )

    # Same collapse regardless of operand order.
    inner_join = cross.filter(pl.col.x == pl.col.a)
    e = inner_join.explain()
    assert "INNER JOIN" in e
    assert "FILTER" not in e
    assert_frame_equal(
        inner_join.collect(optimizations=pl.QueryOptFlags.none()),
        inner_join.collect(),
        check_row_order=False,
    )

    # Two stacked equality filters still collapse into one inner join.
    double_inner_join = cross.filter(pl.col.x == pl.col.a).filter(pl.col.x == pl.col.b)
    e = double_inner_join.explain()
    assert "INNER JOIN" in e
    assert "FILTER" not in e
    assert_frame_equal(
        double_inner_join.collect(optimizations=pl.QueryOptFlags.none()),
        double_inner_join.collect(),
        check_row_order=False,
    )

    # A predicate mixing columns from both sides in one expression cannot
    # become an equi-/ie-join and falls back to a nested loop join.
    dont_mix = cross.filter(pl.col.x + pl.col.a != 0)
    e = dont_mix.explain()
    assert "NESTED LOOP JOIN" in e
    assert "FILTER" not in e
    assert_frame_equal(
        dont_mix.collect(optimizations=pl.QueryOptFlags.none()),
        dont_mix.collect(),
        check_row_order=False,
    )

    iejoin = cross.filter(pl.col.x >= pl.col.a)
    e = iejoin.explain()
    assert "IEJOIN" in e
    assert "NESTED LOOP JOIN" not in e
    assert "CROSS JOIN" not in e
    assert "FILTER" not in e
    assert_frame_equal(
        iejoin.collect(optimizations=pl.QueryOptFlags.none()),
        iejoin.collect(),
        check_row_order=False,
    )

    iejoin = cross.filter(pl.col.x >= pl.col.a).filter(pl.col.x <= pl.col.b)
    e = iejoin.explain()
    assert "IEJOIN" in e
    assert "CROSS JOIN" not in e
    assert "NESTED LOOP JOIN" not in e
    assert "FILTER" not in e
    assert_frame_equal(
        iejoin.collect(optimizations=pl.QueryOptFlags.none()),
        iejoin.collect(),
        check_row_order=False,
    )


@pytest.mark.slow
def test_collapse_joins_combinations() -> None:
    # This just tests all possible combinations for expressions on a cross join.

    a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]})
    b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]})

    cross = a.join(b, how="cross")

    exprs = []

    for lhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a + pl.col.b]:
        for rhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a * pl.col.x]:
            for cmp in ["__eq__", "__ge__", "__lt__"]:
                e = (getattr(lhs, cmp))(rhs)
                exprs.append(e)

    for amount in range(3):
        # NOTE(review): for amount < 2 the multiplied list is empty, so
        # `itertools.product` yields nothing and only amount == 2 is
        # exercised here — confirm this is intended (single predicates are
        # already covered by test_collapse_joins).
        for merge in itertools.product(["__and__", "__or__"] * (amount - 1)):
            for es in itertools.product(*([exprs] * amount)):
                e = es[0]
                for i in range(amount - 1):
                    e = (getattr(e, merge[i]))(es[i + 1])

                # NOTE: We need to sort because the order of the cross-join &
                # IE-join is unspecified. Therefore, this might not necessarily
                # create the exact same dataframe.
                optimized = cross.filter(e).sort(pl.all()).collect()
                unoptimized = cross.filter(e).collect(
                    optimizations=pl.QueryOptFlags.none()
                )

                try:
                    assert_frame_equal(optimized, unoptimized, check_row_order=False)
                except Exception:
                    # Dump both plans and results for the failing expression
                    # before re-raising, to make the combinatorial failure
                    # reproducible.
                    print(e)
                    print()
                    print("Optimized")
                    print(cross.filter(e).explain())
                    print(optimized)
                    print()
                    print("Unoptimized")
                    print(
                        cross.filter(e).explain(optimizations=pl.QueryOptFlags.none())
                    )
                    print(unoptimized)
                    print()

                    raise


def test_order_observe_sort_before_unique_22485() -> None:
    # The sort before unique(keep="last") is order-observing and must not be
    # optimized away: the plan must still contain a SORT BY above UNIQUE.
    lf = pl.LazyFrame(
        {
            "order": [3, 2, 1],
            "id": ["A", "A", "B"],
        }
    )

    expect = pl.DataFrame({"order": [1, 3], "id": ["B", "A"]})

    q = lf.sort("order").unique(["id"], keep="last").sort("order")

    plan = q.explain()
    assert "SORT BY" in plan[plan.index("UNIQUE") :]

    assert_frame_equal(q.collect(), expect)

    q = lf.sort("order").unique(["id"], keep="last", maintain_order=True)

    plan = q.explain()
    assert "SORT BY" in plan[plan.index("UNIQUE") :]

    assert_frame_equal(q.collect(), expect)


def test_order_observe_group_by() -> None:
    # A later sort makes group-by output order unobservable, so the
    # optimizer may clear maintain_order...
    q = (
        pl.LazyFrame({"a": range(5)})
        .group_by("a", maintain_order=True)
        .agg(b=1)
        .sort("b")
    )

    plan = q.explain()
    assert "AGGREGATE[maintain_order: false]" in plan

    # ...but not when the sort itself is order-maintaining.
    q = (
        pl.LazyFrame({"a": range(5)})
        .group_by("a", maintain_order=True)
        .agg(b=1)
        .sort("b", maintain_order=True)
    )

    plan = q.explain()
    assert "AGGREGATE[maintain_order: true]" in plan


def test_fused_correct_name() -> None:
    # Fused multiply-add must keep the output name of the left-most leaf.
    df = pl.DataFrame({"x": [1, 2, 3]})

    lf = df.lazy().select(
        (pl.col.x.alias("a") * pl.col.x.alias("b")) + pl.col.x.alias("c")
    )

    no_opts = lf.collect(optimizations=pl.QueryOptFlags.none())
    opts = lf.collect()
    assert_frame_equal(
        no_opts,
        opts,
    )
    assert_frame_equal(opts, pl.DataFrame({"a": [2, 6, 12]}))


def test_slice_pushdown_within_concat_24734() -> None:
    # head() on a concat input should be pushed into the input, leaving no
    # SLICE node in the plan.
    q = pl.concat(
        [
            pl.LazyFrame({"x": [0, 1, 2, 3, 4]}).head(2),
            pl.LazyFrame(schema={"x": pl.Int64}),
        ]
    )

    plan = q.explain()
    assert "SLICE" not in plan

    assert_frame_equal(q, pl.LazyFrame({"x": [0, 1]}))

    # A slice over a reverse() input is rewritten but must stay inside the
    # concat input's plan.
    q = pl.concat(
        [
            pl.LazyFrame({"x": [0, 1, 2, 3, 4]}).select(pl.col("x").reverse()),
            pl.LazyFrame(schema={"x": pl.Int64}),
        ]
    ).slice(1, 2)

    plan = q.explain()
    assert plan.index("SLICE[offset: 0, len: 3]") > plan.index("PLAN 0:")

    assert_frame_equal(q, pl.LazyFrame({"x": [3, 2]}))


def test_is_between_pushdown_25499() -> None:
    # An is_between predicate with a column as a bound must give the same
    # result whether pushed down into the parquet scan or applied after.
    f = io.BytesIO()
    pl.LazyFrame(
        {"a": [0, 1, 2, 3, 4]}, schema_overrides={"a": pl.UInt32}
    ).sink_parquet(f)
    parquet = f.getvalue()

    expr = pl.lit(3, dtype=pl.UInt32).is_between(
        pl.lit(1, dtype=pl.UInt32), pl.col("a")
    )

    df1 = pl.scan_parquet(parquet).filter(expr).collect()
    df2 = pl.scan_parquet(parquet).collect().filter(expr)
    assert_frame_equal(df1, df2)


def test_slice_pushdown_expr_25473() -> None:
    # Slice pushdown through expressions must stay correct for negative
    # offsets and for expressions whose length differs from the input.
    lf = pl.LazyFrame({"a": [0, 1, 2, 3, 4]})

    assert_frame_equal(
        lf.select((pl.col("a") + 1).slice(-4, 2)).collect(), pl.DataFrame({"a": [2, 3]})
    )

    assert_frame_equal(
        lf.select(
            a=(
                pl.when(pl.col("a") == 1).then(pl.lit("one")).otherwise(pl.lit("other"))
            ).slice(-4, 2)
        ).collect(),
        pl.DataFrame({"a": ["one", "other"]}),
    )

    assert_frame_equal(
        lf.select(a=pl.col("a").is_in(pl.Series([1]).implode()).slice(-4, 2)).collect(),
        pl.DataFrame({"a": [True, False]}),
    )

    # Length-mismatched is_in must still raise despite the slice.
    q = pl.LazyFrame().select(
        pl.lit(pl.Series([0, 1, 2, 3, 4])).is_in(pl.Series([[3], [1]])).slice(-2, 1)
    )

    with pytest.raises(pl.exceptions.ShapeError, match=r"lengths.*5 != 2"):
        q.collect()


def test_lazy_groupby_maintain_order_after_asof_join_25973() -> None:
    # Small target times: 00:00, 00:10, 00:20, 00:30
    targettime = (
        pl.DataFrame(
            {
                "targettime": pl.time_range(
                    dt.time(0, 0),
                    dt.time(0, 30),
                    interval="10m",
                    closed="both",
                    eager=True,
                )
            }
        )
        .with_columns(
            targettime=pl.lit(dt.date(2026, 1, 1)).dt.combine(pl.col("targettime")),
            grp=pl.lit(1),
        )
        .lazy()
    )

    # Small input times: every second from 00:00 to 00:30
    df = (
        pl.DataFrame(
            {
                "time": pl.time_range(
                    dt.time(0, 0),
                    dt.time(0, 30),
                    interval="1s",
                    closed="both",
                    eager=True,
                )
            }
        )
        .with_row_index("value")
        .with_columns(
            time=pl.lit(dt.date(2026, 1, 1)).dt.combine(pl.col("time")),
            grp=pl.lit(1),
        )
        .lazy()
    )

    # This used to produce out-of-order results.
    # The optimizer previously cleared maintain_order.
    q = (
        df.join_asof(
            targettime,
            left_on="time",
            right_on="targettime",
            strategy="forward",
        )
        .drop_nulls("targettime")
        .group_by("targettime", maintain_order=True)
        .agg(pl.col("value").last())
    )

    # Verify the optimizer preserves maintain_order on the AGGREGATE node.
    plan = q.explain()
    assert "AGGREGATE[maintain_order: true" in plan

    result = q.collect()

    idx_dtype = pl.get_index_type()

    expected = pl.DataFrame(
        {
            "targettime": [
                dt.datetime(2026, 1, 1, 0, 0),
                dt.datetime(2026, 1, 1, 0, 10),
                dt.datetime(2026, 1, 1, 0, 20),
                dt.datetime(2026, 1, 1, 0, 30),
            ],
            "value": pl.Series("value", [0, 600, 1200, 1800], dtype=idx_dtype),
        }
    )

    assert_frame_equal(result, expected)


def test_fast_count_alias_18581() -> None:
    # The fast-count optimization on a CSV scan must respect an alias.
    f = io.BytesIO()
    f.write(b"a,b,c\n1,2,3\n4,5,6")
    f.flush()
    f.seek(0)

    df = pl.scan_csv(f).select(pl.len().alias("weird_name")).collect()

    # Just check the value, let assert_frame_equal handle dtype matching
    expected = pl.DataFrame(
        {"weird_name": [2]}, schema={"weird_name": pl.get_index_type()}
    )
    assert_frame_equal(expected, df)


def test_flatten_alias() -> None:
    # Chained aliases should be flattened to the outermost one in the plan.
    assert (
        """len().alias("bar")"""
        in pl.LazyFrame({"a": [1, 2]})
        .select(pl.len().alias("foo").alias("bar"))
        .explain()
    )


def test_concat_str_sortedness_26466() -> None:
    # A single-input concat_str preserves sortedness, so the streaming
    # engine may use a sorted group-by...
    df = pl.DataFrame({"x": ["", "a", "b"], "y": [1, 2, 3]})
    lf = df.lazy().set_sorted("x")

    dot = (
        lf.with_columns(x=pl.concat_str("x"))
        .group_by("x")
        .agg(pl.col.y.sum())
        .show_graph(engine="streaming", plan_stage="physical", raw_output=True)
    )

    assert "sorted-group-by" in typing.cast("str", dot)

    # ...but concatenating extra values or ignoring nulls breaks sortedness.
    for e in [pl.concat_str("x", pl.lit("c")), pl.concat_str("x", ignore_nulls=True)]:
        dot = (
            lf.with_columns(x=e)
            .group_by("x")
            .agg(pl.col.y.sum())
            .show_graph(engine="streaming", plan_stage="physical", raw_output=True)
        )

        assert "sorted-group-by" not in typing.cast("str", dot)