Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_cwc.py
8424 views
1
# Tests for the optimization pass cluster WITH_COLUMNS
2
3
import pytest
4
5
import polars as pl
6
from polars.exceptions import ColumnNotFoundError
7
from polars.testing import assert_frame_equal
8
9
10
def test_basic_cwc() -> None:
    """Three `with_columns` calls that only read "a" share one expression list.

    The assertion looks for a single list holding "b", "c" and "d" in the
    text returned by `explain()`.
    """
    lf = pl.LazyFrame({"a": [1, 2]})
    lf = lf.with_columns(pl.col("a").alias("b") * 2)
    lf = lf.with_columns(pl.col("a").alias("c") * 3)
    lf = lf.with_columns(pl.col("a").alias("d") * 4)

    clustered = (
        '[[(col("a")) * (2)].alias("b"), '
        '[(col("a")) * (3)].alias("c"), '
        '[(col("a")) * (4)].alias("d")]'
    )
    assert clustered in lf.explain()
22
23
24
def test_disable_cwc() -> None:
    """With `cluster_with_columns=False` each expression keeps its own list.

    Same query as the basic clustering case, but explained with the
    optimization flag turned off; every expression is expected to appear
    in a one-element list.
    """
    lf = pl.LazyFrame({"a": [1, 2]})
    lf = lf.with_columns(pl.col("a").alias("b") * 2)
    lf = lf.with_columns(pl.col("a").alias("c") * 3)
    lf = lf.with_columns(pl.col("a").alias("d") * 4)

    flags = pl.QueryOptFlags(cluster_with_columns=False)
    plan = lf.explain(optimizations=flags)

    for fragment in (
        '[[(col("a")) * (2)].alias("b")]',
        '[[(col("a")) * (3)].alias("c")]',
        '[[(col("a")) * (4)].alias("d")]',
    ):
        assert fragment in plan
37
38
39
def test_refuse_with_deps() -> None:
    """A chain where each new column reads the previous one stays unmerged.

    "b" reads "a", "c" reads "b" and "d" reads "c"; each expression is
    expected to appear in its own one-element list in the plan.
    """
    lf = pl.LazyFrame({"a": [1, 2]})
    lf = lf.with_columns(pl.col("a").alias("b") * 2)
    lf = lf.with_columns(pl.col("b").alias("c") * 3)
    lf = lf.with_columns(pl.col("c").alias("d") * 4)

    plan = lf.explain()

    for fragment in (
        '[[(col("a")) * (2)].alias("b")]',
        '[[(col("b")) * (3)].alias("c")]',
        '[[(col("c")) * (4)].alias("d")]',
    ):
        assert fragment in plan
52
53
54
def test_partial_deps() -> None:
    """Expressions are grouped by whether they read the derived column "b".

    The assertions expect the "b"-reading expressions ("d", "f") in one
    list and the expressions reading only "a" ("b", "c", "e") in another.
    """
    base = pl.LazyFrame({"a": [1, 2]}).with_columns(pl.col("a").alias("b") * 2)
    mixed = base.with_columns(
        pl.col("a").alias("c") * 3,
        pl.col("b").alias("d") * 4,
        pl.col("a").alias("e") * 5,
    )
    plan = mixed.with_columns(pl.col("b").alias("f") * 6).explain()

    reads_b = '[[(col("b")) * (4)].alias("d"), [(col("b")) * (6)].alias("f")]'
    reads_a = (
        '[[(col("a")) * (2)].alias("b"), '
        '[(col("a")) * (3)].alias("c"), '
        '[(col("a")) * (5)].alias("e")]'
    )
    assert reads_b in plan
    assert reads_a in plan
75
76
77
def test_swap_remove() -> None:
    """Check result and plan when only one of four expressions can move.

    The second `with_columns` holds three expressions reading "b" and one
    ("c") reading only "a". The collected frame must keep the declared
    column order, and the plan is expected to show "c" next to "b", the
    remaining expressions in their original relative order, and a
    `simple π` node.
    """
    df = (
        pl.LazyFrame({"a": [1, 2]})
        .with_columns(pl.col("a").alias("b") * 2)
        .with_columns(
            pl.col("b").alias("f") * 6,
            pl.col("a").alias("c") * 3,
            pl.col("b").alias("d") * 4,
            pl.col("b").alias("e") * 5,
        )
    )

    explain = df.explain()

    # assert_frame_equal reports which column/value differs on failure,
    # unlike a bare `assert frame.equals(...)`.
    assert_frame_equal(
        df.collect(),
        pl.DataFrame(
            {
                "a": [1, 2],
                "b": [2, 4],
                "f": [12, 24],
                "c": [3, 6],
                "d": [8, 16],
                "e": [10, 20],
            }
        ),
    )

    assert (
        """[[(col("b")) * (6)].alias("f"), [(col("b")) * (4)].alias("d"), [(col("b")) * (5)].alias("e")]"""
        in explain
    )
    assert (
        """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c")]""" in explain
    )
    assert """simple π""" in explain
111
112
113
def test_try_remove_simple_project() -> None:
    """Check when a `simple π` node appears after an expression is moved.

    Two variants of the same query are run, differing only in the order of
    the expressions in the second `with_columns`. Both must collect to the
    declared column order; the plan is expected to contain `simple π` only
    in the second variant.
    """
    moved_with_b = '[[(col("a")) * (2)].alias("b"), [(col("a")) * (4)].alias("d")]'
    left_behind = '[[(col("b")) * (3)].alias("c")]'

    # Variant 1: the "a"-only expression ("d") is listed first.
    first = (
        pl.LazyFrame({"a": [1, 2]})
        .with_columns(pl.col("a").alias("b") * 2)
        .with_columns(pl.col("a").alias("d") * 4, pl.col("b").alias("c") * 3)
    )
    first_expected = pl.DataFrame(
        [
            pl.Series("a", [1, 2], dtype=pl.Int64),
            pl.Series("b", [2, 4], dtype=pl.Int64),
            pl.Series("d", [4, 8], dtype=pl.Int64),
            pl.Series("c", [6, 12], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(first.collect(), first_expected)

    first_plan = first.explain()
    assert moved_with_b in first_plan
    assert left_behind in first_plan
    assert "simple π" not in first_plan

    # Variant 2: the "a"-only expression ("d") is listed last.
    second = (
        pl.LazyFrame({"a": [1, 2]})
        .with_columns(pl.col("a").alias("b") * 2)
        .with_columns(pl.col("b").alias("c") * 3, pl.col("a").alias("d") * 4)
    )
    second_expected = pl.DataFrame(
        [
            pl.Series("a", [1, 2], dtype=pl.Int64),
            pl.Series("b", [2, 4], dtype=pl.Int64),
            pl.Series("c", [6, 12], dtype=pl.Int64),
            pl.Series("d", [4, 8], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(second.collect(), second_expected)

    second_plan = second.explain()
    assert moved_with_b in second_plan
    assert left_behind in second_plan
    assert "simple π" in second_plan
161
162
163
def test_cwc_with_internal_aliases() -> None:
    """An inner `.alias("b")` under an outer `.alias("c")` still clusters.

    The first expression carries an inner alias "b" beneath its output
    alias "c"; the assertion expects it to share one expression list with
    a later expression that reads column "b".
    """
    inner = (pl.col("a") == 2).alias("b")

    lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]})
    lf = lf.with_columns(pl.any_horizontal(inner).alias("c"))
    lf = lf.with_columns(pl.col("b").alias("d") * 3)

    clustered = '[[(col("a")) == (2)].alias("c"), [(col("b")) * (3)].alias("d")]'
    assert clustered in lf.explain()
175
176
177
def test_read_of_pushed_column_16436() -> None:
    """Collect a query that derives "z" and then overwrites it from itself.

    There are no assertions on the result; the test only requires that
    `collect()` completes without raising.
    """
    frame = pl.DataFrame(
        {
            "x": [1.12, 2.21, 4.2, 3.21],
            "y": [2.11, 3.32, 2.1, 6.12],
        }
    )

    ratio = (pl.col("y") / pl.col("x")).alias("z")
    guarded = (
        pl.when(pl.col("z").is_infinite()).then(0).otherwise(pl.col("z")).alias("z")
    )

    query = frame.lazy().with_columns(ratio).with_columns(guarded).fill_nan(0)
    query.collect()
194
195
196
def test_multiple_simple_projections_16435() -> None:
    """Collect a chain of five `with_columns` mixing column copies and literals.

    There are no assertions on the result; the test only requires that
    `collect()` completes without raising.
    """
    lf = pl.DataFrame({"a": [1]}).lazy()

    lf = lf.with_columns(b=pl.col("a"))
    lf = lf.with_columns(c=pl.col("b"))
    lf = lf.with_columns(l2a=pl.lit(2))
    lf = lf.with_columns(l2b=pl.col("l2a"))
    lf = lf.with_columns(m=pl.lit(3))

    lf.collect()
208
209
210
def test_reverse_order() -> None:
    """Collect a query whose last `with_columns` swaps "a" and "b".

    There are no assertions on the result; the test only requires that
    `collect()` completes without raising.
    """
    lf = pl.LazyFrame({"a": [1], "b": [2]})

    lf = lf.with_columns(a=pl.col("a"), b=pl.col("b"), c=pl.col("a") * pl.col("b"))
    lf = lf.with_columns(x=pl.col("a"), y=pl.col("b"))
    lf = lf.with_columns(b=pl.col("a"), a=pl.col("b"))

    lf.collect()
220
221
222
def test_realias_of_unread_column_16530() -> None:
    """A column that is overwritten before being read collapses the plan.

    "y" is first computed from "x" and then replaced by a literal without
    ever being read; the plan is expected to contain exactly one
    `WITH_COLUMNS` node and the result must hold the final literals.
    """
    df = (
        pl.LazyFrame({"x": [True]})
        .with_columns(x=pl.lit(False))
        .with_columns(y=~pl.col("x"))
        .with_columns(y=pl.lit(False))
    )

    plan = df.explain()

    assert plan.count("WITH_COLUMNS") == 1
    # assert_frame_equal reports which column/value differs on failure,
    # unlike a bare `assert frame.equals(...)`.
    assert_frame_equal(df.collect(), pl.DataFrame({"x": [False], "y": [False]}))
234
235
236
def test_realias_with_dependencies() -> None:
    """Overwriting "y" while a sibling expression reads it keeps three nodes.

    The last `with_columns` both replaces "y" and computes "z" from the
    previous "y", which itself depends on the rewritten "x"; the plan is
    expected to contain three `WITH_COLUMNS` nodes.
    """
    df = (
        pl.LazyFrame({"x": [True]})
        .with_columns(x=pl.lit(False))
        .with_columns(y=~pl.col("x"))
        .with_columns(y=pl.lit(False), z=pl.col("y") | True)
    )

    explain = df.explain()

    assert explain.count("WITH_COLUMNS") == 3
    # assert_frame_equal reports which column/value differs on failure,
    # unlike a bare `assert frame.equals(...)`.
    assert_frame_equal(
        df.collect(),
        pl.DataFrame({"x": [False], "y": [False], "z": [True]}),
    )
248
249
250
def test_refuse_pushdown_with_aliases() -> None:
    """Overwriting "y" while a sibling expression reads the old "y".

    Here the earlier "y" is a literal that does not depend on "x"; the
    plan is expected to contain two `WITH_COLUMNS` nodes and "z" must be
    computed from the old value of "y".
    """
    df = (
        pl.LazyFrame({"x": [True]})
        .with_columns(x=pl.lit(False))
        .with_columns(y=pl.lit(True))
        .with_columns(y=pl.lit(False), z=pl.col("y") | True)
    )

    explain = df.explain()

    assert explain.count("WITH_COLUMNS") == 2
    # assert_frame_equal reports which column/value differs on failure,
    # unlike a bare `assert frame.equals(...)`.
    assert_frame_equal(
        df.collect(),
        pl.DataFrame({"x": [False], "y": [False], "z": [True]}),
    )
262
263
264
def test_neighbour_live_expr() -> None:
    """An expression overwriting "x" sits next to one that reads the old "x".

    The plan is expected to contain a single `WITH_COLUMNS` node, and "z"
    must still be computed from the original value of "x" (True).
    """
    df = (
        pl.LazyFrame({"x": [True]})
        .with_columns(y=pl.lit(False))
        .with_columns(x=pl.lit(False), z=pl.col("x") | False)
    )

    explain = df.explain()

    assert explain.count("WITH_COLUMNS") == 1
    # assert_frame_equal reports which column/value differs on failure,
    # unlike a bare `assert frame.equals(...)`.
    assert_frame_equal(
        df.collect(),
        pl.DataFrame({"x": [False], "y": [False], "z": [True]}),
    )
275
276
277
def test_cluster_with_columns_collect_all_panic_26092() -> None:
    """Collect the same two-step query twice through `pl.collect_all`.

    Both collected frames must equal the expected one-row frame.
    """
    query = (
        pl.LazyFrame()
        .with_columns(pl.lit(1.0).cast(pl.Float64()).alias("numbers1"))
        .with_columns(pl.lit(2.0).cast(pl.Float64()).alias("numbers2"))
    )

    first, second = pl.collect_all([query, query])

    assert_frame_equal(first, pl.DataFrame({"numbers1": 1.0, "numbers2": 2.0}))
    assert_frame_equal(second, pl.DataFrame({"numbers1": 1.0, "numbers2": 2.0}))
286
287
288
def test_cluster_with_columns_schema_update_26417() -> None:
    """Cast two list columns to arrays, then index into the second one.

    The final `arr.get(0)` on "y" relies on the preceding cast of "y" to an
    Array dtype; the collected frame must match the expected dtypes and
    values.
    """
    source = pl.LazyFrame({"x": [[0.0, 1.0]], "y": [[2.0]]})

    query = source.with_columns(pl.col("x").cast(pl.Array(pl.Float64, shape=2)))
    query = query.with_columns(pl.col("y").cast(pl.Array(pl.Float64, shape=1)))
    query = query.with_columns(pl.col("y").arr.get(0))

    expected = pl.DataFrame(
        [
            pl.Series("x", [[0.0, 1.0]], dtype=pl.Array(pl.Float64, shape=(2,))),
            pl.Series("y", [2.0], dtype=pl.Float64),
        ]
    )
    assert_frame_equal(query.collect(), expected)
306
307
308
def test_cluster_with_columns_use_existing_names_26456() -> None:
    """Overwrite "a" and "b" in one `with_columns`, with "b" reading the old "a".

    The expected frame has "b" equal to 1 plus the original "a" values,
    not the incremented ones.
    """
    query = pl.LazyFrame({"a": [1, 2, 3]}).with_columns(pl.lit(1).alias("b"))
    query = query.with_columns(pl.col("a") + 1, pl.col("b") + pl.col("a"))

    expected = pl.DataFrame(
        [
            pl.Series("a", [2, 3, 4], dtype=pl.Int64),
            pl.Series("b", [2, 3, 4], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(query.collect(), expected)
324
325
326
def test_cluster_with_columns_prune_col() -> None:
    """Check handling of a bare `pl.col(...)` inside `with_columns`.

    First case: re-selecting "buzz" unchanged alongside a real update of
    "foo" is expected to yield a single `WITH_COLUMNS` node and the
    expected frame. Second case: a bare reference to a missing column must
    raise `ColumnNotFoundError` on collect.
    """
    query = pl.LazyFrame({"foo": [0.5, 1.7, 3.2], "bar": [4.1, 1.5, 9.2]})
    query = query.with_columns(pl.col("foo").alias("buzz"))
    query = query.with_columns(pl.col("buzz"), pl.col("foo") * 2.0)

    assert query.explain().count("WITH_COLUMNS") == 1

    expected = pl.DataFrame(
        [
            pl.Series("foo", [1.0, 3.4, 6.4], dtype=pl.Float64),
            pl.Series("bar", [4.1, 1.5, 9.2], dtype=pl.Float64),
            pl.Series("buzz", [0.5, 1.7, 3.2], dtype=pl.Float64),
        ]
    )
    assert_frame_equal(query.collect(), expected)

    # Column "b" does not exist in the frame.
    missing = pl.LazyFrame({"a": 1}).with_columns(pl.col("a"))
    missing = missing.with_columns(pl.col("b"))

    with pytest.raises(ColumnNotFoundError, match='unable to find column "b"'):
        missing.collect()
352
353