# Path: blob/main/py-polars/tests/unit/lazyframe/test_order_observability.py
# 8424 views
from __future__ import annotations12from typing import Any34import pytest56import polars as pl7from polars.testing import assert_frame_equal, assert_series_equal8910def test_order_observability() -> None:11q = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).sort("a")1213opts = pl.QueryOptFlags(check_order_observe=True)1415assert "SORT" not in q.group_by("a").sum().explain(optimizations=opts)16assert "SORT" not in q.group_by("a").min().explain(optimizations=opts)17assert "SORT" not in q.group_by("a").max().explain(optimizations=opts)18assert "SORT" in q.group_by("a").last().explain(optimizations=opts)19assert "SORT" in q.group_by("a").first().explain(optimizations=opts)2021# (sort on column: keys) -- missed optimization opportunity for now22# assert "SORT" not in q.group_by("a").agg(pl.col("b")).explain(optimizations=opts)2324# (sort on columns: agg) -- sort cannot be dropped25assert "SORT" in q.group_by("b").agg(pl.col("a")).explain(optimizations=opts)262728def test_order_observability_group_by_dynamic() -> None:29assert (30pl.LazyFrame(31{"REGIONID": [1, 23, 4], "INTERVAL_END": [32, 43, 12], "POWER": [12, 3, 1]}32)33.sort("REGIONID", "INTERVAL_END")34.group_by_dynamic(index_column="INTERVAL_END", every="1i", group_by="REGIONID")35.agg(pl.col("POWER").sum())36.sort("POWER")37.head()38.explain()39).count("SORT") == 2404142def test_remove_double_sort() -> None:43assert (44pl.LazyFrame({"a": [1, 2, 3, 3]}).sort("a").sort("a").explain().count("SORT")45== 146)474849def test_double_sort_maintain_order_18558() -> None:50df = pl.DataFrame(51{52"col1": [1, 2, 2, 4, 5, 6],53"col2": [2, 2, 0, 0, 2, None],54}55)5657lf = df.lazy().sort("col2").sort("col1", maintain_order=True)5859expect = pl.DataFrame(60[61pl.Series("col1", [1, 2, 2, 4, 5, 6], dtype=pl.Int64),62pl.Series("col2", [2, 0, 2, 0, 2, None], dtype=pl.Int64),63]64)6566assert_frame_equal(lf.collect(), expect)676869def test_sort_on_agg_maintain_order() -> None:70lf = pl.DataFrame(71{72"grp": [10, 10, 10, 30, 30, 30, 20, 
20, 20],73"val": [1, 33, 2, 7, 99, 8, 4, 66, 5],74}75).lazy()76opts = pl.QueryOptFlags(check_order_observe=True)7778out = lf.sort(pl.col("val")).group_by("grp").agg(pl.col("val"))79assert "SORT" in out.explain(optimizations=opts)8081expected = pl.DataFrame(82{83"grp": [10, 20, 30],84"val": [[1, 2, 33], [4, 5, 66], [7, 8, 99]],85}86)87assert_frame_equal(out.collect(optimizations=opts), expected, check_row_order=False)888990@pytest.mark.parametrize(91("func", "result"),92[93(pl.col("val").cum_sum(), 16), # (3 + (3+10)) after sort94(pl.col("val").cum_prod(), 33), # (3 + (3*10)) after sort95(pl.col("val").cum_min(), 6), # (3 + 3) after sort96(pl.col("val").cum_max(), 13), # (3 + 10) after sort97],98)99def test_sort_agg_with_nested_windowing_22918(func: pl.Expr, result: int) -> None:100# target pattern: df.sort().group_by().agg(_fooexpr()._barexpr())101# where _fooexpr is order dependent (e.g., cum_sum)102# and _barexpr is not order dependent (e.g., sum)103104lf = pl.DataFrame(105data=[106{"val": 10, "id": 1, "grp": 0},107{"val": 3, "id": 0, "grp": 0},108]109).lazy()110111out = lf.sort("id").group_by("grp").agg(func.sum())112expected = pl.DataFrame({"grp": 0, "val": result}) # (3 + (3+10)) after sort113114assert_frame_equal(out.collect(), expected)115assert "SORT" in out.explain()116117118def test_remove_sorts_on_unordered() -> None:119lf = pl.LazyFrame({"a": [1, 2, 3]}).sort("a").sort("a").sort("a")120explain = lf.explain()121assert explain.count("SORT") == 1122123lf = (124pl.LazyFrame({"a": [1, 2, 3]})125.sort("a")126.group_by("a")127.agg([])128.sort("a")129.group_by("a")130.agg([])131.sort("a")132.group_by("a")133.agg([])134)135explain = lf.explain()136assert explain.count("SORT") == 0137138lf = (139pl.LazyFrame({"a": [1, 2, 3]})140.sort("a")141.join(pl.LazyFrame({"b": [1, 2, 3]}), on=pl.lit(1))142)143explain = lf.explain()144assert explain.count("SORT") == 0145146lf = pl.LazyFrame({"a": [1, 2, 3]}).sort("a").unique()147explain = lf.explain()148assert 
explain.count("SORT") == 0149150151def test_merge_sorted_to_union() -> None:152lf1 = pl.LazyFrame({"a": [1, 2, 3]})153lf2 = pl.LazyFrame({"a": [2, 3, 4]})154155lf = lf1.merge_sorted(lf2, "a").unique()156157explain = lf.explain(optimizations=pl.QueryOptFlags(check_order_observe=False))158assert "MERGE_SORTED" in explain159assert "UNION" not in explain160161explain = lf.explain()162assert "MERGE_SORTED" not in explain163assert "UNION" in explain164165166@pytest.mark.parametrize(167"order_sensitive_expr",168[169pl.arange(0, pl.len()),170pl.int_range(pl.len()),171pl.row_index().cast(pl.Int64),172pl.lit([0, 1, 2, 3, 4], dtype=pl.List(pl.Int64)).explode(),173pl.lit(pl.Series([0, 1, 2, 3, 4])),174pl.lit(pl.Series([[0], [1], [2], [3], [4]])).explode(),175pl.col("y").sort(),176pl.col("y").sort_by(pl.col("y"), maintain_order=True),177pl.col("y").sort_by(pl.col("y"), maintain_order=False),178pl.col("x").gather(pl.col("x")),179],180)181def test_order_sensitive_exprs_24335(order_sensitive_expr: pl.Expr) -> None:182expect = pl.DataFrame(183{184"x": [0, 1, 2, 3, 4],185"y": [3, 4, 0, 1, 2],186"out": [0, 1, 2, 3, 4],187}188)189190q = (191pl.LazyFrame({"x": [0, 1, 2, 3, 4], "y": [3, 4, 0, 1, 2]})192.unique(maintain_order=True)193.with_columns(order_sensitive_expr.alias("out"))194.unique()195)196197plan = q.explain()198199assert plan.index("UNIQUE[maintain_order: true") > plan.index("WITH_COLUMNS")200201assert_frame_equal(q.collect().sort(pl.all()), expect)202203204def assert_correct_ordering(205lf: pl.LazyFrame,206expr: pl.Expr,207*,208expected: pl.Series | None,209is_order_observing: bool,210pad_exprs: list[pl.Expr] | None = None,211) -> None:212if pad_exprs is None:213pad_exprs = []214q = lf.unique(maintain_order=True).select(pad_exprs + [expr]).unique()215assert ("UNIQUE[maintain_order: true" in q.explain()) == is_order_observing216217result = q.collect()218if expected is not None:219unoptimized_result = 
q.collect(optimizations=pl.QueryOptFlags.none())220221assert_series_equal(222result.to_series(len(pad_exprs)), expected, check_order=False223)224assert_frame_equal(225result,226unoptimized_result,227check_row_order=False,228)229230231c = pl.col.a232233234@pytest.mark.parametrize(235("is_order_observing", "agg", "output", "output_dtype"),236[237(False, c.min(), 1, pl.Int64()),238(False, c.count(), 3, pl.get_index_type()),239(False, c.len(), 3, pl.get_index_type()),240(False, c.product(), 6, pl.Int64()),241(False, c.bitwise_or(), 3, pl.Int64()),242(False, (c == 1).any(), True, pl.Boolean()),243(False, pl.when(c != 1).then(c).null_count(), 1, pl.get_index_type()),244(True, c.first(), 2, pl.Int64()),245(True, c.implode(), [2, 1, 3], pl.List(pl.Int64())),246(True, c.arg_min(), 1, pl.get_index_type()),247],248)249def test_order_sensitive_aggregations_parametric(250is_order_observing: bool, agg: pl.Expr, output: Any, output_dtype: pl.DataType251) -> None:252assert_correct_ordering(253pl.LazyFrame({"a": [2, 1, 3]}),254agg.alias("agg"),255expected=pl.Series("agg", [output] * 3, output_dtype),256is_order_observing=is_order_observing,257pad_exprs=[pl.col.a],258)259260261lf1 = pl.LazyFrame({"a": [3, 1, 2]})262lf2 = pl.LazyFrame({"a": [2, 1, 3]})263lf3 = pl.LazyFrame({"a": [[1, 2], [3]], "b": [[3], [4, 5]]})264lf4 = pl.LazyFrame({"a": [2, 1, 3], "b": [4, 6, 5]})265lf5 = pl.LazyFrame({"a": [2, None, 3]})266lf6 = pl.LazyFrame({"a": [[1], [2]], "b": [[3], [4]]})267268269@pytest.mark.parametrize(270("lf", "expr", "expected", "is_order_observing"),271[272(lf1, pl.col.a.sort() * pl.col.a, [3, 2, 6], True),273(lf1, pl.col.a * pl.col.a, [1, 4, 9], False),274(275lf2,276pl.lit(pl.Series("a", [2, 1, 3, 4])).gather(277pl.col.a.filter(pl.col.a > 1) - 1278),279[1, 3],280False,281),282(lf1, pl.col.a.mode(), [1, 2, 3], False),283(lf2, pl.col.a.gather([0, 2]), [2, 3], True),284(lf2, pl.col.a, [2, 1, 3], False),285(lf2, pl.col.a + 1, [3, 2, 4], False),286(lf2, pl.lit(pl.Series("a", [2, 1, 3, 
4])).gather([0, 2]), [2, 3], False),287(lf2, pl.col.a.filter(pl.col.a != 1), [2, 3], False),288(lf3, pl.col.a.explode() * pl.col.b.explode(), [3, 8, 15], True),289(lf4, pl.col.a.sort() + pl.col.b, [5, 8], True),290(lf4, pl.col.a.sort() + pl.col.b.sort(), [5, 7, 9], False),291(lf4, pl.col.a + pl.col.b, pl.Series("a", [6, 7, 8]), False),292(lf4, pl.col.a.unique() * pl.col.b.unique(), None, False),293(lf5, pl.col.a.drop_nulls(), [2, 3], False),294],295)296def test_order_sensitive_paramateric(297lf: pl.LazyFrame,298expr: pl.Expr,299expected: pl.Series | list[Any] | None,300is_order_observing: bool,301) -> None:302if isinstance(expected, pl.Series):303expected = expected.rename("a")304elif isinstance(expected, list):305expected = pl.Series("a", expected)306307assert_correct_ordering(308lf,309expr.alias("a"),310expected=expected,311is_order_observing=is_order_observing,312)313314315def test_with_columns_implicit_columns() -> None:316# Test that overwriting all columns in `with_columns` does not require ordering to317# be preserved.318q = (319lf6.select("a")320.unique(maintain_order=True)321.with_columns(pl.col.a.explode())322.unique()323)324assert "UNIQUE[maintain_order: true" not in q.explain()325assert_series_equal(326q.collect().to_series(), pl.Series("a", [1, 2]), check_order=False327)328q = lf6.unique(maintain_order=True).with_columns(pl.col.a.explode()).unique()329assert "UNIQUE[maintain_order: true" in q.explain()330assert_frame_equal(331q.collect(),332pl.DataFrame(333{334"a": [1, 2],335"b": [[3], [4]],336}337),338check_row_order=False,339)340q = lf6.unique(maintain_order=True).with_columns(pl.col.a.alias("c")).unique()341assert "UNIQUE[maintain_order: true" not in q.explain()342assert_frame_equal(343q.collect(),344pl.DataFrame(345{346"a": [[1], [2]],347"b": [[3], [4]],348"c": [[1], [2]],349}350),351check_row_order=False,352)353354355@pytest.mark.parametrize(356("expr", "values", "is_ordered", "is_output_ordered"),357[358(pl.col.a, [1, 2, 3], False, 
False),359(pl.col.a.map_batches(lambda x: x), [1, 2, 3], True, False),360(361pl.col.a.map_batches(lambda x: x, is_elementwise=True),362[1, 2, 3],363False,364False,365),366(367pl.col.a.cast(pl.List(pl.Int64))368.map_batches(lambda x: x, is_elementwise=True)369.explode(),370[1, 2, 3],371True,372False,373),374(pl.col.a.sort(), [1, 2, 3], True, True),375(pl.col.a.sort() + pl.col.a, None, True, True),376(pl.col.a.min() + pl.col.a, [2, 3, 4], False, False),377(pl.col.a.first() + pl.col.a, None, False, False),378],379)380def test_group_by_key_sensitivity(381expr: pl.Expr, values: list[int] | None, is_ordered: bool, is_output_ordered: bool382) -> None:383lf = pl.LazyFrame({"a": [2, 2, 1, 3], "b": ["A", "B", "C", "D"]}).unique()384385q = lf.group_by(expr.alias("a"), maintain_order=True).agg("b")386df = q.collect()387assert ("AGGREGATE[maintain_order: true]" in q.explain()) is is_ordered388389expected_values = pl.Series("a", values)390391if values is not None:392assert_series_equal(df["a"], expected_values, check_order=is_output_ordered)393394395@pytest.mark.parametrize(396("expr", "is_ordered"),397[398(pl.col.a, False),399(pl.col.a.map_batches(lambda x: x), True),400(pl.col.a.map_batches(lambda x: x, is_elementwise=True), False),401(402pl.col.a.cast(pl.List(pl.Int64))403.map_batches(lambda x: x, is_elementwise=True)404.explode(),405True,406),407(pl.col.a.cum_prod(), True),408(pl.col.a.cum_prod() + pl.col.a, True),409(pl.col.a.min() + pl.col.a, False),410(pl.col.a.first() + pl.col.a, True),411],412)413def test_sort_key_sensitivity(expr: pl.Expr, is_ordered: bool) -> None:414lf = pl.LazyFrame({"a": [2, 2, 1, 3], "b": ["A", "B", "C", "D"]}).sort(pl.all())415q = lf.sort(expr)416assert (q.explain().count("SORT BY") == 2) is is_ordered417assert_frame_equal(q.collect(), lf.sort("a").collect())418419420@pytest.mark.parametrize(421("expr", "is_ordered"),422[423(pl.col.a, False),424(pl.col.a.map_batches(lambda x: x), True),425(pl.col.a.map_batches(lambda x: x, is_elementwise=True), 
False),426(427pl.col.a.cast(pl.List(pl.Int64))428.map_batches(lambda x: x, is_elementwise=True)429.explode(),430True,431),432(pl.col.a.cum_prod(), True),433(pl.col.a.cum_prod() + pl.col.a, True),434(pl.col.a.min() + pl.col.a, False),435(pl.col.a.first() + pl.col.a, True),436],437)438def test_filter_sensitivity(expr: pl.Expr, is_ordered: bool) -> None:439lf = pl.LazyFrame({"a": [2, 2, 1, 3], "b": ["A", "B", "C", "D"]}).sort(pl.all())440q = lf.filter(expr > 0).unique()441assert ("SORT BY" in q.explain()) is is_ordered442assert_frame_equal(q.collect(), lf.collect(), check_row_order=False)443444445@pytest.mark.parametrize(446("exprs", "is_ordered", "unordered_columns"),447[448([pl.col.a], True, None),449([pl.col.a, pl.col.b], True, None),450([pl.col.a.unique()], True, ["a"]),451([pl.col.a.min()], True, None),452([pl.col.a.product()], True, None),453([pl.col.a.unique(), pl.col.b], True, ["a"]),454([pl.col.a.unique(), pl.col.b.unique()], False, ["a", "b"]),455([pl.col.a.min(), pl.col.b.min()], False, None),456([pl.col.a.product(), pl.col.b.null_count()], False, None),457([pl.col.b.unique()], True, ["b"]),458([pl.col.a.unique(), pl.col.b.unique(), pl.col.a.alias("c")], True, ["a", "b"]),459(460[pl.col.a.unique(), pl.col.b.unique(), (pl.col.a + 1).unique().alias("c")],461False,462["a", "b", "c"],463),464(465[pl.col.a.min(), pl.col.b.min(), (pl.col.a + 1).min().alias("c")],466False,467None,468),469(470[471pl.col.a.product(),472pl.col.b.null_count(),473(pl.col.a + 1).product().alias("c"),474],475False,476None,477),478],479)480def test_with_columns_sensitivity(481exprs: list[pl.Expr], is_ordered: bool, unordered_columns: list[str] | None482) -> None:483lf = (484pl.LazyFrame({"a": [2, 4, 1, 3], "b": ["A", "C", "B", "D"]})485.sort("a")486.with_columns(*exprs)487.unique(maintain_order=True)488)489assert ("UNIQUE[maintain_order: true" in lf.explain()) is is_ordered490491df_opt = lf.collect()492df_unopt = 
lf.collect(optimizations=pl.QueryOptFlags(check_order_observe=False))493494if unordered_columns is None:495assert_frame_equal(df_opt, df_unopt)496else:497assert_frame_equal(498df_opt.drop(unordered_columns), df_unopt.drop(unordered_columns)499)500for c in unordered_columns:501assert_series_equal(df_opt[c], df_unopt[c], check_order=False)502503504def test_reverse_non_order_observe() -> None:505q = (506pl.LazyFrame({"x": [0, 1, 2, 3, 4]})507.unique(maintain_order=True)508.select(pl.col("x").reverse().sum())509)510511plan = q.explain()512513assert "UNIQUE[maintain_order: false" in plan514assert q.collect().item() == 10515516# Observing the order of the output of `reverse()` implicitly observes the517# input to `reverse()`.518q = (519pl.LazyFrame({"x": [0, 1, 2, 3, 4]})520.unique(maintain_order=True)521.select(pl.col("x").reverse().last())522)523524plan = q.explain()525526assert "UNIQUE[maintain_order: true" in plan527assert q.collect().item() == 0528529# Zipping `reverse()` must also consider the ordering of the input to530# `reverse()`.531q = (532pl.LazyFrame({"x": [0, 1, 2, 3, 4]})533.unique(maintain_order=True)534.select(x=pl.Series([0, 1, 2, 3, 4]), x_reverse=pl.col("x").reverse())535)536537plan = q.explain()538assert "UNIQUE[maintain_order: true" in plan539assert_frame_equal(540q,541pl.LazyFrame(542{543"x": [0, 1, 2, 3, 4],544"x_reverse": [4, 3, 2, 1, 0],545}546),547)548549550def test_order_optimize_cspe_26277() -> None:551df = pl.LazyFrame({"x": [1, 2]}).sort("x")552553q1 = pl.concat([df, df])554q2 = pl.concat([q1, q1])555q3 = q2.sort("x").with_columns("x")556557assert_frame_equal(558q3.collect(),559pl.DataFrame({"x": [1, 1, 1, 1, 2, 2, 2, 2]}),560)561562563