Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_predicates.py
8430 views
1
from __future__ import annotations
2
3
import re
4
from datetime import date, datetime, timedelta
5
from typing import TYPE_CHECKING, Any
6
7
import numpy as np
8
import pytest
9
10
import polars as pl
11
from polars.exceptions import ComputeError, InvalidOperationError
12
from polars.io.plugins import register_io_source
13
from polars.testing import assert_frame_equal
14
from polars.testing.asserts.series import assert_series_equal
15
16
if TYPE_CHECKING:
17
from collections.abc import Iterator
18
19
from tests.conftest import PlMonkeyPatch
20
21
22
def test_predicate_4906() -> None:
    """Filter on `min_horizontal` of a day-shifted date column and a date literal.

    Rows are kept when the shifted date (capped at 2022-09-30) lies after
    2022-09-10.
    """
    step = timedelta(days=1)
    dates = [date(2022, 9, 1), date(2022, 9, 10), date(2022, 9, 20)]
    lazy_dates = pl.DataFrame({"dt": dates}).lazy()

    capped_next_day = pl.min_horizontal((pl.col("dt") + step), date(2022, 9, 30))
    kept = lazy_dates.filter(capped_next_day > date(2022, 9, 10)).collect()

    assert kept.to_dict(as_series=False) == {
        "dt": [date(2022, 9, 10), date(2022, 9, 20)]
    }
41
42
43
def test_predicate_null_block_asof_join() -> None:
    """Apply an `is_not_null` filter after `join_asof` and check the kept rows.

    Row order of the result is not checked.
    """
    left_times = [datetime(2022, 1, 1, 10, minute) for minute in range(4)]
    left = (
        pl.DataFrame({"id": [1, 2, 3, 4], "timestamp": left_times})
        .lazy()
        .set_sorted("timestamp")
    )

    right_times = [
        datetime(2022, 1, 1, 9, 59, 50),
        datetime(2022, 1, 1, 10, 0, 50),
        datetime(2022, 1, 1, 10, 1, 50),
    ] + [datetime(2022, 1, 1, 8, 0, 0)] * 3
    right = (
        pl.DataFrame(
            {
                "id": [1, 2, 3] * 2,
                "timestamp": right_times,
                "value": ["a", "b", "c"] * 2,
            }
        )
        .lazy()
        .set_sorted("timestamp")
    )

    joined = left.join_asof(right, by="id", on="timestamp")
    result = joined.filter(pl.col("value").is_not_null()).collect()

    expected = pl.DataFrame(
        {
            "id": [1, 2, 3],
            "timestamp": left_times[:3],
            "value": ["a", "b", "c"],
        }
    )
    assert_frame_equal(result, expected, check_row_order=False)
96
97
98
def test_predicate_strptime_6558() -> None:
    """Parse strings to dates with `strptime`, then filter on year and month."""
    raw = pl.DataFrame({"date": ["2022-01-03", "2020-01-04", "2021-02-03", "2019-01-04"]})
    parsed = raw.lazy().select(pl.col("date").str.strptime(pl.Date, format="%F"))

    is_january_2022 = (pl.col("date").dt.year() == 2022) & (
        pl.col("date").dt.month() == 1
    )
    result = parsed.filter(is_january_2022).collect()

    assert result.to_dict(as_series=False) == {"date": [date(2022, 1, 3)]}
106
107
108
def test_predicate_arr_first_6573() -> None:
    """Implode column "a", take the first list element, then filter `a == b`."""
    frame = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6],
            "b": [6, 5, 4, 3, 2, 1],
        }
    )

    imploded = frame.lazy().with_columns(pl.col("a").implode())
    first_of_a = imploded.with_columns(pl.col("a").list.first())
    matched = first_of_a.filter(pl.col("a") == pl.col("b")).collect()

    assert matched.to_dict(as_series=False) == {"a": [1], "b": [1]}
123
124
125
def test_fast_path_comparisons() -> None:
    """Comparisons on a sorted-flagged Series must equal the plain comparisons.

    The input is 100 sorted integers in [0, 50), drawn from a seeded generator
    so that a failure can be reproduced exactly.
    """
    rng = np.random.default_rng(0)
    s = pl.Series(np.sort(rng.integers(0, 50, 100)))

    assert_series_equal(s > 25, s.set_sorted() > 25)
    assert_series_equal(s >= 25, s.set_sorted() >= 25)
    assert_series_equal(s < 25, s.set_sorted() < 25)
    assert_series_equal(s <= 25, s.set_sorted() <= 25)
132
133
134
def test_predicate_pushdown_block_8661() -> None:
    """Filter on a windowed `shift` comparison applied after a sort."""
    frame = pl.DataFrame(
        {
            "g": [1, 1, 1, 1, 2, 2, 2, 2],
            "t": [1, 2, 3, 4, 4, 3, 2, 1],
            "x": [10, 20, 30, 40, 10, 20, 30, 40],
        }
    )

    previous_x_above_20 = (pl.col("x").shift() > 20).over("g")
    result = frame.lazy().sort(["g", "t"]).filter(previous_x_above_20).collect()

    assert result.to_dict(as_series=False) == {
        "g": [1, 2, 2],
        "t": [4, 2, 3],
        "x": [40, 30, 20],
    }
149
150
151
def test_predicate_pushdown_cumsum_9566() -> None:
    """Filter on a `cum_sum` of an `is_in` mask evaluated after a sort."""
    frame = pl.DataFrame({"A": range(10), "B": ["b"] * 5 + ["a"] * 5})

    running_hits = pl.col("A").is_in([8, 2]).cum_sum()
    query = frame.lazy().sort(["B", "A"]).filter(running_hits == 1)

    assert query.collect()["A"].to_list() == [8, 9, 0, 1]
157
158
159
def test_predicate_pushdown_join_fill_null_10058() -> None:
    """Filter on `fill_null(True)` of a right-side column after a left join."""
    ids = pl.LazyFrame({"id": [0, 1, 2]})
    filters = pl.LazyFrame({"id": [0, 1], "filter": [True, False]})

    joined = ids.join(filters, how="left", on="id")
    kept = joined.filter(pl.col("filter").fill_null(True)).collect()

    assert sorted(kept.to_dict(as_series=False)["id"]) == [0, 2]
169
170
171
def test_is_in_join_blocked() -> None:
    """Negated `is_in` on a right-side column of a left join, via filter and remove."""
    groups = pl.LazyFrame(
        {"Groups": ["A", "B", "C", "D", "E", "F"], "values0": [1, 2, 3, 4, 5, 6]}
    )
    values = pl.LazyFrame(
        {"values_22": [1, 2, None, 4, 5, 6], "values_20": [1, 2, 3, 4, 5, 6]}
    )
    joined = values.join(
        groups,
        left_on="values_20",
        right_on="values0",
        how="left",
        maintain_order="right_left",
    )

    # `filter(~p)` and `remove(p)` are expected to yield the same frame.
    candidates = (
        joined.filter(~pl.col("Groups").is_in(["A", "B", "F"])),
        joined.remove(pl.col("Groups").is_in(["A", "B", "F"])),
    )
    expected = pl.LazyFrame(
        {
            "values_22": [None, 4, 5],
            "values_20": [3, 4, 5],
            "Groups": ["C", "D", "E"],
        }
    )
    for result in candidates:
        assert_frame_equal(result, expected)
198
199
200
def test_predicate_pushdown_group_by_keys() -> None:
    """A filter on a group-by key must not stay at the top of the optimized plan.

    With predicate pushdown disabled, the plan is expected to start with the
    FILTER node instead.
    """
    # The frame is already lazy, so no extra `.lazy()` call is needed.
    lf = pl.LazyFrame({"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]})
    q = (
        lf.group_by("group")
        .agg([pl.len().alias("str_list")])
        .filter(pl.col("group") == 1)
    )

    assert not q.explain().startswith("FILTER")
    assert q.explain(
        optimizations=pl.QueryOptFlags(predicate_pushdown=False)
    ).startswith("FILTER")
213
214
215
def test_no_predicate_push_down_with_cast_and_alias_11883() -> None:
    """Check the plan of two filters stacked on a cast-and-aliased column.

    The optimized plan must not match the `FILTER ... FROM / DF` pattern.
    """
    frame = pl.DataFrame({"a": [1, 2, 3]})
    query = (
        frame.lazy()
        .select(pl.col("a").cast(pl.Int64).alias("b"))
        .filter(pl.col("b") == 1)
        .filter((pl.col("b") >= 1) & (pl.col("b") < 1))
    )

    plan = query.explain(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    assert re.search(r"FILTER.*FROM\n\s*DF", plan) is None
230
231
232
@pytest.mark.parametrize(
    "predicate",
    [0, "x", [2, 3], {"x": 1}, pl.Series([1, 2, 3]), None],
)
def test_invalid_filter_predicates(predicate: Any) -> None:
    """Each parametrized value must be rejected by `filter` with a TypeError."""
    frame = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]})

    with pytest.raises(TypeError, match="invalid predicate"):
        frame.filter(predicate)
247
248
249
def test_fast_path_boolean_filter_predicates() -> None:
    """Literal `True`/`False` predicates keep or drop every row."""
    full = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]})
    empty = full.clear()

    # filter: keep rows where the predicate holds
    assert_frame_equal(full.filter(False), empty)
    assert_frame_equal(full.filter(True), full)

    # remove: drop rows where the predicate holds
    assert_frame_equal(full.remove(True), empty)
    assert_frame_equal(full.remove(False), full)
258
259
260
def test_predicate_pushdown_boundary_12102() -> None:
    """Results with and without predicate pushdown must agree.

    The middle filter compares against `pl.min("x")`, placed between two
    plain filters on column "y".
    """
    frame = pl.DataFrame({"x": [1, 2, 4], "y": [1, 2, 4]})

    query = frame.lazy().filter(pl.col("y") > 1)
    query = query.filter(pl.col("x") == pl.min("x"))
    query = query.filter(pl.col("y") > 2)

    with_pushdown = query.collect()
    without_pushdown = query.collect(
        optimizations=pl.QueryOptFlags(predicate_pushdown=False)
    )
    assert_frame_equal(with_pushdown, without_pushdown)
273
274
275
def test_take_can_block_predicate_pushdown() -> None:
    """A filter that uses `gather(0)` sits between two filters on "y"."""
    frame = pl.DataFrame({"x": [1, 2, 4], "y": [False, True, True]})

    query = frame.lazy().filter(pl.col("y"))
    query = query.filter(pl.col("x") == pl.col("x").gather(0))
    query = query.filter(pl.col("y"))

    out = query.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    assert out.to_dict(as_series=False) == {"x": [2], "y": [True]}
285
286
287
def test_literal_series_expr_predicate_pushdown() -> None:
    """Literal Series masks versus Series used inside `is_in`."""
    lf = pl.LazyFrame({"x": [0, 1, 2]})

    # No pushdown should occur in this case, because otherwise the filter will
    # attempt to filter 3 rows with a boolean mask of 2 rows.
    mask_queries = (
        lf.filter(pl.col("x") > 0).filter(pl.Series([True, True])),
        lf.remove(pl.col("x") <= 0).remove(pl.Series([False, False])),
    )
    for query in mask_queries:
        assert query.collect().to_series().to_list() == [1, 2]

    # Pushdown should occur here; series is being used as part of an `is_in`.
    is_in_queries = (
        lf.filter(pl.col("x") > 0).filter(pl.col("x").is_in([0, 1])),
        lf.remove(pl.col("x") <= 0).remove(~pl.col("x").is_in([0, 1])),
    )
    for query in is_in_queries:
        assert re.search(r"FILTER .*\nFROM\n\s*DF", query.explain(), re.DOTALL)
        assert query.collect().to_series().to_list() == [1]
305
306
307
def test_multi_alias_pushdown() -> None:
    """A filter over two aliased columns ends up as a single FILTER above DF."""
    lf = pl.LazyFrame({"a": [1], "b": [1]})

    aliased = lf.with_columns(m="a", n="b")
    plan = aliased.filter((pl.col("m") + pl.col("n")) < 2).explain()

    assert plan.count("FILTER") == 1
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        # confirm we aren't using `eq_missing` in the query plan (denoted as " ==v ")
        none_plan = lf.select(pl.col("a").filter(a=None)).explain()
        assert " ==v " not in none_plan
319
320
321
def test_predicate_pushdown_with_window_projections_12637() -> None:
    """Inspect FILTER placement in plans that mix filters with window expressions.

    Each scenario builds a query, renders its plan, and asserts on the number
    of FILTER nodes and on whether a FILTER sits directly above the DF node.
    """
    lf = pl.LazyFrame(
        {
            "key": [1],
            "key_2": [1],
            "key_3": [1],
            "value": [1],
            "value_2": [1],
            "value_3": [1],
        }
    )

    any_filter_on_df = r"FILTER.*FROM\n\s*DF"
    key_5_filter_on_df = r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF'

    # Both windows partition by "key" only.
    plan = (
        lf.with_columns(
            (pl.col("value") * 2).over("key").alias("value_2"),
            (pl.col("value") * 2).over("key").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .explain()
    )
    assert re.search(key_5_filter_on_df, plan, re.DOTALL) is not None
    assert plan.count("FILTER") == 1

    # Both windows partition by ("key", "key_2").
    plan = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", "key_2").alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
        .explain()
    )
    assert plan.count("FILTER") == 1
    assert re.search(any_filter_on_df, plan, re.DOTALL) is not None

    # The windows share only "key".
    plan = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", "key_2").alias("value_2"),
            (pl.col("value") * 2).over("key", "key_3").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
        .explain()
    )
    assert plan.count("FILTER") == 2
    assert re.search(key_5_filter_on_df, plan, re.DOTALL) is not None

    # One window partitions by an expression derived from "key_2".
    plan = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", pl.col("key_2") + 1).alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
        .explain()
    )
    assert plan.count("FILTER") == 2
    assert re.search(key_5_filter_on_df, plan, re.DOTALL) is not None

    # Should block when .over() contains groups-sensitive expr
    plan = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", pl.sum("key_2")).alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
        .explain()
    )
    assert plan.count("FILTER") == 1
    assert "FILTER" in plan
    assert re.search(any_filter_on_df, plan, re.DOTALL) is None

    # Ensure the implementation doesn't accidentally push a window expression
    # that only refers to the common window keys.
    plan = (
        lf.with_columns(
            (pl.col("value") * 2).over("key").alias("value_2"),
        )
        .filter(pl.len().over("key") == 1)
        .explain()
    )
    assert re.search(any_filter_on_df, plan, re.DOTALL) is None
    assert plan.count("FILTER") == 1

    # Test window in filter
    plan = lf.filter(pl.len().over("key") == 1).filter(pl.col("key") == 1).explain()
    assert plan.count("FILTER") == 2
    window_filter_on_filter = re.search(
        r'FILTER \[\(len\(\).over\(\[col\("key"\)\]\)\) == \(1\)\]\s*FROM\n\s*FILTER',
        plan,
    )
    assert window_filter_on_filter is not None
    key_1_filter_on_df = re.search(
        r'FILTER \[\(col\("key"\)\) == \(1\)\]\s*FROM\n\s*DF', plan, re.DOTALL
    )
    assert key_1_filter_on_df is not None
436
437
438
def test_predicate_reduction() -> None:
    """Plans built from two predicates must not contain the text "cast"."""
    # ensure we get clean reduction without casts
    lf = pl.LazyFrame({"a": [1], "b": [2]})

    for apply_predicates in (lf.filter, lf.remove):
        plan = apply_predicates(
            pl.col("a") > 1,
            pl.col("b") > 1,
        ).explain()
        assert "cast" not in plan
449
450
451
def test_all_any_cleanup_at_single_predicate_case() -> None:
    """`drop_nulls` on one column must not leave "horizontal"/"all" in the plan."""
    single_column = pl.LazyFrame({"a": [1], "b": [2]}).select(["a"])
    plan = single_column.drop_nulls().explain()

    for unwanted in ("horizontal", "all"):
        assert unwanted not in plan
455
456
457
def test_hconcat_predicate() -> None:
    """Filter the result of a horizontal concat of two pre-filtered frames."""
    # Predicates shouldn't be pushed down past an hconcat as we can't filter
    # across the different inputs
    left = pl.LazyFrame(
        {
            "a1": [0, 1, 2, 3, 4],
            "a2": [5, 6, 7, 8, 9],
        }
    )
    right = pl.LazyFrame(
        {
            "b1": [0, 1, 2, 3, 4],
            "b2": [5, 6, 7, 8, 9],
        }
    )

    inputs = [
        left.filter(pl.col("a1") < 4),
        right.filter(pl.col("b1") > 0),
    ]
    query = pl.concat(inputs, how="horizontal").filter(pl.col("b2") < 9)

    result = query.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    expected = pl.DataFrame(
        {
            "a1": [0, 1, 2],
            "a2": [5, 6, 7],
            "b1": [1, 2, 3],
            "b2": [6, 7, 8],
        }
    )
    assert_frame_equal(result, expected)
491
492
493
def test_predicate_pd_join_13300() -> None:
    """Filter/remove on the left join key after a left join."""
    # https://github.com/pola-rs/polars/issues/13300

    base = pl.LazyFrame({"col3": range(10, 14), "new_col": range(11, 15)})
    other = pl.LazyFrame({"col4": [0, 11, 2, 13]})
    joined = base.join(other, left_on="new_col", right_on="col4", how="left")

    queries = (
        joined.filter(pl.col("new_col") < 12),
        joined.remove(pl.col("new_col") >= 12),
    )
    for query in queries:
        rows = query.collect().to_dict(as_series=False)
        assert rows == {"col3": [10], "new_col": [11]}
505
506
507
def test_filter_eq_missing_13861() -> None:
    """`filter(a=None)` / `remove(a=None)` warn and compare with plain equality.

    Comparing against None through the keyword form matches no rows, and the
    lazy plan must not contain the `eq_missing` operator (rendered " ==v ").
    Explicit `eq_missing(None)` / `is_null()` predicates do select the null row.
    """
    lf = pl.LazyFrame({"a": [1, None, 3], "b": ["xx", "yy", None]})
    df = lf.collect()
    df_empty = lf.clear().collect()
    none_warning = "Comparisons with None always result in null"

    # Each eager check is performed once; repeating the identical assertion
    # pair adds no coverage.
    with pytest.warns(UserWarning, match=none_warning):
        assert_frame_equal(df.filter(a=None), df_empty)

    with pytest.warns(UserWarning, match=none_warning):
        assert_frame_equal(df.remove(a=None), df)

    with pytest.warns(UserWarning, match=none_warning):
        lff = lf.filter(a=None)
        assert lff.collect().rows() == []
        assert " ==v " not in lff.explain()  # check no `eq_missing` op

    for filter_expr in (
        pl.col("a").eq_missing(None),
        pl.col("a").is_null(),
    ):
        assert df.filter(filter_expr).rows() == [(None, "yy")]
533
534
535
@pytest.mark.parametrize("how", ["left", "inner"])
def test_predicate_pushdown_block_join(how: Any) -> None:
    """Optimized and unoptimized results must agree for a filter after a join."""
    left = pl.LazyFrame({"a": [1]})
    right = pl.LazyFrame({"a": [2], "b": [1]})

    joined = left.join(right, left_on=["a"], right_on=["b"], how=how)
    query = joined.filter(pl.col("a") == 1)

    unoptimized = query.collect(optimizations=pl.QueryOptFlags.none())
    optimized = query.collect()
    assert_frame_equal(unoptimized, optimized)
548
549
550
def test_predicate_push_down_with_alias_15442() -> None:
    """A predicate containing `alias` and `drop_nulls` still filters correctly."""
    frame = pl.DataFrame({"a": [1]})

    predicate = pl.col("a").alias("x").drop_nulls() > 0
    query = frame.lazy().filter(predicate)
    result = query.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))

    assert result.to_dict(as_series=False) == {"a": [1]}
558
559
560
def test_predicate_slice_pushdown_list_gather_17492(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Predicate and slice handling around `list.get`."""
    lf = pl.LazyFrame({"val": [[1], [1, 1]], "len": [1, 2]})
    has_two_items = pl.col("len") == 2

    filtered = lf.filter(has_two_items).filter(pl.col("val").list.get(1) == 1)
    assert_frame_equal(filtered, lf.slice(1, 1))

    # null_on_oob=True can pass

    oob_safe = lf.filter(has_two_items).filter(
        pl.col("val").list.get(1, null_on_oob=True) == 1
    )
    plan = oob_safe.explain()

    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None

    # Also check slice pushdown
    sliced = lf.with_columns(pl.col("val").list.get(1).alias("b")).slice(1, 1)

    expected = pl.DataFrame(
        {
            "val": [[1, 1]],
            "len": pl.Series([2], dtype=pl.Int64),
            "b": pl.Series([1], dtype=pl.Int64),
        }
    )
    assert_frame_equal(sliced.collect(), expected)
593
594
595
def test_predicate_pushdown_struct_unnest_19632() -> None:
    """In each plan the FILTER text must appear before the UNNEST text."""
    struct_sources = (
        # Struct column built from Python dicts
        pl.LazyFrame({"a": [{"a": 1, "b": 2}]}).unnest("a"),
        # With `pl.struct()`
        pl.LazyFrame({"a": 1, "b": 2}).select(pl.struct(pl.all())).unnest("a"),
    )
    for unnested in struct_sources:
        query = unnested.filter(pl.col("a") == 1)
        plan = query.explain()

        assert "FILTER" in plan
        assert plan.index("FILTER") < plan.index("UNNEST")

        assert_frame_equal(query.collect(), pl.DataFrame({"a": 1, "b": 2}))

    # With `value_counts()`
    counted = pl.LazyFrame({"a": [1]}).select(pl.col("a").value_counts()).unnest("a")
    query = counted.filter(pl.col("a") == 1)
    plan = query.explain()

    assert plan.index("FILTER") < plan.index("UNNEST")

    expected_counts = pl.DataFrame(
        {"a": 1, "count": 1}, schema={"a": pl.Int64, "count": pl.get_index_type()}
    )
    assert_frame_equal(query.collect(), expected_counts)
637
638
639
@pytest.mark.parametrize(
    "predicate",
    [
        pl.col("v") == 7,
        pl.col("v") != 99,
        pl.col("v") > 0,
        pl.col("v") < 999,
        pl.col("v").is_in([7]),
        pl.col("v").cast(pl.Boolean),
        pl.col("b"),
    ],
)
@pytest.mark.parametrize("alias", [True, False])
@pytest.mark.parametrize("join_type", ["left", "right"])
def test_predicate_pushdown_join_19772(
    predicate: pl.Expr, join_type: str, alias: bool
) -> None:
    """Filter on columns of the value-carrying side of a left or right join.

    The frame holding "v" and "b" is the right input for a left join and the
    left input for a right join; results are checked with and without
    optimizations.
    """
    keys_only = pl.LazyFrame({"k": [1, 2]})
    with_values = pl.LazyFrame({"k": [1], "v": [7], "b": True})
    is_right_join = join_type == "right"

    if is_right_join:
        left, right = with_values, keys_only
    else:
        left, right = keys_only, with_values

    if alias:
        predicate = predicate.alias(":V")

    q = left.join(right, on="k", how=join_type).filter(predicate)  # type: ignore[arg-type]

    expect = pl.DataFrame({"k": 1, "v": 7, "b": True})
    if is_right_join:
        # Column order differs for the right join.
        expect = expect.select("v", "b", "k")

    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)
    assert_frame_equal(q.collect(), expect)
674
675
676
def test_predicate_pushdown_scalar_20489() -> None:
    """A one-element False mask after `with_columns` yields an empty frame."""
    frame = pl.DataFrame({"a": [1]})
    mask = pl.Series([False])

    result = frame.lazy().with_columns(b=pl.Series([2])).filter(mask).collect()
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64})

    assert_frame_equal(result, expected)
684
685
686
def test_predicates_not_split_when_pushdown_disabled_20475() -> None:
    """Three predicates stay in one FILTER node when pushdown is disabled."""
    # This is important for the eager `DataFrame.filter()`, as that runs without
    # predicate pushdown enabled. Splitting the predicates in that case can
    # severely degrade performance.
    predicates = [pl.col("a") > 0, pl.col("b") > 0, pl.col("c") > 0]
    query = pl.LazyFrame({"a": 1, "b": 1, "c": 1}).filter(*predicates)

    plan = query.explain(optimizations=pl.QueryOptFlags(predicate_pushdown=False))
    assert plan.count("FILTER") == 1
699
700
701
def test_predicate_filtering_against_nulls() -> None:
    """How `filter` and `remove` treat a null row for several predicates."""
    frame = pl.DataFrame({"num": [1, 2, None, 4]})
    num = pl.col("num")

    # (frames that must agree, expected "num" values)
    scenarios = [
        ((frame.filter(num > 2), frame.filter(num.is_in([3, 4, 5]))), [4]),
        ((frame.remove(num <= 2), frame.remove(num.is_in([1, 2, 3]))), [None, 4]),
        (
            (frame.filter(num.ne_missing(None)), frame.remove(num.eq_missing(None))),
            [1, 2, 4],
        ),
    ]
    for results, expected in scenarios:
        for res in results:
            assert res["num"].to_list() == expected
721
722
723
@pytest.mark.parametrize(
    ("query", "expected"),
    [
        (
            pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
            .rename(mapping)
            .select(*selected)
            .filter(pl.col(filter_column) == 1),
            pl.DataFrame(expected_data),
        )
        # (rename mapping, selected columns, filtered column, expected data)
        for mapping, selected, filter_column, expected_data in [
            ({"a": "A", "b": "a"}, ("A", "c"), "A", {"A": 1, "c": 3}),
            ({"b": "a", "a": "A"}, ("A", "c"), "A", {"A": 1, "c": 3}),
            ({"a": "b", "b": "a"}, ("a", "b", "c"), "b", {"a": 2, "b": 1, "c": 3}),
            ({"a": "b", "b": "a"}, ("b", "c"), "b", {"b": 1, "c": 3}),
            ({"b": "a", "a": "b"}, ("a", "b", "c"), "b", {"a": 2, "b": 1, "c": 3}),
        ]
    ],
)
def test_predicate_pushdown_lazy_rename_22373(
    query: pl.LazyFrame,
    expected: pl.DataFrame,
) -> None:
    """Filter after rename + select: check the result and the plan ordering."""
    assert_frame_equal(query.collect(), expected)

    # Ensure filter is pushed past rename
    plan = query.explain()
    filter_pos = plan.index("FILTER")
    select_pos = plan.index("SELECT")
    assert filter_pos > select_pos
785
786
787
@pytest.mark.parametrize(
    "base_query",
    [
        # Fallible expr in earlier `with_columns()`
        (
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .with_columns(b=pl.col("a").list.get(1, null_on_oob=False))
        ),
        # Fallible expr in earlier `filter()`
        (
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .filter(
                pl.col("a")
                .list.get(1, null_on_oob=False)
                .cast(pl.Boolean, strict=False)
            )
        ),
        # Fallible expr in earlier `select()`
        (
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .select("a", "MARKER", b=pl.col("a").list.get(1, null_on_oob=False))
        ),
    ],
)
def test_predicate_pushdown_pushes_past_fallible(
    base_query: pl.LazyFrame, plmonkeypatch: PlMonkeyPatch
) -> None:
    """A length filter added after a fallible expression changes the outcome.

    The base query raises on its own; with the extra filter it returns an
    empty frame, unless POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS is set, in which
    case it raises again.
    """
    out_of_bounds = "index is out of bounds"

    # Ensure baseline fails
    with pytest.raises(ComputeError, match=out_of_bounds):
        base_query.collect()

    guarded = base_query.filter(pl.col("a").list.len() > 1)
    plan = guarded.explain()

    assert plan.index("list.len") > plan.index("MARKER")

    empty = pl.DataFrame(schema=guarded.collect_schema())
    assert_frame_equal(guarded.collect(), empty)

    plmonkeypatch.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

    with pytest.raises(ComputeError, match=out_of_bounds):
        guarded.collect()
830
831
832
def test_predicate_pushdown_fallible_exprs_22284(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Fallible casts / datetime parsing combined with guarding filters."""
    digits_only = pl.col.a.str.contains(r"^\d{3}$")
    at_least_123 = pl.col.a.cast(pl.Int64) >= 123
    q = (
        pl.LazyFrame({"a": ["xyz", "123", "456", "789"]})
        .with_columns(MARKER=1)
        .filter(digits_only)
        .filter(at_least_123)
    )

    plan = q.explain()

    # Expected text order in the plan: cast filter, MARKER, contains filter.
    assert (
        plan.index('FILTER [(col("a").strict_cast(Int64)) >= (123)]')
        < plan.index("MARKER")
        < plan.index(r'FILTER col("a").str.contains(["^\d{3}$"])')
    )

    expected_numeric = pl.DataFrame(
        {
            "a": ["123", "456", "789"],
            "MARKER": 1,
        }
    )
    assert_frame_equal(q.collect(), expected_numeric)

    lf = pl.LazyFrame(
        {
            "str_date": ["2025-01-01", "20250101"],
            "data_source": ["system_1", "system_2"],
        }
    )
    from_system_1 = pl.col("data_source") == "system_1"
    parsed_date = pl.col("str_date").str.to_datetime("%Y-%m-%d", strict=True)

    q = lf.filter(from_system_1).filter(parsed_date == datetime(2025, 1, 1))

    expected_strings = pl.DataFrame(
        {
            "str_date": ["2025-01-01"],
            "data_source": ["system_1"],
        }
    )
    assert_frame_equal(q.collect(), expected_strings)

    q = lf.with_columns(parsed_date).filter(from_system_1)

    expected_parsed = pl.DataFrame(
        {
            "str_date": [datetime(2025, 1, 1)],
            "data_source": ["system_1"],
        }
    )
    assert_frame_equal(q.collect(), expected_parsed)

    plmonkeypatch.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

    with pytest.raises(
        InvalidOperationError, match=r"`str` to `datetime\[μs\]` failed"
    ):
        q.collect()
902
903
904
def test_predicate_pushdown_single_fallible() -> None:
    """A lone strict-cast predicate appears after MARKER in the plan text."""
    base = pl.LazyFrame({"a": [0, 1]}).with_columns(MARKER=pl.lit(1, dtype=pl.Int64))
    query = base.filter(pl.col("a").cast(pl.Boolean))

    plan = query.explain()
    filter_pos = plan.index('FILTER col("a").strict_cast(Boolean)')
    assert filter_pos > plan.index("MARKER")

    assert_frame_equal(query.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))
914
915
916
def test_predicate_pushdown_split_pushable(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Placement of pushable versus fallible predicates in the plan.

    Covers a plain `with_columns` barrier, a predicate combined with `sort()`,
    and filters applied after an inner self-join (with and without a select
    node in between).
    """
    lf = pl.LazyFrame({"a": [1, 999]}).with_columns(MARKER=pl.lit(1, dtype=pl.Int64))

    def self_join() -> pl.LazyFrame:
        # Inner join of `lf` against a materialized copy of itself without MARKER.
        return lf.join(
            lf.drop("MARKER").collect().lazy(),
            on="a",
            how="inner",
            coalesce=False,
            maintain_order="left_right",
        )

    q = lf.filter(
        pl.col("a") == 1,  # pushable
        pl.col("a").cast(pl.Int8) == 1,  # fallible
    )

    plan = q.explain()

    assert (
        plan.index('FILTER [(col("a").strict_cast(Int8)) == (1)]')
        < plan.index("MARKER")
        < plan.index('FILTER [(col("a")) == (1)]')
    )

    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

        with pytest.raises(
            InvalidOperationError, match="conversion from `i64` to `i8` failed"
        ):
            q.collect()

    q = lf.filter(
        pl.col("a").cast(pl.UInt16) == 1,
        pl.col("a").sort() == 1,
    )

    plan = q.explain()

    combined_filter_pos = plan.index(
        'FILTER [([(col("a").strict_cast(UInt16)) == (1)]) & ([(col("a").sort(asc)) == (1)])]'
    )
    assert combined_filter_pos < plan.index("MARKER")

    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")
        assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    # Ensure it is not pushed past a join

    # Baseline
    q = self_join().filter(pl.col("a_right") == 1)

    plan = q.explain()

    assert not plan.startswith("FILTER")

    joined_row = pl.DataFrame(
        {
            "a": 1,
            "MARKER": 1,
            "a_right": 1,
        }
    )
    assert_frame_equal(q.collect(), joined_row)

    q = self_join().filter(pl.col("a_right").cast(pl.Int16) == 1)

    plan = q.explain()

    assert plan.startswith("FILTER")

    joined_row = pl.DataFrame(
        {
            "a": 1,
            "MARKER": 1,
            "a_right": 1,
        }
    )
    assert_frame_equal(q.collect(), joined_row)

    # With a select node in between

    q = (
        self_join()
        .select(
            "a",
            "a_right",
            "MARKER",
        )
        .filter(pl.col("a_right").cast(pl.Int16) == 1)
    )

    plan = q.explain()

    assert plan.startswith("FILTER")

    selected_row = pl.DataFrame(
        {
            "a": 1,
            "a_right": 1,
            "MARKER": 1,
        }
    )
    assert_frame_equal(q.collect(), selected_row)
1042
1043
1044
def test_predicate_pushdown_fallible_literal_in_filter_expr() -> None:
    """Predicates whose fallible part operates only on literals.

    For each predicate the FILTER text must appear after MARKER in the plan,
    and the query must return exactly one row.
    """
    # Fallible operations on literals inside of the predicate expr should not
    # block pushdown.

    # Pushdown will also push any fallible expression if it's the only accumulated
    # predicate, we insert this dummy predicate to ensure the predicate is being
    # pushed solely because it is considered infallible.
    dummy_predicate = pl.lit(1) == pl.lit(1)

    lf = pl.LazyFrame(
        {"column": "2025-01-01", "column_date": datetime(2025, 1, 1), "integer": 1}
    )

    literal_predicates = (
        pl.col("column_date")
        == pl.lit("2025-01-01").str.to_datetime("%Y-%m-%d", strict=True),
        pl.col("column_date") == pl.lit("2025-01-01").str.strptime(pl.Datetime),
        pl.col("integer") == pl.lit("1").cast(pl.Int64, strict=True),
    )

    for predicate in literal_predicates:
        q = lf.with_columns(
            MARKER=1,
        ).filter(predicate, dummy_predicate)

        plan = q.explain()

        assert plan.index("FILTER") > plan.index("MARKER")

        assert q.collect().height == 1
1095
1096
1097
def test_predicate_does_not_split_barrier_expr() -> None:
    """A plain predicate and a `sort()` predicate stay in one top-level FILTER."""
    indexed = pl.LazyFrame({"a": [1, 2, 3]}).with_row_index()
    query = indexed.filter(pl.col("a") > 1, pl.col("a").sort() == 3)

    plan = query.explain()
    combined_filter = 'FILTER [([(col("a")) > (1)]) & ([(col("a").sort(asc)) == (3)])]'
    assert plan.startswith(combined_filter)

    expected = pl.DataFrame({"a": 3}).with_row_index(offset=2)
    assert_frame_equal(query.collect(), expected)
1114
1115
1116
def test_predicate_passes_set_sorted_22397() -> None:
    """With `set_sorted` in between, FILTER appears after MARKER in the plan text."""
    query = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .with_columns(MARKER=1, b=pl.lit(1))
        .set_sorted("a")
        .filter(pl.col("a") <= 1)
    )
    plan = query.explain()

    filter_pos = plan.index("FILTER")
    marker_pos = plan.index("MARKER")
    assert filter_pos > marker_pos
1125
1126
1127
@pytest.mark.filterwarnings("ignore")
def test_predicate_pass() -> None:
    """A `map_elements` predicate appears after MARKER in the plan text."""
    greater_than_two = pl.col("a").map_elements(
        lambda x: x > 2, return_dtype=pl.Boolean
    )
    query = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .with_columns(MARKER=pl.col("a"))
        .filter(greater_than_two)
    )
    plan = query.explain()

    assert plan.index("FILTER") > plan.index("MARKER")
1136
1137
1138
def test_predicate_pushdown_auto_disable_strict() -> None:
    """Cast-based predicates on compatible types end up after MARKER in the plan."""
    # Test that type-coercion automatically switches strict cast to
    # non-strict/overflowing for compatible types, allowing the predicate to be
    # pushed.
    lf = pl.LazyFrame(
        {"column": "2025-01-01", "column_date": datetime(2025, 1, 1), "integer": 1},
        schema={
            "column": pl.String,
            "column_date": pl.Datetime("ns"),
            "integer": pl.Int64,
        },
    )
    target = pl.lit(datetime(2025, 1, 1))

    predicate_sets = (
        # Default (strict) datetime cast plus a plain integer comparison
        (
            pl.col("column_date").cast(pl.Datetime("us")) == target,
            pl.col("integer") == 1,
        ),
        # Non-strict datetime cast plus a strict integer cast
        (
            pl.col("column_date").cast(pl.Datetime("us"), strict=False) == target,
            pl.col("integer").cast(pl.Int128, strict=True) == 1,
        ),
    )

    for predicates in predicate_sets:
        q = lf.with_columns(
            MARKER=1,
        ).filter(*predicates)

        plan = q.explain()
        assert plan.index("FILTER") > plan.index("MARKER")
1171
1172
1173
@pytest.mark.may_fail_auto_streaming  # IO plugin validate=False schema mismatch
def test_predicate_pushdown_map_elements_io_plugin_22860() -> None:
    """A `map_elements` predicate is handed to a registered IO source."""

    def scan_rows(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        n_rows: int | None,
        batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        # The source applies the predicate itself, so it must receive one.
        assert predicate is not None
        rows = pl.DataFrame({"row_nr": [1, 2, 3, 4, 5], "y": [0, 1, 0, 1, 1]})
        yield rows.filter(predicate)

    source = register_io_source(
        io_source=scan_rows, schema={"x": pl.Int64, "y": pl.Int64}
    )
    q = source.filter(pl.col("y").map_elements(bool, return_dtype=pl.Boolean))

    plan = q.explain()
    assert plan.index("SELECTION") > plan.index("PYTHON SCAN")

    expected = pl.DataFrame({"row_nr": [2, 4, 5], "y": [1, 1, 1]})
    assert_frame_equal(q.collect(), expected)
1193
1194
1195
def test_duplicate_filter_removal_23243() -> None:
    """Two identical predicates collapse into one in the first plan line."""
    lf = pl.LazyFrame({"x": [1, 2, 3]})
    equals_two = pl.col("x") == 2

    query = lf.filter(equals_two, pl.col("x") == 2)

    first_plan_line = query.explain().split("\n", 1)[0]
    assert first_plan_line == 'FILTER [(col("x")) == (2)]'

    assert_frame_equal(query.collect(), pl.DataFrame({"x": [2]}))
1207
1208
1209
@pytest.mark.parametrize("maintain_order", [True, False])
def test_no_predicate_pushdown_on_modified_groupby_keys_21439(
    maintain_order: bool,
) -> None:
    """Filters on group-by keys that are computed or renamed expressions.

    The filter refers to the key's output values, not to the input column of
    the same name, so the result must reflect the post-aggregation values.
    Both queries honour the parametrized `maintain_order`.
    """
    df = pl.DataFrame({"a": [1, 2, 3]})
    q = (
        df.lazy()
        .group_by(pl.col.a + 1, maintain_order=maintain_order)
        .agg()
        .filter(pl.col.a <= 3)
    )
    expected = pl.DataFrame({"a": [2, 3]})
    assert_frame_equal(q.collect(), expected, check_row_order=maintain_order)

    # Keys swap names: output "b" is `a + 1`, output "a" is input "b".
    df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
    q = (
        df.lazy()
        .group_by(
            [(pl.col.a + 1).alias("b"), pl.col.b.alias("a")],
            # Use the parameter here too, so both orderings are exercised.
            maintain_order=maintain_order,
        )
        .agg()
        .filter(pl.col.b <= 2)
        .select(pl.col.b)
    )
    expected = pl.DataFrame({"b": [2]})
    assert_frame_equal(q.collect(), expected, check_row_order=maintain_order)
1233
1234
1235
def test_no_predicate_pushdown_on_modified_groupby_keys_21439b() -> None:
    """Eager and lazy results must agree when filtering on a derived group key."""
    times = pl.datetime_range(
        datetime(2021, 1, 1),
        datetime(2021, 1, 2),
        timedelta(minutes=15),
        eager=True,
    )
    frame = pl.DataFrame({"time": times})

    eager_result = (
        frame.group_by(pl.col("time").dt.hour())
        .agg()
        .filter(pl.col("time").is_between(0, 10))
    )
    lazy_query = (
        frame.lazy()
        .group_by(pl.col("time").dt.hour())
        .agg()
        .filter(pl.col("time").is_between(0, 10))
    )
    lazy_result = lazy_query.collect()

    assert_frame_equal(eager_result, lazy_result, check_row_order=False)
1259
1260
1261
def test_no_predicate_pushdown_unpivot() -> None:
    """After an unpivot, FILTER appears after UNPIVOT in the plan text."""
    data = {"a": [5, 2, 8, 2], "b": [99, 33, 77, 44]}
    cases = [
        ("a", pl.col.a == 2),
        (["b", "a"], pl.col.b != 33),
    ]

    for index, pred in cases:
        unpivoted = pl.LazyFrame(data).unpivot(on="b", index=index)
        plan = unpivoted.filter(pred).explain()
        assert plan.index("FILTER") > plan.index("UNPIVOT")
1268
1269
1270
def test_replace_strict_predicate_merging() -> None:
    """A plain filter followed by a `replace_strict` filter keeps three rows."""
    lf = pl.LazyFrame({"x": [True, True, True, False]})

    only_true = lf.filter(pl.col("x"))
    result = only_true.filter(pl.col("x").replace_strict(True, True)).collect()

    assert result.height == 3
1276
1277