Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_order_observability.py
8424 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import pytest
6
7
import polars as pl
8
from polars.testing import assert_frame_equal, assert_series_equal
9
10
11
def test_order_observability() -> None:
    """Check which group-by aggregations keep the upstream SORT in the plan."""
    sorted_lf = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).sort("a")
    flags = pl.QueryOptFlags(check_order_observe=True)

    def plan(aggregated: pl.LazyFrame) -> str:
        # Render the plan with order-observation checking switched on.
        return aggregated.explain(optimizations=flags)

    # For these aggregations the plan is expected to contain no SORT.
    for agg_name in ("sum", "min", "max"):
        aggregated = getattr(sorted_lf.group_by("a"), agg_name)()
        assert "SORT" not in plan(aggregated)

    # For `last` / `first` the SORT is expected to stay in the plan.
    for agg_name in ("last", "first"):
        aggregated = getattr(sorted_lf.group_by("a"), agg_name)()
        assert "SORT" in plan(aggregated)

    # (sort on column: keys) -- missed optimization opportunity for now
    # assert "SORT" not in plan(sorted_lf.group_by("a").agg(pl.col("b")))

    # (sort on columns: agg) -- sort cannot be dropped
    assert "SORT" in plan(sorted_lf.group_by("b").agg(pl.col("a")))
27
28
29
def test_order_observability_group_by_dynamic() -> None:
    """Both sorts around a `group_by_dynamic` must appear in the plan."""
    frame = pl.LazyFrame(
        {"REGIONID": [1, 23, 4], "INTERVAL_END": [32, 43, 12], "POWER": [12, 3, 1]}
    )
    query = (
        frame.sort("REGIONID", "INTERVAL_END")
        .group_by_dynamic(index_column="INTERVAL_END", every="1i", group_by="REGIONID")
        .agg(pl.col("POWER").sum())
        .sort("POWER")
        .head()
    )

    plan = query.explain()
    assert plan.count("SORT") == 2
41
42
43
def test_remove_double_sort() -> None:
    """Two identical consecutive sorts must show up as one SORT in the plan."""
    twice_sorted = pl.LazyFrame({"a": [1, 2, 3, 3]}).sort("a").sort("a")
    plan = twice_sorted.explain()
    assert plan.count("SORT") == 1
48
49
50
def test_double_sort_maintain_order_18558() -> None:
    """Sort by `col2`, then stably by `col1`, and compare with a fixed result."""
    source = pl.DataFrame(
        {
            "col1": [1, 2, 2, 4, 5, 6],
            "col2": [2, 2, 0, 0, 2, None],
        }
    )
    query = source.lazy().sort("col2").sort("col1", maintain_order=True)

    # Expected output: for the two `col1 == 2` rows, `col2` reads 0 then 2.
    expected_col1 = pl.Series("col1", [1, 2, 2, 4, 5, 6], dtype=pl.Int64)
    expected_col2 = pl.Series("col2", [2, 0, 2, 0, 2, None], dtype=pl.Int64)
    expected = pl.DataFrame([expected_col1, expected_col2])

    assert_frame_equal(query.collect(), expected)
68
69
70
def test_sort_on_agg_maintain_order() -> None:
    """A sort on the aggregated column must stay in the plan and shape the lists."""
    flags = pl.QueryOptFlags(check_order_observe=True)
    source = pl.DataFrame(
        {
            "grp": [10, 10, 10, 30, 30, 30, 20, 20, 20],
            "val": [1, 33, 2, 7, 99, 8, 4, 66, 5],
        }
    ).lazy()

    query = source.sort(pl.col("val")).group_by("grp").agg(pl.col("val"))
    plan = query.explain(optimizations=flags)
    assert "SORT" in plan

    # Each group's list is expected in ascending `val` order; group order is free.
    expected = pl.DataFrame(
        {
            "grp": [10, 20, 30],
            "val": [[1, 2, 33], [4, 5, 66], [7, 8, 99]],
        }
    )
    result = query.collect(optimizations=flags)
    assert_frame_equal(result, expected, check_row_order=False)
89
90
91
@pytest.mark.parametrize(
    ("func", "result"),
    [
        (pl.col("val").cum_sum(), 16),  # (3 + (3+10)) after sort
        (pl.col("val").cum_prod(), 33),  # (3 + (3*10)) after sort
        (pl.col("val").cum_min(), 6),  # (3 + 3) after sort
        (pl.col("val").cum_max(), 13),  # (3 + 10) after sort
    ],
)
def test_sort_agg_with_nested_windowing_22918(func: pl.Expr, result: int) -> None:
    # target pattern: df.sort().group_by().agg(_fooexpr()._barexpr())
    # where _fooexpr is order dependent (e.g., cum_sum)
    # and _barexpr is not order dependent (e.g., sum)
    rows = [
        {"val": 10, "id": 1, "grp": 0},
        {"val": 3, "id": 0, "grp": 0},
    ]
    sorted_lf = pl.DataFrame(data=rows).lazy().sort("id")
    query = sorted_lf.group_by("grp").agg(func.sum())

    # `result` is the per-case total listed in the parametrization above.
    expected = pl.DataFrame({"grp": 0, "val": result})

    assert_frame_equal(query.collect(), expected)
    assert "SORT" in query.explain()
117
118
119
def test_remove_sorts_on_unordered() -> None:
    """Count SORT nodes left in the plan for several sort-then-consume queries."""

    def base() -> pl.LazyFrame:
        return pl.LazyFrame({"a": [1, 2, 3]})

    # Three identical sorts in a row: exactly one SORT is expected.
    triple_sorted = base().sort("a").sort("a").sort("a")
    assert triple_sorted.explain().count("SORT") == 1

    # Sort -> group_by -> empty agg, repeated three times: no SORT expected.
    regrouped = base()
    for _ in range(3):
        regrouped = regrouped.sort("a").group_by("a").agg([])
    assert regrouped.explain().count("SORT") == 0

    # Sort feeding a join on a literal key: no SORT expected.
    joined = base().sort("a").join(pl.LazyFrame({"b": [1, 2, 3]}), on=pl.lit(1))
    assert joined.explain().count("SORT") == 0

    # Sort feeding `unique()` with default arguments: no SORT expected.
    deduplicated = base().sort("a").unique()
    assert deduplicated.explain().count("SORT") == 0
150
151
152
def test_merge_sorted_to_union() -> None:
    """`merge_sorted` followed by `unique()` should be planned as a UNION."""
    left = pl.LazyFrame({"a": [1, 2, 3]})
    right = pl.LazyFrame({"a": [2, 3, 4]})
    query = left.merge_sorted(right, "a").unique()

    # With order-observation checking disabled the MERGE_SORTED node stays.
    no_order_check = pl.QueryOptFlags(check_order_observe=False)
    plan_without_check = query.explain(optimizations=no_order_check)
    assert "MERGE_SORTED" in plan_without_check
    assert "UNION" not in plan_without_check

    # With default optimizations it is replaced by a UNION.
    default_plan = query.explain()
    assert "MERGE_SORTED" not in default_plan
    assert "UNION" in default_plan
165
166
167
@pytest.mark.parametrize(
    "order_sensitive_expr",
    [
        pl.arange(0, pl.len()),
        pl.int_range(pl.len()),
        pl.row_index().cast(pl.Int64),
        pl.lit([0, 1, 2, 3, 4], dtype=pl.List(pl.Int64)).explode(),
        pl.lit(pl.Series([0, 1, 2, 3, 4])),
        pl.lit(pl.Series([[0], [1], [2], [3], [4]])).explode(),
        pl.col("y").sort(),
        pl.col("y").sort_by(pl.col("y"), maintain_order=True),
        pl.col("y").sort_by(pl.col("y"), maintain_order=False),
        pl.col("x").gather(pl.col("x")),
    ],
)
def test_order_sensitive_exprs_24335(order_sensitive_expr: pl.Expr) -> None:
    """An order-sensitive `with_columns` must keep the ordered UNIQUE beneath it."""
    data = {"x": [0, 1, 2, 3, 4], "y": [3, 4, 0, 1, 2]}
    expected = pl.DataFrame({**data, "out": [0, 1, 2, 3, 4]})

    query = (
        pl.LazyFrame(data)
        .unique(maintain_order=True)
        .with_columns(order_sensitive_expr.alias("out"))
        .unique()
    )

    plan = query.explain()
    with_columns_at = plan.index("WITH_COLUMNS")
    ordered_unique_at = plan.index("UNIQUE[maintain_order: true")

    # The ordered UNIQUE must still be present, later in the plan text than
    # WITH_COLUMNS (`str.index` raises if either marker is missing).
    assert ordered_unique_at > with_columns_at

    assert_frame_equal(query.collect().sort(pl.all()), expected)
203
204
205
def assert_correct_ordering(
    lf: pl.LazyFrame,
    expr: pl.Expr,
    *,
    expected: pl.Series | None,
    is_order_observing: bool,
    pad_exprs: list[pl.Expr] | None = None,
) -> None:
    """
    Check whether selecting `expr` keeps an upstream ordered `unique` in the plan.

    The query is `lf.unique(maintain_order=True).select(pad_exprs + [expr]).unique()`.

    Parameters
    ----------
    lf
        Input frame.
    expr
        Expression under test; selected after any padding expressions.
    expected
        Expected values of the `expr` column, compared ignoring order. When
        `None`, the query is still collected but no values are compared.
    is_order_observing
        Whether the plan is expected to contain `UNIQUE[maintain_order: true`.
    pad_exprs
        Expressions selected ahead of `expr`; defaults to none.
    """
    padding = [] if pad_exprs is None else pad_exprs
    query = lf.unique(maintain_order=True).select([*padding, expr]).unique()

    keeps_ordered_unique = "UNIQUE[maintain_order: true" in query.explain()
    assert keeps_ordered_unique == is_order_observing

    # Always collect, so the optimized query is executed even without `expected`.
    optimized = query.collect()
    if expected is None:
        return

    unoptimized = query.collect(optimizations=pl.QueryOptFlags.none())

    # The column produced by `expr` sits right after the padding columns.
    expr_column = optimized.to_series(len(padding))
    assert_series_equal(expr_column, expected, check_order=False)
    assert_frame_equal(optimized, unoptimized, check_row_order=False)
230
231
232
# Shorthand for column "a", used in the aggregation parametrization below.
c = pl.col.a
233
234
235
@pytest.mark.parametrize(
    ("is_order_observing", "agg", "output", "output_dtype"),
    [
        (False, c.min(), 1, pl.Int64()),
        (False, c.count(), 3, pl.get_index_type()),
        (False, c.len(), 3, pl.get_index_type()),
        (False, c.product(), 6, pl.Int64()),
        (False, c.bitwise_or(), 3, pl.Int64()),
        (False, (c == 1).any(), True, pl.Boolean()),
        (False, pl.when(c != 1).then(c).null_count(), 1, pl.get_index_type()),
        (True, c.first(), 2, pl.Int64()),
        (True, c.implode(), [2, 1, 3], pl.List(pl.Int64())),
        (True, c.arg_min(), 1, pl.get_index_type()),
    ],
)
def test_order_sensitive_aggregations_parametric(
    is_order_observing: bool, agg: pl.Expr, output: Any, output_dtype: pl.DataType
) -> None:
    """Check per aggregation whether it observes the order of its input."""
    source = pl.LazyFrame({"a": [2, 1, 3]})

    # The aggregate is selected next to column "a", so the expected column
    # holds the same value three times.
    expected = pl.Series("agg", [output] * 3, output_dtype)

    assert_correct_ordering(
        source,
        agg.alias("agg"),
        expected=expected,
        is_order_observing=is_order_observing,
        pad_exprs=[pl.col.a],
    )
260
261
262
# Shared input frames for the parametrized tests below.
# Single integer column "a".
lf1 = pl.LazyFrame({"a": [3, 1, 2]})
lf2 = pl.LazyFrame({"a": [2, 1, 3]})
# List columns "a" and "b" with differing list lengths.
lf3 = pl.LazyFrame({"a": [[1, 2], [3]], "b": [[3], [4, 5]]})
# Two integer columns.
lf4 = pl.LazyFrame({"a": [2, 1, 3], "b": [4, 6, 5]})
# Integer column containing a null.
lf5 = pl.LazyFrame({"a": [2, None, 3]})
# List columns "a" and "b" with single-element lists.
lf6 = pl.LazyFrame({"a": [[1], [2]], "b": [[3], [4]]})
268
269
270
@pytest.mark.parametrize(
    ("lf", "expr", "expected", "is_order_observing"),
    [
        (lf1, pl.col.a.sort() * pl.col.a, [3, 2, 6], True),
        (lf1, pl.col.a * pl.col.a, [1, 4, 9], False),
        (
            lf2,
            pl.lit(pl.Series("a", [2, 1, 3, 4])).gather(
                pl.col.a.filter(pl.col.a > 1) - 1
            ),
            [1, 3],
            False,
        ),
        (lf1, pl.col.a.mode(), [1, 2, 3], False),
        (lf2, pl.col.a.gather([0, 2]), [2, 3], True),
        (lf2, pl.col.a, [2, 1, 3], False),
        (lf2, pl.col.a + 1, [3, 2, 4], False),
        (lf2, pl.lit(pl.Series("a", [2, 1, 3, 4])).gather([0, 2]), [2, 3], False),
        (lf2, pl.col.a.filter(pl.col.a != 1), [2, 3], False),
        (lf3, pl.col.a.explode() * pl.col.b.explode(), [3, 8, 15], True),
        (lf4, pl.col.a.sort() + pl.col.b, [5, 8], True),
        (lf4, pl.col.a.sort() + pl.col.b.sort(), [5, 7, 9], False),
        (lf4, pl.col.a + pl.col.b, pl.Series("a", [6, 7, 8]), False),
        (lf4, pl.col.a.unique() * pl.col.b.unique(), None, False),
        (lf5, pl.col.a.drop_nulls(), [2, 3], False),
    ],
)
def test_order_sensitive_paramateric(
    lf: pl.LazyFrame,
    expr: pl.Expr,
    expected: pl.Series | list[Any] | None,
    is_order_observing: bool,
) -> None:
    """Check per expression whether it observes the order of its input frame."""
    # Normalize `expected` to a Series named "a"; `None` passes through and
    # skips the value comparison in `assert_correct_ordering`.
    if isinstance(expected, list):
        expected = pl.Series("a", expected)
    elif isinstance(expected, pl.Series):
        expected = expected.rename("a")

    assert_correct_ordering(
        lf,
        expr.alias("a"),
        expected=expected,
        is_order_observing=is_order_observing,
    )
314
315
316
def test_with_columns_implicit_columns() -> None:
    # Test that overwriting all columns in `with_columns` does not require ordering to
    # be preserved.
    ordered_unique = "UNIQUE[maintain_order: true"

    # Only column "a" is selected and `with_columns` overwrites it.
    overwrite_all = (
        lf6.select("a")
        .unique(maintain_order=True)
        .with_columns(pl.col.a.explode())
        .unique()
    )
    assert ordered_unique not in overwrite_all.explain()
    assert_series_equal(
        overwrite_all.collect().to_series(),
        pl.Series("a", [1, 2]),
        check_order=False,
    )

    # Column "b" is carried through untouched next to the exploded "a".
    overwrite_one = (
        lf6.unique(maintain_order=True).with_columns(pl.col.a.explode()).unique()
    )
    assert ordered_unique in overwrite_one.explain()
    expected_overwrite_one = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [[3], [4]],
        }
    )
    assert_frame_equal(
        overwrite_one.collect(), expected_overwrite_one, check_row_order=False
    )

    # Adding a plain copy of "a" as a new column "c".
    add_alias = (
        lf6.unique(maintain_order=True).with_columns(pl.col.a.alias("c")).unique()
    )
    assert ordered_unique not in add_alias.explain()
    expected_add_alias = pl.DataFrame(
        {
            "a": [[1], [2]],
            "b": [[3], [4]],
            "c": [[1], [2]],
        }
    )
    assert_frame_equal(add_alias.collect(), expected_add_alias, check_row_order=False)
354
355
356
@pytest.mark.parametrize(
    ("expr", "values", "is_ordered", "is_output_ordered"),
    [
        (pl.col.a, [1, 2, 3], False, False),
        (pl.col.a.map_batches(lambda x: x), [1, 2, 3], True, False),
        (
            pl.col.a.map_batches(lambda x: x, is_elementwise=True),
            [1, 2, 3],
            False,
            False,
        ),
        (
            pl.col.a.cast(pl.List(pl.Int64))
            .map_batches(lambda x: x, is_elementwise=True)
            .explode(),
            [1, 2, 3],
            True,
            False,
        ),
        (pl.col.a.sort(), [1, 2, 3], True, True),
        (pl.col.a.sort() + pl.col.a, None, True, True),
        (pl.col.a.min() + pl.col.a, [2, 3, 4], False, False),
        (pl.col.a.first() + pl.col.a, None, False, False),
    ],
)
def test_group_by_key_sensitivity(
    expr: pl.Expr, values: list[int] | None, is_ordered: bool, is_output_ordered: bool
) -> None:
    """
    Check whether a group-by key expression keeps `maintain_order` in the plan.

    Parameters
    ----------
    expr
        Key expression, aliased to "a".
    values
        Expected key values, or `None` to skip the value comparison.
    is_ordered
        Whether the plan should contain `AGGREGATE[maintain_order: true]`.
    is_output_ordered
        Passed as `check_order` when comparing the key column with `values`.
    """
    lf = pl.LazyFrame({"a": [2, 2, 1, 3], "b": ["A", "B", "C", "D"]}).unique()

    q = lf.group_by(expr.alias("a"), maintain_order=True).agg("b")
    df = q.collect()
    assert ("AGGREGATE[maintain_order: true]" in q.explain()) is is_ordered

    # Build the expected Series only when there is something to compare;
    # constructing it from `None` would create an unused Series.
    if values is not None:
        expected_values = pl.Series("a", values)
        assert_series_equal(df["a"], expected_values, check_order=is_output_ordered)
394
395
396
@pytest.mark.parametrize(
    ("expr", "is_ordered"),
    [
        (pl.col.a, False),
        (pl.col.a.map_batches(lambda x: x), True),
        (pl.col.a.map_batches(lambda x: x, is_elementwise=True), False),
        (
            pl.col.a.cast(pl.List(pl.Int64))
            .map_batches(lambda x: x, is_elementwise=True)
            .explode(),
            True,
        ),
        (pl.col.a.cum_prod(), True),
        (pl.col.a.cum_prod() + pl.col.a, True),
        (pl.col.a.min() + pl.col.a, False),
        (pl.col.a.first() + pl.col.a, True),
    ],
)
def test_sort_key_sensitivity(expr: pl.Expr, is_ordered: bool) -> None:
    """Check whether a sort key expression keeps the preceding sort in the plan."""
    presorted = pl.LazyFrame({"a": [2, 2, 1, 3], "b": ["A", "B", "C", "D"]}).sort(
        pl.all()
    )
    resorted = presorted.sort(expr)

    # Two "SORT BY" entries mean the first sort survived optimization.
    sort_count = resorted.explain().count("SORT BY")
    assert (sort_count == 2) is is_ordered

    expected = presorted.sort("a").collect()
    assert_frame_equal(resorted.collect(), expected)
419
420
421
@pytest.mark.parametrize(
    ("expr", "is_ordered"),
    [
        (pl.col.a, False),
        (pl.col.a.map_batches(lambda x: x), True),
        (pl.col.a.map_batches(lambda x: x, is_elementwise=True), False),
        (
            pl.col.a.cast(pl.List(pl.Int64))
            .map_batches(lambda x: x, is_elementwise=True)
            .explode(),
            True,
        ),
        (pl.col.a.cum_prod(), True),
        (pl.col.a.cum_prod() + pl.col.a, True),
        (pl.col.a.min() + pl.col.a, False),
        (pl.col.a.first() + pl.col.a, True),
    ],
)
def test_filter_sensitivity(expr: pl.Expr, is_ordered: bool) -> None:
    """Check whether a filter predicate keeps the preceding sort in the plan."""
    presorted = pl.LazyFrame({"a": [2, 2, 1, 3], "b": ["A", "B", "C", "D"]}).sort(
        pl.all()
    )
    filtered = presorted.filter(expr > 0).unique()

    sort_kept = "SORT BY" in filtered.explain()
    assert sort_kept is is_ordered

    # The result is compared with the unfiltered frame, ignoring row order.
    assert_frame_equal(filtered.collect(), presorted.collect(), check_row_order=False)
444
445
446
@pytest.mark.parametrize(
    ("exprs", "is_ordered", "unordered_columns"),
    [
        ([pl.col.a], True, None),
        ([pl.col.a, pl.col.b], True, None),
        ([pl.col.a.unique()], True, ["a"]),
        ([pl.col.a.min()], True, None),
        ([pl.col.a.product()], True, None),
        ([pl.col.a.unique(), pl.col.b], True, ["a"]),
        ([pl.col.a.unique(), pl.col.b.unique()], False, ["a", "b"]),
        ([pl.col.a.min(), pl.col.b.min()], False, None),
        ([pl.col.a.product(), pl.col.b.null_count()], False, None),
        ([pl.col.b.unique()], True, ["b"]),
        ([pl.col.a.unique(), pl.col.b.unique(), pl.col.a.alias("c")], True, ["a", "b"]),
        (
            [pl.col.a.unique(), pl.col.b.unique(), (pl.col.a + 1).unique().alias("c")],
            False,
            ["a", "b", "c"],
        ),
        (
            [pl.col.a.min(), pl.col.b.min(), (pl.col.a + 1).min().alias("c")],
            False,
            None,
        ),
        (
            [
                pl.col.a.product(),
                pl.col.b.null_count(),
                (pl.col.a + 1).product().alias("c"),
            ],
            False,
            None,
        ),
    ],
)
def test_with_columns_sensitivity(
    exprs: list[pl.Expr], is_ordered: bool, unordered_columns: list[str] | None
) -> None:
    """
    Check whether `with_columns` lets a downstream ordered `unique` stay ordered.

    `unordered_columns` names the columns that are compared ignoring order; all
    other columns must match exactly. `None` means the full frames must match.
    """
    query = (
        pl.LazyFrame({"a": [2, 4, 1, 3], "b": ["A", "C", "B", "D"]})
        .sort("a")
        .with_columns(*exprs)
        .unique(maintain_order=True)
    )
    keeps_ordered_unique = "UNIQUE[maintain_order: true" in query.explain()
    assert keeps_ordered_unique is is_ordered

    # Reference run with order-observation checking switched off.
    no_order_check = pl.QueryOptFlags(check_order_observe=False)
    optimized = query.collect()
    reference = query.collect(optimizations=no_order_check)

    if unordered_columns is None:
        assert_frame_equal(optimized, reference)
        return

    assert_frame_equal(
        optimized.drop(unordered_columns), reference.drop(unordered_columns)
    )
    for column_name in unordered_columns:
        assert_series_equal(
            optimized[column_name], reference[column_name], check_order=False
        )
503
504
505
def test_reverse_non_order_observe() -> None:
    """Check when `reverse()` forces the upstream `unique` to stay ordered."""

    def deduplicated() -> pl.LazyFrame:
        return pl.LazyFrame({"x": [0, 1, 2, 3, 4]}).unique(maintain_order=True)

    # Summing the reversed column: the plan shows an unordered UNIQUE.
    summed = deduplicated().select(pl.col("x").reverse().sum())
    assert "UNIQUE[maintain_order: false" in summed.explain()
    assert summed.collect().item() == 10

    # Observing the order of the output of `reverse()` implicitly observes the
    # input to `reverse()`.
    last_of_reversed = deduplicated().select(pl.col("x").reverse().last())
    assert "UNIQUE[maintain_order: true" in last_of_reversed.explain()
    assert last_of_reversed.collect().item() == 0

    # Zipping `reverse()` must also consider the ordering of the input to
    # `reverse()`.
    zipped = deduplicated().select(
        x=pl.Series([0, 1, 2, 3, 4]), x_reverse=pl.col("x").reverse()
    )
    assert "UNIQUE[maintain_order: true" in zipped.explain()

    expected_zipped = pl.LazyFrame(
        {
            "x": [0, 1, 2, 3, 4],
            "x_reverse": [4, 3, 2, 1, 0],
        }
    )
    assert_frame_equal(zipped, expected_zipped)
549
550
551
def test_order_optimize_cspe_26277() -> None:
    """Sorting a concat of repeated sorted sub-plans must give a sorted result."""
    sorted_base = pl.LazyFrame({"x": [1, 2]}).sort("x")

    # The same sub-plan is reused at two nesting levels.
    doubled = pl.concat([sorted_base, sorted_base])
    quadrupled = pl.concat([doubled, doubled])
    query = quadrupled.sort("x").with_columns("x")

    expected = pl.DataFrame({"x": [1, 1, 1, 1, 2, 2, 2, 2]})
    assert_frame_equal(query.collect(), expected)
562
563