# Path: blob/main/py-polars/tests/unit/lazyframe/test_optimizations.py
# 6939 views
"""Tests for LazyFrame query-plan optimizations."""

import itertools

import pytest

import polars as pl
from polars.testing import assert_frame_equal


def test_is_null_followed_by_all() -> None:
    """`is_null().all()` on a plain column is planned as `len == null_count`."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, None, None]})

    expected_df = pl.DataFrame({"group": [0, 1], "val": [False, True]})
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().all()
    )

    plan = result_lf.explain()
    assert r'[[(col("val").len()) == (col("val").null_count())]]' in plan
    # Check the plan text, not the LazyFrame: `"is_null" not in result_lf`
    # does not inspect the plan, so it could never detect a missed rewrite.
    assert "is_null" not in plan
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().is_null().all())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "is_null" in non_optimized_result_plan

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [True]})
    result_df = lf.select(pl.col("val").is_null().all()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_null_followed_by_any() -> None:
    """`is_null().any()` returns the expected per-group and empty-input results."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, True, False]})
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().any()
    )
    assert_frame_equal(expected_df, result_lf.collect())

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [False]})
    result_df = lf.select(pl.col("val").is_null().any()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_not_null_followed_by_all() -> None:
    """`is_not_null().all()` returns the expected per-group and empty-input results."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, 5, None]})

    expected_df = pl.DataFrame({"group": [0, 1], "val": [True, False]})
    result_df = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").is_not_null().all())
        .collect()
    )

    assert_frame_equal(expected_df, result_df)

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [True]})
    result_df = lf.select(pl.col("val").is_not_null().all()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_not_null_followed_by_any() -> None:
    """`is_not_null().any()` on a plain column is planned as `null_count < len`."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, False, True]})
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_not_null().any()
    )

    assert r'[[(col("val").null_count()) < (col("val").len())]]' in result_lf.explain()
    assert "is_not_null" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().is_not_null().any())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "is_not_null" in non_optimized_result_plan

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [False]})
    result_df = lf.select(pl.col("val").is_not_null().any()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_null_followed_by_sum() -> None:
    """`is_null().sum()` on a plain column is planned as `null_count`."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [1, 1, 0]}, schema_overrides={"val": pl.UInt32}
    )
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().sum()
    )

    assert r'[col("val").null_count()]' in result_lf.explain()
    assert "is_null" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32})
    result_df = lf.select(pl.col("val").is_null().sum()).collect()
    assert_frame_equal(expected_df, result_df)


def test_is_not_null_followed_by_sum() -> None:
    """`is_not_null().sum()` on a plain column is planned as `len - null_count`."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32}
    )
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_not_null().sum()
    )

    assert r'[[(col("val").len()) - (col("val").null_count())]]' in result_lf.explain()
    assert "is_not_null" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").abs().is_not_null().sum()
    )
    assert "null_count" not in non_optimized_result_lf.explain()
    assert "is_not_null" in non_optimized_result_lf.explain()

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32})
    result_df = lf.select(pl.col("val").is_not_null().sum()).collect()
    assert_frame_equal(expected_df, result_df)


def test_drop_nulls_followed_by_len() -> None:
    """`drop_nulls().len()` on a plain column is planned as `len - null_count`."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32}
    )
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").drop_nulls().len()
    )

    assert r'[[(col("val").len()) - (col("val").null_count())]]' in result_lf.explain()
    assert "drop_nulls" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().drop_nulls().len())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "drop_nulls" in non_optimized_result_plan


def test_drop_nulls_followed_by_count() -> None:
    """`drop_nulls().count()` on a plain column is planned as `len - null_count`."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32}
    )
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").drop_nulls().count()
    )

    assert r'[[(col("val").len()) - (col("val").null_count())]]' in result_lf.explain()
    assert "drop_nulls" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().drop_nulls().count())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "drop_nulls" in non_optimized_result_plan


def test_collapse_joins() -> None:
    """Filters on a cross join are folded into the join node of the plan.

    Each case checks the join kind named in the plan, that no separate FILTER
    node remains, and that the result equals the one collected with
    `collapse_joins` disabled.
    """
    a = pl.LazyFrame({"a": [1, 2, 3], "b": [2, 2, 2]})
    b = pl.LazyFrame({"x": [7, 1, 2]})

    cross = a.join(b, how="cross")

    inner_join = cross.filter(pl.col.a == pl.col.x)
    e = inner_join.explain()
    assert "INNER JOIN" in e
    assert "FILTER" not in e
    assert_frame_equal(
        inner_join.collect(optimizations=pl.QueryOptFlags(collapse_joins=False)),
        inner_join.collect(),
        check_row_order=False,
    )

    inner_join = cross.filter(pl.col.x == pl.col.a)
    e = inner_join.explain()
    assert "INNER JOIN" in e
    assert "FILTER" not in e
    assert_frame_equal(
        inner_join.collect(optimizations=pl.QueryOptFlags(collapse_joins=False)),
        inner_join.collect(),
        check_row_order=False,
    )

    double_inner_join = cross.filter(pl.col.x == pl.col.a).filter(pl.col.x == pl.col.b)
    e = double_inner_join.explain()
    assert "INNER JOIN" in e
    assert "FILTER" not in e
    assert_frame_equal(
        double_inner_join.collect(optimizations=pl.QueryOptFlags(collapse_joins=False)),
        double_inner_join.collect(),
        check_row_order=False,
    )

    dont_mix = cross.filter(pl.col.x + pl.col.a != 0)
    e = dont_mix.explain()
    assert "NESTED LOOP JOIN" in e
    assert "FILTER" not in e
    assert_frame_equal(
        dont_mix.collect(optimizations=pl.QueryOptFlags(collapse_joins=False)),
        dont_mix.collect(),
        check_row_order=False,
    )

    iejoin = cross.filter(pl.col.x >= pl.col.a)
    e = iejoin.explain()
    assert "IEJOIN" in e
    assert "NESTED LOOP JOIN" not in e
    assert "CROSS JOIN" not in e
    assert "FILTER" not in e
    assert_frame_equal(
        iejoin.collect(optimizations=pl.QueryOptFlags(collapse_joins=False)),
        iejoin.collect(),
        check_row_order=False,
    )

    iejoin = cross.filter(pl.col.x >= pl.col.a).filter(pl.col.x <= pl.col.b)
    e = iejoin.explain()
    assert "IEJOIN" in e
    assert "CROSS JOIN" not in e
    assert "NESTED LOOP JOIN" not in e
    assert "FILTER" not in e
    assert_frame_equal(
        iejoin.collect(optimizations=pl.QueryOptFlags(collapse_joins=False)),
        iejoin.collect(),
        check_row_order=False,
    )


@pytest.mark.slow
def test_collapse_joins_combinations() -> None:
    """Compare optimized and unoptimized results for combined cross-join filters."""
    # This just tests all possible combinations for expressions on a cross join.

    a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]})
    b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]})

    cross = a.join(b, how="cross")

    exprs = []

    for lhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a + pl.col.b]:
        for rhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a * pl.col.x]:
            for cmp in ["__eq__", "__ge__", "__lt__"]:
                e = (getattr(lhs, cmp))(rhs)
                exprs.append(e)

    # NOTE(review): for amount 0 and 1 the list passed to `itertools.product`
    # is empty, so the `merge` loop yields nothing and only pairs of
    # expressions (amount == 2) are exercised. Confirm this is intended.
    for amount in range(3):
        for merge in itertools.product(["__and__", "__or__"] * (amount - 1)):
            for es in itertools.product(*([exprs] * amount)):
                e = es[0]
                for i in range(amount - 1):
                    e = (getattr(e, merge[i]))(es[i + 1])

                # NOTE: We need to sort because the order of the cross-join &
                # IE-join is unspecified. Therefore, this might not necessarily
                # create the exact same dataframe.
                optimized = cross.filter(e).sort(pl.all()).collect()
                unoptimized = cross.filter(e).collect(
                    optimizations=pl.QueryOptFlags(collapse_joins=False)
                )

                try:
                    assert_frame_equal(optimized, unoptimized, check_row_order=False)
                except Exception:
                    # Dump the failing expression and both plans/results to
                    # aid debugging, then re-raise the original error.
                    print(e)
                    print()
                    print("Optimized")
                    print(cross.filter(e).explain())
                    print(optimized)
                    print()
                    print("Unoptimized")
                    print(
                        cross.filter(e).explain(
                            optimizations=pl.QueryOptFlags(collapse_joins=False)
                        )
                    )
                    print(unoptimized)
                    print()

                    raise


def test_order_observe_sort_before_unique_22485() -> None:
    """A sort feeding `unique(keep="last")` stays in the plan below UNIQUE."""
    lf = pl.LazyFrame(
        {
            "order": [3, 2, 1],
            "id": ["A", "A", "B"],
        }
    )

    expect = pl.DataFrame({"order": [1, 3], "id": ["B", "A"]})

    q = lf.sort("order").unique(["id"], keep="last").sort("order")

    plan = q.explain()
    # Only the plan text from the UNIQUE node onwards is searched.
    assert "SORT BY" in plan[plan.index("UNIQUE") :]

    assert_frame_equal(q.collect(), expect)

    q = lf.sort("order").unique(["id"], keep="last", maintain_order=True)

    plan = q.explain()
    assert "SORT BY" in plan[plan.index("UNIQUE") :]

    assert_frame_equal(q.collect(), expect)


def test_order_observe_group_by() -> None:
    """The planned `maintain_order` of a group_by depends on the following sort."""
    q = (
        pl.LazyFrame({"a": range(5)})
        .group_by("a", maintain_order=True)
        .agg(b=1)
        .sort("b")
    )

    plan = q.explain()
    assert "AGGREGATE[maintain_order: false]" in plan

    q = (
        pl.LazyFrame({"a": range(5)})
        .group_by("a", maintain_order=True)
        .agg(b=1)
        .sort("b", maintain_order=True)
    )

    plan = q.explain()
    assert "AGGREGATE[maintain_order: true]" in plan


def test_fused_correct_name() -> None:
    """`a * b + c` over aliased columns is named "a" with or without optimizations."""
    df = pl.DataFrame({"x": [1, 2, 3]})

    lf = df.lazy().select(
        (pl.col.x.alias("a") * pl.col.x.alias("b")) + pl.col.x.alias("c")
    )

    no_opts = lf.collect(optimizations=pl.QueryOptFlags.none())
    opts = lf.collect()
    assert_frame_equal(
        no_opts,
        opts,
    )
    assert_frame_equal(opts, pl.DataFrame({"a": [2, 6, 12]}))