# Path: blob/main/py-polars/tests/unit/operations/aggregation/test_horizontal.py
# 6940 views
from __future__ import annotations12import datetime3from collections import OrderedDict4from typing import TYPE_CHECKING, Any56import pytest78import polars as pl9import polars.selectors as cs10from polars.exceptions import ComputeError, PolarsError11from polars.testing import assert_frame_equal, assert_series_equal1213if TYPE_CHECKING:14from polars._typing import PolarsDataType151617def test_any_expr(fruits_cars: pl.DataFrame) -> None:18assert fruits_cars.select(pl.any_horizontal("A", "B")).to_series()[0] is True192021def test_all_any_horizontally() -> None:22df = pl.DataFrame(23[24[False, False, True],25[False, False, True],26[True, False, False],27[False, None, True],28[None, None, False],29],30schema=["var1", "var2", "var3"],31orient="row",32)33result = df.select(34any=pl.any_horizontal(pl.col("var2"), pl.col("var3")),35all=pl.all_horizontal(pl.col("var2"), pl.col("var3")),36)37expected = pl.DataFrame(38{39"any": [True, True, False, True, None],40"all": [False, False, False, None, False],41}42)43assert_frame_equal(result, expected)4445# note: a kwargs filter will use an internal call to all_horizontal46dfltr = df.lazy().filter(var1=True, var3=False)47assert dfltr.collect().rows() == [(True, False, False)]4849# confirm that we reduced the horizontal filter components50# (eg: explain does not contain an "all_horizontal" node)51assert "horizontal" not in dfltr.explain().lower()525354def test_empty_all_any_horizontally() -> None:55# any/all_horizontal don't allow empty input, but we can still trigger this56# by selecting an empty set of columns with pl.selectors.57df = pl.DataFrame({"x": [1, 2, 3]})58assert_frame_equal(59df.select(pl.any_horizontal(cs.string().is_null())),60pl.DataFrame({"any_horizontal": False}),61)62assert_frame_equal(63df.select(pl.all_horizontal(cs.string().is_null())),64pl.DataFrame({"all_horizontal": True}),65)666768def test_all_any_single_input() -> None:69df = pl.DataFrame({"a": [0, 1, None]})70out = 
df.select(71all=pl.all_horizontal(pl.col("a")), any=pl.any_horizontal(pl.col("a"))72)7374expected = pl.DataFrame(75{76"all": [False, True, None],77"any": [False, True, None],78}79)80assert_frame_equal(out, expected)818283def test_all_any_accept_expr() -> None:84lf = pl.LazyFrame(85{86"a": [1, None, 2, None],87"b": [1, 2, None, None],88}89)9091result = lf.select(92pl.any_horizontal(pl.all().is_null()).alias("null_in_row"),93pl.all_horizontal(pl.all().is_null()).alias("all_null_in_row"),94)9596expected = pl.LazyFrame(97{98"null_in_row": [False, True, True, True],99"all_null_in_row": [False, False, False, True],100}101)102assert_frame_equal(result, expected)103104105def test_max_min_multiple_columns(fruits_cars: pl.DataFrame) -> None:106result = fruits_cars.select(max=pl.max_horizontal("A", "B"))107expected = pl.Series("max", [5, 4, 3, 4, 5])108assert_series_equal(result.to_series(), expected)109110result = fruits_cars.select(min=pl.min_horizontal("A", "B"))111expected = pl.Series("min", [1, 2, 3, 2, 1])112assert_series_equal(result.to_series(), expected)113114115def test_max_min_nulls_consistency() -> None:116df = pl.DataFrame({"a": [None, 2, 3], "b": [4, None, 6], "c": [7, 5, 0]})117118result = df.select(max=pl.max_horizontal("a", "b", "c")).to_series()119expected = pl.Series("max", [7, 5, 6])120assert_series_equal(result, expected)121122result = df.select(min=pl.min_horizontal("a", "b", "c")).to_series()123expected = pl.Series("min", [4, 2, 0])124assert_series_equal(result, expected)125126127def test_nested_min_max() -> None:128df = pl.DataFrame({"a": [1], "b": [2], "c": [3], "d": [4]})129130result = df.with_columns(131pl.max_horizontal(132pl.min_horizontal("a", "b"), pl.min_horizontal("c", "d")133).alias("t")134)135136expected = pl.DataFrame({"a": [1], "b": [2], "c": [3], "d": [4], "t": [3]})137assert_frame_equal(result, expected)138139140def test_empty_inputs_raise() -> None:141with pytest.raises(142ComputeError,143match="cannot return empty fold because the 
number of output rows is unknown",144):145pl.select(pl.any_horizontal())146147with pytest.raises(148ComputeError,149match="cannot return empty fold because the number of output rows is unknown",150):151pl.select(pl.all_horizontal())152153154def test_max_min_wildcard_columns(fruits_cars: pl.DataFrame) -> None:155result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select(156min=pl.min_horizontal("*")157)158expected = pl.Series("min", [1, 2, 3, 2, 1])159assert_series_equal(result.to_series(), expected)160161result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select(162min=pl.min_horizontal(pl.all())163)164assert_series_equal(result.to_series(), expected)165166result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select(167max=pl.max_horizontal("*")168)169expected = pl.Series("max", [5, 4, 3, 4, 5])170assert_series_equal(result.to_series(), expected)171172result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select(173max=pl.max_horizontal(pl.all())174)175assert_series_equal(result.to_series(), expected)176177result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select(178max=pl.max_horizontal(pl.all(), "A", "*")179)180assert_series_equal(result.to_series(), expected)181182183@pytest.mark.parametrize(184("input", "expected_data"),185[186(pl.col("^a|b$"), [1, 2]),187(pl.col("a", "b"), [1, 2]),188(pl.col("a"), [1, 4]),189(pl.lit(5, dtype=pl.Int64), [5]),190(5.0, [5.0]),191],192)193def test_min_horizontal_single_input(input: Any, expected_data: list[Any]) -> None:194df = pl.DataFrame({"a": [1, 4], "b": [3, 2]})195result = df.select(min=pl.min_horizontal(input)).to_series()196expected = pl.Series("min", expected_data)197assert_series_equal(result, expected)198199200@pytest.mark.parametrize(201("inputs", "expected_data"),202[203((["a", "b"]), [1, 2]),204(("a", "b"), [1, 2]),205(("a", 3), [1, 3]),206],207)208def test_min_horizontal_multi_input(209inputs: tuple[Any, ...], expected_data: list[Any]210) -> None:211df = pl.DataFrame({"a": [1, 4], "b": [3, 
2]})212result = df.select(min=pl.min_horizontal(*inputs))213expected = pl.DataFrame({"min": expected_data})214assert_frame_equal(result, expected)215216217@pytest.mark.parametrize(218("input", "expected_data"),219[220(pl.col("^a|b$"), [3, 4]),221(pl.col("a", "b"), [3, 4]),222(pl.col("a"), [1, 4]),223(pl.lit(5, dtype=pl.Int64), [5]),224(5.0, [5.0]),225],226)227def test_max_horizontal_single_input(input: Any, expected_data: list[Any]) -> None:228df = pl.DataFrame({"a": [1, 4], "b": [3, 2]})229result = df.select(max=pl.max_horizontal(input))230expected = pl.DataFrame({"max": expected_data})231assert_frame_equal(result, expected)232233234@pytest.mark.parametrize(235("inputs", "expected_data"),236[237((["a", "b"]), [3, 4]),238(("a", "b"), [3, 4]),239(("a", 3), [3, 4]),240],241)242def test_max_horizontal_multi_input(243inputs: tuple[Any, ...], expected_data: list[Any]244) -> None:245df = pl.DataFrame({"a": [1, 4], "b": [3, 2]})246result = df.select(max=pl.max_horizontal(*inputs))247expected = pl.DataFrame({"max": expected_data})248assert_frame_equal(result, expected)249250251def test_expanding_sum() -> None:252df = pl.DataFrame(253{254"x": [0, 1, 2],255"y_1": [1.1, 2.2, 3.3],256"y_2": [1.0, 2.5, 3.5],257}258)259260result = df.with_columns(pl.sum_horizontal(pl.col(r"^y_.*$")).alias("y_sum"))[261"y_sum"262]263assert result.to_list() == [2.1, 4.7, 6.8]264265266def test_sum_max_min() -> None:267df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]})268out = df.select(269sum=pl.sum_horizontal("a", "b"),270max=pl.max_horizontal("a", pl.col("b") ** 2),271min=pl.min_horizontal("a", pl.col("b") ** 2),272)273assert_series_equal(out["sum"], pl.Series("sum", [2.0, 4.0, 6.0]))274assert_series_equal(out["max"], pl.Series("max", [1.0, 4.0, 9.0]))275assert_series_equal(out["min"], pl.Series("min", [1.0, 2.0, 3.0]))276277278def test_str_sum_horizontal() -> None:279df = pl.DataFrame(280{"A": ["a", "b", None, "c", None], "B": ["f", "g", "h", None, None]}281)282out = 
df.select(pl.sum_horizontal("A", "B"))283assert_series_equal(out["A"], pl.Series("A", ["af", "bg", "h", "c", ""]))284285286def test_sum_null_dtype() -> None:287df = pl.DataFrame(288{289"A": [5, None, 3, 2, 1],290"B": [5, 3, None, 2, 1],291"C": [None, None, None, None, None],292}293)294295assert_series_equal(296df.select(pl.sum_horizontal("A", "B", "C")).to_series(),297pl.Series("A", [10, 3, 3, 4, 2]),298)299assert_series_equal(300df.select(pl.sum_horizontal("C", "B")).to_series(),301pl.Series("C", [5, 3, 0, 2, 1]),302)303assert_series_equal(304df.select(pl.sum_horizontal("C", "C")).to_series(),305pl.Series("C", [None, None, None, None, None]),306)307308309def test_sum_single_col() -> None:310df = pl.DataFrame(311{312"A": [5, None, 3, None, 1],313}314)315316assert_series_equal(317df.select(pl.sum_horizontal("A")).to_series(), pl.Series("A", [5, 0, 3, 0, 1])318)319320321@pytest.mark.parametrize("ignore_nulls", [False, True])322def test_sum_correct_supertype(ignore_nulls: bool) -> None:323values = [1, 2] if ignore_nulls else [None, None] # type: ignore[list-item]324lf = pl.LazyFrame(325{326"null": [None, None],327"int": pl.Series(values, dtype=pl.Int32),328"float": pl.Series(values, dtype=pl.Float32),329}330)331332# null + int32 should produce int32333out = lf.select(pl.sum_horizontal("null", "int", ignore_nulls=ignore_nulls))334expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Int32)})335assert_frame_equal(out.collect(), expected.collect())336assert out.collect_schema() == expected.collect_schema()337338# null + float32 should produce float32339out = lf.select(pl.sum_horizontal("null", "float", ignore_nulls=ignore_nulls))340expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float32)})341assert_frame_equal(out.collect(), expected.collect())342assert out.collect_schema() == expected.collect_schema()343344# null + int32 + float32 should produce float64345values = [2, 4] if ignore_nulls else [None, None] # type: ignore[list-item]346out = 
lf.select(347pl.sum_horizontal("null", "int", "float", ignore_nulls=ignore_nulls)348)349expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float64)})350assert_frame_equal(out.collect(), expected.collect())351assert out.collect_schema() == expected.collect_schema()352353354def test_cum_sum_horizontal() -> None:355df = pl.DataFrame(356{357"a": [1, 2],358"b": [3, 4],359"c": [5, 6],360}361)362result = df.select(pl.cum_sum_horizontal("a", "c"))363expected = pl.DataFrame({"cum_sum": [{"a": 1, "c": 6}, {"a": 2, "c": 8}]})364assert_frame_equal(result, expected)365366q = df.lazy().select(pl.cum_sum_horizontal("a", "c"))367assert q.collect_schema() == q.collect().schema368369370def test_sum_dtype_12028() -> None:371result = pl.select(372pl.sum_horizontal([pl.duration(seconds=10)]).alias("sum_duration")373)374expected = pl.DataFrame(375[376pl.Series(377"sum_duration",378[datetime.timedelta(seconds=10)],379dtype=pl.Duration(time_unit="us"),380),381]382)383assert_frame_equal(expected, result)384385386def test_horizontal_expr_use_left_name() -> None:387df = pl.DataFrame(388{389"a": [1, 2],390"b": [3, 4],391}392)393394assert df.select(pl.sum_horizontal("a", "b")).columns == ["a"]395assert df.select(pl.max_horizontal("*")).columns == ["a"]396assert df.select(pl.min_horizontal("b", "a")).columns == ["b"]397assert df.select(pl.any_horizontal("b", "a")).columns == ["b"]398assert df.select(pl.all_horizontal("a", "b")).columns == ["a"]399400401def test_horizontal_broadcasting() -> None:402df = pl.DataFrame(403{404"a": [1, 3],405"b": [3, 6],406}407)408409assert_series_equal(410df.select(sum=pl.sum_horizontal(1, "a", "b")).to_series(),411pl.Series("sum", [5, 10]),412)413assert_series_equal(414df.select(mean=pl.mean_horizontal(1, "a", "b")).to_series(),415pl.Series("mean", [1.66666, 3.33333]),416)417assert_series_equal(418df.select(max=pl.max_horizontal(4, "*")).to_series(), pl.Series("max", [4, 6])419)420assert_series_equal(421df.select(min=pl.min_horizontal(2, "b", 
"a")).to_series(),422pl.Series("min", [1, 2]),423)424assert_series_equal(425df.select(any=pl.any_horizontal(False, pl.Series([True, False]))).to_series(),426pl.Series("any", [True, False]),427)428assert_series_equal(429df.select(all=pl.all_horizontal(True, pl.Series([True, False]))).to_series(),430pl.Series("all", [True, False]),431)432433434def test_mean_horizontal() -> None:435lf = pl.LazyFrame({"a": [1, 2, 3], "b": [2.0, 4.0, 6.0], "c": [3, None, 9]})436result = lf.select(pl.mean_horizontal(pl.all()).alias("mean"))437438expected = pl.LazyFrame({"mean": [2.0, 3.0, 6.0]}, schema={"mean": pl.Float64})439assert_frame_equal(result, expected)440441442def test_mean_horizontal_bool() -> None:443df = pl.DataFrame(444{445"a": [True, False, False],446"b": [None, True, False],447"c": [True, False, False],448}449)450expected = pl.DataFrame({"mean": [1.0, 1 / 3, 0.0]}, schema={"mean": pl.Float64})451result = df.select(mean=pl.mean_horizontal(pl.all()))452assert_frame_equal(result, expected)453454455def test_mean_horizontal_no_columns() -> None:456lf = pl.LazyFrame({"a": [1, 2, 3], "b": [2.0, 4.0, 6.0], "c": [3, None, 9]})457458with pytest.raises(ComputeError, match="number of output rows is unknown"):459lf.select(pl.mean_horizontal())460461462def test_mean_horizontal_no_rows() -> None:463lf = pl.LazyFrame({"a": [], "b": [], "c": []}).with_columns(pl.all().cast(pl.Int64))464465result = lf.select(pl.mean_horizontal(pl.all()))466467expected = pl.LazyFrame({"a": []}, schema={"a": pl.Float64})468assert_frame_equal(result, expected)469470471def test_mean_horizontal_all_null() -> None:472lf = pl.LazyFrame({"a": [1, None], "b": [2, None], "c": [None, None]})473474result = lf.select(pl.mean_horizontal(pl.all()))475476expected = pl.LazyFrame({"a": [1.5, None]}, schema={"a": pl.Float64})477assert_frame_equal(result, expected)478479480@pytest.mark.parametrize(481("in_dtype", "out_dtype"),482[483(pl.Boolean, pl.Float64),484(pl.UInt8, pl.Float64),485(pl.UInt16, pl.Float64),486(pl.UInt32, 
pl.Float64),487(pl.UInt64, pl.Float64),488(pl.Int8, pl.Float64),489(pl.Int16, pl.Float64),490(pl.Int32, pl.Float64),491(pl.Int64, pl.Float64),492(pl.Float32, pl.Float32),493(pl.Float64, pl.Float64),494],495)496def test_schema_mean_horizontal_single_column(497in_dtype: PolarsDataType,498out_dtype: PolarsDataType,499) -> None:500lf = pl.LazyFrame({"a": pl.Series([1, 0]).cast(in_dtype)}).select(501pl.mean_horizontal(pl.all())502)503504assert lf.collect_schema() == OrderedDict([("a", out_dtype)])505506507def test_schema_boolean_sum_horizontal() -> None:508lf = pl.LazyFrame({"a": [True, False]}).select(pl.sum_horizontal("a"))509assert lf.collect_schema() == OrderedDict([("a", pl.UInt32)])510511512def test_fold_all_schema() -> None:513df = pl.DataFrame(514{515"A": [1, 2, 3, 4, 5],516"fruits": ["banana", "banana", "apple", "apple", "banana"],517"B": [5, 4, 3, 2, 1],518"cars": ["beetle", "audi", "beetle", "beetle", "beetle"],519"optional": [28, 300, None, 2, -30],520}521)522# divide because of overflow523result = df.select(pl.sum_horizontal(pl.all().hash(seed=1) // int(1e8)))524assert result.dtypes == [pl.UInt64]525526527@pytest.mark.parametrize(528"horizontal_func",529[530pl.all_horizontal,531pl.any_horizontal,532pl.max_horizontal,533pl.min_horizontal,534pl.mean_horizontal,535pl.sum_horizontal,536],537)538def test_expected_horizontal_dtype_errors(horizontal_func: type[pl.Expr]) -> None:539from decimal import Decimal as D540541import polars as pl542543df = pl.DataFrame(544{545"cola": [D("1.5"), D("0.5"), D("5"), D("0"), D("-0.25")],546"colb": [[0, 1], [2], [3, 4], [5], [6]],547"colc": ["aa", "bb", "cc", "dd", "ee"],548"cold": ["bb", "cc", "dd", "ee", "ff"],549"cole": [1000, 2000, 3000, 4000, 5000],550}551)552with pytest.raises(PolarsError):553df.select(554horizontal_func( # type: ignore[call-arg]555pl.col("cola"),556pl.col("colb"),557pl.col("colc"),558pl.col("cold"),559pl.col("cole"),560)561)562563564def test_horizontal_sum_boolean_with_null() -> None:565lf = 
pl.LazyFrame(566{567"null": [None, None],568"bool": [True, False],569}570)571572out = lf.select(573pl.sum_horizontal("null", "bool").alias("null_first"),574pl.sum_horizontal("bool", "null").alias("bool_first"),575)576577expected_schema = pl.Schema(578{579"null_first": pl.get_index_type(),580"bool_first": pl.get_index_type(),581}582)583584assert out.collect_schema() == expected_schema585586expected_df = pl.DataFrame(587{588"null_first": pl.Series([1, 0], dtype=pl.get_index_type()),589"bool_first": pl.Series([1, 0], dtype=pl.get_index_type()),590}591)592593assert_frame_equal(out.collect(), expected_df)594595596@pytest.mark.parametrize("ignore_nulls", [True, False])597@pytest.mark.parametrize(598("dtype_in", "dtype_out"),599[600(pl.Null, pl.Null),601(pl.Boolean, pl.get_index_type()),602(pl.UInt8, pl.UInt8),603(pl.Float32, pl.Float32),604(pl.Float64, pl.Float64),605(pl.Decimal(None, 5), pl.Decimal(None, 5)),606],607)608def test_horizontal_sum_with_null_col_ignore_strategy(609dtype_in: PolarsDataType,610dtype_out: PolarsDataType,611ignore_nulls: bool,612) -> None:613lf = pl.LazyFrame(614{615"null": [None, None, None],616"s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False),617"s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False),618}619)620result = lf.select(pl.sum_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls))621if ignore_nulls and dtype_in != pl.Null:622values = [2, 0, 1]623else:624values = [None, None, None] # type: ignore[list-item]625expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))626assert_frame_equal(result, expected)627assert result.collect_schema() == expected.collect_schema()628629630@pytest.mark.parametrize("ignore_nulls", [True, False])631@pytest.mark.parametrize(632("dtype_in", "dtype_out"),633[634(pl.Null, pl.Float64),635(pl.Boolean, pl.Float64),636(pl.UInt8, pl.Float64),637(pl.Float32, pl.Float32),638(pl.Float64, pl.Float64),639],640)641def test_horizontal_mean_with_null_col_ignore_strategy(642dtype_in: 
PolarsDataType,643dtype_out: PolarsDataType,644ignore_nulls: bool,645) -> None:646lf = pl.LazyFrame(647{648"null": [None, None, None],649"s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False),650"s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False),651}652)653result = lf.select(pl.mean_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls))654if ignore_nulls and dtype_in != pl.Null:655values = [1, 0, 1]656else:657values = [None, None, None] # type: ignore[list-item]658expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))659assert_frame_equal(result, expected)660661662def test_raise_invalid_types_21835() -> None:663df = pl.DataFrame({"x": [1, 2], "y": ["three", "four"]})664665with pytest.raises(666ComputeError,667match=r"cannot compare string with numeric type \(i64\)",668):669df.select(pl.min_horizontal("x", "y"))670671672