Path: blob/main/py-polars/tests/unit/functions/test_when_then.py
8420 views
"""Tests for the ``pl.when(...).then(...).otherwise(...)`` conditional expression.

Covers chaining, implicit-null results, dtype coercion / supertypes, broadcasting
of length-1 inputs (in ``select`` and ``group_by().agg``), shape errors, and a
number of regression cases (numeric suffixes presumably reference upstream
GitHub issue/PR numbers).
"""

from __future__ import annotations

import itertools
import random
from datetime import datetime
from typing import Any

import pytest

import polars as pl
import polars.selectors as cs
from polars.exceptions import InvalidOperationError, ShapeError
from polars.testing import assert_frame_equal, assert_series_equal


def test_when_then() -> None:
    """Unmatched rows take the ``otherwise`` value, or null when it is omitted."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})

    expr = pl.when(pl.col("a") < 3).then(pl.lit("x"))

    result = df.select(
        expr.otherwise(pl.lit("y")).alias("a"),
        expr.alias("b"),
    )

    expected = pl.DataFrame(
        {
            "a": ["x", "x", "y", "y", "y"],
            "b": ["x", "x", None, None, None],
        }
    )
    assert_frame_equal(result, expected)


def test_when_then_chained() -> None:
    """Chained branches each apply; rows matching none get ``otherwise`` or null."""
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})

    expr = (
        pl.when(pl.col("a") < 3)
        .then(pl.lit("x"))
        .when(pl.col("a") > 4)
        .then(pl.lit("z"))
    )

    result = df.select(
        expr.otherwise(pl.lit("y")).alias("a"),
        expr.alias("b"),
    )

    expected = pl.DataFrame(
        {
            "a": ["x", "x", "y", "y", "z"],
            "b": ["x", "x", None, None, "z"],
        }
    )
    assert_frame_equal(result, expected)


def test_when_then_invalid_chains() -> None:
    """Each builder state only exposes the methods valid at that chain position."""
    with pytest.raises(AttributeError):
        pl.when("a").when("b")  # type: ignore[attr-defined]
    with pytest.raises(AttributeError):
        pl.when("a").otherwise(2)  # type: ignore[attr-defined]
    with pytest.raises(AttributeError):
        pl.when("a").then(1).then(2)  # type: ignore[attr-defined]
    with pytest.raises(AttributeError):
        pl.when("a").then(1).otherwise(2).otherwise(3)  # type: ignore[attr-defined]
    with pytest.raises(AttributeError):
        pl.when("a").then(1).when("b").when("c")  # type: ignore[attr-defined]
    with pytest.raises(AttributeError):
        pl.when("a").then(1).when("b").otherwise("2")  # type: ignore[attr-defined]
    with pytest.raises(AttributeError):
        pl.when("a").then(1).when("b").then(2).when("c").when("d")  # type: ignore[attr-defined]


def test_when_then_implicit_none() -> None:
    """Without ``otherwise`` unmatched rows are null; an unaliased result is ``literal``."""
    df = pl.DataFrame(
        {
            "team": ["A", "A", "A", "B", "B", "C"],
            "points": [11, 8, 10, 6, 6, 5],
        }
    )

    result = df.select(
        pl.when(pl.col("points") > 7).then(pl.lit("Foo")),
        pl.when(pl.col("points") > 7).then(pl.lit("Foo")).alias("bar"),
    )

    expected = pl.DataFrame(
        {
            "literal": ["Foo", "Foo", "Foo", None, None, None],
            "bar": ["Foo", "Foo", "Foo", None, None, None],
        }
    )
    assert_frame_equal(result, expected)


def test_when_then_empty_list_5547() -> None:
    """A list-literal ``then`` on an empty frame yields an empty ``List(Int64)`` column."""
    out = pl.DataFrame({"a": []}).select([pl.when(pl.col("a") > 1).then([1])])
    assert out.shape == (0, 1)
    assert out.dtypes == [pl.List(pl.Int64)]


def test_nested_when_then_and_wildcard_expansion_6284() -> None:
    """Nesting inside ``otherwise`` matches chaining, with ``pl.all()`` in predicates."""
    df = pl.DataFrame(
        {
            "1": ["a", "b"],
            "2": ["c", "d"],
        }
    )

    out0 = df.with_columns(
        pl.when(pl.any_horizontal(pl.all() == "a"))
        .then(pl.lit("a"))
        .otherwise(
            pl.when(pl.any_horizontal(pl.all() == "d"))
            .then(pl.lit("d"))
            .otherwise(None)
        )
        .alias("result")
    )

    out1 = df.with_columns(
        pl.when(pl.any_horizontal(pl.all() == "a"))
        .then(pl.lit("a"))
        .when(pl.any_horizontal(pl.all() == "d"))
        .then(pl.lit("d"))
        .otherwise(None)
        .alias("result")
    )

    assert_frame_equal(out0, out1)
    assert out0.to_dict(as_series=False) == {
        "1": ["a", "b"],
        "2": ["c", "d"],
        "result": ["a", "d"],
    }


def test_list_zip_with_logical_type() -> None:
    """Choosing between list-of-datetime columns keeps the ``List(Datetime)`` dtype."""
    df = pl.DataFrame(
        {
            "start": [datetime(2023, 1, 1, 1, 1, 1), datetime(2023, 1, 1, 1, 1, 1)],
            "stop": [datetime(2023, 1, 1, 1, 3, 1), datetime(2023, 1, 1, 1, 4, 1)],
            "use": [1, 0],
        }
    )

    df = df.with_columns(
        pl.datetime_ranges(
            pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left"
        ).alias("interval_1"),
        pl.datetime_ranges(
            pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left"
        ).alias("interval_2"),
    )

    out = df.select(
        pl.when(pl.col("use") == 1)
        .then(pl.col("interval_2"))
        .otherwise(pl.col("interval_1"))
        .alias("interval_new")
    )
    assert out.dtypes == [pl.List(pl.Datetime(time_unit="us", time_zone=None))]


def test_type_coercion_when_then_otherwise_2806() -> None:
    """Int/str branches coerce to string; Float32 with a float literal stays Float32."""
    out = (
        pl.DataFrame({"names": ["foo", "spam", "spam"], "nrs": [1, 2, 3]})
        .select(
            pl.when(pl.col("names") == "spam")
            .then(pl.col("nrs") * 2)
            .otherwise(pl.lit("other"))
            .alias("new_col"),
        )
        .to_series()
    )
    expected = pl.Series("new_col", ["other", "4", "6"])
    assert out.to_list() == expected.to_list()

    # test it remains float32
    assert (
        pl.Series("a", [1.0, 2.0, 3.0], dtype=pl.Float32)
        .to_frame()
        .select(pl.when(pl.col("a") > 2.0).then(pl.col("a")).otherwise(0.0))
    ).to_series().dtype == pl.Float32


def test_when_then_edge_cases_3994() -> None:
    """Ternaries over aggregated list columns work lazily and on empty results."""
    df = pl.DataFrame(data={"id": [1, 1], "type": [2, 2]})

    # this tests if lazy correctly assigns the list schema to the column aggregation
    assert (
        df.lazy()
        .group_by(["id"])
        .agg(pl.col("type"))
        .with_columns(
            pl.when(pl.col("type").list.len() == 0)
            .then(pl.lit(None))
            .otherwise(pl.col("type"))
            .name.keep()
        )
        .collect()
    ).to_dict(as_series=False) == {"id": [1], "type": [[2, 2]]}

    # this tests ternary with an empty argument
    assert (
        df.filter(pl.col("id") == 42)
        .group_by(["id"])
        .agg(pl.col("type"))
        .with_columns(
            pl.when(pl.col("type").list.len() == 0)
            .then(pl.lit(None))
            .otherwise(pl.col("type"))
            .name.keep()
        )
    ).to_dict(as_series=False) == {"id": [], "type": []}


@pytest.mark.may_fail_cloud  # reason: object
def test_object_when_then_4702() -> None:
    """Object literals (``allow_object=True``) can be selected between row-wise."""
    # please don't ever do this
    x = pl.DataFrame({"Row": [1, 2], "Type": [pl.Date, pl.UInt8]})

    assert x.with_columns(
        pl.when(pl.col("Row") == 1)
        .then(pl.lit(pl.UInt16, allow_object=True))
        .otherwise(pl.lit(pl.UInt8, allow_object=True))
        .alias("New_Type")
    ).to_dict(as_series=False) == {
        "Row": [1, 2],
        "Type": [pl.Date, pl.UInt8],
        "New_Type": [pl.UInt16, pl.UInt8],
    }


def test_comp_categorical_lit_dtype() -> None:
    """A string literal mixed with a Categorical column keeps the Categorical dtype."""
    df = pl.DataFrame(
        data={"column": ["a", "b", "e"], "values": [1, 5, 9]},
        schema=[("column", pl.Categorical), ("more", pl.Int32)],
    )

    assert df.with_columns(
        pl.when(pl.col("column") == "e")
        .then(pl.lit("d"))
        .otherwise(pl.col("column"))
        .alias("column")
    ).dtypes == [pl.Categorical, pl.Int32]


def test_comp_incompatible_enum_dtype() -> None:
    """A string literal outside the Enum's categories raises InvalidOperationError."""
    df = pl.DataFrame({"a": pl.Series(["a", "b"], dtype=pl.Enum(["a", "b"]))})

    with pytest.raises(
        InvalidOperationError,
        match="conversion from `str` to `enum` failed in column 'scalar'",
    ):
        df.with_columns(
            pl.when(pl.col("a") == "a").then(pl.col("a")).otherwise(pl.lit("c"))
        )


def test_predicate_broadcast() -> None:
    """An aggregated (scalar) predicate is broadcast over each group in ``agg``."""
    df = pl.DataFrame(
        {
            "key": ["a", "a", "b", "b", "c", "c"],
            "val": [1, 2, 3, 4, 5, 6],
        }
    )
    out = df.group_by("key", maintain_order=True).agg(
        agg=pl.when(pl.col("val").min() >= 3).then(pl.col("val")),
    )
    assert out.to_dict(as_series=False) == {
        "key": ["a", "b", "c"],
        "agg": [[None, None], [3, 4], [5, 6]],
    }


@pytest.mark.parametrize(
    "mask_expr",
    [
        pl.lit(True),
        pl.first("true"),
        pl.lit(False),
        pl.first("false"),
        pl.lit(None, dtype=pl.Boolean),
        pl.col("null_bool"),
        pl.col("true"),
        pl.col("false"),
    ],
)
@pytest.mark.parametrize(
    "truthy_expr",
    [
        pl.lit(1),
        pl.first("x"),
        pl.col("x"),
    ],
)
@pytest.mark.parametrize(
    "falsy_expr",
    [
        pl.lit(1),
        pl.first("x"),
        pl.col("x"),
    ],
)
@pytest.mark.parametrize("maintain_order", [False, True])
def test_single_element_broadcast(
    mask_expr: pl.Expr,
    truthy_expr: pl.Expr,
    falsy_expr: pl.Expr,
    maintain_order: bool,
) -> None:
    """Length-1 inputs broadcast to the longest input in ``select`` and ``agg``."""
    df = (
        pl.Series("x", 5 * [1], dtype=pl.Int32)
        .to_frame()
        .with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean))
    )

    # Given that the lengths of the mask, truthy and falsy are all either:
    # - Length 1
    # - Equal length to the maximum length of the 3.
    # This test checks that all length-1 exprs are broadcast to the max length.
    result = df.select(
        pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr)
    )
    expected = df.select("x").head(
        df.select(
            pl.max_horizontal(mask_expr.len(), truthy_expr.len(), falsy_expr.len())
        ).item()
    )
    assert_frame_equal(result, expected)

    result = (
        df.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)
        .agg(pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr))
        .drop("key")
    )
    if expected.height > 1:
        result = result.explode(cs.all())
    assert_frame_equal(result, expected, check_row_order=maintain_order)


@pytest.mark.parametrize(
    "df",
    [pl.DataFrame({"x": range(5)}), pl.DataFrame({"x": 5 * [[*range(5)]]})],
)
@pytest.mark.parametrize(
    "ternary_expr",
    [
        pl.when(True).then(pl.col("x").head(2)).otherwise(pl.col("x")),
        pl.when(False).then(pl.col("x").head(2)).otherwise(pl.col("x")),
    ],
)
def test_mismatched_height_should_raise(
    df: pl.DataFrame, ternary_expr: pl.Expr
) -> None:
    """Branches of different non-unit lengths raise ShapeError, even for literal masks."""
    with pytest.raises(ShapeError):
        df.select(ternary_expr)

    with pytest.raises(ShapeError):
        df.group_by(pl.lit(True).alias("key")).agg(ternary_expr)


@pytest.mark.parametrize("maintain_order", [False, True])
def test_when_then_output_name_12380(maintain_order: bool) -> None:
    """Output is named after the ``then`` branch and uses the branches' supertype.

    Checked for constant-true predicates (result is ``x`` cast to Int64) and for
    constant false/null predicates (result is ``y`` renamed to ``x``), both in
    ``select`` and inside ``group_by().agg``.
    """
    df = pl.DataFrame(
        {"x": range(5), "y": range(5, 10)}, schema={"x": pl.Int8, "y": pl.Int64}
    ).with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean))

    expect = df.select(pl.col("x").cast(pl.Int64))
    for true_expr in (pl.first("true"), pl.col("true"), pl.lit(True)):
        ternary_expr = pl.when(true_expr).then(pl.col("x")).otherwise(pl.col("y"))

        actual = df.select(ternary_expr)
        assert_frame_equal(
            expect,
            actual,
        )
        actual = (
            df.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)
            .agg(ternary_expr)
            .drop("key")
            .explode(cs.all())
        )
        assert_frame_equal(expect, actual, check_row_order=maintain_order)

    expect = df.select(pl.col("y").alias("x"))
    for false_expr in (
        pl.first("false"),
        pl.col("false"),
        pl.lit(False),
        pl.first("null_bool"),
        pl.col("null_bool"),
        pl.lit(None, dtype=pl.Boolean),
    ):
        ternary_expr = pl.when(false_expr).then(pl.col("x")).otherwise(pl.col("y"))

        actual = df.select(ternary_expr)
        assert_frame_equal(
            expect,
            actual,
        )
        # Mirror the true-predicate loop so the `maintain_order` parametrization
        # exercises this path as well (it was previously ignored here).
        actual = (
            df.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)
            .agg(ternary_expr)
            .drop("key")
            .explode(cs.all())
        )
        assert_frame_equal(expect, actual, check_row_order=maintain_order)


def test_when_then_nested_non_unit_literal_predicate_agg_broadcast_12242() -> None:
    """A non-unit predicate built from aggregations broadcasts correctly in ``agg``."""
    df = pl.DataFrame(
        {
            "array_name": ["A", "A", "A", "B", "B"],
            "array_idx": [5, 0, 3, 7, 2],
            "array_val": [1, 2, 3, 4, 5],
        }
    )

    int_range = pl.int_range(pl.min("array_idx"), pl.max("array_idx") + 1)

    is_valid_idx = int_range.is_in("array_idx")

    idxs = is_valid_idx.cum_sum() - 1

    ternary_expr = pl.when(is_valid_idx).then(pl.col("array_val").gather(idxs))

    expect = pl.DataFrame(
        [
            pl.Series("array_name", ["A", "B"], dtype=pl.String),
            pl.Series(
                "array_val",
                [[1, None, None, 2, None, 3], [4, None, None, None, None, 5]],
                dtype=pl.List(pl.Int64),
            ),
        ]
    )

    assert_frame_equal(
        expect, df.group_by("array_name").agg(ternary_expr).sort("array_name")
    )


def test_when_then_non_unit_literal_predicate_agg_broadcast_12382() -> None:
    """A literal ``then`` is broadcast to the length of a non-unit predicate in ``agg``."""
    df = pl.DataFrame({"id": [1, 1], "value": [0, 3]})

    expect = pl.DataFrame({"id": [1], "literal": [["yes", None, None, "yes", None]]})
    actual = df.group_by("id").agg(
        pl.when(pl.int_range(0, 5).is_in("value")).then(pl.lit("yes"))
    )

    assert_frame_equal(expect, actual)


def test_when_then_binary_op_predicate_agg_12526() -> None:
    """Multi-predicate ``when`` with ``shift``-based comparisons works inside ``agg``."""
    df = pl.DataFrame(
        {
            "a": [1, 1, 1],
            "b": [1, 2, 5],
        }
    )

    expect = pl.DataFrame(
        {"a": [1], "col": [None]}, schema={"a": pl.Int64, "col": pl.String}
    )

    actual = df.group_by("a").agg(
        col=(
            pl.when(
                pl.col("a").shift(1) > 2,
                pl.col("b").is_not_null(),
            )
            .then(pl.lit("abc"))
            .when(
                pl.col("a").shift(1) > 1,
                pl.col("b").is_not_null(),
            )
            .then(pl.lit("def"))
            .otherwise(pl.lit(None))
            .first()
        )
    )

    assert_frame_equal(expect, actual)


def test_when_predicates_kwargs() -> None:
    """``when`` accepts keyword equality predicates, alone or mixed with expressions."""
    df = pl.DataFrame(
        {
            "x": [10, 20, 30, 40],
            "y": [15, -20, None, 1],
            "z": ["a", "b", "c", "d"],
        }
    )
    assert_frame_equal(  # kwargs only
        df.select(matched=pl.when(x=30, z="c").then(True).otherwise(False)),
        pl.DataFrame({"matched": [False, False, True, False]}),
    )
    assert_frame_equal(  # mixed predicates & kwargs
        df.select(matched=pl.when(pl.col("x") < 30, z="b").then(True).otherwise(False)),
        pl.DataFrame({"matched": [False, True, False, False]}),
    )
    assert_frame_equal(  # chained when/then with mixed predicates/kwargs
        df.select(
            misc=pl.when(pl.col("x") > 50)
            .then(pl.lit("x>50"))
            .when(y=1)
            .then(pl.lit("y=1"))
            .when(pl.col("z").is_in(["a", "b"]), pl.col("y") < 0)
            .then(pl.lit("z in (a|b), y<0"))
            .otherwise(pl.lit("?"))
        ),
        pl.DataFrame({"misc": ["?", "z in (a|b), y<0", "?", "y=1"]}),
    )


def test_when_then_null_broadcast() -> None:
    """A length-1 Null-typed ``then`` is broadcast to the predicate's length."""
    assert (
        pl.select(
            pl.when(pl.repeat(True, 2, dtype=pl.Boolean)).then(
                pl.repeat(None, 1, dtype=pl.Null)
            )
        ).height
        == 2
    )


@pytest.mark.slow
@pytest.mark.parametrize("length", [1, 10, 100, 500])
@pytest.mark.parametrize(
    ("dtype", "vals"),
    [
        pytest.param(pl.Boolean, [False, True], id="Boolean"),
        pytest.param(pl.UInt8, [0, 1], id="UInt8"),
        pytest.param(pl.UInt16, [0, 1], id="UInt16"),
        pytest.param(pl.UInt32, [0, 1], id="UInt32"),
        pytest.param(pl.UInt64, [0, 1], id="UInt64"),
        pytest.param(pl.Float32, [0.0, 1.0], id="Float32"),
        pytest.param(pl.Float64, [0.0, 1.0], id="Float64"),
        pytest.param(pl.String, ["0", "12"], id="String"),
        pytest.param(pl.Array(pl.String, 2), [["0", "1"], ["3", "4"]], id="StrArray"),
        pytest.param(pl.Array(pl.Int64, 2), [[0, 1], [3, 4]], id="IntArray"),
        pytest.param(pl.List(pl.String), [["0"], ["1", "2"]], id="List"),
        pytest.param(
            pl.Struct({"foo": pl.Int32, "bar": pl.String}),
            [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}],
            id="Struct",
        ),
        pytest.param(
            pl.Object,
            ["x", "y"],
            id="Object",
            marks=pytest.mark.may_fail_cloud,
            # reason: objects are not allowed in cloud
        ),
    ],
)
# Broadcasting all three inputs makes no sense, so that combination is excluded
# here instead of being generated and silently passing without any assertion.
# It is the last element of the product, so the ids of the remaining cases match.
@pytest.mark.parametrize(
    "broadcast",
    [b for b in itertools.product([False, True], repeat=3) if not all(b)],
)
def test_when_then_parametric(
    length: int,
    dtype: pl.DataType,
    vals: list[Any],
    broadcast: tuple[bool, bool, bool],
) -> None:
    """Compare when/then/otherwise against a pure-Python reference on random data.

    ``broadcast`` flags, in order, whether the mask, the ``then`` input and the
    ``otherwise`` input are reduced to their first element (``.first()``), which
    the Python reference mirrors by repeating element 0 ``length`` times.
    """
    rng = random.Random(42)

    for _ in range(10):
        mask = rng.choices([False, True, None], k=length)
        if_true = rng.choices(vals + [None], k=length)
        if_false = rng.choices(vals + [None], k=length)

        py_mask, py_true, py_false = (
            [c[0]] * length if b else c
            for b, c in zip(broadcast, [mask, if_true, if_false], strict=True)
        )
        pl_mask, pl_true, pl_false = (
            c.first() if b else c
            for b, c in zip(
                broadcast, [pl.col.mask, pl.col.if_true, pl.col.if_false], strict=True
            )
        )

        # A null mask is falsy, so it selects the `otherwise` value here.
        ref = pl.DataFrame(
            {
                "if_true": [
                    t if m else f
                    for m, t, f in zip(py_mask, py_true, py_false, strict=True)
                ]
            },
            schema={"if_true": dtype},
        )
        df = pl.DataFrame(
            {
                "mask": mask,
                "if_true": if_true,
                "if_false": if_false,
            },
            schema={"mask": pl.Boolean, "if_true": dtype, "if_false": dtype},
        )

        ans = df.select(pl.when(pl_mask).then(pl_true).otherwise(pl_false))
        if dtype != pl.Object:
            assert_frame_equal(ref, ans)
        else:
            assert ref["if_true"].to_list() == ans["if_true"].to_list()


def test_when_then_else_struct_18961() -> None:
    """Struct branches with nulls and ``first()``-broadcast operands give the right rows."""
    v1 = [None, {"foo": 0, "bar": "1"}]
    v2 = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]

    df = pl.DataFrame({"left": v1, "right": v2, "mask": [False, True]})

    expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]
    ans = (
        df.select(
            pl.when(pl.col.mask).then(pl.col.left).otherwise(pl.col.right.first())
        )
        .get_column("left")
        .to_list()
    )
    assert expected == ans

    df = pl.DataFrame({"left": v2, "right": v1, "mask": [True, False]})

    expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]
    ans = (
        df.select(
            pl.when(pl.col.mask).then(pl.col.left.first()).otherwise(pl.col.right)
        )
        .get_column("left")
        .to_list()
    )
    assert expected == ans

    df = pl.DataFrame({"left": v1, "right": v2, "mask": [True, False]})

    expected2 = [None, {"foo": 0, "bar": "1"}]
    ans = (
        df.select(
            pl.when(pl.col.mask)
            .then(pl.col.left.first())
            .otherwise(pl.col.right.first())
        )
        .get_column("left")
        .to_list()
    )
    assert expected2 == ans


def test_when_then_supertype_15975() -> None:
    """Int/float arithmetic inside ``then`` produces float results."""
    df = pl.DataFrame({"a": [1, 2, 3]})

    assert df.with_columns(
        pl.when(True).then(1 ** pl.col("a") + 1.0 * pl.col("a"))
    ).to_dict(as_series=False) == {"a": [1, 2, 3], "literal": [2.0, 3.0, 4.0]}


def test_when_then_supertype_15975_comment() -> None:
    """Mixed int/float ``then`` literals across chained branches resolve to float."""
    df = pl.LazyFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]})

    q = df.with_columns(
        pl.when(pl.col("foo") == 1)
        .then(1)
        .when(pl.col("foo") == 2)
        .then(4)
        .when(pl.col("foo") == 3)
        .then(1.5)
        .when(pl.col("foo") == 4)
        .then(16)
        .otherwise(0)
        .alias("val")
    )

    assert q.collect()["val"].to_list() == [1.0, 1.5, 16.0]


def test_chained_when_no_subclass_17142() -> None:
    """A chained ``when`` builder is not an ``Expr`` and keeps the default repr."""
    # https://github.com/pola-rs/polars/pull/17142
    when = pl.when(True).then(1).when(True)

    assert not isinstance(when, pl.Expr)
    assert "<polars.expr.whenthen.ChainedWhen object at" in str(when)


def test_when_then_chunked_structs_18673() -> None:
    """A broadcast struct branch over a multi-chunk frame does not panic."""
    df = pl.DataFrame(
        [
            pl.Series("x", [{"a": 1}]),
            pl.Series("b", [False]),
        ]
    )

    df = df.vstack(df)

    # This used to panic
    assert_frame_equal(
        df.select(pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))),
        pl.DataFrame({"x": [{"a": 1}, {"a": 1}]}),
    )


# Struct operands for the broadcasting-combination test below: a length-1 valid
# struct, a length-1 null struct, and a full-length (2-row) struct column.
some_scalar = pl.Series("a", [{"x": 2}], pl.Struct)
none_scalar = pl.Series("a", [None], pl.Struct({"x": pl.Int64}))
column = pl.Series("a", [{"x": 2}, {"x": 2}], pl.Struct)


@pytest.mark.parametrize(
    "values",
    [
        (some_scalar, some_scalar),
        (some_scalar, pl.col.a),
        (some_scalar, none_scalar),
        (some_scalar, column),
        (none_scalar, pl.col.a),
        (none_scalar, none_scalar),
        (none_scalar, column),
        (pl.col.a, pl.col.a),
        (pl.col.a, column),
        (column, column),
    ],
)
def test_struct_when_then_broadcasting_combinations_19122(
    values: tuple[Any, Any],
) -> None:
    """The never-selected struct branch does not affect the result for any combination."""
    lv, rv = values

    df = pl.Series("a", [{"x": 1}, {"x": 1}], pl.Struct).to_frame()

    assert_frame_equal(
        df.select(
            pl.when(pl.col.a.struct.field("x") == 0).then(lv).otherwise(rv).alias("a")
        ),
        df.select(
            pl.when(pl.col.a.struct.field("x") == 0).then(None).otherwise(rv).alias("a")
        ),
    )

    assert_frame_equal(
        df.select(
            pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(lv).alias("a")
        ),
        df.select(
            pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(None).alias("a")
        ),
    )


@pytest.mark.may_fail_cloud  # reason str.to_decimal is an eager construct
def test_when_then_to_decimal_18375() -> None:
    """Decimal branches combined with null or omitted branches keep the Decimal dtype."""
    df = pl.DataFrame({"a": ["1.23", "4.56"]})

    result = df.with_columns(
        b=pl.when(False).then(None).otherwise(pl.col("a").str.to_decimal(scale=2)),
        c=pl.when(True).then(pl.col("a").str.to_decimal(scale=2)),
    )
    expected = pl.DataFrame(
        {
            "a": ["1.23", "4.56"],
            "b": ["1.23", "4.56"],
            "c": ["1.23", "4.56"],
        },
        schema={"a": pl.String, "b": pl.Decimal(scale=2), "c": pl.Decimal(scale=2)},
    )
    assert_frame_equal(result, expected)


def test_when_then_chunked_fill_null_22794() -> None:
    """``forward_fill`` after when/then on a multi-chunk struct column fills correctly."""
    df = pl.DataFrame(
        {
            "node": [{"x": "a", "y": "a"}, {"x": "b", "y": "b"}, {"x": "c", "y": "c"}],
            "level": [0, 1, 2],
        }
    )

    out = pl.concat([df.slice(0, 1), df.slice(1, 1), df.slice(2, 1)]).with_columns(
        pl.when(level=1).then("node").forward_fill()
    )
    expected = pl.DataFrame(
        {
            "node": [None, {"x": "b", "y": "b"}, {"x": "b", "y": "b"}],
            "level": [0, 1, 2],
        }
    )

    assert_frame_equal(out, expected)


def test_when_then_complex_conditional_22959() -> None:
    """Struct branches with differing fields unify to one struct; missing fields are null."""
    df = pl.DataFrame(
        {"B": [None, "T1", "T2"], "C": [None, None, [1.0]], "E": [None, 2.0, None]}
    )

    res = df.with_columns(
        Result=(
            pl.when(B="T1")
            .then(pl.struct(X="C", Y="C"))
            .when(B="T2")
            .then(pl.struct(X=pl.concat_list([3.0, "E"])))
        )
    )

    assert_series_equal(
        res["Result"],
        pl.Series(
            "Result",
            [None, {"X": None, "Y": None}, {"X": [3.0, None], "Y": None}],
            pl.Struct({"X": pl.List(pl.Float64), "Y": pl.List(pl.Float64)}),
        ),
    )


def test_when_then_simplification() -> None:
    """Constant-predicate ternaries are simplified to the chosen branch in the plan."""
    lf = pl.LazyFrame({"a": [12]})
    assert (
        """[col("a")]"""
        in (
            lf.select(pl.when(True).then(pl.col("a")).otherwise(pl.col("a") * 2))
        ).explain()
    )
    assert (
        """(col("a")) * (2)"""
        in (
            lf.select(pl.when(False).then(pl.col("a")).otherwise(pl.col("a") * 2))
        ).explain()
    )


def test_when_then_in_group_by_aggregated_22922() -> None:
    """An aggregating ``then`` branch followed by ``first()`` works inside ``agg``."""
    df = pl.DataFrame({"group": ["x", "y", "x", "y"], "value": [1, 2, 3, 4]})
    out = df.group_by("group", maintain_order=True).agg(
        expr=pl.when(group="x").then(pl.col.value.max()).first()
    )
    expected = pl.DataFrame({"group": ["x", "y"], "expr": [3, None]})
    assert_frame_equal(out, expected)