# Path: blob/main/py-polars/tests/unit/operations/namespaces/array/test_eval.py
# (8354 views)
"""Tests for the array namespace ``eval`` / ``agg`` methods (``.arr.eval``, ``.arr.agg``)."""

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.testing import assert_series_equal

if TYPE_CHECKING:
    from collections.abc import Callable


# Row masks shared by the ``eval`` tests: ``True`` keeps a row, ``False`` turns
# it into an outer (whole-array) null via ``set_nulls``.
NULL_MASKS = [
    [True] * 3,
    [False, True, True],
    [True, False, True],
    [True, True, False],
    [False, False, True],
    [True, False, False],
    [False] * 3,
]


def set_nulls(s: pl.Series, mask: list[bool]) -> pl.Series:
    """
    Return ``s`` with every row whose ``mask`` entry is ``False`` set to null.

    Uses ``when/then`` without an ``otherwise`` branch, so unmatched rows become
    null; the name of ``s`` is preserved via ``alias``.
    """
    return pl.select(pl.when(pl.Series(mask)).then(s).alias(s.name)).to_series()


@pytest.mark.parametrize("as_list", [False, True])
@pytest.mark.parametrize("nulls", NULL_MASKS)
def test_eval_basic(as_list: bool, nulls: list[bool]) -> None:
    """Length-preserving ``arr.eval`` keeps outer nulls; ``as_list`` picks List vs Array."""

    def rtdt(dt: pl.DataType) -> pl.DataType:
        # Expected result dtype: a List when ``as_list`` is set, else a width-2 Array.
        return pl.List(dt) if as_list else pl.Array(dt, 2)

    s = set_nulls(
        pl.Series("a", [[1, 4], [8, 5], [3, 2]], pl.Array(pl.Int64(), 2)), nulls
    )

    assert_series_equal(
        s.arr.eval(pl.element().rank(), as_list=as_list),
        set_nulls(
            pl.Series("a", [[1.0, 2.0], [2.0, 1.0], [2.0, 1.0]], rtdt(pl.Float64())),
            nulls,
        ),
    )
    assert_series_equal(
        s.arr.eval(pl.element() + 1, as_list=as_list),
        set_nulls(pl.Series("a", [[2, 5], [9, 6], [4, 3]], rtdt(pl.Int64())), nulls),
    )
    assert_series_equal(
        s.arr.eval(pl.element().cast(pl.String()), as_list=as_list),
        s.cast(rtdt(pl.Int64())).cast(rtdt(pl.String())),
    )

    if as_list:
        # Every array in ``s`` already has distinct elements, so ``unique`` is a
        # no-op apart from the List output type.
        assert_series_equal(
            s.arr.eval(pl.element().unique(maintain_order=True), as_list=True),
            s.cast(rtdt(pl.Int64())),
        )


def test_eval_raises_for_non_length_preserving() -> None:
    """A length-changing expression without ``as_list=True`` raises, mentioning ``as_list``."""
    s = pl.Series(
        "a", [["A", "B", "C"], ["C", "C", "D"], ["D", "E", "E"]], pl.Array(pl.String, 3)
    )

    with pytest.raises(pl.exceptions.InvalidOperationError, match="as_list"):
        s.arr.eval(pl.element().unique(maintain_order=True))


@pytest.mark.parametrize("nulls", NULL_MASKS)
def test_eval_changing_length(nulls: list[bool]) -> None:
    """With ``as_list=True`` a length-changing expression yields a List column."""
    s = set_nulls(
        pl.Series(
            "a",
            [["A", "B", "C"], ["C", "C", "D"], ["D", "E", "E"]],
            pl.Array(pl.String, 3),
        ),
        nulls,
    )

    assert_series_equal(
        s.arr.eval(pl.element().unique(maintain_order=True), as_list=True),
        set_nulls(
            pl.Series(
                "a", [["A", "B", "C"], ["C", "D"], ["D", "E"]], pl.List(pl.String)
            ),
            nulls,
        ),
    )


def set_validity(s: pl.Series, validity: list[bool]) -> pl.Series:
    """Null the rows of ``s`` where ``validity`` is ``False``, using ``zip_with``."""
    return s.zip_with(pl.Series(validity), pl.Series([None], dtype=s.dtype))


@pytest.mark.parametrize(
    "sum_expr",
    [pl.element().sum(), pl.element().unique().sum(), pl.element().fill_null(1).sum()],
)
def test_arr_agg_sum(sum_expr: pl.Expr) -> None:
    """Sum-style ``arr.agg`` reduces each array to a scalar; null rows stay null."""
    assert_series_equal(
        pl.Series("a", [], pl.Array(pl.Int64, 2)).arr.agg(sum_expr),
        pl.Series("a", [], pl.Int64),
    )

    assert_series_equal(
        pl.Series("a", [[0, 1, 2], [1, 3, 5]], pl.Array(pl.Int64, 3)).arr.agg(sum_expr),
        pl.Series("a", [3, 9]),
    )

    # Zero-width arrays sum to 0.
    assert_series_equal(
        pl.Series("a", [[], []], pl.Array(pl.Int64, 0)).arr.agg(sum_expr),
        pl.Series("a", [0, 0]),
    )

    assert_series_equal(
        pl.Series("a", [None, [1, 3, 5]], pl.Array(pl.Int64, 3)).arr.agg(sum_expr),
        pl.Series("a", [None, 9]),
    )

    # Row masked out via ``zip_with`` (its underlying values are not summed).
    assert_series_equal(
        set_validity(
            pl.Series("a", [[1, 2, 3], [3, 4, 5], [1, 3, 5]], pl.Array(pl.Int64, 3)),
            [True, False, True],
        ).arr.agg(sum_expr),
        pl.Series("a", [6, None, 9]),
    )


@pytest.mark.parametrize(
    ("expr", "is_scalar"),
    [
        (pl.Expr.null_count, True),
        (lambda e: e.rank().null_count(), True),
        (pl.Expr.rank, False),
        (lambda e: e + pl.lit(1), False),
        (lambda e: e.filter(e != 0), False),
        (pl.Expr.drop_nulls, False),
        (pl.Expr.n_unique, True),
    ],
)
def test_arr_agg_parametric(
    expr: Callable[[pl.Expr], pl.Expr], is_scalar: bool
) -> None:
    """``arr.agg(expr)`` matches evaluating ``expr`` on each array row on its own."""

    def test_case(s: pl.Series) -> None:
        out = s.arr.agg(expr(pl.element()))
        # One output row per input row. This also gives the empty-input case a
        # real check, since the loop below never runs for it.
        assert len(out) == len(s)

        for i, v in enumerate(s):
            if v is None:
                assert out[i] is None
                continue

            assert isinstance(v, pl.Series)

            # Reference result: run ``expr`` on the row's elements as a column.
            expected = v.rename("").to_frame().select(expr(pl.col(""))).to_series()
            if not is_scalar:
                # Non-scalar results come back from ``arr.agg`` as one list per row.
                expected = expected.implode()

            assert_series_equal(out.rename("").slice(i, 1), expected)

    test_case(pl.Series("a", [], pl.Array(pl.Int64, 2)))
    test_case(pl.Series("a", [[]], pl.Array(pl.Int64, 0)))
    test_case(pl.Series("a", [[7], [0]], pl.Array(pl.Int64, 1)))
    test_case(pl.Series("a", [[8], [0], None], pl.Array(pl.Int64, 1)))
    test_case(pl.Series("a", [None, [0], None], pl.Array(pl.Int64, 1)))
    test_case(pl.Series("a", [[1, 2, 3], [4, 5, 6]], pl.Array(pl.Int64, 3)))


def _null_second_row() -> pl.Expr:
    """Expression for column ``a`` with the value at row index 1 replaced by null."""
    return pl.when(pl.int_range(0, pl.len()) != 1).then(pl.col.a).otherwise(None)


def _assert_in_all_contexts(
    df: pl.DataFrame, q_inner: pl.Expr, keys: pl.Expr, expected: pl.Series
) -> None:
    """Check ``q_inner`` against ``expected`` in plain select, ``over`` and ``group_by``."""
    # no agg
    q = df.lazy().select(q_inner)
    assert_series_equal(q.collect().to_series(), expected)

    # over
    q = df.lazy().select(q_inner.over(keys))
    assert_series_equal(q.collect().to_series(), expected)

    # group_by: one list per group, exploded back to one row per input row.
    # Assumes each group's rows are contiguous in ``df`` so the order matches.
    q = df.lazy().group_by(keys, maintain_order=True).agg(q_inner)
    out = q.collect().select(pl.col.a).explode("a")
    assert_series_equal(out.to_series(), expected)


@pytest.mark.parametrize("insert_none", [False, True])
@pytest.mark.parametrize("keys", [pl.lit(42), pl.col.g])
@pytest.mark.parametrize("filter", [None, pl.lit(True), pl.col.b])
@pytest.mark.parametrize(
    ("expr", "as_list", "result"),
    [
        (
            pl.element(),
            False,
            pl.Series("a", [[0, 1, 2], [5, 3, 4], [7, 7, 8]], pl.Array(pl.Int64, 3)),
        ),
        (
            pl.element() + pl.element(),
            False,
            pl.Series(
                "a", [[0, 2, 4], [10, 6, 8], [14, 14, 16]], pl.Array(pl.Int64, 3)
            ),
        ),
        (
            pl.element().rank(),
            False,
            pl.Series(
                "a",
                [[1.0, 2.0, 3.0], [3.0, 1.0, 2.0], [1.5, 1.5, 3.0]],
                pl.Array(pl.Float64, 3),
            ),
        ),
        (pl.element().unique(), True, pl.Series("a", [[0, 1, 2], [5, 3, 4], [7, 8]])),
    ],
)
def test_arr_eval_with_filter_in_agg_25384(
    insert_none: bool,
    keys: pl.Expr,
    filter: pl.Expr | None,
    expr: pl.Expr,
    as_list: bool,
    result: pl.Series,
) -> None:
    """
    ``arr.eval`` on an (optionally filtered) column in select / over / group_by.

    Regression test; the name suffix presumably references polars issue #25384.
    """
    s = pl.Series("a", [[0, 1, 2], [5, 3, 4], [7, 7, 8]], pl.Array(pl.Int64, 3))
    df = s.to_frame().with_columns(
        pl.Series("g", [10, 10, 20]), pl.Series("b", [True, True, True])
    )
    q_inner = (
        pl.col("a").arr.eval(expr, as_list=as_list)
        if filter is None
        else pl.col("a").filter(filter).arr.eval(expr, as_list=as_list)
    )

    if insert_none:
        # Null the same row in both the input and the expected output.
        df = df.with_columns(_null_second_row())
        result = result.to_frame().with_columns(_null_second_row()).to_series()

    _assert_in_all_contexts(df, q_inner, keys, result)


@pytest.mark.parametrize("insert_none", [False, True])
@pytest.mark.parametrize("keys", [pl.lit(42), pl.col.g])
@pytest.mark.parametrize("filter", [None, pl.lit(True), pl.col.b])
@pytest.mark.parametrize(
    ("expr", "result"),
    [
        (pl.element().sum(), pl.Series("a", [1, 8, 22])),
        (pl.element().null_count(), pl.Series("a", [1, 1, 0], pl.get_index_type())),
    ],
)
def test_arr_agg_with_filter_in_agg_25384(
    insert_none: bool,
    keys: pl.Expr,
    filter: pl.Expr | None,
    expr: pl.Expr,
    result: pl.Series,
) -> None:
    """
    ``arr.agg`` on an (optionally filtered) column in select / over / group_by.

    Regression test; the name suffix presumably references polars issue #25384.
    """
    s = pl.Series("a", [[0, 1, None], [5, 3, None], [7, 7, 8]], pl.Array(pl.Int64, 3))
    df = s.to_frame().with_columns(
        pl.Series("g", [10, 10, 20]), pl.Series("b", [True, True, True])
    )
    q_inner = (
        pl.col("a").arr.agg(expr)
        if filter is None
        else pl.col("a").filter(filter).arr.agg(expr)
    )

    if insert_none:
        # Null the same row in both the input and the expected output.
        df = df.with_columns(_null_second_row())
        result = result.to_frame().with_columns(_null_second_row()).to_series()

    _assert_in_all_contexts(df, q_inner, keys, result)