Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/test_cse.py
6939 views
1
from __future__ import annotations
2
3
import re
4
from datetime import date, datetime, timedelta
5
from pathlib import Path
6
from tempfile import NamedTemporaryFile, TemporaryDirectory
7
from typing import TYPE_CHECKING, Any, TypeVar
8
from unittest.mock import Mock
9
10
import numpy as np
11
import pytest
12
13
import polars as pl
14
from polars.io.plugins import register_io_source
15
from polars.testing import assert_frame_equal
16
17
if TYPE_CHECKING:
18
from collections.abc import Iterator
19
20
21
def num_cse_occurrences(explanation: str) -> int:
    """Count the distinct CSE temporary columns mentioned in an explain string."""
    cse_names = re.findall('__POLARS_CSER_0x[^"]+"', explanation)
    return len(set(cse_names))
24
25
26
def create_dataframe_source(
    source_df: pl.DataFrame,
    is_pure: bool,
    validate_schema: bool = False,
) -> pl.LazyFrame:
    """Generate a custom io source based on the provided pl.DataFrame.

    Parameters
    ----------
    source_df
        The eager frame the io source yields (a clone per scan).
    is_pure
        Forwarded to `register_io_source`; marks the source as side-effect
        free so the optimizer may cache/deduplicate scans of it.
    validate_schema
        Forwarded to `register_io_source`.
        NOTE: fixed typo in the parameter name (was `validate_schame`); no
        caller in this file passes it by keyword.
    """

    def dataframe_source(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        _n_rows: int | None,
        _batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        # Apply the pushed-down predicate and projection to a fresh clone so
        # repeated scans never mutate the captured frame.
        df = source_df.clone()
        if predicate is not None:
            df = df.filter(predicate)
        if with_columns is not None:
            df = df.select(with_columns)
        yield df

    return register_io_source(
        dataframe_source,
        schema=source_df.schema,
        validate_schema=validate_schema,
        is_pure=is_pure,
    )
52
53
54
@pytest.mark.parametrize("use_custom_io_source", [True, False])
def test_cse_rename_cross_join_5405(use_custom_io_source: bool) -> None:
    """A renamed column in a self-join must not confuse subplan caching."""
    # https://github.com/pola-rs/polars/issues/5405
    rhs = pl.DataFrame({"A": [1, 2], "B": [3, 4], "D": [5, 6]}).lazy()
    if use_custom_io_source:
        rhs = create_dataframe_source(rhs.collect(), is_pure=True)

    lhs = pl.DataFrame({"C": [3, 4]}).lazy().join(rhs.select("A"), how="cross")
    out = lhs.join(rhs.rename({"B": "C"}), on=["A", "C"], how="left").collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )

    expected = pl.DataFrame(
        {
            "C": [3, 3, 4, 4],
            "A": [1, 2, 1, 2],
            "D": [5, None, None, 6],
        }
    )
    assert_frame_equal(out, expected, check_row_order=False)
75
76
77
def test_union_duplicates() -> None:
    """Concatenating one lazy frame N times yields N references to a single cache."""
    n_dfs = 10
    df_lazy = pl.DataFrame({}).lazy()
    lazy_dfs = [df_lazy for _ in range(n_dfs)]

    # Every union input is the identical plan, so subplan elimination should
    # insert one shared cache node referenced from each branch.
    matches = re.findall(r"CACHE\[id: (.*)]", pl.concat(lazy_dfs).explain())

    # Ten cache references in the plan ...
    assert len(matches) == 10
    # ... all pointing at the same cache id.
    assert len(set(matches)) == 1
86
87
88
def test_cse_with_struct_expr_11116() -> None:
    """Repeated struct.field accesses are handled correctly under CSE."""
    # https://github.com/pola-rs/polars/issues/11116
    lf = pl.DataFrame([{"s": {"a": 1, "b": 4}, "c": 3}]).lazy()

    field_a = pl.col("s").struct.field("a")
    field_b = pl.col("s").struct.field("b")
    out = lf.with_columns(
        field_a.alias("s_a"),
        field_b.alias("s_b"),
        ((field_a <= pl.col("c")) & (field_b > pl.col("c"))).alias(
            "c_between_a_and_b"
        ),
    ).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))

    expected = pl.DataFrame(
        {
            "s": [{"a": 1, "b": 4}],
            "c": [3],
            "s_a": [1],
            "s_b": [4],
            "c_between_a_and_b": [True],
        }
    )
    assert_frame_equal(out, expected)
112
113
114
def test_cse_schema_6081() -> None:
    """Joining a frame with a grouped aggregation of itself keeps the schema intact."""
    # https://github.com/pola-rs/polars/issues/6081
    rows = [
        [date(2022, 12, 12), 1, 1],
        [date(2022, 12, 12), 1, 2],
        [date(2022, 12, 13), 5, 2],
    ]
    lf = pl.DataFrame(data=rows, schema=["date", "id", "value"], orient="row").lazy()

    grouped_min = lf.group_by(["date", "id"]).agg(
        pl.col("value").min().alias("min_value")
    )

    out = lf.join(grouped_min, on=["date", "id"], how="left").collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True, projection_pushdown=True)
    )
    expected = pl.DataFrame(
        {
            "date": [date(2022, 12, 12), date(2022, 12, 12), date(2022, 12, 13)],
            "id": [1, 1, 5],
            "value": [1, 2, 2],
            "min_value": [1, 1, 2],
        }
    )
    assert_frame_equal(out, expected, check_row_order=False)
143
144
145
def test_cse_9630() -> None:
    """A diamond-shaped join graph collects correctly under subplan caching."""
    base_x = pl.LazyFrame({"key": [1], "x": [1]})
    base_y = pl.LazyFrame({"key": [1], "y": [2]})

    joined = base_x.join(base_y, on="key")

    stacked = pl.concat(
        [
            base_x.select("key", pl.col("x").alias("value")),
            joined.select("key", pl.col("y").alias("value")),
        ]
    )
    all_subsections = stacked.group_by("key").agg(pl.col("value"))

    with_x = all_subsections.join(base_x, on="key")
    with_y = all_subsections.join(base_y, on="key")

    out = with_x.join(with_y, on=["key"], how="left").collect(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )

    expected = pl.DataFrame(
        {
            "key": [1],
            "value": [[1, 2]],
            "x": [1],
            "value_right": [[1, 2]],
            "y": [2],
        }
    )
    assert_frame_equal(out, expected)
179
180
181
@pytest.mark.write_disk
@pytest.mark.parametrize("maintain_order", [False, True])
def test_schema_row_index_cse(maintain_order: bool) -> None:
    """A row-index column must survive CSE when a CSV scan is self-joined."""
    with NamedTemporaryFile() as csv_a:
        csv_a.write(b"A,B\nGr1,A\nGr1,B")
        csv_a.seek(0)

        # Both join sides are the same scan + row-index plan, so the
        # optimizer may cache it; the "Idx" column's schema must be kept.
        df_a = pl.scan_csv(csv_a.name).with_row_index("Idx")

        result = (
            df_a.join(df_a, on="B", maintain_order="left" if maintain_order else "none")
            .group_by("A", maintain_order=maintain_order)
            .all()
            .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
        )

        expected = pl.DataFrame(
            {
                "A": ["Gr1"],
                "Idx": [[0, 1]],
                "B": [["A", "B"]],
                "Idx_right": [[0, 1]],
                "A_right": [["Gr1", "Gr1"]],
            },
            schema_overrides={"Idx": pl.List(pl.UInt32), "Idx_right": pl.List(pl.UInt32)},
        )
        assert_frame_equal(result, expected, check_row_order=maintain_order)
208
209
210
@pytest.mark.debug
def test_cse_expr_selection_context() -> None:
    """Shared sub-expressions are eliminated in both select and with_columns."""
    q = pl.LazyFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
        }
    )

    # `derived` is used by three outputs and `derived2` by two, so exactly
    # two CSE columns should be materialized.
    derived = (pl.col("a") * pl.col("b")).sum()
    derived2 = derived * derived

    exprs = [
        derived.alias("d1"),
        (derived * pl.col("c").sum() - 1).alias("foo"),
        derived2.alias("d2"),
        (derived2 * 10).alias("d3"),
    ]

    result = q.select(exprs).collect(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert (
        num_cse_occurrences(
            q.select(exprs).explain(
                optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
            )
        )
        == 2
    )
    expected = pl.DataFrame(
        {
            "d1": [30],
            "foo": [299],
            "d2": [900],
            "d3": [9000],
        }
    )
    assert_frame_equal(result, expected)

    # The same expressions in a with_columns context keep the input columns
    # and must still produce exactly two CSE columns.
    result = q.with_columns(exprs).collect(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert (
        num_cse_occurrences(
            q.with_columns(exprs).explain(
                optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
            )
        )
        == 2
    )
    expected = pl.DataFrame(
        {
            "a": [1, 2, 3, 4],
            "b": [1, 2, 3, 4],
            "c": [1, 2, 3, 4],
            "d1": [30, 30, 30, 30],
            "foo": [299, 299, 299, 299],
            "d2": [900, 900, 900, 900],
            "d3": [9000, 9000, 9000, 9000],
        }
    )
    assert_frame_equal(result, expected)
274
275
276
def test_windows_cse_excluded() -> None:
    """A plain diff and a windowed diff over the same column stay distinct."""
    rows = [
        ("a", "aaa", 1),
        ("a", "bbb", 3),
        ("a", "ccc", 1),
        ("c", "xxx", 2),
        ("c", "yyy", 3),
        ("c", "zzz", 4),
        ("b", "qqq", 0),
    ]
    lf = pl.LazyFrame(data=rows, schema=["a", "b", "c"], orient="row")

    out = lf.select(
        c_diff=pl.col("c").diff(1),
        c_diff_by_a=pl.col("c").diff(1).over("a"),
    ).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))

    expected = pl.DataFrame(
        {
            "c_diff": [None, 2, -2, 1, 1, 1, -4],
            "c_diff_by_a": [None, 2, -2, None, 1, 1, None],
        }
    )
    assert_frame_equal(out, expected)
303
304
305
def test_cse_group_by_10215() -> None:
    """Repeated aggregates inside group_by are shared via CSE (issue #10215)."""
    lf = pl.LazyFrame({"a": [1], "b": [1]})

    # x/x2 are identical products and `(a + 2).sum()` repeats three times;
    # these are the common sub-expressions the optimizer should extract.
    result = lf.group_by("b").agg(
        (pl.col("a").sum() * pl.col("a").sum()).alias("x"),
        (pl.col("b").sum() * pl.col("b").sum()).alias("y"),
        (pl.col("a").sum() * pl.col("a").sum()).alias("x2"),
        ((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"),
        ((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"),
        ((pl.col("a") + 2).sum() * pl.col("b").sum()),
    )

    # The plan must show that at least one CSE column was inserted.
    assert "__POLARS_CSER" in result.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    expected = pl.DataFrame(
        {
            "b": [1],
            "x": [1],
            "y": [1],
            "x2": [1],
            "x3": [3],
            "x4": [3],
            "a": [3],
        }
    )
    assert_frame_equal(
        result.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)), expected
    )
334
335
336
def test_cse_mixed_window_functions() -> None:
    """Window, rank and cumulative expressions evaluate correctly together under CSE."""
    # checks if the window caches are cleared
    # there are windows in the cse's and the default expressions
    lf = pl.LazyFrame({"a": [1], "b": [1], "c": [1]})

    result = lf.select(
        pl.col("a"),
        pl.col("b"),
        pl.col("c"),
        pl.col("b").rank().alias("rank"),
        pl.col("b").rank().alias("d_rank"),
        pl.col("b").first().over([pl.col("a")]).alias("b_first"),
        pl.col("b").last().over([pl.col("a")]).alias("b_last"),
        pl.col("b").shift().alias("b_lag_1"),
        pl.col("b").shift().alias("b_lead_1"),
        pl.col("c").cum_sum().alias("c_cumsum"),
        pl.col("c").cum_sum().over([pl.col("a")]).alias("c_cumsum_by_a"),
        pl.col("c").diff().alias("c_diff"),
        pl.col("c").diff().over([pl.col("a")]).alias("c_diff_by_a"),
    ).collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))

    expected = pl.DataFrame(
        {
            "a": [1],
            "b": [1],
            "c": [1],
            "rank": [1.0],
            "d_rank": [1.0],
            "b_first": [1],
            "b_last": [1],
            "b_lag_1": [None],
            "b_lead_1": [None],
            "c_cumsum": [1],
            "c_cumsum_by_a": [1],
            "c_diff": [None],
            "c_diff_by_a": [None],
        },
    ).with_columns(pl.col(pl.Null).cast(pl.Int64))
    assert_frame_equal(result, expected)
375
376
377
def test_cse_10401() -> None:
    """fill_null/fill_nan chains share the fill_null step via CSE (issue #10401)."""
    df = pl.LazyFrame({"clicks": [1.0, float("nan"), None]})

    q = df.with_columns(pl.all().fill_null(0).fill_nan(0))

    # The fill_null step should be factored out into a CSE column.
    assert r"""col("clicks").fill_null([0.0]).alias("__POLARS_CSER""" in q.explain()

    expected = pl.DataFrame({"clicks": [1.0, 0.0, 0.0]})
    assert_frame_equal(
        q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)), expected
    )
388
389
390
def test_cse_10441() -> None:
    """Repeated sum aggregates inside one expression evaluate correctly (issue #10441)."""
    lf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 2, 1]})

    summed = pl.col("a").sum() + pl.col("a").sum() + pl.col("b").sum()
    out = lf.select(summed).collect(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )

    assert_frame_equal(out, pl.DataFrame({"a": [18]}))
399
400
401
def test_cse_10452() -> None:
    """A sum reused both inside and outside a window is CSE'd (issue #10452)."""
    lf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 2, 1]})
    q = lf.select(
        pl.col("b").sum() + pl.col("a").sum().over(pl.col("b")) + pl.col("b").sum()
    )

    # The repeated `pl.col("b").sum()` should be factored out.
    assert "__POLARS_CSE" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )

    expected = pl.DataFrame({"b": [13, 14, 15]})
    assert_frame_equal(
        q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)), expected
    )
415
416
417
def test_cse_group_by_ternary_10490() -> None:
    """Ternary and repeated product aggregates in group_by are CSE'd (issue #10490)."""
    lf = pl.LazyFrame(
        {
            "a": [1, 1, 2, 2],
            "b": [1, 2, 3, 4],
            "c": [2, 3, 4, 5],
        }
    )

    # Mix when/then aggregates with deliberately duplicated sum products.
    result = (
        lf.group_by("a")
        .agg(
            [
                pl.when(pl.col(col).is_null().all()).then(None).otherwise(1).alias(col)
                for col in ["b", "c"]
            ]
            + [
                (pl.col("a").sum() * pl.col("a").sum()).alias("x"),
                (pl.col("b").sum() * pl.col("b").sum()).alias("y"),
                (pl.col("a").sum() * pl.col("a").sum()).alias("x2"),
                ((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"),
                ((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"),
            ]
        )
        .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
        .sort("a")
    )

    expected = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [1, 1],
            "c": [1, 1],
            "x": [4, 16],
            "y": [9, 49],
            "x2": [4, 16],
            "x3": [12, 32],
            "x4": [18, 56],
        },
        schema_overrides={"b": pl.Int32, "c": pl.Int32},
    )
    assert_frame_equal(result, expected)
459
460
461
def test_cse_quantile_10815() -> None:
    """Different quantiles of one column must not be merged by CSE (issue #10815)."""
    np.random.seed(1)
    a = np.random.random(10)
    b = np.random.random(10)
    df = pl.DataFrame({"a": a, "b": b})
    cols = ["a", "b"]
    # The 0.75 and 0.25 quantiles differ only in a parameter; they must not
    # be treated as a common sub-expression.
    q = df.lazy().select(
        *(
            pl.col(c).quantile(0.75, interpolation="midpoint").name.suffix("_3")
            for c in cols
        ),
        *(
            pl.col(c).quantile(0.25, interpolation="midpoint").name.suffix("_1")
            for c in cols
        ),
    )
    assert "__POLARS_CSE" not in q.explain()
    assert q.collect().to_dict(as_series=False) == {
        "a_3": [0.40689473946662197],
        "b_3": [0.6145786693120769],
        "a_1": [0.16650805109739197],
        "b_1": [0.2012768694081981],
    }
484
485
486
def test_cse_nan_10824() -> None:
    """NaN literals inside a CSE'd ternary keep their value (issue #10824)."""
    ratio = pl.col("a") / pl.col("b")
    magic = pl.when(ratio > 0).then(pl.lit(float("nan"))).otherwise(ratio)

    df = pl.DataFrame({"a": [1.0], "b": [1.0]})
    out = (
        df.lazy()
        .select(magic)
        .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
    )
    assert str(out.to_dict(as_series=False)) == "{'literal': [nan]}"
505
506
507
def test_cse_10901() -> None:
    """Rolling sums with different window sizes keep separate results (issue #10901)."""
    df = pl.DataFrame(data=range(6), schema={"a": pl.Int64})
    roll2 = pl.col("a").rolling_sum(window_size=2)
    roll3 = pl.col("a").rolling_sum(window_size=3)
    exprs = {"ax1": roll2, "ax2": roll2 * 2, "bx1": roll3, "bx2": roll3 * 2}

    expected = pl.DataFrame(
        {
            "a": [0, 1, 2, 3, 4, 5],
            "ax1": [None, 1, 3, 5, 7, 9],
            "ax2": [None, 2, 6, 10, 14, 18],
            "bx1": [None, None, 3, 6, 9, 12],
            "bx2": [None, None, 6, 12, 18, 24],
        }
    )

    assert_frame_equal(df.lazy().with_columns(**exprs).collect(), expected)
529
530
531
def test_cse_count_in_group_by() -> None:
    """`pl.len()` inside a group_by slice must not be extracted as a CSE column."""
    q = (
        pl.LazyFrame({"a": [1, 1, 2], "b": [1, 2, 3], "c": [40, 51, 12]})
        .group_by("a")
        .agg(pl.all().slice(0, pl.len() - 1))
    )

    # No CSE column should appear in the plan for this query.
    assert "POLARS_CSER" not in q.explain()
    assert q.collect().sort("a").to_dict(as_series=False) == {
        "a": [1, 2],
        "b": [[1], []],
        "c": [[40], []],
    }
544
545
546
def test_cse_slice_11594() -> None:
    """Slice expressions are only shared when offset and length match (issue #11594)."""
    df = pl.LazyFrame({"a": [1, 2, 1, 2, 1, 2]})

    # Identical slices: shared via CSE, both output columns are equal.
    q = df.select(
        pl.col("a").slice(offset=1, length=pl.len() - 1).alias("1"),
        pl.col("a").slice(offset=1, length=pl.len() - 1).alias("2"),
    )

    assert "__POLARS_CSE" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )

    assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(
        as_series=False
    ) == {
        "1": [2, 1, 2, 1, 2],
        "2": [2, 1, 2, 1, 2],
    }

    # Different offsets: the slices still share `pl.len() - 1` but must
    # produce different columns.
    q = df.select(
        pl.col("a").slice(offset=1, length=pl.len() - 1).alias("1"),
        pl.col("a").slice(offset=0, length=pl.len() - 1).alias("2"),
    )

    assert "__POLARS_CSE" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )

    assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(
        as_series=False
    ) == {
        "1": [2, 1, 2, 1, 2],
        "2": [1, 2, 1, 2, 1],
    }
580
581
582
def test_cse_is_in_11489() -> None:
    """A shared is_in-based condition evaluates consistently (issue #11489)."""
    lf = pl.DataFrame(
        {"cond": [1, 2, 3, 2, 1], "x": [1.0, 0.20, 3.0, 4.0, 0.50]}
    ).lazy()

    any_cond = (
        pl.when(pl.col("cond").is_in([2, 3]))
        .then(True)
        .when(pl.col("cond").is_in([1]))
        .then(False)
        .otherwise(None)
        .alias("any_cond")
    )
    # `val` reuses `any_cond` twice: once as-is, once negated.
    val = (
        pl.when(any_cond)
        .then(1.0)
        .when(~any_cond)
        .then(0.0)
        .otherwise(None)
        .alias("val")
    )

    out = lf.select("cond", any_cond, val).collect()
    assert out.to_dict(as_series=False) == {
        "cond": [1, 2, 3, 2, 1],
        "any_cond": [False, True, True, True, False],
        "val": [0.0, 1.0, 1.0, 1.0, 0.0],
    }
607
608
609
def test_cse_11958() -> None:
    """Shifted differences built in a loop share their common parts (issue #11958)."""
    df = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
    vector_losses = []
    for lag in range(1, 5):
        # Each loss reuses `difference` twice (condition and scaled value).
        difference = pl.col("a") - pl.col("a").shift(lag)
        component_loss = pl.when(difference >= 0).then(difference * 10)
        vector_losses.append(component_loss.alias(f"diff{lag}"))

    q = df.select(vector_losses)
    # The repeated `difference` sub-expressions should appear as CSE columns.
    assert "__POLARS_CSE" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    )
    assert q.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)).to_dict(
        as_series=False
    ) == {
        "diff1": [None, 10, 10, 10, 10],
        "diff2": [None, None, 20, 20, 20],
        "diff3": [None, None, None, 30, 30],
        "diff4": [None, None, None, None, 40],
    }
629
630
631
def test_cse_14047() -> None:
    """CSE and non-CSE collection agree for rolling-window expressions (issue #14047)."""
    ldf = pl.LazyFrame(
        {
            "timestamp": pl.datetime_range(
                datetime(2024, 1, 12),
                datetime(2024, 1, 12, 0, 0, 0, 150_000),
                "10ms",
                eager=True,
                closed="left",
            ),
            "price": list(range(15)),
        }
    )

    def count_diff(
        price: pl.Expr, upper_bound: float = 0.1, lower_bound: float = 0.001
    ) -> pl.Expr:
        # Difference between the counts over a long and a short rolling window.
        span_end_to_curr = (
            price.count()
            .cast(int)
            .rolling("timestamp", period=timedelta(seconds=lower_bound))
        )
        span_start_to_curr = (
            price.count()
            .cast(int)
            .rolling("timestamp", period=timedelta(seconds=upper_bound))
        )
        return (span_start_to_curr - span_end_to_curr).alias(
            f"count_diff_{upper_bound}_{lower_bound}"
        )

    def s_per_count(count_diff: pl.Expr, span: tuple[float, float]) -> pl.Expr:
        # Milliseconds covered by the span divided by the count difference.
        return (span[1] * 1000 - span[0] * 1000) / count_diff

    spans = [(0.001, 0.1), (1, 10)]
    count_diff_exprs = [count_diff(pl.col("price"), span[0], span[1]) for span in spans]
    s_per_count_exprs = [
        s_per_count(count_diff, span).alias(f"zz_{span}")
        for count_diff, span in zip(count_diff_exprs, spans)
    ]

    exprs = count_diff_exprs + s_per_count_exprs
    ldf = ldf.with_columns(*exprs)
    # Enabling CSE must not change the collected result.
    assert_frame_equal(
        ldf.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)),
        ldf.collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=False)),
    )
678
679
680
def test_cse_15536() -> None:
    """Disjoint filters over one shared subplan return their own rows (issue #15536)."""
    base = pl.LazyFrame({"a": range(10)}).filter(pl.col("a") >= 5)

    part_one = base.filter(pl.lit(True) & (pl.col("a") == 6) | (pl.col("a") == 9))
    part_two = base.filter(pl.lit(True) & (pl.col("a") == 7) | (pl.col("a") == 8))

    out = pl.concat([part_one, part_two]).collect()
    assert out["a"].to_list() == [6, 9, 7, 8]
691
692
693
def test_cse_15548() -> None:
    """Concatenating a frame with a cached filter of itself keeps all rows (issue #15548)."""
    base = pl.LazyFrame({"a": [1, 2, 3]})
    cached_filter = base.filter(pl.col("a") == 1).cache()
    combined = pl.concat([base, cached_filter])

    # Row count must be identical with and without subplan elimination.
    for use_cse in (False, True):
        opts = pl.QueryOptFlags(comm_subplan_elim=use_cse)
        assert len(combined.collect(optimizations=opts)) == 4
704
705
706
@pytest.mark.debug
def test_cse_and_schema_update_projection_pd() -> None:
    """Projection pushdown keeps the schema valid when a when/then chain is CSE'd."""
    df = pl.LazyFrame({"a": [1, 2], "b": [99, 99]})

    q = (
        df.lazy()
        .with_row_index()
        .select(
            pl.when(pl.col("b") < 10)
            .then(0.1 * pl.col("b"))
            .when(pl.col("b") < 100)
            .then(0.2 * pl.col("b"))
        )
    )
    assert q.collect(optimizations=pl.QueryOptFlags(comm_subplan_elim=False)).to_dict(
        as_series=False
    ) == {"literal": [19.8, 19.8]}
    # Exactly one CSE column should be extracted from the when/then chain.
    assert (
        num_cse_occurrences(
            q.explain(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
        )
        == 1
    )
729
730
731
@pytest.mark.debug
@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize("use_custom_io_source", [True, False])
def test_cse_predicate_self_join(
    capfd: Any, monkeypatch: Any, use_custom_io_source: bool
) -> None:
    """A self-join whose sides differ only by a predicate still hits the cache."""
    # Verbose mode logs cache hits to stderr, which is asserted on below.
    monkeypatch.setenv("POLARS_VERBOSE", "1")
    y = pl.LazyFrame({"a": [1], "b": [2], "y": [3]})
    if use_custom_io_source:
        y = create_dataframe_source(y.collect(), is_pure=True)

    xf = y.filter(pl.col("y") == 2).select(["a", "b"])
    y_xf = y.join(xf, on=["a", "b"], how="left")

    y_xf_c = y_xf.select("a", "b")
    assert y_xf_c.collect().to_dict(as_series=False) == {"a": [1], "b": [2]}
    captured = capfd.readouterr().err
    assert "CACHE HIT" in captured
749
750
751
def test_cse_manual_cache_15688() -> None:
    """Manually inserted cache nodes compose with CSE caching (issue #15688)."""
    lf = pl.LazyFrame(
        {"a": [1, 2, 3, 1, 2, 3], "b": [1, 1, 1, 1, 1, 1], "id": [1, 1, 1, 2, 2, 2]}
    )

    semi_once = lf.filter(id=1).join(lf.filter(id=2), on=["a", "b"], how="semi")
    semi_twice = lf.filter(id=1).join(semi_once, on=["a", "b"], how="semi").cache()
    aggregated = semi_twice.group_by("b").agg(pl.all().sum())

    out = aggregated.cache().with_columns(foo=1).collect()
    assert out.to_dict(as_series=False) == {
        "b": [1],
        "a": [6],
        "id": [3],
        "foo": [1],
    }
767
768
769
def test_cse_drop_nulls_15795() -> None:
    """Joining plans that share a filtered subplan must not drop rows (issue #15795)."""
    left = pl.LazyFrame({"X": 1})
    filtered = pl.LazyFrame({"X": 1, "Y": 0}).filter(pl.col("Y").is_not_null())
    joined = left.join(filtered, on="X").select("X")
    projected = filtered.select("X")
    assert joined.join(projected, on="X").collect().shape == (1, 1)
775
776
777
def test_cse_no_projection_15980() -> None:
    """CSE copes with a union whose branches carry no projection (issue #15980)."""
    base = pl.LazyFrame({"x": "a", "y": 1})
    unioned = pl.concat(base.with_columns(pl.col("y").add(n)) for n in range(2))

    out = unioned.filter(pl.col("x").eq("a")).select("x").collect()
    assert out.to_dict(as_series=False) == {"x": ["a", "a"]}
784
785
786
@pytest.mark.debug
def test_cse_series_collision_16138() -> None:
    """Distinct is_in literal lists must not collide in the CSE cache (issue #16138)."""
    holdings = pl.DataFrame(
        {
            "fund_currency": ["CLP", "CLP"],
            "asset_currency": ["EUR", "USA"],
        }
    )

    usd = ["USD"]
    eur = ["EUR"]
    clp = ["CLP"]

    # Similar-looking predicates over different literal lists; only the
    # genuinely identical sub-expressions may be shared.
    currency_factor_query_dict = [
        pl.col("asset_currency").is_in(eur) & pl.col("fund_currency").is_in(clp),
        pl.col("asset_currency").is_in(eur) & pl.col("fund_currency").is_in(usd),
        pl.col("asset_currency").is_in(clp) & pl.col("fund_currency").is_in(clp),
        pl.col("asset_currency").is_in(usd) & pl.col("fund_currency").is_in(usd),
    ]

    factor_holdings = holdings.lazy().with_columns(
        pl.coalesce(currency_factor_query_dict).alias("currency_factor"),
    )

    assert factor_holdings.collect(
        optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
    ).to_dict(as_series=False) == {
        "fund_currency": ["CLP", "CLP"],
        "asset_currency": ["EUR", "USA"],
        "currency_factor": [True, False],
    }
    # Exactly three distinct CSE columns are expected in the plan.
    assert (
        num_cse_occurrences(
            factor_holdings.explain(
                optimizations=pl.QueryOptFlags(comm_subexpr_elim=True)
            )
        )
        == 3
    )
825
826
827
def test_nested_cache_no_panic_16553() -> None:
    """Selecting a deeply nested list literal must not panic under CSE (issue #16553)."""
    out = (
        pl.LazyFrame()
        .select(a=[[[1]]])
        .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
    )
    assert out.to_dict(as_series=False) == {"a": [[[[1]]]]}
831
832
833
def test_hash_empty_series_16577() -> None:
    """Hashing an empty/null series while planning must not fail (issue #16577)."""
    empty = pl.Series(values=None)
    result = pl.LazyFrame().select(empty).collect()
    assert result.equals(empty.to_frame())
837
838
839
def test_cse_non_scalar_length_mismatch_17732() -> None:
    """Scalars from shared head() sub-expressions broadcast to full length (issue #17732)."""
    values = pl.Series(range(30), dtype=pl.Int32)

    got = (
        pl.LazyFrame({"a": values})
        .with_columns(
            pl.col("a").head(5).min().alias("b"),
            pl.col("a").head(5).max().alias("c"),
        )
        .collect(optimizations=pl.QueryOptFlags(comm_subexpr_elim=True))
    )

    expect = pl.DataFrame(
        {
            "a": values,
            "b": pl.Series([0] * 30, dtype=pl.Int32),
            "c": pl.Series([4] * 30, dtype=pl.Int32),
        }
    )
    assert_frame_equal(expect, got)
858
859
860
def test_cse_chunks_18124() -> None:
    """CSE works on frames made of multiple chunks (issue #18124)."""
    half = pl.DataFrame(
        {
            "ts_diff": [timedelta(seconds=60)] * 2,
            "ts_diff_after": [timedelta(seconds=120)] * 2,
        }
    )
    # rechunk=False keeps the two halves as separate chunks.
    df = pl.concat([half, half], rechunk=False)

    out = (
        df.lazy()
        .with_columns(
            ts_diff_sign=pl.col("ts_diff") > pl.duration(seconds=0),
            ts_diff_after_sign=pl.col("ts_diff_after") > pl.duration(seconds=0),
        )
        .filter(pl.col("ts_diff") > 1)
        .collect()
    )
    assert out.shape == (4, 4)
876
877
878
@pytest.mark.may_fail_auto_streaming
def test_eager_cse_during_struct_expansion_18411() -> None:
    """Eager CSE during struct expansion yields stable output (issue #18411)."""
    df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]})
    vc = pl.col("foo").value_counts()
    classes = vc.struct[0]
    counts = vc.struct[1]

    # Evaluate the same replacement twice; the two results must agree row-wise.
    first = df.select(pl.col("foo").replace(classes, counts))
    second = df.select(pl.col("foo").replace(classes, counts))
    assert (first == second)["foo"].all()
889
890
891
def test_cse_as_struct_19253() -> None:
    """Struct fields with and without a window keep distinct results (issue #19253)."""
    lf = pl.LazyFrame({"x": [1, 2], "y": [4, 5]})

    out = lf.with_columns(
        q1=pl.struct(pl.col.x - pl.col.y.mean()),
        q2=pl.struct(pl.col.x - pl.col.y.mean().over("y")),
    ).collect()

    assert out.to_dict(as_series=False) == {
        "x": [1, 2],
        "y": [4, 5],
        "q1": [{"x": -3.5}, {"x": -2.5}],
        "q2": [{"x": -3.0}, {"x": -3.0}],
    }
905
906
907
@pytest.mark.may_fail_auto_streaming
def test_cse_as_struct_value_counts_20927() -> None:
    """value_counts over a struct column unnests correctly (issue #20927)."""
    values = [i for i in range(1, 6) for _ in range(i)]
    out = (
        pl.DataFrame({"x": values})
        .select(pl.struct("x").value_counts().struct.unnest())
        .sort("count")
    )
    assert out.to_dict(as_series=False) == {
        "x": [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}, {"x": 5}],
        "count": [1, 2, 3, 4, 5],
    }
915
916
917
def test_cse_union_19227() -> None:
    """A union of plans sharing renamed projections keeps the column order (issue #19227)."""
    base = pl.LazyFrame({"A": [1], "B": [2]})
    renamed_keep_b = base.select(C="A", B="B")
    renamed_keep_a = base.select(C="A", A="B")

    direct = renamed_keep_a.join(base, on=["A"]).select("C", "A", "B")
    indirect = renamed_keep_b.join(direct, on=["C", "B"]).select("C", "A", "B")

    combined = pl.concat([direct, indirect])
    expected_schema = pl.Schema([("C", pl.Int64), ("A", pl.Int64), ("B", pl.Int64)])
    assert combined.collect().schema == expected_schema
930
931
932
def test_cse_21115() -> None:
    """A horizontal min over exp() combines correctly with per-column exp() (issue #21115)."""
    lf = pl.LazyFrame({"x": 1, "y": 5})

    out = lf.with_columns(
        pl.all().exp() + pl.min_horizontal(pl.all().exp())
    ).collect()
    assert out.to_dict(as_series=False) == {
        "x": [5.43656365691809],
        "y": [151.13144093103566],
    }
941
942
943
@pytest.mark.parametrize("use_custom_io_source", [True, False])
def test_cse_cache_leakage_22339(use_custom_io_source: bool) -> None:
    """Caches of similar-but-different subplans must not leak into each other."""
    lf1 = pl.LazyFrame({"x": [True] * 2})
    lf2 = pl.LazyFrame({"x": [True] * 3})
    if use_custom_io_source:
        lf1 = create_dataframe_source(lf1.collect(), is_pure=True)
        lf2 = create_dataframe_source(lf2.collect(), is_pure=True)

    plain = lf1
    filtered_1 = lf1.filter(pl.col("x").not_().over(1))
    filtered_2 = lf2.filter(pl.col("x").not_().over(1))

    joins = [
        plain.join(filtered_1, on="x"),
        filtered_1.join(filtered_2, on="x"),
        plain.join(filtered_2, on="x"),
    ]
    assert pl.concat(joins).collect().to_dict(as_series=False) == {"x": []}
960
961
962
@pytest.mark.write_disk
def test_multiplex_predicate_pushdown() -> None:
    """Predicates are pushed through a multiplexed (shared) parquet scan."""
    ldf = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
    with TemporaryDirectory() as f:
        tmppath = Path(f)
        # Write the frame partitioned by "a" so it can be scanned with hive
        # partitioning.
        ldf.sink_parquet(
            pl.PartitionByKey(tmppath, by="a", include_key=True),
            sync_on_close="all",
            mkdir=True,
        )
        ldf = pl.scan_parquet(tmppath, hive_partitioning=True)
        ldf = ldf.filter(pl.col("a").eq(1)).select("b")
        # The filter must survive as a scan-level selection in the shared plan.
        assert 'SELECTION: [(col("a")) == (1)]' in pl.explain_all([ldf, ldf])
975
976
977
def test_cse_custom_io_source_same_object() -> None:
    """The same pure io-source object used twice is cached and scanned once."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})

    io_source = Mock(wraps=lambda *_: iter([df]))

    lf = register_io_source(
        io_source,
        schema=df.schema,
        validate_schema=True,
        is_pure=True,
    )

    # Both plans should reference one shared cache node.
    plan = pl.explain_all([lf, lf])
    caches: list[str] = [
        x for x in map(str.strip, plan.splitlines()) if x.startswith("CACHE[")
    ]
    assert len(caches) == 2
    assert len(set(caches)) == 1

    # Planning alone must not invoke the source.
    assert io_source.call_count == 0

    assert_frame_equal(
        pl.concat(pl.collect_all([lf, lf])),
        pl.DataFrame({"a": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]}),
    )

    # With the cache in place the source was read exactly once.
    assert io_source.call_count == 1

    io_source = Mock(wraps=lambda *_: iter([df]))

    # Without explicit is_pure parameter should default to False
    lf = register_io_source(
        io_source,
        schema=df.schema,
        validate_schema=True,
    )

    plan = pl.explain_all([lf, lf])

    # An impure source must not be cached at all.
    caches = [x for x in map(str.strip, plan.splitlines()) if x.startswith("CACHE[")]
    assert len(caches) == 0

    assert io_source.call_count == 0

    assert_frame_equal(
        pl.concat(pl.collect_all([lf, lf])),
        pl.DataFrame({"a": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]}),
    )

    # No cache means the source is read once per plan.
    assert io_source.call_count == 2
1027
1028
1029
@pytest.mark.write_disk
def test_cse_preferred_over_slice() -> None:
    """Caching is chosen over slice pushdown for disjoint slices of one scan."""
    # This test asserts that even if we slice disjoint sections of a lazyframe, caching
    # is preferred, and slicing is not pushed down
    df = pl.DataFrame({"a": list(range(1, 21))})
    with NamedTemporaryFile() as f:
        val = df.write_csv()
        f.write(val.encode())
        f.seek(0)
        ldf = pl.scan_csv(f.name)
        left = ldf.slice(0, 5)
        right = ldf.slice(6, 5)
        q = left.join(right, on="a", how="inner")
        assert "CACHE[id:" in q.explain(
            optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
        )
1046
1047
def test_cse_preferred_over_slice_custom_io_source() -> None:
    """Caching beats slice pushdown for a pure io source — but not an impure one."""
    # This test asserts that even if we slice disjoint sections of a custom io source,
    # caching is preferred, and slicing is not pushed down
    df = pl.DataFrame({"a": list(range(1, 21))})
    lf = create_dataframe_source(df, is_pure=True)
    left = lf.slice(0, 5)
    right = lf.slice(6, 5)
    q = left.join(right, on="a", how="inner")
    assert "CACHE[id:" in q.explain(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )

    # An impure source must not be cached at all.
    lf = create_dataframe_source(df, is_pure=False)
    left = lf.slice(0, 5)
    right = lf.slice(6, 5)
    q = left.join(right, on="a", how="inner")
    assert "CACHE[id:" not in q.explain(
        optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )
1066
1067
1068
def test_cse_custom_io_source_diff_columns() -> None:
    """Two different projections of one pure io source share a single cached scan."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [10, 11, 12, 13, 14]})
    lf = create_dataframe_source(df, is_pure=True)
    collection = [lf.select("a"), lf.select("b")]
    # Both plans should read from one shared cache node.
    assert "CACHE[id:" in pl.explain_all(collection)
    collected = pl.collect_all(
        collection, optimizations=pl.QueryOptFlags(comm_subplan_elim=True)
    )
    assert_frame_equal(df.select("a"), collected[0])
    assert_frame_equal(df.select("b"), collected[1])
1078
1079
1080
def test_cse_custom_io_source_diff_filters() -> None:
    """Predicate pushdown is preferred over caching when the filters differ."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [10, 11, 12, 13, 14]})
    lf = create_dataframe_source(df, is_pure=True)

    # We use this so that the true type of the input is passed through
    # to the output
    PolarsFrame = TypeVar("PolarsFrame", pl.DataFrame, pl.LazyFrame)

    def left_pipe(df_or_lf: PolarsFrame) -> PolarsFrame:
        return df_or_lf.select("a").filter(pl.col("a").is_between(2, 6))

    def right_pipe(df_or_lf: PolarsFrame) -> PolarsFrame:
        return df_or_lf.select("b").filter(pl.col("b").is_between(10, 13))

    collection = [lf.pipe(left_pipe), lf.pipe(right_pipe)]
    explanation = pl.explain_all(collection)
    # we prefer predicate pushdown over CSE
    assert "CACHE[id:" not in explanation
    assert 'SELECTION: col("a").is_between([2, 6])' in explanation
    assert 'SELECTION: col("b").is_between([10, 13])' in explanation

    # Applying the same pipes eagerly must give the same results.
    res = pl.collect_all(collection)
    expected = [df.pipe(left_pipe), df.pipe(right_pipe)]
    assert_frame_equal(expected[0], res[0])
    assert_frame_equal(expected[1], res[1])
1105
1106