Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_optimizations.py
6939 views
1
import itertools
2
3
import pytest
4
5
import polars as pl
6
from polars.testing import assert_frame_equal
7
8
9
def test_is_null_followed_by_all() -> None:
    """``is_null().all()`` on a bare column is rewritten to a null-count check.

    The optimizer should turn ``col.is_null().all()`` into
    ``col.len() == col.null_count()`` when the input is a plain column, and
    leave the expression alone when another expression sits in between.
    """
    lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, None, None]})

    expected_df = pl.DataFrame({"group": [0, 1], "val": [False, True]})
    result_lf = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().all()
    )

    plan = result_lf.explain()
    assert r'[[(col("val").len()) == (col("val").null_count())]]' in plan
    # Inspect the plan text, not the LazyFrame: ``"x" in lazyframe`` checks
    # column-name membership and would pass trivially here.
    assert "is_null" not in plan
    assert_frame_equal(expected_df, result_lf.collect())

    # verify we don't optimize on chained expressions when last one is not col
    non_optimized_result_plan = (
        lf.group_by("group", maintain_order=True)
        .agg(pl.col("val").abs().is_null().all())
        .explain()
    )
    assert "null_count" not in non_optimized_result_plan
    assert "is_null" in non_optimized_result_plan

    # edge case of empty series: ``all`` over nothing is vacuously True
    lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})

    expected_df = pl.DataFrame({"val": [True]})
    result_df = lf.select(pl.col("val").is_null().all()).collect()
    assert_frame_equal(expected_df, result_df)
36
37
38
def test_is_null_followed_by_any() -> None:
    """``is_null().any()`` reports whether a group (or column) holds a null."""
    frame = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    aggregated = frame.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().any()
    )
    assert_frame_equal(
        pl.DataFrame({"group": [0, 1, 2], "val": [True, True, False]}),
        aggregated.collect(),
    )

    # An empty column contains no nulls, so ``any`` yields False.
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [False]}),
        empty.select(pl.col("val").is_null().any()).collect(),
    )
53
54
55
def test_is_not_null_followed_by_all() -> None:
    """``is_not_null().all()`` is True only for groups without any null."""
    source = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, 5, None]})

    grouped = (
        source.group_by("group", maintain_order=True)
        .agg(pl.col("val").is_not_null().all())
        .collect()
    )
    assert_frame_equal(
        pl.DataFrame({"group": [0, 1], "val": [True, False]}),
        grouped,
    )

    # Vacuous truth: ``all`` over an empty column is True.
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [True]}),
        empty.select(pl.col("val").is_not_null().all()).collect(),
    )
73
74
75
def test_is_not_null_followed_by_any() -> None:
    """``is_not_null().any()`` on a bare column becomes ``null_count() < len()``."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    def grouped(expr: pl.Expr) -> pl.LazyFrame:
        # Build a fresh ordered group-by aggregation over ``lf`` for ``expr``.
        return lf.group_by("group", maintain_order=True).agg(expr)

    optimized = grouped(pl.col("val").is_not_null().any())
    optimized_plan = optimized.explain()
    assert r'[[(col("val").null_count()) < (col("val").len())]]' in optimized_plan
    assert "is_not_null" not in optimized_plan
    assert_frame_equal(
        pl.DataFrame({"group": [0, 1, 2], "val": [True, False, True]}),
        optimized.collect(),
    )

    # With ``abs()`` in between the input is no longer a bare column, so the
    # rewrite must not fire.
    kept_plan = grouped(pl.col("val").abs().is_not_null().any()).explain()
    assert "null_count" not in kept_plan
    assert "is_not_null" in kept_plan

    # No values at all means no non-null values either.
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [False]}),
        empty.select(pl.col("val").is_not_null().any()).collect(),
    )
102
103
104
def test_is_null_followed_by_sum() -> None:
    """``is_null().sum()`` is rewritten to ``null_count()`` with a UInt32 result."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    query = lf.group_by("group", maintain_order=True).agg(
        pl.col("val").is_null().sum()
    )
    plan = query.explain()
    assert r'[col("val").null_count()]' in plan
    assert "is_null" not in plan

    expected = pl.DataFrame(
        {"group": [0, 1, 2], "val": [1, 1, 0]}, schema_overrides={"val": pl.UInt32}
    )
    assert_frame_equal(expected, query.collect())

    # Summing over an empty column gives a single zero count.
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32}),
        empty.select(pl.col("val").is_null().sum()).collect(),
    )
124
125
126
def test_is_not_null_followed_by_sum() -> None:
    """``is_not_null().sum()`` on a bare column becomes ``len() - null_count()``."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    def grouped(expr: pl.Expr) -> pl.LazyFrame:
        # Build a fresh ordered group-by aggregation over ``lf`` for ``expr``.
        return lf.group_by("group", maintain_order=True).agg(expr)

    rewritten = grouped(pl.col("val").is_not_null().sum())
    rewritten_plan = rewritten.explain()
    assert r'[[(col("val").len()) - (col("val").null_count())]]' in rewritten_plan
    assert "is_not_null" not in rewritten_plan
    expected = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32}
    )
    assert_frame_equal(expected, rewritten.collect())

    # ``abs()`` breaks the bare-column pattern, so no rewrite is expected.
    kept_plan = grouped(pl.col("val").abs().is_not_null().sum()).explain()
    assert "null_count" not in kept_plan
    assert "is_not_null" in kept_plan

    # Counting non-nulls in an empty column yields zero.
    empty = pl.LazyFrame({"val": []}, schema={"val": pl.Int32})
    assert_frame_equal(
        pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32}),
        empty.select(pl.col("val").is_not_null().sum()).collect(),
    )
153
154
155
def test_drop_nulls_followed_by_len() -> None:
    """``drop_nulls().len()`` on a bare column becomes ``len() - null_count()``."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    def grouped(expr: pl.Expr) -> pl.LazyFrame:
        # Build a fresh ordered group-by aggregation over ``lf`` for ``expr``.
        return lf.group_by("group", maintain_order=True).agg(expr)

    rewritten = grouped(pl.col("val").drop_nulls().len())
    rewritten_plan = rewritten.explain()
    assert r'[[(col("val").len()) - (col("val").null_count())]]' in rewritten_plan
    assert "drop_nulls" not in rewritten_plan
    expected = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32}
    )
    assert_frame_equal(expected, rewritten.collect())

    # ``abs()`` breaks the bare-column pattern, so no rewrite is expected.
    kept_plan = grouped(pl.col("val").abs().drop_nulls().len()).explain()
    assert "null_count" not in kept_plan
    assert "drop_nulls" in kept_plan
177
178
179
def test_drop_nulls_followed_by_count() -> None:
    """``drop_nulls().count()`` on a bare column becomes ``len() - null_count()``."""
    lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]})

    def grouped(expr: pl.Expr) -> pl.LazyFrame:
        # Build a fresh ordered group-by aggregation over ``lf`` for ``expr``.
        return lf.group_by("group", maintain_order=True).agg(expr)

    rewritten = grouped(pl.col("val").drop_nulls().count())
    rewritten_plan = rewritten.explain()
    assert r'[[(col("val").len()) - (col("val").null_count())]]' in rewritten_plan
    assert "drop_nulls" not in rewritten_plan
    expected = pl.DataFrame(
        {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32}
    )
    assert_frame_equal(expected, rewritten.collect())

    # ``abs()`` breaks the bare-column pattern, so no rewrite is expected.
    kept_plan = grouped(pl.col("val").abs().drop_nulls().count()).explain()
    assert "null_count" not in kept_plan
    assert "drop_nulls" in kept_plan
201
202
203
def test_collapse_joins() -> None:
    """Filters on top of a cross join are collapsed into a specialized join.

    For each query the plan must contain the expected join node and none of
    the listed leftovers, and the result must equal the result obtained with
    ``collapse_joins`` disabled (ignoring row order).
    """
    cross = pl.LazyFrame({"a": [1, 2, 3], "b": [2, 2, 2]}).join(
        pl.LazyFrame({"x": [7, 1, 2]}), how="cross"
    )

    def check(
        query: pl.LazyFrame, present: tuple[str, ...], absent: tuple[str, ...]
    ) -> None:
        # Plan-shape assertions first, then a result comparison against the
        # un-collapsed query.
        plan = query.explain()
        for node in present:
            assert node in plan
        for node in absent:
            assert node not in plan
        assert_frame_equal(
            query.collect(optimizations=pl.QueryOptFlags(collapse_joins=False)),
            query.collect(),
            check_row_order=False,
        )

    # Equality predicates (in either operand order) become inner joins.
    check(cross.filter(pl.col.a == pl.col.x), ("INNER JOIN",), ("FILTER",))
    check(cross.filter(pl.col.x == pl.col.a), ("INNER JOIN",), ("FILTER",))
    check(
        cross.filter(pl.col.x == pl.col.a).filter(pl.col.x == pl.col.b),
        ("INNER JOIN",),
        ("FILTER",),
    )

    # A predicate mixing both sides inside one expression ends up as a nested
    # loop join, yet is still absorbed into the join (no FILTER left).
    check(
        cross.filter(pl.col.x + pl.col.a != 0),
        ("NESTED LOOP JOIN",),
        ("FILTER",),
    )

    # Inequality predicates become IE-joins.
    check(
        cross.filter(pl.col.x >= pl.col.a),
        ("IEJOIN",),
        ("NESTED LOOP JOIN", "CROSS JOIN", "FILTER"),
    )
    check(
        cross.filter(pl.col.x >= pl.col.a).filter(pl.col.x <= pl.col.b),
        ("IEJOIN",),
        ("CROSS JOIN", "NESTED LOOP JOIN", "FILTER"),
    )
272
273
274
@pytest.mark.slow
def test_collapse_joins_combinations() -> None:
    """Exhaustively compare filtered cross joins with and without join collapsing.

    Builds every comparison between a small set of operands, then filters a
    cross join by each single predicate and by every pair of predicates
    combined with ``&`` / ``|``. The result with ``collapse_joins`` enabled
    must equal the result with it disabled (row order aside).
    """
    a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]})
    b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]})

    cross = a.join(b, how="cross")

    exprs = []

    for lhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a + pl.col.b]:
        for rhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a * pl.col.x]:
            for cmp in ["__eq__", "__ge__", "__lt__"]:
                e = (getattr(lhs, cmp))(rhs)
                exprs.append(e)

    # ``amount`` is the number of predicates combined into one filter. It
    # starts at 1 because ``es[0]`` below needs at least one predicate.
    for amount in range(1, 3):
        # One boolean combinator per gap between adjacent predicates;
        # ``repeat=0`` yields a single empty tuple, so lone predicates run too.
        for merge in itertools.product(["__and__", "__or__"], repeat=amount - 1):
            for es in itertools.product(exprs, repeat=amount):
                e = es[0]
                for i in range(amount - 1):
                    e = (getattr(e, merge[i]))(es[i + 1])

                # NOTE: The row order of the cross join & IE-join is
                # unspecified, so the comparison below ignores row order.
                # Sorting the optimized side only makes the debug output
                # easier to read.
                optimized = cross.filter(e).sort(pl.all()).collect()
                unoptimized = cross.filter(e).collect(
                    optimizations=pl.QueryOptFlags(collapse_joins=False)
                )

                try:
                    assert_frame_equal(optimized, unoptimized, check_row_order=False)
                except AssertionError:
                    # Dump both plans and results for the failing predicate,
                    # then re-raise so the test still fails.
                    print(e)
                    print()
                    print("Optimized")
                    print(cross.filter(e).explain())
                    print(optimized)
                    print()
                    print("Unoptimized")
                    print(
                        cross.filter(e).explain(
                            optimizations=pl.QueryOptFlags(collapse_joins=False)
                        )
                    )
                    print(unoptimized)
                    print()

                    raise
325
326
327
def test_order_observe_sort_before_unique_22485() -> None:
    """A sort feeding ``unique(keep="last")`` must survive optimization (#22485)."""
    lf = pl.LazyFrame({"order": [3, 2, 1], "id": ["A", "A", "B"]})
    expect = pl.DataFrame({"order": [1, 3], "id": ["B", "A"]})

    queries = [
        lf.sort("order").unique(["id"], keep="last").sort("order"),
        lf.sort("order").unique(["id"], keep="last", maintain_order=True),
    ]
    for q in queries:
        plan = q.explain()
        # The sort must still be present in the plan text after UNIQUE.
        assert "SORT BY" in plan[plan.index("UNIQUE") :]
        assert_frame_equal(q.collect(), expect)
350
351
352
def test_order_observe_group_by() -> None:
    """Group-by ``maintain_order`` is relaxed only when a later sort hides it.

    A plain ``sort`` after the aggregation lets the optimizer drop the
    group-by ordering; ``sort(..., maintain_order=True)`` keeps it.
    """

    def build(**sort_kwargs: bool) -> pl.LazyFrame:
        # Ordered group-by + constant aggregate, then a sort on that aggregate.
        return (
            pl.LazyFrame({"a": range(5)})
            .group_by("a", maintain_order=True)
            .agg(b=1)
            .sort("b", **sort_kwargs)
        )

    assert "AGGREGATE[maintain_order: false]" in build().explain()
    assert "AGGREGATE[maintain_order: true]" in build(maintain_order=True).explain()
372
373
374
def test_fused_correct_name() -> None:
    """``a * b + c`` keeps the leftmost operand's name with or without optimization."""
    lf = (
        pl.DataFrame({"x": [1, 2, 3]})
        .lazy()
        .select((pl.col.x.alias("a") * pl.col.x.alias("b")) + pl.col.x.alias("c"))
    )

    unoptimized = lf.collect(optimizations=pl.QueryOptFlags.none())
    optimized = lf.collect()

    assert_frame_equal(unoptimized, optimized)
    assert_frame_equal(optimized, pl.DataFrame({"a": [2, 6, 12]}))
388
389