# Path: blob/main/py-polars/tests/unit/lazyframe/test_order_observability.py
# 6939 views
import pytest

import polars as pl
from polars.testing import assert_frame_equal


def test_order_observability() -> None:
    """Check which aggregations keep a preceding ``sort`` in the query plan.

    With ``check_order_observe`` enabled, the plan text must not mention
    ``SORT`` for ``sum``/``min``/``max`` and must mention it for
    ``first``/``last`` and for an aggregation that collects the sorted column.
    """
    q = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).sort("a")

    opts = pl.QueryOptFlags(check_order_observe=True)

    assert "SORT" not in q.group_by("a").sum().explain(optimizations=opts)
    assert "SORT" not in q.group_by("a").min().explain(optimizations=opts)
    assert "SORT" not in q.group_by("a").max().explain(optimizations=opts)
    assert "SORT" in q.group_by("a").last().explain(optimizations=opts)
    assert "SORT" in q.group_by("a").first().explain(optimizations=opts)

    # (sort on column: keys) -- missed optimization opportunity for now
    # assert "SORT" not in q.group_by("a").agg(pl.col("b")).explain(optimizations=opts)

    # (sort on columns: agg) -- sort cannot be dropped
    assert "SORT" in q.group_by("b").agg(pl.col("a")).explain(optimizations=opts)


def test_order_observability_group_by_dynamic() -> None:
    """A sort before and a sort after ``group_by_dynamic`` both stay in the plan.

    The plan text is expected to contain ``SORT`` exactly twice.
    """
    assert (
        pl.LazyFrame(
            {"REGIONID": [1, 23, 4], "INTERVAL_END": [32, 43, 12], "POWER": [12, 3, 1]}
        )
        .sort("REGIONID", "INTERVAL_END")
        .group_by_dynamic(index_column="INTERVAL_END", every="1i", group_by="REGIONID")
        .agg(pl.col("POWER").sum())
        .sort("POWER")
        .head()
        .explain()
    ).count("SORT") == 2


def test_remove_double_sort() -> None:
    """Two identical consecutive sorts yield a plan with a single ``SORT``."""
    assert (
        pl.LazyFrame({"a": [1, 2, 3, 3]}).sort("a").sort("a").explain().count("SORT")
        == 1
    )


def test_double_sort_maintain_order_18558() -> None:
    """A ``maintain_order=True`` sort must still see the order of the prior sort.

    Rows are sorted by ``col2`` first and then by ``col1`` with
    ``maintain_order=True``; for the tied ``col1 == 2`` rows the expected
    output keeps the ``col2`` ordering (0 before 2).
    """
    df = pl.DataFrame(
        {
            "col1": [1, 2, 2, 4, 5, 6],
            "col2": [2, 2, 0, 0, 2, None],
        }
    )

    lf = df.lazy().sort("col2").sort("col1", maintain_order=True)

    expect = pl.DataFrame(
        [
            pl.Series("col1", [1, 2, 2, 4, 5, 6], dtype=pl.Int64),
            pl.Series("col2", [2, 0, 2, 0, 2, None], dtype=pl.Int64),
        ]
    )

    assert_frame_equal(lf.collect(), expect)


def test_sort_on_agg_maintain_order() -> None:
    """Sorting the aggregated column before ``group_by`` must not be dropped.

    The plan has to keep ``SORT`` and every collected list has to come out in
    sorted order; the order of the groups themselves is not checked.
    """
    lf = pl.DataFrame(
        {
            "grp": [10, 10, 10, 30, 30, 30, 20, 20, 20],
            "val": [1, 33, 2, 7, 99, 8, 4, 66, 5],
        }
    ).lazy()
    opts = pl.QueryOptFlags(check_order_observe=True)

    out = lf.sort(pl.col("val")).group_by("grp").agg(pl.col("val"))
    assert "SORT" in out.explain(optimizations=opts)

    expected = pl.DataFrame(
        {
            "grp": [10, 20, 30],
            "val": [[1, 2, 33], [4, 5, 66], [7, 8, 99]],
        }
    )
    # Group order is not asserted, only the content of each group.
    assert_frame_equal(out.collect(optimizations=opts), expected, check_row_order=False)


@pytest.mark.parametrize(
    ("func", "result"),
    [
        (pl.col("val").cum_sum(), 16),  # (3 + (3+10)) after sort
        (pl.col("val").cum_prod(), 33),  # (3 + (3*10)) after sort
        (pl.col("val").cum_min(), 6),  # (3 + 3) after sort
        (pl.col("val").cum_max(), 13),  # (3 + 10) after sort
    ],
)
def test_sort_agg_with_nested_windowing_22918(func: pl.Expr, result: int) -> None:
    """An order-dependent expression nested in an aggregation keeps the sort.

    Parameters
    ----------
    func
        Cumulative expression applied to ``val`` before the final ``sum``.
    result
        Expected aggregated value when the rows are processed in ``id`` order.
    """
    # target pattern: df.sort().group_by().agg(_fooexpr()._barexpr())
    # where _fooexpr is order dependent (e.g., cum_sum)
    # and _barexpr is not order dependent (e.g., sum)

    lf = pl.DataFrame(
        data=[
            {"val": 10, "id": 1, "grp": 0},
            {"val": 3, "id": 0, "grp": 0},
        ]
    ).lazy()

    out = lf.sort("id").group_by("grp").agg(func.sum())
    # `result` is the per-case value from the parametrization above; it is
    # only reached when the sort on `id` is applied before the cumulative op.
    expected = pl.DataFrame({"grp": 0, "val": result})

    assert_frame_equal(out.collect(), expected)
    assert "SORT" in out.explain()


def test_remove_sorts_on_unordered() -> None:
    """Count the ``SORT`` nodes left in the plan for several query shapes."""
    # Three identical sorts in a row: the plan mentions SORT once.
    lf = pl.LazyFrame({"a": [1, 2, 3]}).sort("a").sort("a").sort("a")
    explain = lf.explain()
    assert explain.count("SORT") == 1

    # Every sort is followed by a group_by with an empty aggregation: no SORT.
    lf = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .sort("a")
        .group_by("a")
        .agg([])
        .sort("a")
        .group_by("a")
        .agg([])
        .sort("a")
        .group_by("a")
        .agg([])
    )
    explain = lf.explain()
    assert explain.count("SORT") == 0

    # Sort followed by a join on a literal key: no SORT.
    lf = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .sort("a")
        .join(pl.LazyFrame({"b": [1, 2, 3]}), on=pl.lit(1))
    )
    explain = lf.explain()
    assert explain.count("SORT") == 0

    # Sort followed by unique() with default arguments: no SORT.
    lf = pl.LazyFrame({"a": [1, 2, 3]}).sort("a").unique()
    explain = lf.explain()
    assert explain.count("SORT") == 0


def test_merge_sorted_to_union() -> None:
    """``merge_sorted`` followed by ``unique`` is planned as a ``UNION``.

    With ``check_order_observe=False`` the plan keeps ``MERGE_SORTED``; with
    the default optimization flags it shows ``UNION`` instead.
    """
    lf1 = pl.LazyFrame({"a": [1, 2, 3]})
    lf2 = pl.LazyFrame({"a": [2, 3, 4]})

    lf = lf1.merge_sorted(lf2, "a").unique()

    explain = lf.explain(optimizations=pl.QueryOptFlags(check_order_observe=False))
    assert "MERGE_SORTED" in explain
    assert "UNION" not in explain

    explain = lf.explain()
    assert "MERGE_SORTED" not in explain
    assert "UNION" in explain


@pytest.mark.parametrize(
    "order_sensitive_expr",
    [
        pl.arange(0, pl.len()),
        pl.int_range(pl.len()),
        pl.row_index().cast(pl.Int64),
        pl.lit([0, 1, 2, 3, 4], dtype=pl.List(pl.Int64)).explode(),
        pl.lit(pl.Series([0, 1, 2, 3, 4])),
        pl.lit(pl.Series([[0], [1], [2], [3], [4]])).explode(),
        pl.col("y").sort(),
        pl.col("y").sort_by(pl.col("y"), maintain_order=True),
        pl.col("y").sort_by(pl.col("y"), maintain_order=False),
        pl.col("x").gather(pl.col("x")),
    ],
)
def test_order_sensitive_exprs_24335(order_sensitive_expr: pl.Expr) -> None:
    """Order-sensitive expressions must keep the preceding ordered ``unique``.

    Each parametrized expression is expected to produce ``[0, 1, 2, 3, 4]``
    in the ``out`` column when evaluated on the frame in its original row
    order.
    """
    expect = pl.DataFrame(
        {
            "x": [0, 1, 2, 3, 4],
            "y": [3, 4, 0, 1, 2],
            "out": [0, 1, 2, 3, 4],
        }
    )

    q = (
        pl.LazyFrame({"x": [0, 1, 2, 3, 4], "y": [3, 4, 0, 1, 2]})
        .unique(maintain_order=True)
        .with_columns(order_sensitive_expr.alias("out"))
        .unique()
    )

    plan = q.explain()

    # The order-maintaining UNIQUE must still be present (str.index raises
    # ValueError otherwise) and appear after WITH_COLUMNS in the plan text.
    assert plan.index("UNIQUE[maintain_order: true") > plan.index("WITH_COLUMNS")

    # The final unique() does not maintain order, so compare after sorting.
    assert_frame_equal(q.collect().sort(pl.all()), expect)