Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_cse.py
8430 views
1
from __future__ import annotations
2
3
import re
4
from datetime import date, datetime, timedelta
5
from pathlib import Path
6
from tempfile import NamedTemporaryFile, TemporaryDirectory
7
from typing import TYPE_CHECKING, Any, TypeVar
8
from unittest.mock import Mock
9
10
import numpy as np
11
import pytest
12
13
import polars as pl
14
from polars.io.plugins import register_io_source
15
from polars.testing import assert_frame_equal
16
17
if TYPE_CHECKING:
18
from collections.abc import Iterator
19
20
from tests.conftest import PlMonkeyPatch
21
22
23
def num_cse_occurrences(explanation: str) -> int:
    """Count the distinct CSE column names appearing in an explain string."""
    hits = re.findall(r'__POLARS_CSER_0x[^"]+"', explanation)
    return len({h for h in hits})
26
27
28
def create_dataframe_source(
    source_df: pl.DataFrame,
    is_pure: bool,
    validate_schema: bool = False,
) -> pl.LazyFrame:
    """Generate a custom io source based on the provided pl.DataFrame.

    Parameters
    ----------
    source_df
        The frame served by the io source; it is cloned for every batch.
    is_pure
        Forwarded to ``register_io_source``; a pure source is eligible for
        common-subplan caching.
    validate_schema
        Forwarded to ``register_io_source``. (Fixed the misspelled
        ``validate_schame`` parameter name; no caller in this file passes it
        by keyword.)
    """

    def dataframe_source(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        _n_rows: int | None,
        _batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        # Honour pushed-down predicate and projection so the optimizer's
        # decisions are observable through this source.
        df = source_df.clone()
        if predicate is not None:
            df = df.filter(predicate)
        if with_columns is not None:
            df = df.select(with_columns)
        yield df

    return register_io_source(
        dataframe_source,
        schema=source_df.schema,
        validate_schema=validate_schema,
        is_pure=is_pure,
    )
54
55
56
@pytest.mark.parametrize("use_custom_io_source", [True, False])
def test_cse_rename_cross_join_5405(use_custom_io_source: bool) -> None:
    """Left join after a rename must survive common-subplan elimination.

    See https://github.com/pola-rs/polars/issues/5405.
    """
    right = pl.DataFrame({"A": [1, 2], "B": [3, 4], "D": [5, 6]}).lazy()
    if use_custom_io_source:
        right = create_dataframe_source(right.collect(), is_pure=True)
    left = pl.DataFrame({"C": [3, 4]}).lazy().join(right.select("A"), how="cross")

    out = left.join(right.rename({"B": "C"}), on=["A", "C"], how="left").collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "C": [3, 3, 4, 4],
                "A": [1, 2, 1, 2],
                "D": [5, None, None, 6],
            }
        ),
        check_row_order=False,
    )
77
78
79
def test_union_duplicates() -> None:
    """Concatenating one LazyFrame ten times yields ten hits on a single cache."""
    df_lazy = pl.DataFrame({}).lazy()
    plan = pl.concat([df_lazy] * 10).explain()

    cache_ids = re.findall(r"CACHE\[id: (.*)]", plan)
    assert len(cache_ids) == 10
    assert len(set(cache_ids)) == 1
88
89
90
def test_cse_with_struct_expr_11116() -> None:
    """Repeated struct-field accesses stay correct under CSE.

    See https://github.com/pola-rs/polars/issues/11116.
    """
    lf = pl.DataFrame([{"s": {"a": 1, "b": 4}, "c": 3}]).lazy()

    out = lf.with_columns(
        pl.col("s").struct.field("a").alias("s_a"),
        pl.col("s").struct.field("b").alias("s_b"),
        (
            (pl.col("s").struct.field("a") <= pl.col("c"))
            & (pl.col("s").struct.field("b") > pl.col("c"))
        ).alias("c_between_a_and_b"),
    ).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "s": [{"a": 1, "b": 4}],
                "c": [3],
                "s_a": [1],
                "s_b": [4],
                "c_between_a_and_b": [True],
            }
        ),
    )
114
115
116
def test_cse_schema_6081() -> None:
    """Schema stays consistent when CSE combines with projection pushdown.

    See https://github.com/pola-rs/polars/issues/6081.
    """
    lf = pl.DataFrame(
        data=[
            [date(2022, 12, 12), 1, 1],
            [date(2022, 12, 12), 1, 2],
            [date(2022, 12, 13), 5, 2],
        ],
        schema=["date", "id", "value"],
        orient="row",
    ).lazy()

    grouped_min = lf.group_by(["date", "id"]).agg(
        pl.col("value").min().alias("min_value")
    )

    out = lf.join(grouped_min, on=["date", "id"], how="left").collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True, projection_pushdown=True)
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "date": [date(2022, 12, 12), date(2022, 12, 12), date(2022, 12, 13)],
                "id": [1, 1, 5],
                "value": [1, 2, 2],
                "min_value": [1, 1, 2],
            }
        ),
        check_row_order=False,
    )
145
146
147
def test_cse_9630() -> None:
    """Nested joins sharing a common subplan collect the expected frame."""
    lf1 = pl.LazyFrame({"key": [1], "x": [1]})
    lf2 = pl.LazyFrame({"key": [1], "y": [2]})

    joined = lf1.join(lf2, on="key")

    unioned = (
        pl.concat(
            [
                lf1.select("key", pl.col("x").alias("value")),
                joined.select("key", pl.col("y").alias("value")),
            ]
        )
        .group_by("key")
        .agg(pl.col("value"))
    )

    # Both sides of the final join share `unioned` as a common subplan.
    left_side = unioned.join(lf1, on="key")
    right_side = unioned.join(lf2, on="key")

    out = left_side.join(right_side, on=["key"], how="left").collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "key": [1],
                "value": [[1, 2]],
                "x": [1],
                "value_right": [[1, 2]],
                "y": [2],
            }
        ),
    )
181
182
183
@pytest.mark.write_disk
@pytest.mark.parametrize("maintain_order", [False, True])
def test_schema_row_index_cse(maintain_order: bool) -> None:
    """Self-join on a scan with a row index keeps a correct schema under CSE."""
    with NamedTemporaryFile() as csv_a:
        csv_a.write(b"A,B\nGr1,A\nGr1,B")
        csv_a.seek(0)

        scanned = pl.scan_csv(csv_a.name).with_row_index("Idx")

        out = (
            scanned.join(
                scanned, on="B", maintain_order="left" if maintain_order else "none"
            )
            .group_by("A", maintain_order=maintain_order)
            .all()
            .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
        )

    expected = pl.DataFrame(
        {
            "A": ["Gr1"],
            "Idx": [[0, 1]],
            "B": [["A", "B"]],
            "Idx_right": [[0, 1]],
            "A_right": [["Gr1", "Gr1"]],
        },
        schema_overrides={"Idx": pl.List(pl.UInt32), "Idx_right": pl.List(pl.UInt32)},
    )
    assert_frame_equal(out, expected, check_row_order=maintain_order)
210
211
212
@pytest.mark.debug
def test_cse_expr_selection_context() -> None:
    """CSE collapses repeated sub-expressions in select and with_columns."""
    q = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
        }
    )

    derived = (pl.col("a") * pl.col("b")).sum()
    derived2 = derived * derived

    exprs = [
        derived.alias("d1"),
        (derived * pl.col("c").sum() - 1).alias("foo"),
        derived2.alias("d2"),
        (derived2 * 10).alias("d3"),
    ]

    # select context: two unique CSE columns expected (derived and derived2).
    out = q.select(exprs).collect(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    explained = q.select(exprs).explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert num_cse_occurrences(explained) == 2
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "d1": [30],
                "foo": [299],
                "d2": [900],
                "d3": [9000],
            }
        ),
    )

    # with_columns context: same CSE count, original columns kept.
    out = q.with_columns(exprs).collect(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    explained = q.with_columns(exprs).explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert num_cse_occurrences(explained) == 2
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "a": [1, 2, 3, 4],
                "b": [1, 2, 3, 4],
                "c": [1, 2, 3, 4],
                "d1": [30, 30, 30, 30],
                "foo": [299, 299, 299, 299],
                "d2": [900, 900, 900, 900],
                "d3": [9000, 9000, 9000, 9000],
            }
        ),
    )
276
277
278
def test_windows_cse_excluded() -> None:
    """A windowed and a non-windowed diff must not be merged by CSE."""
    lf = pl.LazyFrame(
        data=[
            ("a", "aaa", 1),
            ("a", "bbb", 3),
            ("a", "ccc", 1),
            ("c", "xxx", 2),
            ("c", "yyy", 3),
            ("c", "zzz", 4),
            ("b", "qqq", 0),
        ],
        schema=["a", "b", "c"],
        orient="row",
    )

    out = lf.select(
        c_diff=pl.col("c").diff(1),
        c_diff_by_a=pl.col("c").diff(1).over("a"),
    ).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "c_diff": [None, 2, -2, 1, 1, 1, -4],
                "c_diff_by_a": [None, 2, -2, None, 1, 1, None],
            }
        ),
    )
305
306
307
def test_cse_group_by_10215() -> None:
    """Repeated aggregation products inside group_by are shared via CSE."""
    lf = pl.LazyFrame({"a": [1], "b": [1]})

    q = lf.group_by("b").agg(
        (pl.col("a").sum() * pl.col("a").sum()).alias("x"),
        (pl.col("b").sum() * pl.col("b").sum()).alias("y"),
        (pl.col("a").sum() * pl.col("a").sum()).alias("x2"),
        ((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"),
        ((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"),
        ((pl.col("a") + 2).sum() * pl.col("b").sum()),
    )

    assert "__POLARS_CSER" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert_frame_equal(
        q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)),
        pl.DataFrame(
            {
                "b": [1],
                "x": [1],
                "y": [1],
                "x2": [1],
                "x3": [3],
                "x4": [3],
                "a": [3],
            }
        ),
    )
336
337
338
def test_cse_mixed_window_functions() -> None:
    """Window caches are cleared correctly when windows appear both in the
    CSE'd expressions and in the remaining default expressions."""
    lf = pl.LazyFrame({"a": [1], "b": [1], "c": [1]})

    out = lf.select(
        pl.col("a"),
        pl.col("b"),
        pl.col("c"),
        pl.col("b").rank().alias("rank"),
        pl.col("b").rank().alias("d_rank"),
        pl.col("b").first().over([pl.col("a")]).alias("b_first"),
        pl.col("b").last().over([pl.col("a")]).alias("b_last"),
        pl.col("b").item().over([pl.col("a")]).alias("b_item"),
        pl.col("b").shift().alias("b_lag_1"),
        pl.col("b").shift().alias("b_lead_1"),
        pl.col("c").cum_sum().alias("c_cumsum"),
        pl.col("c").cum_sum().over([pl.col("a")]).alias("c_cumsum_by_a"),
        pl.col("c").diff().alias("c_diff"),
        pl.col("c").diff().over([pl.col("a")]).alias("c_diff_by_a"),
    ).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))

    expected = pl.DataFrame(
        {
            "a": [1],
            "b": [1],
            "c": [1],
            "rank": [1.0],
            "d_rank": [1.0],
            "b_first": [1],
            "b_last": [1],
            "b_item": [1],
            "b_lag_1": [None],
            "b_lead_1": [None],
            "c_cumsum": [1],
            "c_cumsum_by_a": [1],
            "c_diff": [None],
            "c_diff_by_a": [None],
        },
    ).with_columns(pl.col(pl.Null).cast(pl.Int64))
    assert_frame_equal(out, expected)
379
380
381
def test_cse_10401() -> None:
    """fill_null/fill_nan chains produce a CSE column and correct values."""
    lf = pl.LazyFrame({"clicks": [1.0, float("nan"), None]})

    q = lf.with_columns(pl.all().fill_null(0).fill_nan(0))

    assert r"""col("clicks").fill_null([0.0]).alias("__POLARS_CSER""" in q.explain()

    assert_frame_equal(
        q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)),
        pl.DataFrame({"clicks": [1.0, 0.0, 0.0]}),
    )
392
393
394
def test_cse_10441() -> None:
    """A duplicated sum inside one expression is evaluated once but summed twice."""
    lf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 2, 1]})

    out = lf.select(
        pl.col("a").sum() + pl.col("a").sum() + pl.col("b").sum()
    ).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))

    assert_frame_equal(out, pl.DataFrame({"a": [18]}))
403
404
405
def test_cse_10452() -> None:
    """Repeated sums around a window expression are CSE'd."""
    lf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 2, 1]})
    q = lf.select(
        pl.col("b").sum() + pl.col("a").sum().over(pl.col("b")) + pl.col("b").sum()
    )

    assert "__POLARS_CSE" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )

    assert_frame_equal(
        q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)),
        pl.DataFrame({"b": [13, 14, 15]}),
    )
419
420
421
def test_cse_group_by_ternary_10490() -> None:
    """Ternary (when/then) expressions mixed with CSE'd sums inside group_by."""
    lf = pl.LazyFrame(
        {
            "a": [1, 1, 2, 2],
            "b": [1, 2, 3, 4],
            "c": [2, 3, 4, 5],
        }
    )

    ternaries = [
        pl.when(pl.col(col).is_null().all()).then(None).otherwise(1).alias(col)
        for col in ["b", "c"]
    ]
    products = [
        (pl.col("a").sum() * pl.col("a").sum()).alias("x"),
        (pl.col("b").sum() * pl.col("b").sum()).alias("y"),
        (pl.col("a").sum() * pl.col("a").sum()).alias("x2"),
        ((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"),
        ((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"),
    ]

    out = (
        lf.group_by("a")
        .agg(ternaries + products)
        .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
        .sort("a")
    )

    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 1],
            "c": [1, 1],
            "x": [4, 16],
            "y": [9, 49],
            "x2": [4, 16],
            "x3": [12, 32],
            "x4": [18, 56],
        },
        schema_overrides={"b": pl.Int32, "c": pl.Int32},
    )
    assert_frame_equal(out, expected)
463
464
465
def test_cse_quantile_10815() -> None:
    """Distinct quantiles over the same columns must NOT be merged by CSE."""
    np.random.seed(1)
    frame = pl.DataFrame({"a": np.random.random(10), "b": np.random.random(10)})
    cols = ["a", "b"]
    q = frame.lazy().select(
        *(
            pl.col(c).quantile(0.75, interpolation="midpoint").name.suffix("_3")
            for c in cols
        ),
        *(
            pl.col(c).quantile(0.25, interpolation="midpoint").name.suffix("_1")
            for c in cols
        ),
    )
    assert "__POLARS_CSE" not in q.explain()
    assert q.collect().to_dict(as_series=False) == {
        "a_3": [0.40689473946662197],
        "b_3": [0.6145786693120769],
        "a_1": [0.16650805109739197],
        "b_1": [0.2012768694081981],
    }
488
489
490
def test_cse_nan_10824() -> None:
    """A NaN literal in when/then is not confused with the CSE'd branch value."""
    ratio = pl.col("a") / pl.col("b")
    magic = pl.when(ratio > 0).then(pl.lit(float("nan"))).otherwise(ratio)

    out = (
        pl.DataFrame({"a": [1.0], "b": [1.0]})
        .lazy()
        .select(magic)
        .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
    )
    assert str(out.to_dict(as_series=False)) == "{'literal': [nan]}"
509
510
511
def test_cse_10901() -> None:
    """Two rolling sums reused with different multipliers collect correctly."""
    df = pl.DataFrame(data=range(6), schema={"a": pl.Int64})
    roll2 = pl.col("a").rolling_sum(window_size=2)
    roll3 = pl.col("a").rolling_sum(window_size=3)
    exprs = {
        "ax1": roll2,
        "ax2": roll2 * 2,
        "bx1": roll3,
        "bx2": roll3 * 2,
    }

    expected = pl.DataFrame(
        {
            "a": [0, 1, 2, 3, 4, 5],
            "ax1": [None, 1, 3, 5, 7, 9],
            "ax2": [None, 2, 6, 10, 14, 18],
            "bx1": [None, None, 3, 6, 9, 12],
            "bx2": [None, None, 6, 12, 18, 24],
        }
    )

    assert_frame_equal(df.lazy().with_columns(**exprs).collect(), expected)
533
534
535
def test_cse_count_in_group_by() -> None:
    """pl.len() inside a group_by slice must not trigger a CSE rewrite."""
    q = (
        pl.LazyFrame({"a": [1, 1, 2], "b": [1, 2, 3], "c": [40, 51, 12]})
        .group_by("a")
        .agg(pl.all().slice(0, pl.len() - 1))
    )

    assert "POLARS_CSER" not in q.explain()
    assert q.collect().sort("a").to_dict(as_series=False) == {
        "a": [1, 2],
        "b": [[1], []],
        "c": [[40], []],
    }
548
549
550
def test_cse_slice_11594() -> None:
    """Slices with pl.len() lengths CSE correctly, both identical and differing."""
    lf = pl.LazyFrame({"a": [1, 2, 1, 2, 1, 2]})

    # Identical slices: both aliases share one computation.
    q = lf.select(
        pl.col("a").slice(offset=1, length=pl.len() - 1).alias("1"),
        pl.col("a").slice(offset=1, length=pl.len() - 1).alias("2"),
    )

    assert "__POLARS_CSE" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(
        as_series=False
    ) == {
        "1": [2, 1, 2, 1, 2],
        "2": [2, 1, 2, 1, 2],
    }

    # Differing offsets: the shared length expression is still CSE'd, but the
    # two slices must stay distinct.
    q = lf.select(
        pl.col("a").slice(offset=1, length=pl.len() - 1).alias("1"),
        pl.col("a").slice(offset=0, length=pl.len() - 1).alias("2"),
    )

    assert "__POLARS_CSE" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(
        as_series=False
    ) == {
        "1": [2, 1, 2, 1, 2],
        "2": [1, 2, 1, 2, 1],
    }
584
585
586
def test_cse_is_in_11489() -> None:
    """A when/then chain reused (plain and negated) evaluates consistently."""
    lf = pl.DataFrame(
        {"cond": [1, 2, 3, 2, 1], "x": [1.0, 0.20, 3.0, 4.0, 0.50]}
    ).lazy()
    any_cond = (
        pl.when(pl.col("cond").is_in([2, 3]))
        .then(True)
        .when(pl.col("cond").is_in([1]))
        .then(False)
        .otherwise(None)
        .alias("any_cond")
    )
    val = (
        pl.when(any_cond)
        .then(1.0)
        .when(~any_cond)
        .then(0.0)
        .otherwise(None)
        .alias("val")
    )
    assert lf.select("cond", any_cond, val).collect().to_dict(as_series=False) == {
        "cond": [1, 2, 3, 2, 1],
        "any_cond": [False, True, True, True, False],
        "val": [0.0, 1.0, 1.0, 1.0, 0.0],
    }
611
612
613
def test_cse_11958() -> None:
    """Shift-based loss expressions built in a loop are CSE'd and correct."""
    lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
    losses = []
    for lag in range(1, 5):
        delta = pl.col("a") - pl.col("a").shift(lag)
        losses.append(pl.when(delta >= 0).then(delta * 10).alias(f"diff{lag}"))

    q = lf.select(losses)
    assert "__POLARS_CSE" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(
        as_series=False
    ) == {
        "diff1": [None, 10, 10, 10, 10],
        "diff2": [None, None, 20, 20, 20],
        "diff3": [None, None, None, 30, 30],
        "diff4": [None, None, None, None, 40],
    }
633
634
635
def test_cse_14047() -> None:
    """Rolling count-difference expressions collect the same with and without CSE."""
    ldf = pl.LazyFrame(
        {
            "timestamp": pl.datetime_range(
                datetime(2024, 1, 12),
                datetime(2024, 1, 12, 0, 0, 0, 150_000),
                "10ms",
                eager=True,
                closed="left",
            ),
            "price": list(range(15)),
        }
    )

    def count_diff(
        price: pl.Expr, upper_bound: float = 0.1, lower_bound: float = 0.001
    ) -> pl.Expr:
        # Difference between rolling counts over the wide and narrow windows.
        span_end_to_curr = (
            price.count()
            .cast(int)
            .rolling("timestamp", period=timedelta(seconds=lower_bound))
        )
        span_start_to_curr = (
            price.count()
            .cast(int)
            .rolling("timestamp", period=timedelta(seconds=upper_bound))
        )
        return (span_start_to_curr - span_end_to_curr).alias(
            f"count_diff_{upper_bound}_{lower_bound}"
        )

    def s_per_count(count_diff: pl.Expr, span: tuple[float, float]) -> pl.Expr:
        return (span[1] * 1000 - span[0] * 1000) / count_diff

    spans = [(0.001, 0.1), (1, 10)]
    count_diff_exprs = [count_diff(pl.col("price"), span[0], span[1]) for span in spans]
    s_per_count_exprs = [
        s_per_count(count_diff, span).alias(f"zz_{span}")
        for count_diff, span in zip(count_diff_exprs, spans, strict=True)
    ]

    ldf = ldf.with_columns(*(count_diff_exprs + s_per_count_exprs))
    assert_frame_equal(
        ldf.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)),
        ldf.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=False)),
    )
682
683
684
def test_cse_15536() -> None:
    """Two different filters over a shared filtered subplan union correctly."""
    filtered = pl.DataFrame({"a": range(10)}).lazy().filter(pl.col("a") >= 5)

    out = pl.concat(
        [
            filtered.filter(pl.lit(True) & (pl.col("a") == 6) | (pl.col("a") == 9)),
            filtered.filter(pl.lit(True) & (pl.col("a") == 7) | (pl.col("a") == 8)),
        ]
    ).collect()
    assert out["a"].to_list() == [6, 9, 7, 8]
695
696
697
def test_cse_15548() -> None:
    """Manual .cache() plus subplan elimination yields the same row count."""
    ldf = pl.LazyFrame({"a": [1, 2, 3]})
    cached = ldf.filter(pl.col("a") == 1).cache()
    unioned = pl.concat([ldf, cached])

    for elim in (False, True):
        collected = unioned.collect(
            optimizations=pl.QueryOptFlags(comm_subplan_elim=elim)
        )
        assert len(collected) == 4
708
709
710
@pytest.mark.debug
def test_cse_and_schema_update_projection_pd() -> None:
    """Schema updates from projection pushdown stay consistent with CSE."""
    lf = pl.LazyFrame({"a": [1, 2], "b": [99, 99]})

    q = (
        lf.lazy()
        .with_row_index()
        .select(
            pl.when(pl.col("b") < 10)
            .then(0.1 * pl.col("b"))
            .when(pl.col("b") < 100)
            .then(0.2 * pl.col("b"))
        )
    )
    assert q.collect(optimizations=pl.QueryOptFlags(comm_subplan_elim=False)).to_dict(
        as_series=False
    ) == {"literal": [19.8, 19.8]}
    explained = q.explain(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
    assert num_cse_occurrences(explained) == 1
733
734
735
@pytest.mark.debug
@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize("use_custom_io_source", [True, False])
def test_cse_predicate_self_join(
    capfd: Any, plmonkeypatch: PlMonkeyPatch, use_custom_io_source: bool
) -> None:
    """A predicate-filtered self join hits the subplan cache (verbose log)."""
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    y = pl.LazyFrame({"a": [1], "b": [2], "y": [3]})
    if use_custom_io_source:
        y = create_dataframe_source(y.collect(), is_pure=True)

    xf = y.filter(pl.col("y") == 2).select(["a", "b"])
    y_xf = y.join(xf, on=["a", "b"], how="left")

    assert y_xf.select("a", "b").collect().to_dict(as_series=False) == {
        "a": [1],
        "b": [2],
    }
    assert "CACHE HIT" in capfd.readouterr().err
753
754
755
def test_cse_manual_cache_15688() -> None:
    """Explicit .cache() nodes nested inside joins collect correctly."""
    lf = pl.LazyFrame(
        {"a": [1, 2, 3, 1, 2, 3], "b": [1, 1, 1, 1, 1, 1], "id": [1, 1, 1, 2, 2, 2]}
    )

    semi1 = lf.filter(id=1).join(lf.filter(id=2), on=["a", "b"], how="semi")
    semi2 = lf.filter(id=1).join(semi1, on=["a", "b"], how="semi").cache()
    res = semi2.group_by("b").agg(pl.all().sum())

    assert res.cache().with_columns(foo=1).collect().to_dict(as_series=False) == {
        "b": [1],
        "a": [6],
        "id": [3],
        "foo": [1],
    }
771
772
773
def test_cse_drop_nulls_15795() -> None:
    """Joining twice through a null-filtered frame keeps the surviving row."""
    lhs = pl.LazyFrame({"X": 1})
    filtered = pl.LazyFrame({"X": 1, "Y": 0}).filter(pl.col("Y").is_not_null())
    joined = lhs.join(filtered, on="X").select("X")
    assert joined.join(filtered.select("X"), on="X").collect().shape == (1, 1)
779
780
781
def test_cse_no_projection_15980() -> None:
    """Filter + select over a CSE'd union without projection still works."""
    lf = pl.LazyFrame({"x": "a", "y": 1})
    unioned = pl.concat(lf.with_columns(pl.col("y").add(n)) for n in range(2))

    out = unioned.filter(pl.col("x").eq("a")).select("x").collect()
    assert out.to_dict(as_series=False) == {"x": ["a", "a"]}
788
789
790
@pytest.mark.debug
def test_cse_series_collision_16138() -> None:
    """is_in lists with overlapping contents must not collide in the CSE cache."""
    holdings = pl.DataFrame(
        {
            "fund_currency": ["CLP", "CLP"],
            "asset_currency": ["EUR", "USA"],
        }
    )

    usd = ["USD"]
    eur = ["EUR"]
    clp = ["CLP"]

    conditions = [
        pl.col("asset_currency").is_in(eur) & pl.col("fund_currency").is_in(clp),
        pl.col("asset_currency").is_in(eur) & pl.col("fund_currency").is_in(usd),
        pl.col("asset_currency").is_in(clp) & pl.col("fund_currency").is_in(clp),
        pl.col("asset_currency").is_in(usd) & pl.col("fund_currency").is_in(usd),
    ]

    factor_holdings = holdings.lazy().with_columns(
        pl.coalesce(conditions).alias("currency_factor"),
    )

    assert factor_holdings.collect(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    ).to_dict(as_series=False) == {
        "fund_currency": ["CLP", "CLP"],
        "asset_currency": ["EUR", "USA"],
        "currency_factor": [True, False],
    }
    explained = factor_holdings.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert num_cse_occurrences(explained) == 3
829
830
831
def test_nested_cache_no_panic_16553() -> None:
    """Deeply nested list literals do not panic under subexpression elimination."""
    out = (
        pl.LazyFrame()
        .select(a=[[[1]]])
        .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
    )
    assert out.to_dict(as_series=False) == {"a": [[[[1]]]]}
835
836
837
def test_hash_empty_series_16577() -> None:
    """Hashing an empty (all-null) series in a select must not fail."""
    s = pl.Series(values=None)
    collected = pl.LazyFrame().select(s).collect()
    assert collected.equals(s.to_frame())
841
842
843
def test_cse_non_scalar_length_mismatch_17732() -> None:
    """Scalar aggregations over a shared head() broadcast to full length."""
    lf = pl.LazyFrame({"a": pl.Series(range(30), dtype=pl.Int32)})
    got = lf.with_columns(
        pl.col("a").head(5).min().alias("b"),
        pl.col("a").head(5).max().alias("c"),
    ).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))

    expect = pl.DataFrame(
        {
            "a": pl.Series(range(30), dtype=pl.Int32),
            "b": pl.Series([0] * 30, dtype=pl.Int32),
            "c": pl.Series([4] * 30, dtype=pl.Int32),
        }
    )

    assert_frame_equal(expect, got)
862
863
864
def test_cse_chunks_18124() -> None:
    """CSE over a multi-chunk (non-rechunked) frame keeps the full row count."""
    df = pl.DataFrame(
        {
            "ts_diff": [timedelta(seconds=60)] * 2,
            "ts_diff_after": [timedelta(seconds=120)] * 2,
        }
    )
    df = pl.concat([df, df], rechunk=False)
    out = (
        df.lazy()
        .with_columns(
            ts_diff_sign=pl.col("ts_diff") > pl.duration(seconds=0),
            ts_diff_after_sign=pl.col("ts_diff_after") > pl.duration(seconds=0),
        )
        .filter(pl.col("ts_diff") > 1)
    ).collect()
    assert out.shape == (4, 4)
880
881
882
@pytest.mark.may_fail_auto_streaming
def test_eager_cse_during_struct_expansion_18411() -> None:
    """Eager CSE during struct expansion produces stable replace() output."""
    df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]})
    vc = pl.col("foo").value_counts()
    classes = vc.struct[0]
    counts = vc.struct[1]
    # The same replace expression evaluated twice must agree element-wise.
    first = df.select(pl.col("foo").replace(classes, counts))
    second = df.select(pl.col("foo").replace(classes, counts))
    assert (first == second)["foo"].all()
893
894
895
def test_cse_as_struct_19253() -> None:
    """Similar struct expressions differing only by a window are kept distinct."""
    lf = pl.LazyFrame({"x": [1, 2], "y": [4, 5]})

    out = lf.with_columns(
        q1=pl.struct(pl.col.x - pl.col.y.mean()),
        q2=pl.struct(pl.col.x - pl.col.y.mean().over("y")),
    ).collect()
    assert out.to_dict(as_series=False) == {
        "x": [1, 2],
        "y": [4, 5],
        "q1": [{"x": -3.5}, {"x": -2.5}],
        "q2": [{"x": -3.0}, {"x": -3.0}],
    }
909
910
911
@pytest.mark.may_fail_auto_streaming
def test_cse_as_struct_value_counts_20927() -> None:
    """struct().value_counts().unnest() returns correct per-value counts."""
    data = [i for i in range(1, 6) for _ in range(i)]
    out = (
        pl.DataFrame({"x": data})
        .select(pl.struct("x").value_counts().struct.unnest())
        .sort("count")
    )
    assert out.to_dict(as_series=False) == {
        "x": [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}, {"x": 5}],
        "count": [1, 2, 3, 4, 5],
    }
919
920
921
def test_cse_union_19227() -> None:
    """Union of direct and indirect join paths preserves the column schema."""
    lf = pl.LazyFrame({"A": [1], "B": [2]})
    renamed_cb = lf.select(C="A", B="B")
    renamed_ca = lf.select(C="A", A="B")

    direct = renamed_ca.join(lf, on=["A"]).select("C", "A", "B")
    indirect = renamed_cb.join(direct, on=["C", "B"]).select("C", "A", "B")

    out = pl.concat([direct, indirect])
    assert out.collect().schema == pl.Schema(
        [("C", pl.Int64), ("A", pl.Int64), ("B", pl.Int64)]
    )
934
935
936
def test_cse_21115() -> None:
    """pl.all().exp() shared between a column and a horizontal min is correct."""
    lf = pl.LazyFrame({"x": 1, "y": 5})

    out = lf.with_columns(
        pl.all().exp() + pl.min_horizontal(pl.all().exp())
    ).collect()
    assert out.to_dict(as_series=False) == {
        "x": [5.43656365691809],
        "y": [151.13144093103566],
    }
945
946
947
@pytest.mark.parametrize("use_custom_io_source", [True, False])
def test_cse_cache_leakage_22339(use_custom_io_source: bool) -> None:
    """Similar-looking filtered subplans over different sources must not
    leak into each other's caches."""
    lf1 = pl.LazyFrame({"x": [True] * 2})
    lf2 = pl.LazyFrame({"x": [True] * 3})
    if use_custom_io_source:
        lf1 = create_dataframe_source(lf1.collect(), is_pure=True)
        lf2 = create_dataframe_source(lf2.collect(), is_pure=True)

    a = lf1
    b = lf1.filter(pl.col("x").not_().over(1))
    c = lf2.filter(pl.col("x").not_().over(1))

    joins = [a.join(b, on="x"), b.join(c, on="x"), a.join(c, on="x")]

    assert pl.concat(joins).collect().to_dict(as_series=False) == {"x": []}
964
965
966
@pytest.mark.write_disk
def test_multiplex_predicate_pushdown() -> None:
    """Predicates push down through a multiplexed (shared) parquet scan."""
    ldf = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
    with TemporaryDirectory() as tmp:
        root = Path(tmp)
        ldf.sink_parquet(
            pl.PartitionBy(root, key="a", include_key=True),
            sync_on_close="all",
            mkdir=True,
        )
        scanned = pl.scan_parquet(root, hive_partitioning=True)
        scanned = scanned.filter(pl.col("a").eq(1)).select("b")
        assert 'SELECTION: [(col("a")) == (1)]' in pl.explain_all([scanned, scanned])
979
980
981
def test_cse_custom_io_source_same_object() -> None:
    """CSE of custom io sources keys on the registered wrapper object.

    Three scenarios are checked:

    1. The same pure LazyFrame object used twice: one shared cache, and the
       io source is invoked only once across both collects.
    2. A source registered without ``is_pure`` (defaults to impure): no
       caching, the source runs once per collect.
    3. Two separate ``register_io_source`` calls over the same callable: no
       CSE, because each registration wraps the callable in a fresh
       locally-constructed wrapper before passing it to the Rust side.
    """
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})

    io_source = Mock(wraps=lambda *_: iter([df]))

    lf = register_io_source(
        io_source,
        schema=df.schema,
        validate_schema=True,
        is_pure=True,
    )

    lfs = [lf, lf]

    plan = pl.explain_all(lfs)
    caches: list[str] = [
        x for x in map(str.strip, plan.splitlines()) if x.startswith("CACHE[")
    ]
    assert len(caches) == 2
    assert len(set(caches)) == 1

    # Explaining must not execute the source.
    assert io_source.call_count == 0

    assert_frame_equal(
        pl.concat(pl.collect_all(lfs)),
        pl.DataFrame({"a": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]}),
    )

    assert io_source.call_count == 1

    io_source = Mock(wraps=lambda *_: iter([df]))

    # Without explicit is_pure parameter should default to False
    lf = register_io_source(
        io_source,
        schema=df.schema,
        validate_schema=True,
    )

    lfs = [lf, lf]

    plan = pl.explain_all(lfs)

    caches = [x for x in map(str.strip, plan.splitlines()) if x.startswith("CACHE[")]
    assert len(caches) == 0

    assert io_source.call_count == 0

    assert_frame_equal(
        pl.concat(pl.collect_all(lfs)),
        pl.DataFrame({"a": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]}),
    )

    assert io_source.call_count == 2

    io_source = Mock(wraps=lambda *_: iter([df]))

    # LazyFrames constructed from separate calls do not CSE even if the
    # io_source function is the same.
    #
    # Note: This behavior is achieved by having `register_io_source` wrap
    # the user-provided io plugin with a locally constructed wrapper before
    # passing to the Rust-side.
    lfs = [
        register_io_source(
            io_source,
            schema=df.schema,
            validate_schema=True,
            is_pure=True,
        ),
        register_io_source(
            io_source,
            schema=df.schema,
            validate_schema=True,
            is_pure=True,
        ),
    ]

    # Bug fix: the original inspected the stale `plan` from the previous
    # scenario; re-explain the freshly registered lazyframes so the
    # no-cache assertion actually tests this scenario.
    plan = pl.explain_all(lfs)
    caches = [x for x in map(str.strip, plan.splitlines()) if x.startswith("CACHE[")]
    assert len(caches) == 0

    assert io_source.call_count == 0

    assert_frame_equal(
        pl.concat(pl.collect_all(lfs)),
        pl.DataFrame({"a": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]}),
    )

    assert io_source.call_count == 2
1070
1071
1072
@pytest.mark.write_disk
def test_cse_preferred_over_slice() -> None:
    """Caching a shared CSV scan is preferred over pushing down disjoint slices."""
    df = pl.DataFrame({"a": list(range(1, 21))})
    with NamedTemporaryFile() as f:
        f.write(df.write_csv().encode())
        f.seek(0)
        ldf = pl.scan_csv(f.name)
        q = ldf.slice(0, 5).join(ldf.slice(6, 5), on="a", how="inner")
        assert "CACHE[id:" in q.explain(
            optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
        )
1088
1089
1090
def test_cse_preferred_over_slice_custom_io_source() -> None:
    """For a pure custom io source, caching beats slice pushdown; an impure
    source is never cached."""
    df = pl.DataFrame({"a": list(range(1, 21))})

    # Pure source: disjoint slices over the same source share one cache.
    pure = create_dataframe_source(df, is_pure=True)
    q = pure.slice(0, 5).join(pure.slice(6, 5), on="a", how="inner")
    assert "CACHE[id:" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )

    # Impure source: no cache node may appear.
    impure = create_dataframe_source(df, is_pure=False)
    q = impure.slice(0, 5).join(impure.slice(6, 5), on="a", how="inner")
    assert "CACHE[id:" not in q.explain(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )
1109
1110
1111
def test_cse_custom_io_source_diff_columns() -> None:
    """Selecting different columns from one pure source still shares a cache."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [10, 11, 12, 13, 14]})
    lf = create_dataframe_source(df, is_pure=True)
    queries = [lf.select("a"), lf.select("b")]
    assert "CACHE[id:" in pl.explain_all(queries)
    collected = pl.collect_all(
        queries, optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )
    assert_frame_equal(df.select("a"), collected[0])
    assert_frame_equal(df.select("b"), collected[1])
1121
1122
1123
def test_cse_custom_io_source_diff_filters() -> None:
    """With differing filters, predicate pushdown is preferred over CSE."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [10, 11, 12, 13, 14]})
    lf = create_dataframe_source(df, is_pure=True)

    # We use this so that the true type of the input is passed through
    # to the output
    PolarsFrame = TypeVar("PolarsFrame", pl.DataFrame, pl.LazyFrame)

    def left_pipe(df_or_lf: PolarsFrame) -> PolarsFrame:
        return df_or_lf.select("a").filter(pl.col("a").is_between(2, 6))

    def right_pipe(df_or_lf: PolarsFrame) -> PolarsFrame:
        return df_or_lf.select("b").filter(pl.col("b").is_between(10, 13))

    queries = [lf.pipe(left_pipe), lf.pipe(right_pipe)]
    explanation = pl.explain_all(queries)
    # we prefer predicate pushdown over CSE
    assert "CACHE[id:" not in explanation
    assert 'SELECTION: col("a").is_between([2, 6])' in explanation
    assert 'SELECTION: col("b").is_between([10, 13])' in explanation

    res = pl.collect_all(queries)
    assert_frame_equal(df.pipe(left_pipe), res[0])
    assert_frame_equal(df.pipe(right_pipe), res[1])
1148
1149
1150
@pytest.mark.skip
def test_cspe_recursive_24744() -> None:
    """Recursive self-joins should be deduplicated by comm_subplan_elim.

    Regression test for https://github.com/pola-rs/polars/issues/24744.
    Builds three nested levels of a convoluted double inner join on the same
    frame and checks the plan is cached correctly and results match the
    unoptimized run.
    """
    df_a = pl.DataFrame([pl.Series("x", [0, 1, 2, 3], dtype=pl.UInt32)])

    def convoluted_inner_join(
        lf_left: pl.LazyFrame,
        lf_right: pl.LazyFrame,
    ) -> pl.LazyFrame:
        # Join left/right on "x", then join the result back onto the left
        # side via a duplicated "index" column, so the left plan appears
        # twice inside one call — a CSE opportunity at every nesting level.
        lf_left = lf_left.with_columns(pl.col("x").alias("index"))

        lf_joined = lf_left.join(
            lf_right,
            how="inner",
            on=["x"],
        )

        lf_joined_final = lf_left.join(
            lf_joined,
            how="inner",
            on=["index", "x"],
        ).drop("index")
        return lf_joined_final

    # Three recursive levels: each level reuses the previous result as its
    # left input, multiplying shared subplans.
    lf_a = df_a.lazy()
    lf_j1 = convoluted_inner_join(lf_left=lf_a, lf_right=lf_a)
    lf_j2 = convoluted_inner_join(lf_left=lf_j1, lf_right=lf_a)
    lf_j3 = convoluted_inner_join(lf_left=lf_j2, lf_right=lf_a).sort("x")

    # Exact CACHE count pins the expected plan shape; update deliberately
    # if the optimizer's caching strategy changes.
    assert lf_j3.explain().count("CACHE") == 14
    assert_frame_equal(
        lf_j3.collect(),
        lf_j3.collect(optimizations=pl.QueryOptFlags(comm_subplan_elim=False)),
    )
    # One physical dedup node per recursion level is expected (3):
    # "multiplexer" in the streaming engine, "CACHE" in the in-memory engine.
    assert (
        lf_j3.show_graph(  # type: ignore[union-attr]
            engine="streaming", plan_stage="physical", raw_output=True
        ).count("multiplexer")
        == 3
    )
    assert (
        lf_j3.show_graph(  # type: ignore[union-attr]
            engine="in-memory", plan_stage="physical", raw_output=True
        ).count("CACHE")
        == 3
    )
1195
1196
1197
def test_cpse_predicates_25030() -> None:
    # Regression test for https://github.com/pola-rs/polars/issues/25030:
    # a subplan reused on both sides of a join must be cached exactly once
    # and produce the same result with CSE disabled.
    lf = pl.LazyFrame({"key": [1, 2, 2], "x": [6, 2, 3], "y": [0, 1, 4]})

    duplicated_keys = lf.group_by("key").len().filter(pl.col("len") > 1)
    x_gt_y = lf.filter(pl.col.x > pl.col.y)
    joined = duplicated_keys.join(x_gt_y, on="key")
    query = joined.group_by("key").len().join(joined, on="key")

    with_cse = query.collect()
    without_cse = query.collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=False)
    )
    assert_frame_equal(with_cse, without_cse)
    # `joined` feeds the query twice, so exactly two CACHE hits are expected.
    assert query.explain().count("CACHE") == 2
1212
1213
1214
def test_asof_join_25699() -> None:
    # Regression test for https://github.com/pola-rs/polars/issues/25699:
    # self asof-join on a frame with redundant with_columns projections.
    lf = (
        pl.LazyFrame({"a": [10], "b": [10]})
        .with_columns(pl.col("a"))
        .with_columns(pl.col("b"))
    )

    expected = pl.DataFrame({"a": [10], "b": [10], "a_right": [10]})
    assert_frame_equal(lf.join_asof(lf, on="b").collect(), expected)
1224
1225
1226
def test_csee_python_function() -> None:
    # CSE of Python UDF expressions triggers only when both occurrences hold
    # the very same function object (identity, not equality), so the
    # expression is built once and reused in both columns.
    shared_expr = pl.col("a").map_elements(lambda x: hash(x))

    q = pl.LazyFrame({"a": [10], "b": [10]}).with_columns(
        a=shared_expr * 10,
        b=shared_expr * 100,
    )

    assert "__POLARS_CSER" in q.explain()
    no_cse = q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=False))
    assert_frame_equal(q.collect(), no_cse)
1239
1240
1241
def test_csee_streaming() -> None:
    lf = pl.LazyFrame({"a": [10], "b": [10]})

    def streaming_plan(expr: pl.Expr) -> str:
        # Reuse `expr` in two columns so it is a CSE candidate, then render
        # the streaming engine's plan.
        q = lf.with_columns(
            a=expr * 10,
            b=expr * 100,
        )
        return q.explain(engine="streaming")

    # Elementwise subexpressions are eligible for elimination when streaming.
    assert "__POLARS_CSER" in streaming_plan(pl.col("a") * pl.col("b"))

    # Non-elementwise (aggregating) subexpressions are not.
    assert "__POLARS_CSER" not in streaming_plan(pl.col("a").sum())
1259
1260