Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/functions/test_when_then.py
6939 views
1
from __future__ import annotations
2
3
import itertools
4
import random
5
from datetime import datetime
6
from typing import Any
7
8
import pytest
9
10
import polars as pl
11
import polars.selectors as cs
12
from polars.exceptions import InvalidOperationError, ShapeError
13
from polars.testing import assert_frame_equal, assert_series_equal
14
15
16
def test_when_then() -> None:
    """A `when/then` without `otherwise` yields null where the predicate is false."""
    frame = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    below_three = pl.when(pl.col("a") < 3).then(pl.lit("x"))

    with_fallback = below_three.otherwise(pl.lit("y")).alias("a")
    without_fallback = below_three.alias("b")
    out = frame.select(with_fallback, without_fallback)

    want = pl.DataFrame(
        {
            "a": ["x"] * 2 + ["y"] * 3,
            "b": ["x"] * 2 + [None] * 3,
        }
    )
    assert_frame_equal(out, want)
33
34
35
def test_when_then_chained() -> None:
    """Chained `when/then` branches are tried in order; rows matching no branch are
    null unless an `otherwise` value is supplied."""
    frame = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    col_a = pl.col("a")
    chained = pl.when(col_a < 3).then(pl.lit("x")).when(col_a > 4).then(pl.lit("z"))

    out = frame.select(
        chained.otherwise(pl.lit("y")).alias("a"),
        chained.alias("b"),
    )

    want = pl.DataFrame(
        {
            "a": ["x", "x", "y", "y", "z"],
            "b": ["x", "x", None, None, "z"],
        }
    )
    assert_frame_equal(out, want)
57
58
59
def test_when_then_invalid_chains() -> None:
    """Builder methods that are not valid at a given position in the chain do not
    exist on the intermediate builder objects (raising AttributeError)."""
    invalid_chains = [
        lambda: pl.when("a").when("b"),  # type: ignore[attr-defined]
        lambda: pl.when("a").otherwise(2),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).then(2),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).otherwise(2).otherwise(3),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).when("b").when("c"),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).when("b").otherwise("2"),  # type: ignore[attr-defined]
        lambda: pl.when("a").then(1).when("b").then(2).when("c").when("d"),  # type: ignore[attr-defined]
    ]
    for build_chain in invalid_chains:
        with pytest.raises(AttributeError):
            build_chain()
74
75
76
def test_when_then_implicit_none() -> None:
    """Without `otherwise`, non-matching rows are null; the unaliased output column
    is named ``literal``."""
    frame = pl.DataFrame(
        {
            "team": ["A", "A", "A", "B", "B", "C"],
            "points": [11, 8, 10, 6, 6, 5],
        }
    )
    foo_if_high = pl.when(pl.col("points") > 7).then(pl.lit("Foo"))

    out = frame.select(foo_if_high, foo_if_high.alias("bar"))

    expected_values = ["Foo"] * 3 + [None] * 3
    want = pl.DataFrame({"literal": expected_values, "bar": expected_values})
    assert_frame_equal(out, want)
96
97
98
def test_when_then_empty_list_5547() -> None:
    """A list literal in `then` over an empty frame gives an empty List(Int64) column."""
    empty_frame = pl.DataFrame({"a": []})
    out = empty_frame.select(pl.when(pl.col("a") > 1).then([1]))

    assert out.shape == (0, 1)
    assert out.dtypes == [pl.List(pl.Int64)]
102
103
104
def test_nested_when_then_and_wildcard_expansion_6284() -> None:
    """A `when` nested inside `otherwise` matches the equivalent flat chain when the
    predicates expand the ``pl.all()`` wildcard."""
    frame = pl.DataFrame({"1": ["a", "b"], "2": ["c", "d"]})

    def any_equals(value: str) -> pl.Expr:
        # True for rows where any column equals `value`.
        return pl.any_horizontal(pl.all() == value)

    nested = (
        pl.when(any_equals("a"))
        .then(pl.lit("a"))
        .otherwise(pl.when(any_equals("d")).then(pl.lit("d")).otherwise(None))
        .alias("result")
    )
    flat = (
        pl.when(any_equals("a"))
        .then(pl.lit("a"))
        .when(any_equals("d"))
        .then(pl.lit("d"))
        .otherwise(None)
        .alias("result")
    )

    nested_out = frame.with_columns(nested)
    flat_out = frame.with_columns(flat)

    assert_frame_equal(nested_out, flat_out)
    assert nested_out.to_dict(as_series=False) == {
        "1": ["a", "b"],
        "2": ["c", "d"],
        "result": ["a", "d"],
    }
138
139
140
def test_list_zip_with_logical_type() -> None:
    """Choosing between two List(Datetime) columns keeps the logical inner dtype."""
    begin = datetime(2023, 1, 1, 1, 1, 1)
    frame = pl.DataFrame(
        {
            "start": [begin, begin],
            "stop": [datetime(2023, 1, 1, 1, 3, 1), datetime(2023, 1, 1, 1, 4, 1)],
            "use": [1, 0],
        }
    )

    def hourly_ranges() -> pl.Expr:
        return pl.datetime_ranges(
            pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left"
        )

    frame = frame.with_columns(
        hourly_ranges().alias("interval_1"),
        hourly_ranges().alias("interval_2"),
    )

    chosen = (
        pl.when(pl.col("use") == 1)
        .then(pl.col("interval_2"))
        .otherwise(pl.col("interval_1"))
        .alias("interval_new")
    )
    out = frame.select(chosen)
    assert out.dtypes == [pl.List(pl.Datetime(time_unit="us", time_zone=None))]
165
166
167
def test_type_coercion_when_then_otherwise_2806() -> None:
    """Integer and string branches coerce to strings, and a float literal in
    `otherwise` does not upcast a Float32 `then` branch."""
    names_frame = pl.DataFrame({"names": ["foo", "spam", "spam"], "nrs": [1, 2, 3]})
    doubled_or_other = (
        pl.when(pl.col("names") == "spam")
        .then(pl.col("nrs") * 2)
        .otherwise(pl.lit("other"))
        .alias("new_col")
    )
    out = names_frame.select(doubled_or_other).to_series()
    assert out.to_list() == pl.Series("new_col", ["other", "4", "6"]).to_list()

    # test it remains float32
    float_frame = pl.Series("a", [1.0, 2.0, 3.0], dtype=pl.Float32).to_frame()
    kept_or_zero = float_frame.select(
        pl.when(pl.col("a") > 2.0).then(pl.col("a")).otherwise(0.0)
    ).to_series()
    assert kept_or_zero.dtype == pl.Float32
187
188
189
def test_when_then_edge_cases_3994() -> None:
    """`when/then/otherwise` over list columns produced by `group_by().agg()`, both
    lazily and on an empty (fully filtered) eager frame."""
    frame = pl.DataFrame(data={"id": [1, 1], "type": [2, 2]})
    null_if_empty = (
        pl.when(pl.col("type").list.len() == 0)
        .then(pl.lit(None))
        .otherwise(pl.col("type"))
        .name.keep()
    )

    # this tests if lazy correctly assigns the list schema to the column aggregation
    lazy_out = (
        frame.lazy()
        .group_by(["id"])
        .agg(pl.col("type"))
        .with_columns(null_if_empty)
        .collect()
    )
    assert lazy_out.to_dict(as_series=False) == {"id": [1], "type": [[2, 2]]}

    # this tests ternary with an empty argument
    empty_out = (
        frame.filter(pl.col("id") == 42)
        .group_by(["id"])
        .agg(pl.col("type"))
        .with_columns(null_if_empty)
    )
    assert empty_out.to_dict(as_series=False) == {"id": [], "type": []}
218
219
220
@pytest.mark.may_fail_cloud  # reason: object
def test_object_when_then_4702() -> None:
    """Object literals (here, polars dtype classes) can be chosen by `when/then`."""
    # please don't ever do this
    frame = pl.DataFrame({"Row": [1, 2], "Type": [pl.Date, pl.UInt8]})
    new_type = (
        pl.when(pl.col("Row") == 1)
        .then(pl.lit(pl.UInt16, allow_object=True))
        .otherwise(pl.lit(pl.UInt8, allow_object=True))
        .alias("New_Type")
    )

    result = frame.with_columns(new_type).to_dict(as_series=False)
    assert result == {
        "Row": [1, 2],
        "Type": [pl.Date, pl.UInt8],
        "New_Type": [pl.UInt16, pl.UInt8],
    }
235
236
237
def test_comp_categorical_lit_dtype() -> None:
    """Combining a string literal with a Categorical column in `when/then/otherwise`
    keeps the Categorical dtype, and the other columns are untouched."""
    # The data keys match the schema names so the frame's construction reads
    # unambiguously (previously the second key, "values", disagreed with "more").
    df = pl.DataFrame(
        data={"column": ["a", "b", "e"], "more": [1, 5, 9]},
        schema=[("column", pl.Categorical), ("more", pl.Int32)],
    )

    replaced = (
        pl.when(pl.col("column") == "e")
        .then(pl.lit("d"))
        .otherwise(pl.col("column"))
        .alias("column")
    )
    assert df.with_columns(replaced).dtypes == [pl.Categorical, pl.Int32]
249
250
251
def test_comp_incompatible_enum_dtype() -> None:
    """A string literal outside the Enum categories raises when combined with the
    Enum column in a ternary."""
    enum_values = pl.Series(["a", "b"], dtype=pl.Enum(["a", "b"]))
    frame = pl.DataFrame({"a": enum_values})
    bad_ternary = pl.when(pl.col("a") == "a").then(pl.col("a")).otherwise(pl.lit("c"))

    expected_message = "conversion from `str` to `enum` failed in column 'scalar'"
    with pytest.raises(InvalidOperationError, match=expected_message):
        frame.with_columns(bad_ternary)
261
262
263
def test_predicate_broadcast() -> None:
    """A scalar (aggregated) predicate is broadcast over every row of its group."""
    frame = pl.DataFrame(
        {
            "key": ["a", "a", "b", "b", "c", "c"],
            "val": [1, 2, 3, 4, 5, 6],
        }
    )
    group_min_at_least_3 = pl.col("val").min() >= 3

    out = frame.group_by("key", maintain_order=True).agg(
        agg=pl.when(group_min_at_least_3).then(pl.col("val")),
    )

    expected = {"key": ["a", "b", "c"], "agg": [[None, None], [3, 4], [5, 6]]}
    assert out.to_dict(as_series=False) == expected
277
278
279
@pytest.mark.parametrize(
    "mask_expr",
    [
        pl.lit(True),
        pl.first("true"),
        pl.lit(False),
        pl.first("false"),
        pl.lit(None, dtype=pl.Boolean),
        pl.col("null_bool"),
        pl.col("true"),
        pl.col("false"),
    ],
)
@pytest.mark.parametrize("truthy_expr", [pl.lit(1), pl.first("x"), pl.col("x")])
@pytest.mark.parametrize("falsy_expr", [pl.lit(1), pl.first("x"), pl.col("x")])
@pytest.mark.parametrize("maintain_order", [False, True])
def test_single_element_broadcast(
    mask_expr: pl.Expr,
    truthy_expr: pl.Expr,
    falsy_expr: pl.Expr,
    maintain_order: bool,
) -> None:
    """Length-1 mask/truthy/falsy inputs are broadcast to the longest input, both in
    `select` and inside a single-group `group_by().agg`."""
    frame = (
        pl.Series("x", 5 * [1], dtype=pl.Int32)
        .to_frame()
        .with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean))
    )
    ternary = pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr)

    # Each of mask, truthy and falsy has length 1 or the common maximum length, so
    # the output must have exactly that maximum length.
    out_len = frame.select(
        pl.max_horizontal(mask_expr.len(), truthy_expr.len(), falsy_expr.len())
    ).item()
    expected = frame.select("x").head(out_len)

    assert_frame_equal(frame.select(ternary), expected)

    grouped = (
        frame.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)
        .agg(ternary)
        .drop("key")
    )
    if expected.height > 1:
        grouped = grouped.explode(cs.all())
    assert_frame_equal(grouped, expected, check_row_order=maintain_order)
343
344
345
@pytest.mark.parametrize(
    "df",
    [pl.DataFrame({"x": range(5)}), pl.DataFrame({"x": 5 * [[*range(5)]]})],
)
@pytest.mark.parametrize(
    "ternary_expr",
    [
        pl.when(True).then(pl.col("x").head(2)).otherwise(pl.col("x")),
        pl.when(False).then(pl.col("x").head(2)).otherwise(pl.col("x")),
    ],
)
def test_mismatched_height_should_raise(
    df: pl.DataFrame, ternary_expr: pl.Expr
) -> None:
    """Branches with different non-unit lengths raise ShapeError, in `select` and in
    `group_by().agg`, whichever way the constant predicate points."""

    def run_select() -> pl.DataFrame:
        return df.select(ternary_expr)

    def run_group_by() -> pl.DataFrame:
        return df.group_by(pl.lit(True).alias("key")).agg(ternary_expr)

    for run in (run_select, run_group_by):
        with pytest.raises(ShapeError):
            run()
364
365
366
@pytest.mark.parametrize("maintain_order", [False, True])
def test_when_then_output_name_12380(maintain_order: bool) -> None:
    """The ternary output is named after the `then` branch, whichever branch the
    (possibly broadcast) predicate selects, in `select` and in `group_by().agg`.

    Both the all-true and the false/null predicate cases now honour the
    ``maintain_order`` parametrization. Previously the second half ignored it and
    always checked row order.
    """
    df = pl.DataFrame(
        {"x": range(5), "y": range(5, 10)}, schema={"x": pl.Int8, "y": pl.Int64}
    ).with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean))

    def check(predicate: pl.Expr, expect: pl.DataFrame) -> None:
        # Same ternary shape for every predicate; only the predicate varies.
        ternary_expr = pl.when(predicate).then(pl.col("x")).otherwise(pl.col("y"))

        assert_frame_equal(expect, df.select(ternary_expr))

        actual = (
            df.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)
            .agg(ternary_expr)
            .drop("key")
            .explode(cs.all())
        )
        assert_frame_equal(expect, actual, check_row_order=maintain_order)

    # True predicates select `x`, cast to the Int64 supertype of the Int8/Int64 branches.
    expect_true = df.select(pl.col("x").cast(pl.Int64))
    for true_expr in (pl.first("true"), pl.col("true"), pl.lit(True)):
        check(true_expr, expect_true)

    # False or null predicates select `y`, but the output keeps the name `x`.
    expect_false = df.select(pl.col("y").alias("x"))
    for false_expr in (
        pl.first("false"),
        pl.col("false"),
        pl.lit(False),
        pl.first("null_bool"),
        pl.col("null_bool"),
        pl.lit(None, dtype=pl.Boolean),
    ):
        check(false_expr, expect_false)
415
416
417
def test_when_then_nested_non_unit_literal_predicate_agg_broadcast_12242() -> None:
    """A per-group predicate built from `int_range`/`is_in` and a `gather` in `then`
    are evaluated per group inside `agg`."""
    df = pl.DataFrame(
        {
            "array_name": ["A", "A", "A", "B", "B"],
            "array_idx": [5, 0, 3, 7, 2],
            "array_val": [1, 2, 3, 4, 5],
        }
    )

    dense_idx = pl.int_range(pl.min("array_idx"), pl.max("array_idx") + 1)
    present = dense_idx.is_in("array_idx")
    source_pos = present.cum_sum() - 1
    ternary_expr = pl.when(present).then(pl.col("array_val").gather(source_pos))

    expect = pl.DataFrame(
        [
            pl.Series("array_name", ["A", "B"], dtype=pl.String),
            pl.Series(
                "array_val",
                [[1, None, None, 2, None, 3], [4, None, None, None, None, 5]],
                dtype=pl.List(pl.Int64),
            ),
        ]
    )
    actual = df.group_by("array_name").agg(ternary_expr).sort("array_name")
    assert_frame_equal(expect, actual)
448
449
450
def test_when_then_non_unit_literal_predicate_agg_broadcast_12382() -> None:
    """A literal `then` value is broadcast over a non-unit `int_range` predicate
    inside `agg`."""
    df = pl.DataFrame({"id": [1, 1], "value": [0, 3]})
    in_values = pl.int_range(0, 5).is_in("value")

    actual = df.group_by("id").agg(pl.when(in_values).then(pl.lit("yes")))

    expect = pl.DataFrame({"id": [1], "literal": [["yes", None, None, "yes", None]]})
    assert_frame_equal(expect, actual)
459
460
461
def test_when_then_binary_op_predicate_agg_12526() -> None:
    """Multi-predicate `when` clauses built from binary ops work inside `agg`."""
    df = pl.DataFrame({"a": [1, 1, 1], "b": [1, 2, 5]})

    prev_a = pl.col("a").shift(1)
    b_present = pl.col("b").is_not_null()
    labelled = (
        pl.when(prev_a > 2, b_present)
        .then(pl.lit("abc"))
        .when(prev_a > 1, b_present)
        .then(pl.lit("def"))
        .otherwise(pl.lit(None))
        .first()
    )
    actual = df.group_by("a").agg(col=labelled)

    expect = pl.DataFrame(
        {"a": [1], "col": [None]}, schema={"a": pl.Int64, "col": pl.String}
    )
    assert_frame_equal(expect, actual)
491
492
493
def test_when_predicates_kwargs() -> None:
    """`when` accepts keyword equality predicates, alone or mixed with expressions,
    including in chained `when/then` clauses."""
    df = pl.DataFrame(
        {
            "x": [10, 20, 30, 40],
            "y": [15, -20, None, 1],
            "z": ["a", "b", "c", "d"],
        }
    )

    cases: list[tuple[str, pl.Expr, list[Any]]] = [
        # kwargs only
        (
            "matched",
            pl.when(x=30, z="c").then(True).otherwise(False),
            [False, False, True, False],
        ),
        # mixed predicates & kwargs
        (
            "matched",
            pl.when(pl.col("x") < 30, z="b").then(True).otherwise(False),
            [False, True, False, False],
        ),
        # chained when/then with mixed predicates/kwargs
        (
            "misc",
            pl.when(pl.col("x") > 50)
            .then(pl.lit("x>50"))
            .when(y=1)
            .then(pl.lit("y=1"))
            .when(pl.col("z").is_in(["a", "b"]), pl.col("y") < 0)
            .then(pl.lit("z in (a|b), y<0"))
            .otherwise(pl.lit("?")),
            ["?", "z in (a|b), y<0", "?", "y=1"],
        ),
    ]
    for name, expr, values in cases:
        assert_frame_equal(df.select(**{name: expr}), pl.DataFrame({name: values}))
521
522
523
def test_when_then_null_broadcast() -> None:
    """A length-1 Null-typed `then` value is broadcast to the predicate's length."""
    predicate = pl.repeat(True, 2, dtype=pl.Boolean)
    null_value = pl.repeat(None, 1, dtype=pl.Null)

    out = pl.select(pl.when(predicate).then(null_value))
    assert out.height == 2
532
533
534
@pytest.mark.slow
@pytest.mark.parametrize("length", [1, 10, 100, 500])
@pytest.mark.parametrize(
    ("dtype", "vals"),
    [
        pytest.param(pl.Boolean, [False, True], id="Boolean"),
        pytest.param(pl.UInt8, [0, 1], id="UInt8"),
        pytest.param(pl.UInt16, [0, 1], id="UInt16"),
        pytest.param(pl.UInt32, [0, 1], id="UInt32"),
        pytest.param(pl.UInt64, [0, 1], id="UInt64"),
        pytest.param(pl.Float32, [0.0, 1.0], id="Float32"),
        pytest.param(pl.Float64, [0.0, 1.0], id="Float64"),
        pytest.param(pl.String, ["0", "12"], id="String"),
        pytest.param(pl.Array(pl.String, 2), [["0", "1"], ["3", "4"]], id="StrArray"),
        pytest.param(pl.Array(pl.Int64, 2), [[0, 1], [3, 4]], id="IntArray"),
        pytest.param(pl.List(pl.String), [["0"], ["1", "2"]], id="List"),
        pytest.param(
            pl.Struct({"foo": pl.Int32, "bar": pl.String}),
            [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}],
            id="Struct",
        ),
        pytest.param(
            pl.Object,
            ["x", "y"],
            id="Object",
            marks=pytest.mark.may_fail_cloud,
            # reason: objects are not allowed in cloud
        ),
    ],
)
@pytest.mark.parametrize("broadcast", list(itertools.product([False, True], repeat=3)))
def test_when_then_parametric(
    length: int,
    dtype: pl.DataType,
    vals: list[Any],
    broadcast: tuple[bool, bool, bool],
) -> None:
    """Randomized comparison of `when/then/otherwise` against a pure-Python reference.

    Parameters
    ----------
    length
        Number of rows in each generated frame. (Renamed from ``len`` so the
        builtin is no longer shadowed; test ids are value-derived and unchanged.)
    dtype
        Dtype of the ``if_true`` / ``if_false`` columns.
    vals
        Non-null values sampled (together with ``None``) for both branches.
    broadcast
        Flags for (mask, if_true, if_false): when set, that input is replaced by
        its first element and must be broadcast to ``length``.
    """
    # Makes no sense to broadcast all columns.
    if all(broadcast):
        return

    rng = random.Random(42)
    value_pool = [*vals, None]

    for _ in range(10):
        mask = rng.choices([False, True, None], k=length)
        if_true = rng.choices(value_pool, k=length)
        if_false = rng.choices(value_pool, k=length)

        # Python-side inputs, with broadcast columns expanded from their first value.
        py_mask, py_true, py_false = (
            [c[0]] * length if b else c
            for b, c in zip(broadcast, [mask, if_true, if_false])
        )
        # Polars-side inputs, with broadcast columns reduced to a length-1 `first()`.
        pl_mask, pl_true, pl_false = (
            c.first() if b else c
            for b, c in zip(broadcast, [pl.col.mask, pl.col.if_true, pl.col.if_false])
        )

        # A null mask is falsy in the reference, matching `otherwise` semantics.
        ref = pl.DataFrame(
            {"if_true": [t if m else f for m, t, f in zip(py_mask, py_true, py_false)]},
            schema={"if_true": dtype},
        )
        df = pl.DataFrame(
            {
                "mask": mask,
                "if_true": if_true,
                "if_false": if_false,
            },
            schema={"mask": pl.Boolean, "if_true": dtype, "if_false": dtype},
        )

        ans = df.select(pl.when(pl_mask).then(pl_true).otherwise(pl_false))
        if dtype != pl.Object:
            assert_frame_equal(ref, ans)
        else:
            # Object columns are compared by value through Python lists.
            assert ref["if_true"].to_list() == ans["if_true"].to_list()
606
607
608
def test_when_then_else_struct_18961() -> None:
    """Struct branches (with and without broadcast `first()`) combine correctly,
    including when the broadcast value is a null struct."""
    v1 = [None, {"foo": 0, "bar": "1"}]
    v2 = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]

    def run(
        left: list[Any], right: list[Any], mask: list[bool], expr: pl.Expr
    ) -> list[Any]:
        frame = pl.DataFrame({"left": left, "right": right, "mask": mask})
        return frame.select(expr).get_column("left").to_list()

    both_set = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]

    otherwise_broadcast = (
        pl.when(pl.col.mask).then(pl.col.left).otherwise(pl.col.right.first())
    )
    assert both_set == run(v1, v2, [False, True], otherwise_broadcast)

    then_broadcast = (
        pl.when(pl.col.mask).then(pl.col.left.first()).otherwise(pl.col.right)
    )
    assert both_set == run(v2, v1, [True, False], then_broadcast)

    both_broadcast = (
        pl.when(pl.col.mask)
        .then(pl.col.left.first())
        .otherwise(pl.col.right.first())
    )
    assert [None, {"foo": 0, "bar": "1"}] == run(v1, v2, [True, False], both_broadcast)
649
650
651
def test_when_then_supertype_15975() -> None:
    """Mixed integer/float arithmetic in `then` produces a float result."""
    df = pl.DataFrame({"a": [1, 2, 3]})
    a = pl.col("a")
    mixed_arithmetic = 1**a + 1.0 * a

    out = df.with_columns(pl.when(True).then(mixed_arithmetic))
    assert out.to_dict(as_series=False) == {"a": [1, 2, 3], "literal": [2.0, 3.0, 4.0]}
657
658
659
def test_when_then_supertype_15975_comment() -> None:
    """Integer and float `then` values across a long chain unify to floats."""
    lf = pl.LazyFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]})
    foo = pl.col("foo")

    val = (
        pl.when(foo == 1)
        .then(1)
        .when(foo == 2)
        .then(4)
        .when(foo == 3)
        .then(1.5)
        .when(foo == 4)
        .then(16)
        .otherwise(0)
        .alias("val")
    )
    collected = lf.with_columns(val).collect()
    assert collected["val"].to_list() == [1.0, 1.5, 16.0]
676
677
678
def test_chained_when_no_subclass_17142() -> None:
    """A `when` following `then` returns a `ChainedWhen` builder, not an `Expr`."""
    # https://github.com/pola-rs/polars/pull/17142
    chained_when = pl.when(True).then(1).when(True)

    assert not isinstance(chained_when, pl.Expr)
    assert "<polars.expr.whenthen.ChainedWhen object at" in str(chained_when)
684
685
686
def test_when_then_chunked_structs_18673() -> None:
    """Broadcast struct scalars in both branches work on a `vstack`-ed frame."""
    single = pl.DataFrame([pl.Series("x", [{"a": 1}]), pl.Series("b", [False])])
    stacked = single.vstack(single)

    # This used to panic
    out = stacked.select(
        pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))
    )
    assert_frame_equal(out, pl.DataFrame({"x": [{"a": 1}, {"a": 1}]}))
701
702
703
# Struct inputs for `test_struct_when_then_broadcasting_combinations_19122`:
# a length-1 struct, a length-1 null struct, and a length-2 struct column.
some_scalar = pl.Series(name="a", values=[{"x": 2}], dtype=pl.Struct)
none_scalar = pl.Series(name="a", values=[None], dtype=pl.Struct({"x": pl.Int64}))
column = pl.Series(name="a", values=[{"x": 2}, {"x": 2}], dtype=pl.Struct)
706
707
708
@pytest.mark.parametrize(
    "values",
    [
        (some_scalar, some_scalar),
        (some_scalar, pl.col.a),
        (some_scalar, none_scalar),
        (some_scalar, column),
        (none_scalar, pl.col.a),
        (none_scalar, none_scalar),
        (none_scalar, column),
        (pl.col.a, pl.col.a),
        (pl.col.a, column),
        (column, column),
    ],
)
def test_struct_when_then_broadcasting_combinations_19122(
    values: tuple[Any, Any],
) -> None:
    """The never-selected struct branch does not affect the result, for every mix
    of scalar, null-scalar, column and expression inputs."""
    lv, rv = values
    frame = pl.Series("a", [{"x": 1}, {"x": 1}], pl.Struct).to_frame()
    field_x = pl.col.a.struct.field("x")

    def pick(predicate: pl.Expr, then_value: Any, otherwise_value: Any) -> pl.DataFrame:
        ternary = pl.when(predicate).then(then_value).otherwise(otherwise_value)
        return frame.select(ternary.alias("a"))

    # `x` is always 1, so `x == 0` never selects `then` and `x != 0` always does;
    # the unused branch can therefore be replaced by None without changing output.
    assert_frame_equal(pick(field_x == 0, lv, rv), pick(field_x == 0, None, rv))
    assert_frame_equal(pick(field_x != 0, rv, lv), pick(field_x != 0, rv, None))
747
748
749
@pytest.mark.may_fail_cloud  # reason str.to_decimal is an eager construct
def test_when_then_to_decimal_18375() -> None:
    """Decimal values from `str.to_decimal` pass through either ternary branch."""
    df = pl.DataFrame({"a": ["1.23", "4.56"]})
    as_decimal = pl.col("a").str.to_decimal(scale=2)

    result = df.with_columns(
        b=pl.when(False).then(None).otherwise(as_decimal),
        c=pl.when(True).then(as_decimal),
    )

    strings = ["1.23", "4.56"]
    expected = pl.DataFrame(
        {"a": strings, "b": strings, "c": strings},
        schema={"a": pl.String, "b": pl.Decimal, "c": pl.Decimal},
    )
    assert_frame_equal(result, expected)
766
767
768
def test_when_then_chunked_fill_null_22794() -> None:
    """Forward-filling a struct `when/then` result over a frame concatenated from
    single-row slices propagates the matched row forward."""
    df = pl.DataFrame(
        {
            "node": [{"x": "a", "y": "a"}, {"x": "b", "y": "b"}, {"x": "c", "y": "c"}],
            "level": [0, 1, 2],
        }
    )
    single_row_slices = [df.slice(offset, 1) for offset in range(3)]

    out = pl.concat(single_row_slices).with_columns(
        pl.when(level=1).then("node").forward_fill()
    )

    b_node = {"x": "b", "y": "b"}
    expected = pl.DataFrame({"node": [None, b_node, b_node], "level": [0, 1, 2]})
    assert_frame_equal(out, expected)
787
788
789
def test_when_then_complex_conditional_22959() -> None:
    """Struct branches with different fields unify into one struct dtype, with
    missing fields null."""
    df = pl.DataFrame(
        {"B": [None, "T1", "T2"], "C": [None, None, [1.0]], "E": [None, 2.0, None]}
    )
    result_expr = (
        pl.when(B="T1")
        .then(pl.struct(X="C", Y="C"))
        .when(B="T2")
        .then(pl.struct(X=pl.concat_list([3.0, "E"])))
    )

    res = df.with_columns(Result=result_expr)

    expected = pl.Series(
        "Result",
        [None, {"X": None, "Y": None}, {"X": [3.0, None], "Y": None}],
        pl.Struct({"X": pl.List(pl.Float64), "Y": pl.List(pl.Float64)}),
    )
    assert_series_equal(res["Result"], expected)
811
812
813
def test_when_then_simplification() -> None:
    """With a constant predicate, the explained plan shows the selected branch."""
    lf = pl.LazyFrame({"a": [12]})

    def plan_for(predicate: bool) -> str:
        ternary = pl.when(predicate).then(pl.col("a")).otherwise(pl.col("a") * 2)
        return lf.select(ternary).explain()

    assert """[col("a")]""" in plan_for(True)
    assert """(col("a")) * (2)""" in plan_for(False)
827
828
829
def test_when_then_in_group_by_aggregated_22922() -> None:
    """An aggregated `then` value inside `agg` is null for groups that do not match."""
    df = pl.DataFrame({"group": ["x", "y", "x", "y"], "value": [1, 2, 3, 4]})
    max_if_x = pl.when(group="x").then(pl.col.value.max()).first()

    out = df.group_by("group", maintain_order=True).agg(expr=max_if_x)

    assert_frame_equal(out, pl.DataFrame({"group": ["x", "y"], "expr": [3, None]}))
836
837