# GitHub Repository: pola-rs/polars
# Path: blob/main/py-polars/tests/unit/test_predicates.py
from __future__ import annotations

import re
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, InvalidOperationError
from polars.io.plugins import register_io_source
from polars.testing import assert_frame_equal
from polars.testing.asserts.series import assert_series_equal

if TYPE_CHECKING:
    from collections.abc import Iterator

def test_predicate_4906() -> None:
    one_day = timedelta(days=1)

    ldf = pl.DataFrame(
        {
            "dt": [
                date(2022, 9, 1),
                date(2022, 9, 10),
                date(2022, 9, 20),
            ]
        }
    ).lazy()

    assert ldf.filter(
        pl.min_horizontal((pl.col("dt") + one_day), date(2022, 9, 30))
        > date(2022, 9, 10)
    ).collect().to_dict(as_series=False) == {
        "dt": [date(2022, 9, 10), date(2022, 9, 20)]
    }

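# Clarifying note (added): the predicate below references "value", a column produced
# by the asof join, so it can only be evaluated after the join; the assertion keeps
# only the rows that found a match.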
def test_predicate_null_block_asof_join() -> None:
    left = (
        pl.DataFrame(
            {
                "id": [1, 2, 3, 4],
                "timestamp": [
                    datetime(2022, 1, 1, 10, 0),
                    datetime(2022, 1, 1, 10, 1),
                    datetime(2022, 1, 1, 10, 2),
                    datetime(2022, 1, 1, 10, 3),
                ],
            }
        )
        .lazy()
        .set_sorted("timestamp")
    )

    right = (
        pl.DataFrame(
            {
                "id": [1, 2, 3] * 2,
                "timestamp": [
                    datetime(2022, 1, 1, 9, 59, 50),
                    datetime(2022, 1, 1, 10, 0, 50),
                    datetime(2022, 1, 1, 10, 1, 50),
                    datetime(2022, 1, 1, 8, 0, 0),
                    datetime(2022, 1, 1, 8, 0, 0),
                    datetime(2022, 1, 1, 8, 0, 0),
                ],
                "value": ["a", "b", "c"] * 2,
            }
        )
        .lazy()
        .set_sorted("timestamp")
    )

    assert_frame_equal(
        left.join_asof(right, by="id", on="timestamp")
        .filter(pl.col("value").is_not_null())
        .collect(),
        pl.DataFrame(
            {
                "id": [1, 2, 3],
                "timestamp": [
                    datetime(2022, 1, 1, 10, 0),
                    datetime(2022, 1, 1, 10, 1),
                    datetime(2022, 1, 1, 10, 2),
                ],
                "value": ["a", "b", "c"],
            }
        ),
        check_row_order=False,
    )

def test_predicate_strptime_6558() -> None:
    assert (
        pl.DataFrame({"date": ["2022-01-03", "2020-01-04", "2021-02-03", "2019-01-04"]})
        .lazy()
        .select(pl.col("date").str.strptime(pl.Date, format="%F"))
        .filter((pl.col("date").dt.year() == 2022) & (pl.col("date").dt.month() == 1))
        .collect()
    ).to_dict(as_series=False) == {"date": [date(2022, 1, 3)]}

def test_predicate_arr_first_6573() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6],
            "b": [6, 5, 4, 3, 2, 1],
        }
    )

    assert (
        df.lazy()
        .with_columns(pl.col("a").implode())
        .with_columns(pl.col("a").list.first())
        .filter(pl.col("a") == pl.col("b"))
        .collect()
    ).to_dict(as_series=False) == {"a": [1], "b": [1]}

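# Clarifying note (added): comparisons on a Series flagged as sorted via set_sorted()
# should match the results of the unflagged Series.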
def test_fast_path_comparisons() -> None:
    s = pl.Series(np.sort(np.random.randint(0, 50, 100)))

    assert_series_equal(s > 25, s.set_sorted() > 25)
    assert_series_equal(s >= 25, s.set_sorted() >= 25)
    assert_series_equal(s < 25, s.set_sorted() < 25)
    assert_series_equal(s <= 25, s.set_sorted() <= 25)

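# Clarifying note (added): the windowed predicate depends on the row order produced
# by the sort, so it should not be pushed below the sort node.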
def test_predicate_pushdown_block_8661() -> None:
    df = pl.DataFrame(
        {
            "g": [1, 1, 1, 1, 2, 2, 2, 2],
            "t": [1, 2, 3, 4, 4, 3, 2, 1],
            "x": [10, 20, 30, 40, 10, 20, 30, 40],
        }
    )
    assert df.lazy().sort(["g", "t"]).filter(
        (pl.col("x").shift() > 20).over("g")
    ).collect().to_dict(as_series=False) == {
        "g": [1, 2, 2],
        "t": [4, 2, 3],
        "x": [40, 30, 20],
    }

def test_predicate_pushdown_cumsum_9566() -> None:
    df = pl.DataFrame({"A": range(10), "B": ["b"] * 5 + ["a"] * 5})

    q = df.lazy().sort(["B", "A"]).filter(pl.col("A").is_in([8, 2]).cum_sum() == 1)

    assert q.collect()["A"].to_list() == [8, 9, 0, 1]

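# Clarifying note (added): the predicate applies fill_null to a column coming from the
# right side of a left join, so pushing it below the join would change which rows survive.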
def test_predicate_pushdown_join_fill_null_10058() -> None:
    ids = pl.LazyFrame({"id": [0, 1, 2]})
    filters = pl.LazyFrame({"id": [0, 1], "filter": [True, False]})

    assert sorted(
        ids.join(filters, how="left", on="id")
        .filter(pl.col("filter").fill_null(True))
        .collect()
        .to_dict(as_series=False)["id"]
    ) == [0, 2]

def test_is_in_join_blocked() -> None:
    lf1 = pl.LazyFrame(
        {"Groups": ["A", "B", "C", "D", "E", "F"], "values0": [1, 2, 3, 4, 5, 6]}
    )
    lf2 = pl.LazyFrame(
        {"values_22": [1, 2, None, 4, 5, 6], "values_20": [1, 2, 3, 4, 5, 6]}
    )
    lf_all = lf2.join(
        lf1,
        left_on="values_20",
        right_on="values0",
        how="left",
        maintain_order="right_left",
    )

    for result in (
        lf_all.filter(~pl.col("Groups").is_in(["A", "B", "F"])),
        lf_all.remove(pl.col("Groups").is_in(["A", "B", "F"])),
    ):
        expected = pl.LazyFrame(
            {
                "values_22": [None, 4, 5],
                "values_20": [3, 4, 5],
                "Groups": ["C", "D", "E"],
            }
        )
        assert_frame_equal(result, expected)

def test_predicate_pushdown_group_by_keys() -> None:
    df = pl.LazyFrame(
        {"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]}
    ).lazy()
    assert (
        "SELECTION: None"
        not in df.group_by("group")
        .agg([pl.len().alias("str_list")])
        .filter(pl.col("group") == 1)
        .explain()
    )

def test_no_predicate_push_down_with_cast_and_alias_11883() -> None:
    df = pl.DataFrame({"a": [1, 2, 3]})
    out = (
        df.lazy()
        .select(pl.col("a").cast(pl.Int64).alias("b"))
        .filter(pl.col("b") == 1)
        .filter((pl.col("b") >= 1) & (pl.col("b") < 1))
    )

    assert (
        re.search(
            r"FILTER.*FROM\n\s*DF",
            out.explain(optimizations=pl.QueryOptFlags(predicate_pushdown=True)),
        )
        is None
    )

@pytest.mark.parametrize(
    "predicate",
    [
        0,
        "x",
        [2, 3],
        {"x": 1},
        pl.Series([1, 2, 3]),
        None,
    ],
)
def test_invalid_filter_predicates(predicate: Any) -> None:
    df = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]})
    with pytest.raises(TypeError, match="invalid predicate"):
        df.filter(predicate)

def test_fast_path_boolean_filter_predicates() -> None:
    df = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]})
    df_empty = df.clear()

    assert_frame_equal(df.filter(False), df_empty)
    assert_frame_equal(df.filter(True), df)

    assert_frame_equal(df.remove(True), df_empty)
    assert_frame_equal(df.remove(False), df)

def test_predicate_pushdown_boundary_12102() -> None:
    df = pl.DataFrame({"x": [1, 2, 4], "y": [1, 2, 4]})

    lf = (
        df.lazy()
        .filter(pl.col("y") > 1)
        .filter(pl.col("x") == pl.min("x"))
        .filter(pl.col("y") > 2)
    )

    result = lf.collect()
    result_no_ppd = lf.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=False))
    assert_frame_equal(result, result_no_ppd)

def test_take_can_block_predicate_pushdown() -> None:
    df = pl.DataFrame({"x": [1, 2, 4], "y": [False, True, True]})
    lf = (
        df.lazy()
        .filter(pl.col("y"))
        .filter(pl.col("x") == pl.col("x").gather(0))
        .filter(pl.col("y"))
    )
    result = lf.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    assert result.to_dict(as_series=False) == {"x": [2], "y": [True]}

def test_literal_series_expr_predicate_pushdown() -> None:
    # No pushdown should occur in this case, because otherwise the filter will
    # attempt to filter 3 rows with a boolean mask of 2 rows.
    lf = pl.LazyFrame({"x": [0, 1, 2]})

    for res in (
        lf.filter(pl.col("x") > 0).filter(pl.Series([True, True])),
        lf.remove(pl.col("x") <= 0).remove(pl.Series([False, False])),
    ):
        assert res.collect().to_series().to_list() == [1, 2]

    # Pushdown should occur here; series is being used as part of an `is_in`.
    for res in (
        lf.filter(pl.col("x") > 0).filter(pl.col("x").is_in([0, 1])),
        lf.remove(pl.col("x") <= 0).remove(~pl.col("x").is_in([0, 1])),
    ):
        assert re.search(r"FILTER .*\nFROM\n\s*DF", res.explain(), re.DOTALL)
        assert res.collect().to_series().to_list() == [1]

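# Clarifying note (added): both aliases are plain column renames, so the combined
# predicate should still be pushed all the way to the frame scan (a single FILTER
# directly above DF in the plan).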
def test_multi_alias_pushdown() -> None:
    lf = pl.LazyFrame({"a": [1], "b": [1]})

    actual = lf.with_columns(m="a", n="b").filter((pl.col("m") + pl.col("n")) < 2)
    plan = actual.explain()

    assert plan.count("FILTER") == 1
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        # confirm we aren't using `eq_missing` in the query plan (denoted as " ==v ")
        assert " ==v " not in lf.select(pl.col("a").filter(a=None)).explain()

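# Clarifying note (added): the cases below exercise pushing a predicate past window
# (.over) projections with matching keys, differing keys, derived keys, and
# group-sensitive keys; only predicates on the shared partition keys should be pushed.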
def test_predicate_pushdown_with_window_projections_12637() -> None:
    lf = pl.LazyFrame(
        {
            "key": [1],
            "key_2": [1],
            "key_3": [1],
            "value": [1],
            "value_2": [1],
            "value_3": [1],
        }
    )

    actual = lf.with_columns(
        (pl.col("value") * 2).over("key").alias("value_2"),
        (pl.col("value") * 2).over("key").alias("value_3"),
    ).filter(pl.col("key") == 5)

    plan = actual.explain()

    assert (
        re.search(
            r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL
        )
        is not None
    )
    assert plan.count("FILTER") == 1

    actual = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", "key_2").alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
    )

    plan = actual.explain()
    assert plan.count("FILTER") == 1
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None
    actual = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", "key_2").alias("value_2"),
            (pl.col("value") * 2).over("key", "key_3").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
    )

    plan = actual.explain()
    assert plan.count("FILTER") == 2
    assert (
        re.search(
            r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL
        )
        is not None
    )

    actual = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", pl.col("key_2") + 1).alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
    )
    plan = actual.explain()
    assert plan.count("FILTER") == 2
    assert (
        re.search(
            r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL
        )
        is not None
    )

    # Should block when .over() contains groups-sensitive expr
    actual = (
        lf.with_columns(
            (pl.col("value") * 2).over("key", pl.sum("key_2")).alias("value_2"),
            (pl.col("value") * 2).over("key", "key_2").alias("value_3"),
        )
        .filter(pl.col("key") == 5)
        .filter(pl.col("key_2") == 5)
    )

    plan = actual.explain()
    assert plan.count("FILTER") == 1
    assert "FILTER" in plan
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is None
    # Ensure the implementation doesn't accidentally push a window expression
    # that only refers to the common window keys.
    actual = lf.with_columns(
        (pl.col("value") * 2).over("key").alias("value_2"),
    ).filter(pl.len().over("key") == 1)

    plan = actual.explain()
    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is None
    assert plan.count("FILTER") == 1

    # Test window in filter
    actual = lf.filter(pl.len().over("key") == 1).filter(pl.col("key") == 1)
    plan = actual.explain()
    assert plan.count("FILTER") == 2
    assert (
        re.search(
            r'FILTER \[\(len\(\).over\(\[col\("key"\)\]\)\) == \(1\)\]\s*FROM\n\s*FILTER',
            plan,
        )
        is not None
    )
    assert (
        re.search(
            r'FILTER \[\(col\("key"\)\) == \(1\)\]\s*FROM\n\s*DF', plan, re.DOTALL
        )
        is not None
    )

def test_predicate_reduction() -> None:
    # ensure we get clean reduction without casts
    lf = pl.LazyFrame({"a": [1], "b": [2]})
    for filter_frame in (lf.filter, lf.remove):
        assert (
            "cast"
            not in filter_frame(
                pl.col("a") > 1,
                pl.col("b") > 1,
            ).explain()
        )

def test_all_any_cleanup_at_single_predicate_case() -> None:
    plan = pl.LazyFrame({"a": [1], "b": [2]}).select(["a"]).drop_nulls().explain()
    assert "horizontal" not in plan
    assert "all" not in plan

def test_hconcat_predicate() -> None:
    # Predicates shouldn't be pushed down past an hconcat as we can't filter
    # across the different inputs
    lf1 = pl.LazyFrame(
        {
            "a1": [0, 1, 2, 3, 4],
            "a2": [5, 6, 7, 8, 9],
        }
    )
    lf2 = pl.LazyFrame(
        {
            "b1": [0, 1, 2, 3, 4],
            "b2": [5, 6, 7, 8, 9],
        }
    )

    query = pl.concat(
        [
            lf1.filter(pl.col("a1") < 4),
            lf2.filter(pl.col("b1") > 0),
        ],
        how="horizontal",
    ).filter(pl.col("b2") < 9)

    expected = pl.DataFrame(
        {
            "a1": [0, 1, 2],
            "a2": [5, 6, 7],
            "b1": [1, 2, 3],
            "b2": [6, 7, 8],
        }
    )
    result = query.collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    assert_frame_equal(result, expected)

def test_predicate_pd_join_13300() -> None:
    # https://github.com/pola-rs/polars/issues/13300

    lf = pl.LazyFrame({"col3": range(10, 14), "new_col": range(11, 15)})
    lf_other = pl.LazyFrame({"col4": [0, 11, 2, 13]})

    lf = lf.join(lf_other, left_on="new_col", right_on="col4", how="left")
    for res in (
        lf.filter(pl.col("new_col") < 12),
        lf.remove(pl.col("new_col") >= 12),
    ):
        assert res.collect().to_dict(as_series=False) == {"col3": [10], "new_col": [11]}

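# Clarifying note (added): `filter(a=None)` is an equality comparison with None, so it
# warns and matches no rows, whereas eq_missing(None) / is_null() match the null row.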
def test_filter_eq_missing_13861() -> None:
    lf = pl.LazyFrame({"a": [1, None, 3], "b": ["xx", "yy", None]})
    lf_empty = lf.clear()

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        assert_frame_equal(lf.collect().filter(a=None), lf_empty.collect())

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        assert_frame_equal(lf.collect().remove(a=None), lf.collect())

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        lff = lf.filter(a=None)
        assert lff.collect().rows() == []
        assert " ==v " not in lff.explain()  # check no `eq_missing` op

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        assert_frame_equal(lf.collect().filter(a=None), lf_empty.collect())

    with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
        assert_frame_equal(lf.collect().remove(a=None), lf.collect())

    for filter_expr in (
        pl.col("a").eq_missing(None),
        pl.col("a").is_null(),
    ):
        assert lf.collect().filter(filter_expr).rows() == [(None, "yy")]

@pytest.mark.parametrize("how", ["left", "inner"])
def test_predicate_pushdown_block_join(how: Any) -> None:
    q = (
        pl.LazyFrame({"a": [1]})
        .join(
            pl.LazyFrame({"a": [2], "b": [1]}),
            left_on=["a"],
            right_on=["b"],
            how=how,
        )
        .filter(pl.col("a") == 1)
    )
    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), q.collect())

def test_predicate_push_down_with_alias_15442() -> None:
    df = pl.DataFrame({"a": [1]})
    output = (
        df.lazy()
        .filter(pl.col("a").alias("x").drop_nulls() > 0)
        .collect(optimizations=pl.QueryOptFlags(predicate_pushdown=True))
    )
    assert output.to_dict(as_series=False) == {"a": [1]}

def test_predicate_slice_pushdown_list_gather_17492(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    lf = pl.LazyFrame({"val": [[1], [1, 1]], "len": [1, 2]})

    assert_frame_equal(
        lf.filter(pl.col("len") == 2).filter(pl.col("val").list.get(1) == 1),
        lf.slice(1, 1),
    )

    # null_on_oob=True can pass

    plan = (
        lf.filter(pl.col("len") == 2)
        .filter(pl.col("val").list.get(1, null_on_oob=True) == 1)
        .explain()
    )

    assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None

    # Also check slice pushdown
    q = lf.with_columns(pl.col("val").list.get(1).alias("b")).slice(1, 1)

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "val": [[1, 1]],
                "len": pl.Series([2], dtype=pl.Int64),
                "b": pl.Series([1], dtype=pl.Int64),
            }
        ),
    )

def test_predicate_pushdown_struct_unnest_19632() -> None:
    lf = pl.LazyFrame({"a": [{"a": 1, "b": 2}]}).unnest("a")

    q = lf.filter(pl.col("a") == 1)
    plan = q.explain()

    assert "FILTER" in plan
    assert plan.index("FILTER") < plan.index("UNNEST")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame({"a": 1, "b": 2}),
    )

    # With `pl.struct()`
    lf = pl.LazyFrame({"a": 1, "b": 2}).select(pl.struct(pl.all())).unnest("a")

    q = lf.filter(pl.col("a") == 1)
    plan = q.explain()

    assert "FILTER" in plan
    assert plan.index("FILTER") < plan.index("UNNEST")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame({"a": 1, "b": 2}),
    )

    # With `value_counts()`
    lf = pl.LazyFrame({"a": [1]}).select(pl.col("a").value_counts()).unnest("a")

    q = lf.filter(pl.col("a") == 1)
    plan = q.explain()

    assert plan.index("FILTER") < plan.index("UNNEST")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame({"a": 1, "count": 1}, schema={"a": pl.Int64, "count": pl.UInt32}),
    )

@pytest.mark.parametrize(
    "predicate",
    [
        pl.col("v") == 7,
        pl.col("v") != 99,
        pl.col("v") > 0,
        pl.col("v") < 999,
        pl.col("v").is_in([7]),
        pl.col("v").cast(pl.Boolean),
        pl.col("b"),
    ],
)
@pytest.mark.parametrize("alias", [True, False])
@pytest.mark.parametrize("join_type", ["left", "right"])
def test_predicate_pushdown_join_19772(
    predicate: pl.Expr, join_type: str, alias: bool
) -> None:
    left = pl.LazyFrame({"k": [1, 2]})
    right = pl.LazyFrame({"k": [1], "v": [7], "b": True})

    if join_type == "right":
        [left, right] = [right, left]

    if alias:
        predicate = predicate.alias(":V")

    q = left.join(right, on="k", how=join_type).filter(predicate)  # type: ignore[arg-type]

    expect = pl.DataFrame({"k": 1, "v": 7, "b": True})

    if join_type == "right":
        expect = expect.select("v", "b", "k")

    assert_frame_equal(q.collect(optimizations=pl.QueryOptFlags.none()), expect)
    assert_frame_equal(q.collect(), expect)

def test_predicate_pushdown_scalar_20489() -> None:
    df = pl.DataFrame({"a": [1]})
    mask = pl.Series([False])

    assert_frame_equal(
        df.lazy().with_columns(b=pl.Series([2])).filter(mask).collect(),
        pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64}),
    )

def test_predicates_not_split_when_pushdown_disabled_20475() -> None:
    # This is important for the eager `DataFrame.filter()`, as that runs without
    # predicate pushdown enabled. Splitting the predicates in that case can
    # severely degrade performance.
    q = pl.LazyFrame({"a": 1, "b": 1, "c": 1}).filter(
        pl.col("a") > 0, pl.col("b") > 0, pl.col("c") > 0
    )
    assert (
        q.explain(optimizations=pl.QueryOptFlags(predicate_pushdown=False)).count(
            "FILTER"
        )
        == 1
    )

def test_predicate_filtering_against_nulls() -> None:
    df = pl.DataFrame({"num": [1, 2, None, 4]})

    for res in (
        df.filter(pl.col("num") > 2),
        df.filter(pl.col("num").is_in([3, 4, 5])),
    ):
        assert res["num"].to_list() == [4]

    for res in (
        df.remove(pl.col("num") <= 2),
        df.remove(pl.col("num").is_in([1, 2, 3])),
    ):
        assert res["num"].to_list() == [None, 4]

    for res in (
        df.filter(pl.col("num").ne_missing(None)),
        df.remove(pl.col("num").eq_missing(None)),
    ):
        assert res["num"].to_list() == [1, 2, 4]

@pytest.mark.parametrize(
    ("query", "expected"),
    [
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"a": "A", "b": "a"})
                .select("A", "c")
                .filter(pl.col("A") == 1)
            ),
            pl.DataFrame({"A": 1, "c": 3}),
        ),
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"b": "a", "a": "A"})
                .select("A", "c")
                .filter(pl.col("A") == 1)
            ),
            pl.DataFrame({"A": 1, "c": 3}),
        ),
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"a": "b", "b": "a"})
                .select("a", "b", "c")
                .filter(pl.col("b") == 1)
            ),
            pl.DataFrame({"a": 2, "b": 1, "c": 3}),
        ),
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"a": "b", "b": "a"})
                .select("b", "c")
                .filter(pl.col("b") == 1)
            ),
            pl.DataFrame({"b": 1, "c": 3}),
        ),
        (
            (
                pl.LazyFrame({"a": [1], "b": [2], "c": [3]})
                .rename({"b": "a", "a": "b"})
                .select("a", "b", "c")
                .filter(pl.col("b") == 1)
            ),
            pl.DataFrame({"a": 2, "b": 1, "c": 3}),
        ),
    ],
)
def test_predicate_pushdown_lazy_rename_22373(
    query: pl.LazyFrame,
    expected: pl.DataFrame,
) -> None:
    assert_frame_equal(
        query.collect(),
        expected,
    )

    # Ensure filter is pushed past rename
    plan = query.explain()
    assert plan.index("FILTER") > plan.index("SELECT")

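# Clarifying note (added): pushdown may move the filter ahead of a fallible expression,
# which here avoids the out-of-bounds error at collect time unless
# POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS is set.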
@pytest.mark.parametrize(
    "base_query",
    [
        (  # Fallible expr in earlier `with_columns()`
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .with_columns(b=pl.col("a").list.get(1, null_on_oob=False))
        ),
        (  # Fallible expr in earlier `filter()`
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .filter(
                pl.col("a")
                .list.get(1, null_on_oob=False)
                .cast(pl.Boolean, strict=False)
            )
        ),
        (  # Fallible expr in earlier `select()`
            pl.LazyFrame({"a": [[1]]})
            .with_columns(MARKER=1)
            .select("a", "MARKER", b=pl.col("a").list.get(1, null_on_oob=False))
        ),
    ],
)
def test_predicate_pushdown_pushes_past_fallible(
    base_query: pl.LazyFrame, monkeypatch: pytest.MonkeyPatch
) -> None:
    # Ensure baseline fails
    with pytest.raises(ComputeError, match="index is out of bounds"):
        base_query.collect()

    q = base_query.filter(pl.col("a").list.len() > 1)

    plan = q.explain()

    assert plan.index("list.len") > plan.index("MARKER")

    assert_frame_equal(q.collect(), pl.DataFrame(schema=q.collect_schema()))

    monkeypatch.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

    with pytest.raises(ComputeError, match="index is out of bounds"):
        q.collect()

def test_predicate_pushdown_fallible_exprs_22284(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    q = (
        pl.LazyFrame({"a": ["xyz", "123", "456", "789"]})
        .with_columns(MARKER=1)
        .filter(pl.col.a.str.contains(r"^\d{3}$"))
        .filter(pl.col.a.cast(pl.Int64) >= 123)
    )

    plan = q.explain()

    assert (
        plan.index('FILTER [(col("a").strict_cast(Int64)) >= (123)]')
        < plan.index("MARKER")
        < plan.index(r'FILTER col("a").str.contains(["^\d{3}$"])')
    )

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "a": ["123", "456", "789"],
                "MARKER": 1,
            }
        ),
    )

    lf = pl.LazyFrame(
        {
            "str_date": ["2025-01-01", "20250101"],
            "data_source": ["system_1", "system_2"],
        }
    )

    q = lf.filter(pl.col("data_source") == "system_1").filter(
        pl.col("str_date").str.to_datetime("%Y-%m-%d", strict=True)
        == datetime(2025, 1, 1)
    )

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "str_date": ["2025-01-01"],
                "data_source": ["system_1"],
            }
        ),
    )

    q = lf.with_columns(
        pl.col("str_date").str.to_datetime("%Y-%m-%d", strict=True)
    ).filter(pl.col("data_source") == "system_1")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "str_date": [datetime(2025, 1, 1)],
                "data_source": ["system_1"],
            }
        ),
    )

    monkeypatch.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

    with pytest.raises(
        InvalidOperationError, match=r"`str` to `datetime\[μs\]` failed"
    ):
        q.collect()

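# Clarifying note (added): a single fallible predicate (strict cast to Boolean) is
# still pushed below the with_columns that adds MARKER.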
def test_predicate_pushdown_single_fallible() -> None:
    lf = pl.LazyFrame({"a": [0, 1]}).with_columns(MARKER=pl.lit(1, dtype=pl.Int64))

    q = lf.filter(pl.col("a").cast(pl.Boolean))

    plan = q.explain()

    assert plan.index('FILTER col("a").strict_cast(Boolean)') > plan.index("MARKER")

    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

def test_predicate_pushdown_split_pushable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    lf = pl.LazyFrame({"a": [1, 999]}).with_columns(MARKER=pl.lit(1, dtype=pl.Int64))

    q = lf.filter(
        pl.col("a") == 1,  # pushable
        pl.col("a").cast(pl.Int8) == 1,  # fallible
    )

    plan = q.explain()

    assert (
        plan.index('FILTER [(col("a").strict_cast(Int8)) == (1)]')
        < plan.index("MARKER")
        < plan.index('FILTER [(col("a")) == (1)]')
    )

    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    with monkeypatch.context() as cx:
        cx.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")

        with pytest.raises(
            InvalidOperationError, match="conversion from `i64` to `i8` failed"
        ):
            q.collect()

    q = lf.filter(
        pl.col("a").cast(pl.UInt16) == 1,
        pl.col("a").sort() == 1,
    )

    plan = q.explain()

    assert plan.index(
        'FILTER [([(col("a").strict_cast(UInt16)) == (1)]) & ([(col("a").sort(asc)) == (1)])]'
    ) < plan.index("MARKER")

    assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    with monkeypatch.context() as cx:
        cx.setenv("POLARS_PUSHDOWN_OPT_MAINTAIN_ERRORS", "1")
        assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "MARKER": 1}))

    # Ensure it is not pushed past a join

    # Baseline
    q = lf.join(
        lf.drop("MARKER").collect().lazy(),
        on="a",
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("a_right") == 1)

    plan = q.explain()

    assert not plan.startswith("FILTER")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "a": 1,
                "MARKER": 1,
                "a_right": 1,
            }
        ),
    )

    q = lf.join(
        lf.drop("MARKER").collect().lazy(),
        on="a",
        how="inner",
        coalesce=False,
        maintain_order="left_right",
    ).filter(pl.col("a_right").cast(pl.Int16) == 1)

    plan = q.explain()

    assert plan.startswith("FILTER")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "a": 1,
                "MARKER": 1,
                "a_right": 1,
            }
        ),
    )

    # With a select node in between

    q = (
        lf.join(
            lf.drop("MARKER").collect().lazy(),
            on="a",
            how="inner",
            coalesce=False,
            maintain_order="left_right",
        )
        .select(
            "a",
            "a_right",
            "MARKER",
        )
        .filter(pl.col("a_right").cast(pl.Int16) == 1)
    )

    plan = q.explain()

    assert plan.startswith("FILTER")

    assert_frame_equal(
        q.collect(),
        pl.DataFrame(
            {
                "a": 1,
                "a_right": 1,
                "MARKER": 1,
            }
        ),
    )

def test_predicate_pushdown_fallible_literal_in_filter_expr() -> None:
    # Fallible operations on literals inside of the predicate expr should not
    # block pushdown.
    lf = pl.LazyFrame(
        {"column": "2025-01-01", "column_date": datetime(2025, 1, 1), "integer": 1}
    )

    q = lf.with_columns(
        MARKER=1,
    ).filter(
        pl.col("column_date")
        == pl.lit("2025-01-01").str.to_datetime("%Y-%m-%d", strict=True)
    )

    plan = q.explain()

    assert plan.index("FILTER") > plan.index("MARKER")

    assert q.collect().height == 1

    q = lf.with_columns(
        MARKER=1,
    ).filter(pl.col("integer") == pl.lit("1").cast(pl.Int64, strict=True))

    plan = q.explain()

    assert plan.index("FILTER") > plan.index("MARKER")

    assert q.collect().height == 1

def test_predicate_does_not_split_barrier_expr() -> None:
    q = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .with_row_index()
        .filter(pl.col("a") > 1, pl.col("a").sort() == 3)
    )

    plan = q.explain()

    assert plan.startswith(
        'FILTER [([(col("a")) > (1)]) & ([(col("a").sort(asc)) == (3)])]'
    )

    assert_frame_equal(
        q.collect(),
        pl.DataFrame({"a": 3}).with_row_index(offset=2),
    )

def test_predicate_passes_set_sorted_22397() -> None:
    plan = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .with_columns(MARKER=1, b=pl.lit(1))
        .set_sorted("a")
        .filter(pl.col("a") <= 1)
        .explain()
    )
    assert plan.index("FILTER") > plan.index("MARKER")

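# Clarifying note (added): a map_elements (Python UDF) predicate is still pushed below
# the preceding with_columns (FILTER appears after MARKER in the printed plan).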
@pytest.mark.filterwarnings("ignore")
def test_predicate_pass() -> None:
    plan = (
        pl.LazyFrame({"a": [1, 2, 3]})
        .with_columns(MARKER=pl.col("a"))
        .filter(pl.col("a").map_elements(lambda x: x > 2, return_dtype=pl.Boolean))
        .explain()
    )
    assert plan.index("FILTER") > plan.index("MARKER")

def test_predicate_pushdown_auto_disable_strict() -> None:
    # Test that type-coercion automatically switches strict cast to
    # non-strict/overflowing for compatible types, allowing the predicate to be
    # pushed.
    lf = pl.LazyFrame(
        {"column": "2025-01-01", "column_date": datetime(2025, 1, 1), "integer": 1},
        schema={
            "column": pl.String,
            "column_date": pl.Datetime("ns"),
            "integer": pl.Int64,
        },
    )

    q = lf.with_columns(
        MARKER=1,
    ).filter(
        pl.col("column_date").cast(pl.Datetime("us")) == pl.lit(datetime(2025, 1, 1)),
        pl.col("integer") == 1,
    )

    plan = q.explain()
    assert plan.index("FILTER") > plan.index("MARKER")

    q = lf.with_columns(
        MARKER=1,
    ).filter(
        pl.col("column_date").cast(pl.Datetime("us"), strict=False)
        == pl.lit(datetime(2025, 1, 1)),
        pl.col("integer").cast(pl.Int128, strict=True) == 1,
    )

    plan = q.explain()
    assert plan.index("FILTER") > plan.index("MARKER")

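# Clarifying note (added): the predicate should be handed to the IO plugin as the scan's
# SELECTION; the generator asserts it receives a non-None predicate and applies it itself.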
@pytest.mark.may_fail_auto_streaming  # IO plugin validate=False schema mismatch
def test_predicate_pushdown_map_elements_io_plugin_22860() -> None:
    def generator(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        n_rows: int | None,
        batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        df = pl.DataFrame({"row_nr": [1, 2, 3, 4, 5], "y": [0, 1, 0, 1, 1]})
        assert predicate is not None
        yield df.filter(predicate)

    q = register_io_source(
        io_source=generator, schema={"x": pl.Int64, "y": pl.Int64}
    ).filter(pl.col("y").map_elements(bool, return_dtype=pl.Boolean))

    plan = q.explain()
    assert plan.index("SELECTION") > plan.index("PYTHON SCAN")

    assert_frame_equal(q.collect(), pl.DataFrame({"row_nr": [2, 4, 5], "y": [1, 1, 1]}))

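# Clarifying note (added): duplicate predicates should collapse into a single FILTER node.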
def test_duplicate_filter_removal_23243() -> None:
    lf = pl.LazyFrame({"x": [1, 2, 3]})

    q = lf.filter(pl.col("x") == 2, pl.col("x") == 2)

    expect = pl.DataFrame({"x": [2]})

    plan = q.explain()

    assert plan.split("\n", 1)[0] == 'FILTER [(col("x")) == (2)]'

    assert_frame_equal(q.collect(), expect)