Path: blob/main/py-polars/tests/unit/functions/test_when_then.py
6939 views
from __future__ import annotations12import itertools3import random4from datetime import datetime5from typing import Any67import pytest89import polars as pl10import polars.selectors as cs11from polars.exceptions import InvalidOperationError, ShapeError12from polars.testing import assert_frame_equal, assert_series_equal131415def test_when_then() -> None:16df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})1718expr = pl.when(pl.col("a") < 3).then(pl.lit("x"))1920result = df.select(21expr.otherwise(pl.lit("y")).alias("a"),22expr.alias("b"),23)2425expected = pl.DataFrame(26{27"a": ["x", "x", "y", "y", "y"],28"b": ["x", "x", None, None, None],29}30)31assert_frame_equal(result, expected)323334def test_when_then_chained() -> None:35df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})3637expr = (38pl.when(pl.col("a") < 3)39.then(pl.lit("x"))40.when(pl.col("a") > 4)41.then(pl.lit("z"))42)4344result = df.select(45expr.otherwise(pl.lit("y")).alias("a"),46expr.alias("b"),47)4849expected = pl.DataFrame(50{51"a": ["x", "x", "y", "y", "z"],52"b": ["x", "x", None, None, "z"],53}54)55assert_frame_equal(result, expected)565758def test_when_then_invalid_chains() -> None:59with pytest.raises(AttributeError):60pl.when("a").when("b") # type: ignore[attr-defined]61with pytest.raises(AttributeError):62pl.when("a").otherwise(2) # type: ignore[attr-defined]63with pytest.raises(AttributeError):64pl.when("a").then(1).then(2) # type: ignore[attr-defined]65with pytest.raises(AttributeError):66pl.when("a").then(1).otherwise(2).otherwise(3) # type: ignore[attr-defined]67with pytest.raises(AttributeError):68pl.when("a").then(1).when("b").when("c") # type: ignore[attr-defined]69with pytest.raises(AttributeError):70pl.when("a").then(1).when("b").otherwise("2") # type: ignore[attr-defined]71with pytest.raises(AttributeError):72pl.when("a").then(1).when("b").then(2).when("c").when("d") # type: ignore[attr-defined]737475def test_when_then_implicit_none() -> None:76df = pl.DataFrame(77{78"team": ["A", "A", "A", "B", "B", "C"],79"points": [11, 8, 10, 6, 6, 5],80}81)8283result = df.select(84pl.when(pl.col("points") > 7).then(pl.lit("Foo")),85pl.when(pl.col("points") > 7).then(pl.lit("Foo")).alias("bar"),86)8788expected = pl.DataFrame(89{90"literal": ["Foo", "Foo", "Foo", None, None, None],91"bar": ["Foo", "Foo", "Foo", None, None, None],92}93)94assert_frame_equal(result, expected)959697def test_when_then_empty_list_5547() -> None:98out = pl.DataFrame({"a": []}).select([pl.when(pl.col("a") > 1).then([1])])99assert out.shape == (0, 1)100assert out.dtypes == [pl.List(pl.Int64)]101102103def test_nested_when_then_and_wildcard_expansion_6284() -> None:104df = pl.DataFrame(105{106"1": ["a", "b"],107"2": ["c", "d"],108}109)110111out0 = df.with_columns(112pl.when(pl.any_horizontal(pl.all() == "a"))113.then(pl.lit("a"))114.otherwise(115pl.when(pl.any_horizontal(pl.all() == "d"))116.then(pl.lit("d"))117.otherwise(None)118)119.alias("result")120)121122out1 = df.with_columns(123pl.when(pl.any_horizontal(pl.all() == "a"))124.then(pl.lit("a"))125.when(pl.any_horizontal(pl.all() == "d"))126.then(pl.lit("d"))127.otherwise(None)128.alias("result")129)130131assert_frame_equal(out0, out1)132assert out0.to_dict(as_series=False) == {133"1": ["a", "b"],134"2": ["c", "d"],135"result": ["a", "d"],136}137138139def test_list_zip_with_logical_type() -> None:140df = pl.DataFrame(141{142"start": [datetime(2023, 1, 1, 1, 1, 1), datetime(2023, 1, 1, 1, 1, 1)],143"stop": [datetime(2023, 1, 1, 1, 3, 1), datetime(2023, 1, 1, 1, 4, 1)],144"use": [1, 0],145}146)147148df = df.with_columns(149pl.datetime_ranges(150pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left"151).alias("interval_1"),152pl.datetime_ranges(153pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left"154).alias("interval_2"),155)156157out = df.select(158pl.when(pl.col("use") == 1)159.then(pl.col("interval_2"))160.otherwise(pl.col("interval_1"))161.alias("interval_new")162)163assert out.dtypes == [pl.List(pl.Datetime(time_unit="us", time_zone=None))]164165166def test_type_coercion_when_then_otherwise_2806() -> None:167out = (168pl.DataFrame({"names": ["foo", "spam", "spam"], "nrs": [1, 2, 3]})169.select(170pl.when(pl.col("names") == "spam")171.then(pl.col("nrs") * 2)172.otherwise(pl.lit("other"))173.alias("new_col"),174)175.to_series()176)177expected = pl.Series("new_col", ["other", "4", "6"])178assert out.to_list() == expected.to_list()179180# test it remains float32181assert (182pl.Series("a", [1.0, 2.0, 3.0], dtype=pl.Float32)183.to_frame()184.select(pl.when(pl.col("a") > 2.0).then(pl.col("a")).otherwise(0.0))185).to_series().dtype == pl.Float32186187188def test_when_then_edge_cases_3994() -> None:189df = pl.DataFrame(data={"id": [1, 1], "type": [2, 2]})190191# this tests if lazy correctly assigns the list schema to the column aggregation192assert (193df.lazy()194.group_by(["id"])195.agg(pl.col("type"))196.with_columns(197pl.when(pl.col("type").list.len() == 0)198.then(pl.lit(None))199.otherwise(pl.col("type"))200.name.keep()201)202.collect()203).to_dict(as_series=False) == {"id": [1], "type": [[2, 2]]}204205# this tests ternary with an empty argument206assert (207df.filter(pl.col("id") == 42)208.group_by(["id"])209.agg(pl.col("type"))210.with_columns(211pl.when(pl.col("type").list.len() == 0)212.then(pl.lit(None))213.otherwise(pl.col("type"))214.name.keep()215)216).to_dict(as_series=False) == {"id": [], "type": []}217218219@pytest.mark.may_fail_cloud # reason: object220def test_object_when_then_4702() -> None:221# please don't ever do this222x = pl.DataFrame({"Row": [1, 2], "Type": [pl.Date, pl.UInt8]})223224assert x.with_columns(225pl.when(pl.col("Row") == 1)226.then(pl.lit(pl.UInt16, allow_object=True))227.otherwise(pl.lit(pl.UInt8, allow_object=True))228.alias("New_Type")229).to_dict(as_series=False) == {230"Row": [1, 2],231"Type": [pl.Date, pl.UInt8],232"New_Type": [pl.UInt16, pl.UInt8],233}234235236def test_comp_categorical_lit_dtype() -> None:237df = pl.DataFrame(238data={"column": ["a", "b", "e"], "values": [1, 5, 9]},239schema=[("column", pl.Categorical), ("more", pl.Int32)],240)241242assert df.with_columns(243pl.when(pl.col("column") == "e")244.then(pl.lit("d"))245.otherwise(pl.col("column"))246.alias("column")247).dtypes == [pl.Categorical, pl.Int32]248249250def test_comp_incompatible_enum_dtype() -> None:251df = pl.DataFrame({"a": pl.Series(["a", "b"], dtype=pl.Enum(["a", "b"]))})252253with pytest.raises(254InvalidOperationError,255match="conversion from `str` to `enum` failed in column 'scalar'",256):257df.with_columns(258pl.when(pl.col("a") == "a").then(pl.col("a")).otherwise(pl.lit("c"))259)260261262def test_predicate_broadcast() -> None:263df = pl.DataFrame(264{265"key": ["a", "a", "b", "b", "c", "c"],266"val": [1, 2, 3, 4, 5, 6],267}268)269out = df.group_by("key", maintain_order=True).agg(270agg=pl.when(pl.col("val").min() >= 3).then(pl.col("val")),271)272assert out.to_dict(as_series=False) == {273"key": ["a", "b", "c"],274"agg": [[None, None], [3, 4], [5, 6]],275}276277278@pytest.mark.parametrize(279"mask_expr",280[281pl.lit(True),282pl.first("true"),283pl.lit(False),284pl.first("false"),285pl.lit(None, dtype=pl.Boolean),286pl.col("null_bool"),287pl.col("true"),288pl.col("false"),289],290)291@pytest.mark.parametrize(292"truthy_expr",293[294pl.lit(1),295pl.first("x"),296pl.col("x"),297],298)299@pytest.mark.parametrize(300"falsy_expr",301[302pl.lit(1),303pl.first("x"),304pl.col("x"),305],306)307@pytest.mark.parametrize("maintain_order", [False, True])308def test_single_element_broadcast(309mask_expr: pl.Expr,310truthy_expr: pl.Expr,311falsy_expr: pl.Expr,312maintain_order: bool,313) -> None:314df = (315pl.Series("x", 5 * [1], dtype=pl.Int32)316.to_frame()317.with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean))318)319320# Given that the lengths of the mask, truthy and falsy are all either:321# - Length 1322# - Equal length to the maximum length of the 3.323# This test checks that all length-1 exprs are broadcast to the max length.324result = df.select(325pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr)326)327expected = df.select("x").head(328df.select(329pl.max_horizontal(mask_expr.len(), truthy_expr.len(), falsy_expr.len())330).item()331)332assert_frame_equal(result, expected)333334result = (335df.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)336.agg(pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr))337.drop("key")338)339if expected.height > 1:340result = result.explode(cs.all())341assert_frame_equal(result, expected, check_row_order=maintain_order)342343344@pytest.mark.parametrize(345"df",346[pl.DataFrame({"x": range(5)}), pl.DataFrame({"x": 5 * [[*range(5)]]})],347)348@pytest.mark.parametrize(349"ternary_expr",350[351pl.when(True).then(pl.col("x").head(2)).otherwise(pl.col("x")),352pl.when(False).then(pl.col("x").head(2)).otherwise(pl.col("x")),353],354)355def test_mismatched_height_should_raise(356df: pl.DataFrame, ternary_expr: pl.Expr357) -> None:358with pytest.raises(ShapeError):359df.select(ternary_expr)360361with pytest.raises(ShapeError):362df.group_by(pl.lit(True).alias("key")).agg(ternary_expr)363364365@pytest.mark.parametrize("maintain_order", [False, True])366def test_when_then_output_name_12380(maintain_order: bool) -> None:367df = pl.DataFrame(368{"x": range(5), "y": range(5, 10)}, schema={"x": pl.Int8, "y": pl.Int64}369).with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean))370371expect = df.select(pl.col("x").cast(pl.Int64))372for true_expr in (pl.first("true"), pl.col("true"), pl.lit(True)):373ternary_expr = pl.when(true_expr).then(pl.col("x")).otherwise(pl.col("y"))374375actual = df.select(ternary_expr)376assert_frame_equal(377expect,378actual,379)380actual = (381df.group_by(pl.lit(True).alias("key"), maintain_order=maintain_order)382.agg(ternary_expr)383.drop("key")384.explode(cs.all())385)386assert_frame_equal(expect, actual, check_row_order=maintain_order)387388expect = df.select(pl.col("y").alias("x"))389for false_expr in (390pl.first("false"),391pl.col("false"),392pl.lit(False),393pl.first("null_bool"),394pl.col("null_bool"),395pl.lit(None, dtype=pl.Boolean),396):397ternary_expr = pl.when(false_expr).then(pl.col("x")).otherwise(pl.col("y"))398399actual = df.select(ternary_expr)400assert_frame_equal(401expect,402actual,403)404actual = (405df.group_by(pl.lit(True).alias("key"))406.agg(ternary_expr)407.drop("key")408.explode(cs.all())409)410assert_frame_equal(411expect,412actual,413)414415416def test_when_then_nested_non_unit_literal_predicate_agg_broadcast_12242() -> None:417df = pl.DataFrame(418{419"array_name": ["A", "A", "A", "B", "B"],420"array_idx": [5, 0, 3, 7, 2],421"array_val": [1, 2, 3, 4, 5],422}423)424425int_range = pl.int_range(pl.min("array_idx"), pl.max("array_idx") + 1)426427is_valid_idx = int_range.is_in("array_idx")428429idxs = is_valid_idx.cum_sum() - 1430431ternary_expr = pl.when(is_valid_idx).then(pl.col("array_val").gather(idxs))432433expect = pl.DataFrame(434[435pl.Series("array_name", ["A", "B"], dtype=pl.String),436pl.Series(437"array_val",438[[1, None, None, 2, None, 3], [4, None, None, None, None, 5]],439dtype=pl.List(pl.Int64),440),441]442)443444assert_frame_equal(445expect, df.group_by("array_name").agg(ternary_expr).sort("array_name")446)447448449def test_when_then_non_unit_literal_predicate_agg_broadcast_12382() -> None:450df = pl.DataFrame({"id": [1, 1], "value": [0, 3]})451452expect = pl.DataFrame({"id": [1], "literal": [["yes", None, None, "yes", None]]})453actual = df.group_by("id").agg(454pl.when(pl.int_range(0, 5).is_in("value")).then(pl.lit("yes"))455)456457assert_frame_equal(expect, actual)458459460def test_when_then_binary_op_predicate_agg_12526() -> None:461df = pl.DataFrame(462{463"a": [1, 1, 1],464"b": [1, 2, 5],465}466)467468expect = pl.DataFrame(469{"a": [1], "col": [None]}, schema={"a": pl.Int64, "col": pl.String}470)471472actual = df.group_by("a").agg(473col=(474pl.when(475pl.col("a").shift(1) > 2,476pl.col("b").is_not_null(),477)478.then(pl.lit("abc"))479.when(480pl.col("a").shift(1) > 1,481pl.col("b").is_not_null(),482)483.then(pl.lit("def"))484.otherwise(pl.lit(None))485.first()486)487)488489assert_frame_equal(expect, actual)490491492def test_when_predicates_kwargs() -> None:493df = pl.DataFrame(494{495"x": [10, 20, 30, 40],496"y": [15, -20, None, 1],497"z": ["a", "b", "c", "d"],498}499)500assert_frame_equal( # kwargs only501df.select(matched=pl.when(x=30, z="c").then(True).otherwise(False)),502pl.DataFrame({"matched": [False, False, True, False]}),503)504assert_frame_equal( # mixed predicates & kwargs505df.select(matched=pl.when(pl.col("x") < 30, z="b").then(True).otherwise(False)),506pl.DataFrame({"matched": [False, True, False, False]}),507)508assert_frame_equal( # chained when/then with mixed predicates/kwargs509df.select(510misc=pl.when(pl.col("x") > 50)511.then(pl.lit("x>50"))512.when(y=1)513.then(pl.lit("y=1"))514.when(pl.col("z").is_in(["a", "b"]), pl.col("y") < 0)515.then(pl.lit("z in (a|b), y<0"))516.otherwise(pl.lit("?"))517),518pl.DataFrame({"misc": ["?", "z in (a|b), y<0", "?", "y=1"]}),519)520521522def test_when_then_null_broadcast() -> None:523assert (524pl.select(525pl.when(pl.repeat(True, 2, dtype=pl.Boolean)).then(526pl.repeat(None, 1, dtype=pl.Null)527)528).height529== 2530)531532533@pytest.mark.slow534@pytest.mark.parametrize("len", [1, 10, 100, 500])535@pytest.mark.parametrize(536("dtype", "vals"),537[538pytest.param(pl.Boolean, [False, True], id="Boolean"),539pytest.param(pl.UInt8, [0, 1], id="UInt8"),540pytest.param(pl.UInt16, [0, 1], id="UInt16"),541pytest.param(pl.UInt32, [0, 1], id="UInt32"),542pytest.param(pl.UInt64, [0, 1], id="UInt64"),543pytest.param(pl.Float32, [0.0, 1.0], id="Float32"),544pytest.param(pl.Float64, [0.0, 1.0], id="Float64"),545pytest.param(pl.String, ["0", "12"], id="String"),546pytest.param(pl.Array(pl.String, 2), [["0", "1"], ["3", "4"]], id="StrArray"),547pytest.param(pl.Array(pl.Int64, 2), [[0, 1], [3, 4]], id="IntArray"),548pytest.param(pl.List(pl.String), [["0"], ["1", "2"]], id="List"),549pytest.param(550pl.Struct({"foo": pl.Int32, "bar": pl.String}),551[{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}],552id="Struct",553),554pytest.param(555pl.Object,556["x", "y"],557id="Object",558marks=pytest.mark.may_fail_cloud,559# reason: objects are not allowed in cloud560),561],562)563@pytest.mark.parametrize("broadcast", list(itertools.product([False, True], repeat=3)))564def test_when_then_parametric(565len: int, dtype: pl.DataType, vals: list[Any], broadcast: list[bool]566) -> None:567# Makes no sense to broadcast all columns.568if all(broadcast):569return570571rng = random.Random(42)572573for _ in range(10):574mask = rng.choices([False, True, None], k=len)575if_true = rng.choices(vals + [None], k=len)576if_false = rng.choices(vals + [None], k=len)577578py_mask, py_true, py_false = (579[c[0]] * len if b else c580for b, c in zip(broadcast, [mask, if_true, if_false])581)582pl_mask, pl_true, pl_false = (583c.first() if b else c584for b, c in zip(broadcast, [pl.col.mask, pl.col.if_true, pl.col.if_false])585)586587ref = pl.DataFrame(588{"if_true": [t if m else f for m, t, f in zip(py_mask, py_true, py_false)]},589schema={"if_true": dtype},590)591df = pl.DataFrame(592{593"mask": mask,594"if_true": if_true,595"if_false": if_false,596},597schema={"mask": pl.Boolean, "if_true": dtype, "if_false": dtype},598)599600ans = df.select(pl.when(pl_mask).then(pl_true).otherwise(pl_false))601if dtype != pl.Object:602assert_frame_equal(ref, ans)603else:604assert ref["if_true"].to_list() == ans["if_true"].to_list()605606607def test_when_then_else_struct_18961() -> None:608v1 = [None, {"foo": 0, "bar": "1"}]609v2 = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]610611df = pl.DataFrame({"left": v1, "right": v2, "mask": [False, True]})612613expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]614ans = (615df.select(616pl.when(pl.col.mask).then(pl.col.left).otherwise(pl.col.right.first())617)618.get_column("left")619.to_list()620)621assert expected == ans622623df = pl.DataFrame({"left": v2, "right": v1, "mask": [True, False]})624625expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}]626ans = (627df.select(628pl.when(pl.col.mask).then(pl.col.left.first()).otherwise(pl.col.right)629)630.get_column("left")631.to_list()632)633assert expected == ans634635df = pl.DataFrame({"left": v1, "right": v2, "mask": [True, False]})636637expected2 = [None, {"foo": 0, "bar": "1"}]638ans = (639df.select(640pl.when(pl.col.mask)641.then(pl.col.left.first())642.otherwise(pl.col.right.first())643)644.get_column("left")645.to_list()646)647assert expected2 == ans648649650def test_when_then_supertype_15975() -> None:651df = pl.DataFrame({"a": [1, 2, 3]})652653assert df.with_columns(654pl.when(True).then(1 ** pl.col("a") + 1.0 * pl.col("a"))655).to_dict(as_series=False) == {"a": [1, 2, 3], "literal": [2.0, 3.0, 4.0]}656657658def test_when_then_supertype_15975_comment() -> None:659df = pl.LazyFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]})660661q = df.with_columns(662pl.when(pl.col("foo") == 1)663.then(1)664.when(pl.col("foo") == 2)665.then(4)666.when(pl.col("foo") == 3)667.then(1.5)668.when(pl.col("foo") == 4)669.then(16)670.otherwise(0)671.alias("val")672)673674assert q.collect()["val"].to_list() == [1.0, 1.5, 16.0]675676677def test_chained_when_no_subclass_17142() -> None:678# https://github.com/pola-rs/polars/pull/17142679when = pl.when(True).then(1).when(True)680681assert not isinstance(when, pl.Expr)682assert "<polars.expr.whenthen.ChainedWhen object at" in str(when)683684685def test_when_then_chunked_structs_18673() -> None:686df = pl.DataFrame(687[688pl.Series("x", [{"a": 1}]),689pl.Series("b", [False]),690]691)692693df = df.vstack(df)694695# This used to panic696assert_frame_equal(697df.select(pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))),698pl.DataFrame({"x": [{"a": 1}, {"a": 1}]}),699)700701702some_scalar = pl.Series("a", [{"x": 2}], pl.Struct)703none_scalar = pl.Series("a", [None], pl.Struct({"x": pl.Int64}))704column = pl.Series("a", [{"x": 2}, {"x": 2}], pl.Struct)705706707@pytest.mark.parametrize(708"values",709[710(some_scalar, some_scalar),711(some_scalar, pl.col.a),712(some_scalar, none_scalar),713(some_scalar, column),714(none_scalar, pl.col.a),715(none_scalar, none_scalar),716(none_scalar, column),717(pl.col.a, pl.col.a),718(pl.col.a, column),719(column, column),720],721)722def test_struct_when_then_broadcasting_combinations_19122(723values: tuple[Any, Any],724) -> None:725lv, rv = values726727df = pl.Series("a", [{"x": 1}, {"x": 1}], pl.Struct).to_frame()728729assert_frame_equal(730df.select(731pl.when(pl.col.a.struct.field("x") == 0).then(lv).otherwise(rv).alias("a")732),733df.select(734pl.when(pl.col.a.struct.field("x") == 0).then(None).otherwise(rv).alias("a")735),736)737738assert_frame_equal(739df.select(740pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(lv).alias("a")741),742df.select(743pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(None).alias("a")744),745)746747748@pytest.mark.may_fail_cloud # reason str.to_decimal is an eager construct749def test_when_then_to_decimal_18375() -> None:750df = pl.DataFrame({"a": ["1.23", "4.56"]})751752result = df.with_columns(753b=pl.when(False).then(None).otherwise(pl.col("a").str.to_decimal(scale=2)),754c=pl.when(True).then(pl.col("a").str.to_decimal(scale=2)),755)756expected = pl.DataFrame(757{758"a": ["1.23", "4.56"],759"b": ["1.23", "4.56"],760"c": ["1.23", "4.56"],761},762schema={"a": pl.String, "b": pl.Decimal, "c": pl.Decimal},763)764assert_frame_equal(result, expected)765766767def test_when_then_chunked_fill_null_22794() -> None:768df = pl.DataFrame(769{770"node": [{"x": "a", "y": "a"}, {"x": "b", "y": "b"}, {"x": "c", "y": "c"}],771"level": [0, 1, 2],772}773)774775out = pl.concat([df.slice(0, 1), df.slice(1, 1), df.slice(2, 1)]).with_columns(776pl.when(level=1).then("node").forward_fill()777)778expected = pl.DataFrame(779{780"node": [None, {"x": "b", "y": "b"}, {"x": "b", "y": "b"}],781"level": [0, 1, 2],782}783)784785assert_frame_equal(out, expected)786787788def test_when_then_complex_conditional_22959() -> None:789df = pl.DataFrame(790{"B": [None, "T1", "T2"], "C": [None, None, [1.0]], "E": [None, 2.0, None]}791)792793res = df.with_columns(794Result=(795pl.when(B="T1")796.then(pl.struct(X="C", Y="C"))797.when(B="T2")798.then(pl.struct(X=pl.concat_list([3.0, "E"])))799)800)801802assert_series_equal(803res["Result"],804pl.Series(805"Result",806[None, {"X": None, "Y": None}, {"X": [3.0, None], "Y": None}],807pl.Struct({"X": pl.List(pl.Float64), "Y": pl.List(pl.Float64)}),808),809)810811812def test_when_then_simplification() -> None:813lf = pl.LazyFrame({"a": [12]})814assert (815"""[col("a")]"""816in (817lf.select(pl.when(True).then(pl.col("a")).otherwise(pl.col("a") * 2))818).explain()819)820assert (821"""(col("a")) * (2)"""822in (823lf.select(pl.when(False).then(pl.col("a")).otherwise(pl.col("a") * 2))824).explain()825)826827828def test_when_then_in_group_by_aggregated_22922() -> None:829df = pl.DataFrame({"group": ["x", "y", "x", "y"], "value": [1, 2, 3, 4]})830out = df.group_by("group", maintain_order=True).agg(831expr=pl.when(group="x").then(pl.col.value.max()).first()832)833expected = pl.DataFrame({"group": ["x", "y"], "expr": [3, None]})834assert_frame_equal(out, expected)835836837