Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_order_observability.py
6939 views
1
import pytest
2
3
import polars as pl
4
from polars.testing import assert_frame_equal
5
6
7
def test_order_observability() -> None:
    """Check which group-by aggregations allow an upstream sort to be dropped.

    With ``check_order_observe`` enabled, a ``SORT`` feeding a group-by must
    only survive when the group-by result can observe the input row order.
    """
    opts = pl.QueryOptFlags(check_order_observe=True)
    sorted_lf = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).sort("a")

    def plan_of(query: pl.LazyFrame) -> str:
        return query.explain(optimizations=opts)

    # sum/min/max do not depend on row order: the sort is dead.
    order_insensitive = [
        sorted_lf.group_by("a").sum(),
        sorted_lf.group_by("a").min(),
        sorted_lf.group_by("a").max(),
    ]
    for query in order_insensitive:
        assert "SORT" not in plan_of(query)

    # last/first pick rows by position, so the sort must be kept.
    order_sensitive = [
        sorted_lf.group_by("a").last(),
        sorted_lf.group_by("a").first(),
    ]
    for query in order_sensitive:
        assert "SORT" in plan_of(query)

    # TODO: a sort on the group-by keys (here `group_by("a").agg(pl.col("b"))`)
    # is currently kept although it could be dropped -- a missed optimization.
    # Once fixed, assert that "SORT" is absent from that plan.

    # The sort is on the aggregated column: it fixes the element order inside
    # each aggregated list, so it cannot be dropped.
    assert "SORT" in plan_of(sorted_lf.group_by("b").agg(pl.col("a")))
23
24
25
def test_order_observability_group_by_dynamic() -> None:
    """Both sorts around a ``group_by_dynamic`` pipeline stay in the plan.

    The first sort orders the input of ``group_by_dynamic``. The second one,
    ``sort("POWER")``, decides which rows ``head()`` returns. Neither may be
    removed, so the plan is expected to contain exactly two ``SORT`` nodes.
    """
    frame = pl.LazyFrame(
        {"REGIONID": [1, 23, 4], "INTERVAL_END": [32, 43, 12], "POWER": [12, 3, 1]}
    )
    query = (
        frame.sort("REGIONID", "INTERVAL_END")
        .group_by_dynamic(
            index_column="INTERVAL_END",
            every="1i",
            group_by="REGIONID",
        )
        .agg(pl.col("POWER").sum())
        .sort("POWER")
        .head()
    )

    sort_count = query.explain().count("SORT")
    assert sort_count == 2
37
38
39
def test_remove_double_sort() -> None:
    """Two identical consecutive sorts collapse into a single ``SORT`` node."""
    doubled = pl.LazyFrame({"a": [1, 2, 3, 3]}).sort("a").sort("a")
    plan = doubled.explain()
    assert plan.count("SORT") == 1
44
45
46
def test_double_sort_maintain_order_18558() -> None:
    """A stable second sort must keep the ordering produced by the first.

    Rows are sorted by ``col2`` and then stably (``maintain_order=True``) by
    ``col1``. Rows that tie on ``col1`` must therefore stay in ``col2`` order,
    so the first sort cannot be optimized away. The numeric suffix presumably
    refers to the originating GitHub issue.
    """
    source = pl.DataFrame(
        {
            "col1": [1, 2, 2, 4, 5, 6],
            "col2": [2, 2, 0, 0, 2, None],
        }
    )
    result = source.lazy().sort("col2").sort("col1", maintain_order=True).collect()

    # The two col1 == 2 rows keep the col2 order established by the first sort.
    expected = pl.DataFrame(
        {
            "col1": [1, 2, 2, 4, 5, 6],
            "col2": [2, 0, 2, 0, 2, None],
        },
        schema={"col1": pl.Int64, "col2": pl.Int64},
    )
    assert_frame_equal(result, expected)
64
65
66
def test_sort_on_agg_maintain_order() -> None:
    """A sort on the aggregated column is observable and must be kept.

    Sorting by ``val`` before ``group_by("grp").agg(pl.col("val"))`` fixes the
    element order inside each aggregated list. The optimizer therefore must not
    drop it, even with ``check_order_observe`` enabled.
    """
    frame = pl.DataFrame(
        {
            "grp": [10, 10, 10, 30, 30, 30, 20, 20, 20],
            "val": [1, 33, 2, 7, 99, 8, 4, 66, 5],
        }
    )
    opts = pl.QueryOptFlags(check_order_observe=True)
    query = frame.lazy().sort(pl.col("val")).group_by("grp").agg(pl.col("val"))

    assert "SORT" in query.explain(optimizations=opts)

    # Group order is unspecified (hence check_row_order=False), but every
    # aggregated list must come out in ascending `val` order.
    expected = pl.DataFrame(
        {"grp": [10, 20, 30], "val": [[1, 2, 33], [4, 5, 66], [7, 8, 99]]}
    )
    result = query.collect(optimizations=opts)
    assert_frame_equal(result, expected, check_row_order=False)
85
86
87
@pytest.mark.parametrize(
    ("func", "result"),
    [
        # After sorting by "id" the group's `val` column is [3, 10]; `result`
        # is the sum of `func` evaluated over that ordered column.
        (pl.col("val").cum_sum(), 16),  # [3, 13]  -> 3 + (3 + 10)
        (pl.col("val").cum_prod(), 33),  # [3, 30]  -> 3 + (3 * 10)
        (pl.col("val").cum_min(), 6),  # [3, 3]   -> 3 + 3
        (pl.col("val").cum_max(), 13),  # [3, 10]  -> 3 + 10
    ],
)
def test_sort_agg_with_nested_windowing_22918(func: pl.Expr, result: int) -> None:
    """Keep a sort that feeds an order-dependent expression nested in an agg.

    Target pattern: ``df.sort().group_by().agg(_fooexpr()._barexpr())``, where
    ``_fooexpr`` is order dependent (e.g. ``cum_sum``) and ``_barexpr`` is not
    (e.g. ``sum``). The outer ``sum`` ignores order, but the inner cumulative
    expression observes it, so the sort must remain in the plan. Without the
    sort the rows would be seen as [10, 3], which gives a different result
    (e.g. 23 for ``cum_sum``).
    """
    lf = pl.DataFrame(
        data=[
            {"val": 10, "id": 1, "grp": 0},
            {"val": 3, "id": 0, "grp": 0},
        ]
    ).lazy()

    out = lf.sort("id").group_by("grp").agg(func.sum())
    # `result` is case-specific; its derivation is listed in the parametrize table.
    expected = pl.DataFrame({"grp": 0, "val": result})

    assert_frame_equal(out.collect(), expected)
    assert "SORT" in out.explain()
113
114
115
def test_remove_sorts_on_unordered() -> None:
    """Sorts whose output order is never observed are removed from the plan."""

    def frame_a() -> pl.LazyFrame:
        return pl.LazyFrame({"a": [1, 2, 3]})

    # Repeated identical sorts collapse into a single SORT.
    stacked = frame_a().sort("a").sort("a").sort("a")
    assert stacked.explain().count("SORT") == 1

    # Three rounds of sort -> group_by -> empty agg: a default group_by does
    # not maintain input order, so every one of these sorts is dead.
    chained = frame_a()
    for _ in range(3):
        chained = chained.sort("a").group_by("a").agg([])
    assert chained.explain().count("SORT") == 0

    # The join here is on a literal key with default options. The test
    # expects it not to preserve the left input's order, so the sort is dead.
    joined = frame_a().sort("a").join(pl.LazyFrame({"b": [1, 2, 3]}), on=pl.lit(1))
    assert joined.explain().count("SORT") == 0

    # unique() without maintain_order discards input order.
    deduped = frame_a().sort("a").unique()
    assert deduped.explain().count("SORT") == 0
146
147
148
def test_merge_sorted_to_union() -> None:
    """``merge_sorted`` is rewritten as ``UNION`` when its order is unobserved.

    The trailing ``unique()`` discards row order. With the default
    optimization flags the merge therefore becomes a plain union. With
    ``check_order_observe`` disabled the ``MERGE_SORTED`` node must be kept.
    """
    left = pl.LazyFrame({"a": [1, 2, 3]})
    right = pl.LazyFrame({"a": [2, 3, 4]})
    query = left.merge_sorted(right, "a").unique()

    no_order_check = pl.QueryOptFlags(check_order_observe=False)
    plan_without_check = query.explain(optimizations=no_order_check)
    assert "MERGE_SORTED" in plan_without_check
    assert "UNION" not in plan_without_check

    plan_with_defaults = query.explain()
    assert "MERGE_SORTED" not in plan_with_defaults
    assert "UNION" in plan_with_defaults
161
162
163
@pytest.mark.parametrize(
    "order_sensitive_expr",
    [
        pl.arange(0, pl.len()),
        pl.int_range(pl.len()),
        pl.row_index().cast(pl.Int64),
        pl.lit([0, 1, 2, 3, 4], dtype=pl.List(pl.Int64)).explode(),
        pl.lit(pl.Series([0, 1, 2, 3, 4])),
        pl.lit(pl.Series([[0], [1], [2], [3], [4]])).explode(),
        pl.col("y").sort(),
        pl.col("y").sort_by(pl.col("y"), maintain_order=True),
        pl.col("y").sort_by(pl.col("y"), maintain_order=False),
        pl.col("x").gather(pl.col("x")),
    ],
)
def test_order_sensitive_exprs_24335(order_sensitive_expr: pl.Expr) -> None:
    """An order-sensitive ``with_columns`` keeps the ordered ``unique`` below it.

    Each parametrized expression yields a column whose alignment with the
    existing ``x``/``y`` rows depends on row order. The inner
    ``unique(maintain_order=True)`` must therefore still feed
    ``WITH_COLUMNS``, even though the outer ``unique()`` discards order.
    """
    query = (
        pl.LazyFrame({"x": [0, 1, 2, 3, 4], "y": [3, 4, 0, 1, 2]})
        .unique(maintain_order=True)
        .with_columns(order_sensitive_expr.alias("out"))
        .unique()
    )

    # explain() is assumed to list parent nodes before their inputs. The
    # ordered UNIQUE must be present (str.index raises otherwise) and must
    # appear after WITH_COLUMNS in the text.
    plan_text = query.explain()
    ordered_unique_at = plan_text.index("UNIQUE[maintain_order: true")
    assert ordered_unique_at > plan_text.index("WITH_COLUMNS")

    expected = pl.DataFrame(
        {
            "x": [0, 1, 2, 3, 4],
            "y": [3, 4, 0, 1, 2],
            "out": [0, 1, 2, 3, 4],
        }
    )
    # The outer unique() does not preserve order, so sort before comparing.
    assert_frame_equal(query.collect().sort(pl.all()), expected)
199
200