Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/functions/test_when_then.py
8420 views
1
from __future__ import annotations
2
3
import itertools
4
import random
5
from datetime import datetime
6
from typing import Any
7
8
import pytest
9
10
import polars as pl
11
import polars.selectors as cs
12
from polars.exceptions import InvalidOperationError, ShapeError
13
from polars.testing import assert_frame_equal, assert_series_equal
14
15
16
def test_when_then() -> None:
    """A `when/then` without `otherwise` yields null where the predicate fails."""
    frame = pl.DataFrame({"a": [1, 2, 3, 4, 5]})

    # One partial expression shared by both outputs; only "a" gets a fallback.
    below_three = pl.when(pl.col("a") < 3).then(pl.lit("x"))
    with_fallback = below_three.otherwise(pl.lit("y")).alias("a")
    without_fallback = below_three.alias("b")

    out = frame.select(with_fallback, without_fallback)

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "a": ["x", "x", "y", "y", "y"],
                "b": ["x", "x", None, None, None],
            }
        ),
    )
33
34
35
def test_when_then_chained() -> None:
    """Chained branches are tried in order; unmatched rows use `otherwise` or null."""
    frame = pl.DataFrame({"a": [1, 2, 3, 4, 5]})

    col_a = pl.col("a")
    branches = (
        pl.when(col_a < 3)
        .then(pl.lit("x"))
        .when(col_a > 4)
        .then(pl.lit("z"))
    )

    out = frame.select(
        branches.otherwise(pl.lit("y")).alias("a"),
        branches.alias("b"),
    )

    want = pl.DataFrame(
        {
            "a": ["x", "x", "y", "y", "z"],
            "b": ["x", "x", None, None, "z"],
        }
    )
    assert_frame_equal(out, want)
57
58
59
def test_when_then_invalid_chains() -> None:
    """Builder methods called out of order do not exist and raise AttributeError."""
    invalid_chains = [
        lambda: pl.when("a").when("b"),  # type: ignore[attr-defined]
        lambda: pl.when("a").otherwise(2),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).then(2),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).otherwise(2).otherwise(3),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).when("b").when("c"),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).when("b").otherwise("2"),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).when("b").then(2).when("c").when("d"),  # type: ignore[attr-defined]
    ]
    for build in invalid_chains:
        with pytest.raises(AttributeError):
            build()
74
75
76
def test_when_then_implicit_none() -> None:
    """Without `otherwise` the result is null; an unaliased output is named `literal`."""
    scores = pl.DataFrame(
        {
            "team": ["A", "A", "A", "B", "B", "C"],
            "points": [11, 8, 10, 6, 6, 5],
        }
    )

    def foo_if_high() -> pl.Expr:
        return pl.when(pl.col("points") > 7).then(pl.lit("Foo"))

    out = scores.select(foo_if_high(), foo_if_high().alias("bar"))

    expected_values = ["Foo", "Foo", "Foo", None, None, None]
    assert_frame_equal(
        out,
        pl.DataFrame({"literal": expected_values, "bar": expected_values}),
    )
96
97
98
def test_when_then_empty_list_5547() -> None:
    """A list literal in `then` keeps a List(Int64) dtype on an empty frame."""
    empty = pl.DataFrame({"a": []})
    out = empty.select(pl.when(pl.col("a") > 1).then([1]))

    assert out.shape == (0, 1)
    assert out.dtypes == [pl.List(pl.Int64)]
102
103
104
def test_nested_when_then_and_wildcard_expansion_6284() -> None:
    """A nested `otherwise(when ...)` equals the flat chained form, even with `pl.all()`."""
    df = pl.DataFrame({"1": ["a", "b"], "2": ["c", "d"]})

    def any_equals(value: str) -> pl.Expr:
        # Row-wise: True when any column holds `value`.
        return pl.any_horizontal(pl.all() == value)

    nested = (
        pl.when(any_equals("a"))
        .then(pl.lit("a"))
        .otherwise(pl.when(any_equals("d")).then(pl.lit("d")).otherwise(None))
        .alias("result")
    )
    flat = (
        pl.when(any_equals("a"))
        .then(pl.lit("a"))
        .when(any_equals("d"))
        .then(pl.lit("d"))
        .otherwise(None)
        .alias("result")
    )

    out_nested = df.with_columns(nested)
    out_flat = df.with_columns(flat)

    assert_frame_equal(out_nested, out_flat)
    assert out_nested.to_dict(as_series=False) == {
        "1": ["a", "b"],
        "2": ["c", "d"],
        "result": ["a", "d"],
    }
138
139
140
def test_list_zip_with_logical_type() -> None:
    """Choosing between two List(Datetime) columns keeps the Datetime inner dtype."""
    begin = datetime(2023, 1, 1, 1, 1, 1)
    df = pl.DataFrame(
        {
            "start": [begin, begin],
            "stop": [datetime(2023, 1, 1, 1, 3, 1), datetime(2023, 1, 1, 1, 4, 1)],
            "use": [1, 0],
        }
    )

    def hourly_ranges() -> pl.Expr:
        return pl.datetime_ranges(
            pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left"
        )

    df = df.with_columns(
        hourly_ranges().alias("interval_1"),
        hourly_ranges().alias("interval_2"),
    )

    picked = (
        pl.when(pl.col("use") == 1)
        .then(pl.col("interval_2"))
        .otherwise(pl.col("interval_1"))
        .alias("interval_new")
    )
    out = df.select(picked)
    assert out.dtypes == [pl.List(pl.Datetime(time_unit="us", time_zone=None))]
165
166
167
def test_type_coercion_when_then_otherwise_2806() -> None:
    """Int and str branches coerce to strings; a float literal keeps Float32."""
    names_df = pl.DataFrame({"names": ["foo", "spam", "spam"], "nrs": [1, 2, 3]})
    mixed = (
        pl.when(pl.col("names") == "spam")
        .then(pl.col("nrs") * 2)
        .otherwise(pl.lit("other"))
        .alias("new_col")
    )
    out = names_df.select(mixed).to_series()
    assert out.to_list() == pl.Series("new_col", ["other", "4", "6"]).to_list()

    # test it remains float32
    f32_frame = pl.Series("a", [1.0, 2.0, 3.0], dtype=pl.Float32).to_frame()
    kept = f32_frame.select(
        pl.when(pl.col("a") > 2.0).then(pl.col("a")).otherwise(0.0)
    )
    assert kept.to_series().dtype == pl.Float32
187
188
189
def test_when_then_edge_cases_3994() -> None:
    """Ternaries over aggregated list columns, lazily and with empty inputs."""
    df = pl.DataFrame(data={"id": [1, 1], "type": [2, 2]})

    def null_if_empty_list() -> pl.Expr:
        return (
            pl.when(pl.col("type").list.len() == 0)
            .then(pl.lit(None))
            .otherwise(pl.col("type"))
            .name.keep()
        )

    # Lazy mode must assign the list schema to the aggregated column.
    lazy_out = (
        df.lazy()
        .group_by(["id"])
        .agg(pl.col("type"))
        .with_columns(null_if_empty_list())
        .collect()
    )
    assert lazy_out.to_dict(as_series=False) == {"id": [1], "type": [[2, 2]]}

    # The same ternary when every argument is empty.
    empty_out = (
        df.filter(pl.col("id") == 42)
        .group_by(["id"])
        .agg(pl.col("type"))
        .with_columns(null_if_empty_list())
    )
    assert empty_out.to_dict(as_series=False) == {"id": [], "type": []}
218
219
220
@pytest.mark.may_fail_cloud  # reason: object
def test_object_when_then_4702() -> None:
    """Object-typed literals (here: dtype classes) pass through when/then/otherwise."""
    # please don't ever do this
    frame = pl.DataFrame({"Row": [1, 2], "Type": [pl.Date, pl.UInt8]})

    new_type = (
        pl.when(pl.col("Row") == 1)
        .then(pl.lit(pl.UInt16, allow_object=True))
        .otherwise(pl.lit(pl.UInt8, allow_object=True))
        .alias("New_Type")
    )
    expected = {
        "Row": [1, 2],
        "Type": [pl.Date, pl.UInt8],
        "New_Type": [pl.UInt16, pl.UInt8],
    }
    assert frame.with_columns(new_type).to_dict(as_series=False) == expected
235
236
237
def test_comp_categorical_lit_dtype() -> None:
    """Mixing a String literal into a Categorical column keeps the Categorical dtype."""
    # Data keys match the schema names, so no positional renaming is involved.
    df = pl.DataFrame(
        data={"column": ["a", "b", "e"], "more": [1, 5, 9]},
        schema=[("column", pl.Categorical), ("more", pl.Int32)],
    )

    result = df.with_columns(
        pl.when(pl.col("column") == "e")
        .then(pl.lit("d"))
        .otherwise(pl.col("column"))
        .alias("column")
    )
    assert result.dtypes == [pl.Categorical, pl.Int32]
249
250
251
def test_comp_incompatible_enum_dtype() -> None:
    """A string literal outside the Enum categories fails to cast and raises."""
    enum_dtype = pl.Enum(["a", "b"])
    df = pl.DataFrame({"a": pl.Series(["a", "b"], dtype=enum_dtype)})

    bad_fallback = (
        pl.when(pl.col("a") == "a").then(pl.col("a")).otherwise(pl.lit("c"))
    )
    with pytest.raises(
        InvalidOperationError,
        match="conversion from `str` to `enum` failed in column 'scalar'",
    ):
        df.with_columns(bad_fallback)
261
262
263
def test_predicate_broadcast() -> None:
    """A per-group scalar predicate is broadcast over every row of that group."""
    df = pl.DataFrame(
        {"key": ["a", "a", "b", "b", "c", "c"], "val": [1, 2, 3, 4, 5, 6]}
    )
    group_min_at_least_3 = pl.col("val").min() >= 3

    out = df.group_by("key", maintain_order=True).agg(
        agg=pl.when(group_min_at_least_3).then(pl.col("val")),
    )
    assert out.to_dict(as_series=False) == {
        "key": ["a", "b", "c"],
        "agg": [[None, None], [3, 4], [5, 6]],
    }
277
278
279
@pytest.mark.parametrize(
    "mask_expr",
    [
        pl.lit(True),
        pl.first("true"),
        pl.lit(False),
        pl.first("false"),
        pl.lit(None, dtype=pl.Boolean),
        pl.col("null_bool"),
        pl.col("true"),
        pl.col("false"),
    ],
)
@pytest.mark.parametrize("truthy_expr", [pl.lit(1), pl.first("x"), pl.col("x")])
@pytest.mark.parametrize("falsy_expr", [pl.lit(1), pl.first("x"), pl.col("x")])
@pytest.mark.parametrize("maintain_order", [False, True])
def test_single_element_broadcast(
    mask_expr: pl.Expr,
    truthy_expr: pl.Expr,
    falsy_expr: pl.Expr,
    maintain_order: bool,
) -> None:
    """Length-1 mask/then/otherwise inputs broadcast to the longest input.

    Every input here is either length 1 or the full frame height, so the result
    height is the maximum of the three lengths. This holds in `select` and in
    `group_by().agg()`.
    """
    df = (
        pl.Series("x", 5 * [1], dtype=pl.Int32)
        .to_frame()
        .with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean))
    )
    ternary = pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr)

    out_height = df.select(
        pl.max_horizontal(mask_expr.len(), truthy_expr.len(), falsy_expr.len())
    ).item()
    expected = df.select("x").head(out_height)

    assert_frame_equal(df.select(ternary), expected)

    grouped = (
        df.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)
        .agg(ternary)
        .drop("key")
    )
    # A unit result aggregates to a scalar; longer results come back as lists.
    if expected.height > 1:
        grouped = grouped.explode(cs.all())
    assert_frame_equal(grouped, expected, check_row_order=maintain_order)
343
344
345
@pytest.mark.parametrize(
    "df",
    [pl.DataFrame({"x": range(5)}), pl.DataFrame({"x": 5 * [[*range(5)]]})],
)
@pytest.mark.parametrize(
    "ternary_expr",
    [
        pl.when(True).then(pl.col("x").head(2)).otherwise(pl.col("x")),
        pl.when(False).then(pl.col("x").head(2)).otherwise(pl.col("x")),
    ],
)
def test_mismatched_height_should_raise(
    df: pl.DataFrame, ternary_expr: pl.Expr
) -> None:
    """Branches of length 2 and 5 cannot broadcast and raise ShapeError.

    This holds whichever branch the constant predicate selects.
    """
    attempts = (
        lambda: df.select(ternary_expr),
        lambda: df.group_by(pl.lit(True).alias("key")).agg(ternary_expr),
    )
    for attempt in attempts:
        with pytest.raises(ShapeError):
            attempt()
364
365
366
@pytest.mark.parametrize("maintain_order", [False, True])
def test_when_then_output_name_12380(maintain_order: bool) -> None:
    """The ternary output is named after the `then` branch, whichever branch is taken.

    This is checked in both `select` and `group_by().agg()` contexts, for scalar,
    column and literal predicates.
    """
    df = pl.DataFrame(
        {"x": range(5), "y": range(5, 10)}, schema={"x": pl.Int8, "y": pl.Int64}
    ).with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean))

    def check(predicate: pl.Expr, expect: pl.DataFrame) -> None:
        ternary_expr = pl.when(predicate).then(pl.col("x")).otherwise(pl.col("y"))

        assert_frame_equal(expect, df.select(ternary_expr))

        # Every predicate uses the parametrized `maintain_order`, so both values
        # exercise the false/null predicates as well as the true ones.
        actual = (
            df.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)
            .agg(ternary_expr)
            .drop("key")
            .explode(cs.all())
        )
        assert_frame_equal(expect, actual, check_row_order=maintain_order)

    # Always-true predicates: values come from "x", cast to Int64 like "y".
    expect_x = df.select(pl.col("x").cast(pl.Int64))
    for true_expr in (pl.first("true"), pl.col("true"), pl.lit(True)):
        check(true_expr, expect_x)

    # False/null predicates: values come from "y", but the column is still "x".
    expect_y = df.select(pl.col("y").alias("x"))
    for false_expr in (
        pl.first("false"),
        pl.col("false"),
        pl.lit(False),
        pl.first("null_bool"),
        pl.col("null_bool"),
        pl.lit(None, dtype=pl.Boolean),
    ):
        check(false_expr, expect_y)
415
416
417
def test_when_then_nested_non_unit_literal_predicate_agg_broadcast_12242() -> None:
    """A per-group `int_range`-based predicate combines with gathered group values."""
    df = pl.DataFrame(
        {
            "array_name": ["A", "A", "A", "B", "B"],
            "array_idx": [5, 0, 3, 7, 2],
            "array_val": [1, 2, 3, 4, 5],
        }
    )

    # Dense range from each group's min to max of array_idx (inclusive).
    dense_idx = pl.int_range(pl.min("array_idx"), pl.max("array_idx") + 1)
    present = dense_idx.is_in("array_idx")
    # Running count of present indices, minus one, used as gather positions.
    gather_pos = present.cum_sum() - 1
    scattered = pl.when(present).then(pl.col("array_val").gather(gather_pos))

    actual = df.group_by("array_name").agg(scattered).sort("array_name")
    expect = pl.DataFrame(
        [
            pl.Series("array_name", ["A", "B"], dtype=pl.String),
            pl.Series(
                "array_val",
                [[1, None, None, 2, None, 3], [4, None, None, None, None, 5]],
                dtype=pl.List(pl.Int64),
            ),
        ]
    )
    assert_frame_equal(expect, actual)
448
449
450
def test_when_then_non_unit_literal_predicate_agg_broadcast_12382() -> None:
    """A length-5 range predicate in an aggregation broadcasts a scalar `then`."""
    df = pl.DataFrame({"id": [1, 1], "value": [0, 3]})
    in_values = pl.int_range(0, 5).is_in("value")

    actual = df.group_by("id").agg(pl.when(in_values).then(pl.lit("yes")))

    expect = pl.DataFrame({"id": [1], "literal": [["yes", None, None, "yes", None]]})
    assert_frame_equal(expect, actual)
459
460
461
def test_when_then_binary_op_predicate_agg_12526() -> None:
    """Multi-predicate `when` built from shifted comparisons, inside an aggregation."""
    df = pl.DataFrame({"a": [1, 1, 1], "b": [1, 2, 5]})

    prev_a = pl.col("a").shift(1)
    b_present = pl.col("b").is_not_null()
    labelled = (
        pl.when(prev_a > 2, b_present)
        .then(pl.lit("abc"))
        .when(prev_a > 1, b_present)
        .then(pl.lit("def"))
        .otherwise(pl.lit(None))
        .first()
    )

    actual = df.group_by("a").agg(col=labelled)

    expect = pl.DataFrame(
        {"a": [1], "col": [None]}, schema={"a": pl.Int64, "col": pl.String}
    )
    assert_frame_equal(expect, actual)
491
492
493
def test_when_predicates_kwargs() -> None:
    """`when` accepts keyword equality constraints, alone or mixed with predicates."""
    df = pl.DataFrame(
        {
            "x": [10, 20, 30, 40],
            "y": [15, -20, None, 1],
            "z": ["a", "b", "c", "d"],
        }
    )

    # Keyword constraints only.
    kwargs_only = pl.when(x=30, z="c").then(True).otherwise(False)
    assert_frame_equal(
        df.select(matched=kwargs_only),
        pl.DataFrame({"matched": [False, False, True, False]}),
    )

    # Positional predicate plus keyword constraint; both must hold.
    mixed = pl.when(pl.col("x") < 30, z="b").then(True).otherwise(False)
    assert_frame_equal(
        df.select(matched=mixed),
        pl.DataFrame({"matched": [False, True, False, False]}),
    )

    # Chained when/then, each branch using a different predicate style.
    chained = (
        pl.when(pl.col("x") > 50)
        .then(pl.lit("x>50"))
        .when(y=1)
        .then(pl.lit("y=1"))
        .when(pl.col("z").is_in(["a", "b"]), pl.col("y") < 0)
        .then(pl.lit("z in (a|b), y<0"))
        .otherwise(pl.lit("?"))
    )
    assert_frame_equal(
        df.select(misc=chained),
        pl.DataFrame({"misc": ["?", "z in (a|b), y<0", "?", "y=1"]}),
    )
521
522
523
def test_when_then_null_broadcast() -> None:
    """A length-1 Null-typed `then` broadcasts to a length-2 predicate."""
    predicate = pl.repeat(True, 2, dtype=pl.Boolean)
    unit_null = pl.repeat(None, 1, dtype=pl.Null)

    out = pl.select(pl.when(predicate).then(unit_null))
    assert out.height == 2
532
533
534
@pytest.mark.slow
@pytest.mark.parametrize("length", [1, 10, 100, 500])
@pytest.mark.parametrize(
    ("dtype", "vals"),
    [
        pytest.param(pl.Boolean, [False, True], id="Boolean"),
        pytest.param(pl.UInt8, [0, 1], id="UInt8"),
        pytest.param(pl.UInt16, [0, 1], id="UInt16"),
        pytest.param(pl.UInt32, [0, 1], id="UInt32"),
        pytest.param(pl.UInt64, [0, 1], id="UInt64"),
        pytest.param(pl.Float32, [0.0, 1.0], id="Float32"),
        pytest.param(pl.Float64, [0.0, 1.0], id="Float64"),
        pytest.param(pl.String, ["0", "12"], id="String"),
        pytest.param(pl.Array(pl.String, 2), [["0", "1"], ["3", "4"]], id="StrArray"),
        pytest.param(pl.Array(pl.Int64, 2), [[0, 1], [3, 4]], id="IntArray"),
        pytest.param(pl.List(pl.String), [["0"], ["1", "2"]], id="List"),
        pytest.param(
            pl.Struct({"foo": pl.Int32, "bar": pl.String}),
            [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}],
            id="Struct",
        ),
        pytest.param(
            pl.Object,
            ["x", "y"],
            id="Object",
            marks=pytest.mark.may_fail_cloud,
            # reason: objects are not allowed in cloud
        ),
    ],
)
@pytest.mark.parametrize("broadcast", list(itertools.product([False, True], repeat=3)))
def test_when_then_parametric(
    length: int,
    dtype: pl.DataType,
    vals: list[Any],
    broadcast: tuple[bool, bool, bool],
) -> None:
    """Randomized check of when/then/otherwise against a pure-Python reference.

    Each of mask / if_true / if_false may be broadcast from its first element:
    `.first()` on the polars side, a repeated first value on the Python side.
    A null mask selects the `otherwise` value.

    `length` is the number of rows. It is named so that it does not shadow the
    builtin `len`. `broadcast` holds one flag per input, in mask/true/false order.
    """
    # Broadcasting all three inputs is not a meaningful case.
    if all(broadcast):
        return

    rng = random.Random(42)  # fixed seed keeps the cases reproducible

    for _ in range(10):
        mask = rng.choices([False, True, None], k=length)
        if_true = rng.choices(vals + [None], k=length)
        if_false = rng.choices(vals + [None], k=length)

        # Python reference inputs: broadcast columns repeat their first value.
        py_mask, py_true, py_false = (
            [c[0]] * length if b else c
            for b, c in zip(broadcast, [mask, if_true, if_false], strict=True)
        )
        # Polars inputs: broadcast columns are reduced with `.first()`.
        pl_mask, pl_true, pl_false = (
            c.first() if b else c
            for b, c in zip(
                broadcast, [pl.col.mask, pl.col.if_true, pl.col.if_false], strict=True
            )
        )

        ref = pl.DataFrame(
            {
                "if_true": [
                    t if m else f
                    for m, t, f in zip(py_mask, py_true, py_false, strict=True)
                ]
            },
            schema={"if_true": dtype},
        )
        df = pl.DataFrame(
            {
                "mask": mask,
                "if_true": if_true,
                "if_false": if_false,
            },
            schema={"mask": pl.Boolean, "if_true": dtype, "if_false": dtype},
        )

        ans = df.select(pl.when(pl_mask).then(pl_true).otherwise(pl_false))
        if dtype != pl.Object:
            assert_frame_equal(ref, ans)
        else:
            # Object columns are compared via their Python values instead.
            assert ref["if_true"].to_list() == ans["if_true"].to_list()
613
614
615
def test_when_then_else_struct_18961() -> None:
    """Struct-valued branches combine correctly with `.first()` broadcasting and nulls."""
    with_null = [None, {"foo": 0, "bar": "1"}]
    all_set = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]

    def evaluate(frame: pl.DataFrame, ternary: pl.Expr) -> list[object]:
        # The output column takes its name from the `then` branch ("left").
        return frame.select(ternary).get_column("left").to_list()

    expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]

    # Broadcast scalar in `otherwise`.
    frame = pl.DataFrame({"left": with_null, "right": all_set, "mask": [False, True]})
    ternary = pl.when(pl.col.mask).then(pl.col.left).otherwise(pl.col.right.first())
    assert expected == evaluate(frame, ternary)

    # Broadcast scalar in `then`.
    frame = pl.DataFrame({"left": all_set, "right": with_null, "mask": [True, False]})
    ternary = pl.when(pl.col.mask).then(pl.col.left.first()).otherwise(pl.col.right)
    assert expected == evaluate(frame, ternary)

    # Both branches are broadcast; the `then` scalar is a null struct.
    frame = pl.DataFrame({"left": with_null, "right": all_set, "mask": [True, False]})
    ternary = (
        pl.when(pl.col.mask)
        .then(pl.col.left.first())
        .otherwise(pl.col.right.first())
    )
    assert [None, {"foo": 0, "bar": "1"}] == evaluate(frame, ternary)
656
657
658
def test_when_then_supertype_15975() -> None:
    """Mixed int/float arithmetic inside `then` produces a float result."""
    df = pl.DataFrame({"a": [1, 2, 3]})
    branch = 1 ** pl.col("a") + 1.0 * pl.col("a")

    out = df.with_columns(pl.when(True).then(branch))
    assert out.to_dict(as_series=False) == {
        "a": [1, 2, 3],
        "literal": [2.0, 3.0, 4.0],
    }
664
665
666
def test_when_then_supertype_15975_comment() -> None:
    """Int and float `then` literals in one chain resolve to float values."""
    lf = pl.LazyFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]})
    foo = pl.col("foo")

    val = (
        pl.when(foo == 1)
        .then(1)
        .when(foo == 2)
        .then(4)
        .when(foo == 3)
        .then(1.5)
        .when(foo == 4)
        .then(16)
        .otherwise(0)
        .alias("val")
    )

    collected = lf.with_columns(val).collect()
    assert collected["val"].to_list() == [1.0, 1.5, 16.0]
683
684
685
def test_chained_when_no_subclass_17142() -> None:
    """An intermediate `ChainedWhen` builder is not an `Expr`."""
    # https://github.com/pola-rs/polars/pull/17142
    chained_when = pl.when(True).then(1).when(True)

    assert not isinstance(chained_when, pl.Expr)
    assert "<polars.expr.whenthen.ChainedWhen object at" in str(chained_when)
691
692
693
def test_when_then_chunked_structs_18673() -> None:
    """Broadcast struct branches over a vertically stacked frame."""
    single = pl.DataFrame([pl.Series("x", [{"a": 1}]), pl.Series("b", [False])])
    stacked = single.vstack(single)

    # This used to panic
    out = stacked.select(
        pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))
    )
    assert_frame_equal(out, pl.DataFrame({"x": [{"a": 1}, {"a": 1}]}))
708
709
710
# Struct inputs for the broadcasting-combinations test: a one-row struct value,
# a one-row null struct, and a two-row struct column.
some_scalar = pl.Series("a", [{"x": 2}], dtype=pl.Struct)
none_scalar = pl.Series("a", [None], dtype=pl.Struct({"x": pl.Int64}))
column = pl.Series("a", [{"x": 2}, {"x": 2}], dtype=pl.Struct)
713
714
715
@pytest.mark.parametrize(
    "values",
    [
        (some_scalar, some_scalar),
        (some_scalar, pl.col.a),
        (some_scalar, none_scalar),
        (some_scalar, column),
        (none_scalar, pl.col.a),
        (none_scalar, none_scalar),
        (none_scalar, column),
        (pl.col.a, pl.col.a),
        (pl.col.a, column),
        (column, column),
    ],
)
def test_struct_when_then_broadcasting_combinations_19122(
    values: tuple[Any, Any],
) -> None:
    """A never-selected struct branch behaves like a null branch, for any broadcast mix."""
    lv, rv = values

    df = pl.Series("a", [{"x": 1}, {"x": 1}], pl.Struct).to_frame()
    x_field = pl.col.a.struct.field("x")

    # Every "x" is 1, so `x == 0` never holds and `lv` in `then` is never chosen.
    never_then = pl.when(x_field == 0).then(lv).otherwise(rv).alias("a")
    null_then = pl.when(x_field == 0).then(None).otherwise(rv).alias("a")
    assert_frame_equal(df.select(never_then), df.select(null_then))

    # Likewise `x != 0` always holds, so `lv` in `otherwise` is never chosen.
    never_else = pl.when(x_field != 0).then(rv).otherwise(lv).alias("a")
    null_else = pl.when(x_field != 0).then(rv).otherwise(None).alias("a")
    assert_frame_equal(df.select(never_else), df.select(null_else))
754
755
756
@pytest.mark.may_fail_cloud  # reason str.to_decimal is an eager construct
def test_when_then_to_decimal_18375() -> None:
    """Decimal values survive both the `otherwise` and the `then` position."""
    df = pl.DataFrame({"a": ["1.23", "4.56"]})
    as_decimal = pl.col("a").str.to_decimal(scale=2)

    result = df.with_columns(
        b=pl.when(False).then(None).otherwise(as_decimal),
        c=pl.when(True).then(as_decimal),
    )

    decimal_2 = pl.Decimal(scale=2)
    expected = pl.DataFrame(
        {
            "a": ["1.23", "4.56"],
            "b": ["1.23", "4.56"],
            "c": ["1.23", "4.56"],
        },
        schema={"a": pl.String, "b": decimal_2, "c": decimal_2},
    )
    assert_frame_equal(result, expected)
773
774
775
def test_when_then_chunked_fill_null_22794() -> None:
    """`forward_fill` after when/then on a struct column rebuilt from slices."""
    df = pl.DataFrame(
        {
            "node": [{"x": "a", "y": "a"}, {"x": "b", "y": "b"}, {"x": "c", "y": "c"}],
            "level": [0, 1, 2],
        }
    )

    # Re-assemble the frame by concatenating its one-row slices.
    pieces = [df.slice(offset, 1) for offset in range(3)]
    out = pl.concat(pieces).with_columns(
        pl.when(level=1).then("node").forward_fill()
    )

    expected = pl.DataFrame(
        {
            "node": [None, {"x": "b", "y": "b"}, {"x": "b", "y": "b"}],
            "level": [0, 1, 2],
        }
    )
    assert_frame_equal(out, expected)
794
795
796
def test_when_then_complex_conditional_22959() -> None:
    """Struct branches with different fields unify into a single struct dtype."""
    df = pl.DataFrame(
        {"B": [None, "T1", "T2"], "C": [None, None, [1.0]], "E": [None, 2.0, None]}
    )

    t1_struct = pl.struct(X="C", Y="C")
    t2_struct = pl.struct(X=pl.concat_list([3.0, "E"]))
    res = df.with_columns(
        Result=pl.when(B="T1").then(t1_struct).when(B="T2").then(t2_struct)
    )

    expected = pl.Series(
        "Result",
        [None, {"X": None, "Y": None}, {"X": [3.0, None], "Y": None}],
        pl.Struct({"X": pl.List(pl.Float64), "Y": pl.List(pl.Float64)}),
    )
    assert_series_equal(res["Result"], expected)
818
819
820
def test_when_then_simplification() -> None:
    """With a constant predicate, the explained plan shows the selected branch."""
    lf = pl.LazyFrame({"a": [12]})

    def plan_for(predicate: bool) -> str:
        query = lf.select(
            pl.when(predicate).then(pl.col("a")).otherwise(pl.col("a") * 2)
        )
        return query.explain()

    assert """[col("a")]""" in plan_for(True)
    assert """(col("a")) * (2)""" in plan_for(False)
834
835
836
def test_when_then_in_group_by_aggregated_22922() -> None:
    """An aggregated `then` inside group_by, reduced with `.first()`."""
    df = pl.DataFrame({"group": ["x", "y", "x", "y"], "value": [1, 2, 3, 4]})
    group_max_if_x = pl.when(group="x").then(pl.col.value.max()).first()

    out = df.group_by("group", maintain_order=True).agg(expr=group_max_if_x)

    assert_frame_equal(out, pl.DataFrame({"group": ["x", "y"], "expr": [3, None]}))
assert_frame_equal(out, expected)
843
844