Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_optimizations.py
8415 views
1
import datetime as dt
2
import io
3
import itertools
4
import typing
5
6
import pytest
7
8
import polars as pl
9
from polars.testing import assert_frame_equal
10
11
12
def test_is_null_followed_by_all() -> None:
    """is_null().all() should be rewritten to a len() == null_count() check."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, None, None]})

    expected_df = pl.DataFrame({"group": [0, 1], "val": [False, True]})
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().all()
    )

    assert r'[[(col("val").len()) == (col("val").null_count())]]' in result_lf.explain()
    # BUG FIX: the substring check must run against the plan string, not the
    # LazyFrame object itself (`"is_null" not in result_lf` never inspects the
    # plan; every sibling test checks `.explain()`).
    assert "is_null" not in result_lf.explain()
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().is_null().all())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "is_null" in non_optimized_result_plan

    # edge case of empty series
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [True]})
    result_df = lf.select(pl.col("val").is_null().all()).collect()
    assert_frame_equal(expected_df, result_df)
39
40
41
def test_is_null_followed_by_any() -> None:
    """is_null().any() yields the correct per-group boolean."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    grouped = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().any()
    )
    assert_frame_equal(
        pl.DataFrame({"group": [0, 1, 2], "val": [True, True, False]}),
        grouped.collect(),
    )

    # edge case of empty series
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [False]}),
        empty.select(pl.col("val").is_null().any()).collect(),
    )
56
57
58
def test_is_not_null_followed_by_all() -> None:
    """is_not_null().all() yields the correct per-group boolean."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, 5, None]})

    grouped = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").is_not_null().all())
        .collect()
    )
    assert_frame_equal(pl.DataFrame({"group": [0, 1], "val": [True, False]}), grouped)

    # edge case of empty series
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [True]}),
        empty.select(pl.col("val").is_not_null().all()).collect(),
    )
76
77
78
def test_is_not_null_followed_by_any() -> None:
    """is_not_null().any() is rewritten to a null_count() < len() check."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_not_null().any()
    )

    plan = result_lf.explain()
    assert r'[[(col("val").null_count()) < (col("val").len())]]' in plan
    assert "is_not_null" not in plan
    assert_frame_equal(
        pl.DataFrame({"group": [0, 1, 2], "val": [True, False, True]}),
        result_lf.collect(),
    )

    # verify we don't optimize on chained expressions when last one is not col
    chained_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().is_not_null().any())
        .explain()
    )
    assert "null_count" not in chained_plan
    assert "is_not_null" in chained_plan

    # edge case of empty series
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [False]}),
        empty.select(pl.col("val").is_not_null().any()).collect(),
    )
105
106
107
def test_is_null_followed_by_sum() -> None:
    """is_null().sum() is rewritten into a plain null_count()."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().sum()
    )

    plan = result_lf.explain()
    assert r'[col("val").null_count()]' in plan
    assert "is_null" not in plan

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [1, 1, 0]},
        schema_overrides={"val": pl.get_index_type()},
    )
    assert_frame_equal(expected_df, result_lf.collect())

    # edge case of empty series
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [0]}, schema={"val": pl.get_index_type()}),
        empty.select(pl.col("val").is_null().sum()).collect(),
    )
128
129
130
def test_is_not_null_followed_by_sum() -> None:
    """is_not_null().sum() is rewritten into len() - null_count()."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_not_null().sum()
    )

    plan = result_lf.explain()
    assert r'[[(col("val").len()) - (col("val").null_count())]]' in plan
    assert "is_not_null" not in plan

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]},
        schema_overrides={"val": pl.get_index_type()},
    )
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    chained_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().is_not_null().sum())
        .explain()
    )
    assert "null_count" not in chained_plan
    assert "is_not_null" in chained_plan

    # edge case of empty series
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [0]}, schema={"val": pl.get_index_type()}),
        empty.select(pl.col("val").is_not_null().sum()).collect(),
    )
158
159
160
def test_drop_nulls_followed_by_len() -> None:
    """drop_nulls().len() is rewritten into len() - null_count()."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").drop_nulls().len()
    )

    plan = result_lf.explain()
    assert r'[[(col("val").len()) - (col("val").null_count())]]' in plan
    assert "drop_nulls" not in plan

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]},
        schema_overrides={"val": pl.get_index_type()},
    )
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    chained_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().drop_nulls().len())
        .explain()
    )
    assert "null_count" not in chained_plan
    assert "drop_nulls" in chained_plan
183
184
185
def test_drop_nulls_followed_by_count() -> None:
    """drop_nulls().count() is rewritten into len() - null_count()."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").drop_nulls().count()
    )

    plan = result_lf.explain()
    assert r'[[(col("val").len()) - (col("val").null_count())]]' in plan
    assert "drop_nulls" not in plan

    expected_df = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]},
        schema_overrides={"val": pl.get_index_type()},
    )
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    chained_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().drop_nulls().count())
        .explain()
    )
    assert "null_count" not in chained_plan
    assert "drop_nulls" in chained_plan
208
209
210
def test_collapse_joins() -> None:
    """Filters over a cross join collapse into specialized join nodes."""
    left = pl.LazyFrame({"a": [1, 2, 3], "b": [2, 2, 2]})
    right = pl.LazyFrame({"x": [7, 1, 2]})

    cross = left.join(right, how="cross")

    def check(query: pl.LazyFrame, present: list[str], absent: list[str]) -> None:
        # The optimized plan must mention/omit the given node names, and the
        # optimized result must match an unoptimized run (row order may differ).
        plan = query.explain()
        for token in present:
            assert token in plan
        for token in absent:
            assert token not in plan
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags.none()),
            query.collect(),
            check_row_order=False,
        )

    # equality predicates become inner joins (either operand order)
    check(cross.filter(pl.col.a == pl.col.x), ["INNER JOIN"], ["FILTER"])
    check(cross.filter(pl.col.x == pl.col.a), ["INNER JOIN"], ["FILTER"])
    check(
        cross.filter(pl.col.x == pl.col.a).filter(pl.col.x == pl.col.b),
        ["INNER JOIN"],
        ["FILTER"],
    )

    # predicates mixing both sides in one expression fall back to nested loop
    check(cross.filter(pl.col.x + pl.col.a != 0), ["NESTED LOOP JOIN"], ["FILTER"])

    # inequality predicates become IE-joins
    check(
        cross.filter(pl.col.x >= pl.col.a),
        ["IEJOIN"],
        ["NESTED LOOP JOIN", "CROSS JOIN", "FILTER"],
    )
    check(
        cross.filter(pl.col.x >= pl.col.a).filter(pl.col.x <= pl.col.b),
        ["IEJOIN"],
        ["CROSS JOIN", "NESTED LOOP JOIN", "FILTER"],
    )
279
280
281
@pytest.mark.slow
def test_collapse_joins_combinations() -> None:
    """Exhaustively compare collapse-joins output against an unoptimized run.

    Builds every comparison from a small expression pool, combines them with
    `&` / `|`, filters a cross join with the result, and checks that the
    optimized query returns the same rows as a fully unoptimized one.
    """
    a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]})
    b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]})

    cross = a.join(b, how="cross")

    exprs = [
        getattr(lhs, cmp)(rhs)
        for lhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a + pl.col.b]
        for rhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a * pl.col.x]
        for cmp in ["__eq__", "__ge__", "__lt__"]
    ]

    # NOTE(review): `["__and__", "__or__"] * (amount - 1)` is empty for
    # amount <= 1, so itertools.product over it yields nothing and only the
    # amount == 2 case actually iterates — confirm against upstream intent.
    for amount in range(3):
        for merge in itertools.product(["__and__", "__or__"] * (amount - 1)):
            for es in itertools.product(*([exprs] * amount)):
                # fold the selected expressions together with the chosen
                # boolean operators
                e = es[0]
                for i in range(amount - 1):
                    e = (getattr(e, merge[i]))(es[i + 1])

                # NOTE: We need to sort because the order of the cross-join &
                # IE-join is unspecified. Therefore, this might not necessarily
                # create the exact same dataframe.
                optimized = cross.filter(e).sort(pl.all()).collect()
                unoptimized = cross.filter(e).collect(
                    optimizations=pl.QueryOptFlags.none()
                )

                try:
                    assert_frame_equal(optimized, unoptimized, check_row_order=False)
                except AssertionError:
                    # BUG FIX: was a bare `except:`, which also intercepts
                    # KeyboardInterrupt/SystemExit before re-raising; only a
                    # failed comparison should trigger this debug dump.
                    print(e)
                    print()
                    print("Optimized")
                    print(cross.filter(e).explain())
                    print(optimized)
                    print()
                    print("Unoptimized")
                    print(
                        cross.filter(e).explain(optimizations=pl.QueryOptFlags.none())
                    )
                    print(unoptimized)
                    print()

                    raise
330
331
332
def test_order_observe_sort_before_unique_22485() -> None:
    """The sort feeding a unique() must not be optimized away (issue 22485)."""
    lf = pl.LazyFrame(
        {
            "order": [3, 2, 1],
            "id": ["A", "A", "B"],
        }
    )

    expect = pl.DataFrame({"order": [1, 3], "id": ["B", "A"]})

    queries = [
        lf.sort("order").unique(["id"], keep="last").sort("order"),
        lf.sort("order").unique(["id"], keep="last", maintain_order=True),
    ]
    for q in queries:
        plan = q.explain()
        # the SORT BY node must survive below the UNIQUE node in the plan
        assert "SORT BY" in plan[plan.index("UNIQUE") :]
        assert_frame_equal(q.collect(), expect)
355
356
357
def test_order_observe_group_by() -> None:
    """maintain_order on group_by is dropped when a later unordered sort makes it moot."""
    agg = pl.LazyFrame({"a": range(5)}).group_by("a", maintain_order=True).agg(b=1)

    # a plain sort afterwards lets the optimizer clear maintain_order
    assert "AGGREGATE[maintain_order: false]" in agg.sort("b").explain()

    # an order-maintaining sort must keep it
    assert (
        "AGGREGATE[maintain_order: true]"
        in agg.sort("b", maintain_order=True).explain()
    )
377
378
379
def test_fused_correct_name() -> None:
    """Fused multiply-add must keep the alias of the left-most expression."""
    lf = (
        pl.DataFrame({"x": [1, 2, 3]})
        .lazy()
        .select((pl.col.x.alias("a") * pl.col.x.alias("b")) + pl.col.x.alias("c"))
    )

    unoptimized = lf.collect(optimizations=pl.QueryOptFlags.none())
    optimized = lf.collect()
    assert_frame_equal(unoptimized, optimized)
    assert_frame_equal(optimized, pl.DataFrame({"a": [2, 6, 12]}))
393
394
395
def test_slice_pushdown_within_concat_24734() -> None:
    """Slice pushdown must recurse into concat inputs (issue 24734)."""
    q = pl.concat(
        [
            pl.LazyFrame({"x": [0, 1, 2, 3, 4]}).head(2),
            pl.LazyFrame(schema={"x": pl.Int64}),
        ]
    )

    # the head() is absorbed, leaving no explicit SLICE node
    assert "SLICE" not in q.explain()
    assert_frame_equal(q, pl.LazyFrame({"x": [0, 1]}))

    q = pl.concat(
        [
            pl.LazyFrame({"x": [0, 1, 2, 3, 4]}).select(pl.col("x").reverse()),
            pl.LazyFrame(schema={"x": pl.Int64}),
        ]
    ).slice(1, 2)

    plan = q.explain()
    # the slice is pushed into the first concat input but stays above the
    # non-pushable reverse()
    assert plan.index("SLICE[offset: 0, len: 3]") > plan.index("PLAN 0:")
    assert_frame_equal(q, pl.LazyFrame({"x": [3, 2]}))
419
420
421
def test_is_between_pushdown_25499() -> None:
    """is_between pushed into a parquet scan must match post-hoc filtering (issue 25499)."""
    buf = io.BytesIO()
    pl.LazyFrame(
        {"a": [0, 1, 2, 3, 4]}, schema_overrides={"a": pl.UInt32}
    ).sink_parquet(buf)
    parquet = buf.getvalue()

    expr = pl.lit(3, dtype=pl.UInt32).is_between(
        pl.lit(1, dtype=pl.UInt32), pl.col("a")
    )

    pushed = pl.scan_parquet(parquet).filter(expr).collect()
    unpushed = pl.scan_parquet(parquet).collect().filter(expr)
    assert_frame_equal(pushed, unpushed)
435
436
437
def test_slice_pushdown_expr_25473() -> None:
    """Expression-level slice() behaves correctly under pushdown (issue 25473)."""
    lf = pl.LazyFrame({"a": [0, 1, 2, 3, 4]})

    # slice over an elementwise arithmetic expression
    assert_frame_equal(
        lf.select((pl.col("a") + 1).slice(-4, 2)).collect(), pl.DataFrame({"a": [2, 3]})
    )

    # slice over a when/then/otherwise expression
    when_expr = (
        pl.when(pl.col("a") == 1).then(pl.lit("one")).otherwise(pl.lit("other"))
    )
    assert_frame_equal(
        lf.select(a=when_expr.slice(-4, 2)).collect(),
        pl.DataFrame({"a": ["one", "other"]}),
    )

    # slice over is_in with an imploded (length-1) needle list
    assert_frame_equal(
        lf.select(a=pl.col("a").is_in(pl.Series([1]).implode()).slice(-4, 2)).collect(),
        pl.DataFrame({"a": [True, False]}),
    )

    # mismatched lengths must still raise, not be sliced away
    q = pl.LazyFrame().select(
        pl.lit(pl.Series([0, 1, 2, 3, 4])).is_in(pl.Series([[3], [1]])).slice(-2, 1)
    )
    with pytest.raises(pl.exceptions.ShapeError, match=r"lengths.*5 != 2"):
        q.collect()
464
465
466
def test_lazy_groupby_maintain_order_after_asof_join_25973() -> None:
    """Regression test for issue 25973: maintain_order on a group_by after an
    asof join must be preserved by the optimizer."""
    # Small target times: 00:00, 00:10, 00:20, 00:30
    targettime = (
        pl.DataFrame(
            {
                "targettime": pl.time_range(
                    dt.time(0, 0),
                    dt.time(0, 30),
                    interval="10m",
                    closed="both",
                    eager=True,
                )
            }
        )
        .with_columns(
            # anchor the times to a fixed date so they become datetimes
            targettime=pl.lit(dt.date(2026, 1, 1)).dt.combine(pl.col("targettime")),
            grp=pl.lit(1),
        )
        .lazy()
    )

    # Small input times: every second from 00:00 to 00:30
    df = (
        pl.DataFrame(
            {
                "time": pl.time_range(
                    dt.time(0, 0),
                    dt.time(0, 30),
                    interval="1s",
                    closed="both",
                    eager=True,
                )
            }
        )
        .with_row_index("value")
        .with_columns(
            time=pl.lit(dt.date(2026, 1, 1)).dt.combine(pl.col("time")),
            grp=pl.lit(1),
        )
        .lazy()
    )

    # This used to produce out-of-order results.
    # The optimizer previously cleared maintain_order.
    q = (
        df.join_asof(
            targettime,
            left_on="time",
            right_on="targettime",
            strategy="forward",
        )
        .drop_nulls("targettime")
        .group_by("targettime", maintain_order=True)
        .agg(pl.col("value").last())
    )

    # Verify the optimizer preserves maintain_order on the group_by AGGREGATE
    # node (the assertion below inspects the AGGREGATE entry of the plan).
    plan = q.explain()
    assert "AGGREGATE[maintain_order: true" in plan

    result = q.collect()

    # with_row_index produces the index dtype, so the expected column must too
    idx_dtype = pl.get_index_type()

    expected = pl.DataFrame(
        {
            "targettime": [
                dt.datetime(2026, 1, 1, 0, 0),
                dt.datetime(2026, 1, 1, 0, 10),
                dt.datetime(2026, 1, 1, 0, 20),
                dt.datetime(2026, 1, 1, 0, 30),
            ],
            # last row index seen before each 10-minute target boundary
            "value": pl.Series("value", [0, 600, 1200, 1800], dtype=idx_dtype),
        }
    )

    assert_frame_equal(result, expected)
543
544
545
def test_fast_count_alias_18581() -> None:
    """The fast-count optimization must respect an alias on pl.len() (issue 18581)."""
    csv = io.BytesIO()
    csv.write(b"a,b,c\n1,2,3\n4,5,6")
    csv.flush()
    csv.seek(0)

    out = pl.scan_csv(csv).select(pl.len().alias("weird_name")).collect()

    # Just check the value, let assert_frame_equal handle dtype matching
    expected = pl.DataFrame(
        {"weird_name": [2]}, schema={"weird_name": pl.get_index_type()}
    )
    assert_frame_equal(expected, out)
558
559
560
def test_flatten_alias() -> None:
    """Chained aliases collapse into the outermost alias in the plan."""
    plan = (
        pl.LazyFrame({"a": [1, 2]})
        .select(pl.len().alias("foo").alias("bar"))
        .explain()
    )
    assert """len().alias("bar")""" in plan
567
568
569
def test_concat_str_sortedness_26466() -> None:
    """Single-column concat_str keeps sortedness; other variants must not (issue 26466)."""
    df = pl.DataFrame({"x": ["", "a", "b"], "y": [1, 2, 3]})
    lf = df.lazy().set_sorted("x")

    def physical_graph(expr: pl.Expr) -> str:
        # render the streaming engine's physical plan for a group_by on `x`
        dot = (
            lf.with_columns(x=expr)
            .group_by("x")
            .agg(pl.col.y.sum())
            .show_graph(engine="streaming", plan_stage="physical", raw_output=True)
        )
        return typing.cast("str", dot)

    # identity-like concat_str preserves the sorted flag
    assert "sorted-group-by" in physical_graph(pl.concat_str("x"))

    # variants that can change values must drop the sorted flag
    for e in [pl.concat_str("x", pl.lit("c")), pl.concat_str("x", ignore_nulls=True)]:
        assert "sorted-group-by" not in physical_graph(e)
591
592