Path: blob/main/py-polars/tests/unit/operations/aggregation/test_aggregations.py
from __future__ import annotations

from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, cast
from zoneinfo import ZoneInfo

import numpy as np
import pytest
from hypothesis import given

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing.parametric import dataframes

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    import numpy.typing as npt

    from polars._typing import PolarsDataType, TimeUnit


def test_quantile_expr_input() -> None:
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0.0, 0.0, 0.3, 0.2, 0.0]})

    assert_frame_equal(
        df.select([pl.col("a").quantile(pl.col("b").sum() + 0.1)]),
        df.select(pl.col("a").quantile(0.6)),
    )

    df = pl.DataFrame({"x": [1, 2, 3, 4], "y": [0.25, 0.3, 0.4, 0.75]})

    assert_frame_equal(
        df.select(
            pl.col.x.quantile(pl.concat_list(pl.col.y.min(), pl.col.y.max().first()))
        ),
        df.select(pl.col.x.quantile([0.25, 0.75])),
    )


def test_boolean_aggs() -> None:
    df = pl.DataFrame({"bool": [True, False, None, True]})

    aggs = [
        pl.mean("bool").alias("mean"),
        pl.std("bool").alias("std"),
        pl.var("bool").alias("var"),
    ]
    assert df.select(aggs).to_dict(as_series=False) == {
        "mean": [0.6666666666666666],
        "std": [0.5773502691896258],
        "var": [0.33333333333333337],
    }

    assert df.group_by(pl.lit(1)).agg(aggs).to_dict(as_series=False) == {
        "literal": [1],
        "mean": [0.6666666666666666],
        "std": [0.5773502691896258],
        "var": [0.33333333333333337],
    }


def test_duration_aggs() -> None:
    df = pl.DataFrame(
        {
            "time1": pl.datetime_range(
                start=datetime(2022, 12, 12),
                end=datetime(2022, 12, 18),
                interval="1d",
                eager=True,
            ),
            "time2": pl.datetime_range(
                start=datetime(2023, 1, 12),
                end=datetime(2023, 1, 18),
                interval="1d",
                eager=True,
            ),
        }
    )

    df = df.with_columns((pl.col("time2") - pl.col("time1")).alias("time_difference"))

    assert df.select("time_difference").mean().to_dict(as_series=False) == {
        "time_difference": [timedelta(days=31)]
    }
    assert df.group_by(pl.lit(1)).agg(pl.mean("time_difference")).to_dict(
        as_series=False
    ) == {
        "literal": [1],
        "time_difference": [timedelta(days=31)],
    }


def test_list_aggregation_that_filters_all_data_6017() -> None:
    out = (
        pl.DataFrame({"col_to_group_by": [2], "flt": [1672740910.967138], "col3": [1]})
        .group_by("col_to_group_by")
        .agg((pl.col("flt").filter(col3=0).diff() * 1000).diff().alias("calc"))
    )

    assert out.schema == {"col_to_group_by": pl.Int64, "calc": pl.List(pl.Float64)}
    assert out.to_dict(as_series=False) == {"col_to_group_by": [2], "calc": [[]]}


def test_median() -> None:
    s = pl.Series([1, 2, 3])
    assert s.median() == 2


def test_single_element_std() -> None:
    s = pl.Series([1])
    assert s.std(ddof=1) is None
    assert s.std(ddof=0) == 0.0


def test_quantile() -> None:
    s = pl.Series([1, 2, 3])
    assert s.quantile(0.5, "nearest") == 2
    assert s.quantile(0.5, "lower") == 2
    assert s.quantile(0.5, "higher") == 2
    assert s.quantile([0.25, 0.75], "linear") == [1.5, 2.5]

    df = pl.DataFrame({"a": [1.0, 2.0, 3.0]})
    expected = pl.DataFrame({"a": [[2.0]]})
    assert_frame_equal(
        df.select(pl.col("a").quantile([0.5], interpolation="linear")), expected
    )


def test_quantile_error_checking() -> None:
    s = pl.Series([1, 2, 3])
    with pytest.raises(pl.exceptions.ComputeError):
        s.quantile(-0.1)
    with pytest.raises(pl.exceptions.ComputeError):
        s.quantile(1.1)
    with pytest.raises(pl.exceptions.ComputeError):
        s.quantile([0.0, 1.2])
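

# Worked arithmetic behind the "linear" assertions in `test_quantile` above:
# for the sorted values [1, 2, 3], a quantile q sits at fractional index
# q * (n - 1), so q=0.25 gives index 0.5 and interpolates (1 + 2) / 2 = 1.5,
# while q=0.75 gives index 1.5 and interpolates (2 + 3) / 2 = 2.5.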


def test_quantile_date() -> None:
    s = pl.Series(
        "a", [date(2025, 1, 1), date(2025, 1, 2), date(2025, 1, 3), date(2025, 1, 4)]
    )
    assert s.quantile(0.5, "nearest") == datetime(2025, 1, 3)
    assert s.quantile(0.5, "lower") == datetime(2025, 1, 2)
    assert s.quantile(0.5, "higher") == datetime(2025, 1, 3)
    assert s.quantile(0.5, "linear") == datetime(2025, 1, 2, 12)

    df = s.to_frame().lazy()
    result = df.select(
        nearest=pl.col("a").quantile(0.5, "nearest"),
        lower=pl.col("a").quantile(0.5, "lower"),
        higher=pl.col("a").quantile(0.5, "higher"),
        linear=pl.col("a").quantile(0.5, "linear"),
    )
    dt = pl.Datetime("us")
    assert result.collect_schema() == pl.Schema(
        {
            "nearest": dt,
            "lower": dt,
            "higher": dt,
            "linear": dt,
        }
    )
    expected = pl.DataFrame(
        {
            "nearest": pl.Series([datetime(2025, 1, 3)], dtype=dt),
            "lower": pl.Series([datetime(2025, 1, 2)], dtype=dt),
            "higher": pl.Series([datetime(2025, 1, 3)], dtype=dt),
            "linear": pl.Series([datetime(2025, 1, 2, 12)], dtype=dt),
        }
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
@pytest.mark.parametrize("tz", [None, "Asia/Tokyo", "UTC"])
def test_quantile_datetime(tu: TimeUnit, tz: str | None) -> None:
    time_zone = ZoneInfo(tz) if tz else None
    dt = pl.Datetime(tu, time_zone)

    s = pl.Series(
        "a",
        [
            datetime(2025, 1, 1, tzinfo=time_zone),
            datetime(2025, 1, 2, tzinfo=time_zone),
            datetime(2025, 1, 3, tzinfo=time_zone),
            datetime(2025, 1, 4, tzinfo=time_zone),
        ],
        dtype=dt,
    )
    assert s.quantile(0.5, "nearest") == datetime(2025, 1, 3, tzinfo=time_zone)
    assert s.quantile(0.5, "lower") == datetime(2025, 1, 2, tzinfo=time_zone)
    assert s.quantile(0.5, "higher") == datetime(2025, 1, 3, tzinfo=time_zone)
    assert s.quantile(0.5, "linear") == datetime(2025, 1, 2, 12, tzinfo=time_zone)

    df = s.to_frame().lazy()
    result = df.select(
        nearest=pl.col("a").quantile(0.5, "nearest"),
        lower=pl.col("a").quantile(0.5, "lower"),
        higher=pl.col("a").quantile(0.5, "higher"),
        linear=pl.col("a").quantile(0.5, "linear"),
    )
    assert result.collect_schema() == pl.Schema(
        {
            "nearest": dt,
            "lower": dt,
            "higher": dt,
            "linear": dt,
        }
    )
    expected = pl.DataFrame(
        {
            "nearest": pl.Series([datetime(2025, 1, 3, tzinfo=time_zone)], dtype=dt),
            "lower": pl.Series([datetime(2025, 1, 2, tzinfo=time_zone)], dtype=dt),
            "higher": pl.Series([datetime(2025, 1, 3, tzinfo=time_zone)], dtype=dt),
            "linear": pl.Series([datetime(2025, 1, 2, 12, tzinfo=time_zone)], dtype=dt),
        }
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
def test_quantile_duration(tu: TimeUnit) -> None:
    dt = pl.Duration(tu)

    s = pl.Series(
        "a",
        [timedelta(days=1), timedelta(days=2), timedelta(days=3), timedelta(days=4)],
        dtype=dt,
    )
    assert s.quantile(0.5, "nearest") == timedelta(days=3)
    assert s.quantile(0.5, "lower") == timedelta(days=2)
    assert s.quantile(0.5, "higher") == timedelta(days=3)
    assert s.quantile(0.5, "linear") == timedelta(days=2, hours=12)

    df = s.to_frame().lazy()
    result = df.select(
        nearest=pl.col("a").quantile(0.5, "nearest"),
        lower=pl.col("a").quantile(0.5, "lower"),
        higher=pl.col("a").quantile(0.5, "higher"),
        linear=pl.col("a").quantile(0.5, "linear"),
    )
    assert result.collect_schema() == pl.Schema(
        {
            "nearest": dt,
            "lower": dt,
            "higher": dt,
            "linear": dt,
        }
    )
    expected = pl.DataFrame(
        {
            "nearest": pl.Series([timedelta(days=3)], dtype=dt),
            "lower": pl.Series([timedelta(days=2)], dtype=dt),
            "higher": pl.Series([timedelta(days=3)], dtype=dt),
            "linear": pl.Series([timedelta(days=2, hours=12)], dtype=dt),
        }
    )
    assert_frame_equal(result.collect(), expected)
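

# The temporal quantile tests above follow one pattern: in effect the quantile
# is computed on the underlying integer representation and cast back to a
# temporal dtype. Date results are widened to Datetime("us") (see the schema
# assertions in `test_quantile_date`) so that "linear" can express sub-day
# midpoints, e.g. q=0.5 over four consecutive days interpolates to 12:00 on
# the second day.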


def test_quantile_time() -> None:
    s = pl.Series("a", [time(hour=1), time(hour=2), time(hour=3), time(hour=4)])
    assert s.quantile(0.5, "nearest") == time(hour=3)
    assert s.quantile(0.5, "lower") == time(hour=2)
    assert s.quantile(0.5, "higher") == time(hour=3)
    assert s.quantile(0.5, "linear") == time(hour=2, minute=30)

    df = s.to_frame().lazy()
    result = df.select(
        nearest=pl.col("a").quantile(0.5, "nearest"),
        lower=pl.col("a").quantile(0.5, "lower"),
        higher=pl.col("a").quantile(0.5, "higher"),
        linear=pl.col("a").quantile(0.5, "linear"),
    )
    assert result.collect_schema() == pl.Schema(
        {
            "nearest": pl.Time,
            "lower": pl.Time,
            "higher": pl.Time,
            "linear": pl.Time,
        }
    )
    expected = pl.DataFrame(
        {
            "nearest": pl.Series([time(hour=3)]),
            "lower": pl.Series([time(hour=2)]),
            "higher": pl.Series([time(hour=3)]),
            "linear": pl.Series([time(hour=2, minute=30)]),
        }
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.slow
@pytest.mark.parametrize("tp", [int, float])
@pytest.mark.parametrize("n", [1, 2, 10, 100])
def test_quantile_vs_numpy(tp: type, n: int) -> None:
    a: np.ndarray[Any, Any] = np.random.randint(0, 50, n).astype(tp)
    np_result: npt.ArrayLike | None = np.median(a)
    # nan check
    if np_result != np_result:
        np_result = None
    median = pl.Series(a).median()
    if median is not None:
        assert np.isclose(median, np_result)  # type: ignore[arg-type]
    else:
        assert np_result is None

    q = np.random.sample()
    try:
        np_result = np.quantile(a, q)
    except IndexError:
        np_result = None
    if np_result:
        # nan check
        if np_result != np_result:
            np_result = None
        assert np.isclose(
            pl.Series(a).quantile(q, interpolation="linear"),  # type: ignore[arg-type]
            np_result,  # type: ignore[arg-type]
        )

    df = pl.DataFrame({"a": a})

    expected = df.select(
        pl.col.a.quantile(0.25).alias("low"), pl.col.a.quantile(0.75).alias("high")
    ).select(pl.concat_list(["low", "high"]).alias("quantiles"))

    result = df.select(pl.col.a.quantile([0.25, 0.75]).alias("quantiles"))

    assert_frame_equal(expected, result)


def test_mean_overflow() -> None:
    assert np.isclose(
        pl.Series([9_223_372_036_854_775_800, 100]).mean(),  # type: ignore[arg-type]
        4.611686018427388e18,
    )


def test_mean_null_simd() -> None:
    for dtype in [int, float]:
        df = (
            pl.Series(np.random.randint(0, 100, 1000))
            .cast(dtype)
            .to_frame("a")
            .select(pl.when(pl.col("a") > 40).then(pl.col("a")))
        )

        s = df["a"]
        assert s.mean() == s.to_pandas().mean()


def test_literal_group_agg_chunked_7968() -> None:
    df = pl.DataFrame({"A": [1, 1], "B": [1, 3]})
    ser = pl.concat([pl.Series([3]), pl.Series([4, 5])], rechunk=False)

    assert_frame_equal(
        df.group_by("A").agg(pl.col("B").search_sorted(ser)),
        pl.DataFrame(
            [
                pl.Series("A", [1], dtype=pl.Int64),
                pl.Series("B", [[1, 2, 2]], dtype=pl.List(pl.get_index_type())),
            ]
        ),
    )


def test_duration_function_literal() -> None:
    df = pl.DataFrame(
        {
            "A": ["x", "x", "y", "y", "y"],
            "T": pl.datetime_range(
                date(2022, 1, 1), date(2022, 5, 1), interval="1mo", eager=True
            ),
            "S": [1, 2, 4, 8, 16],
        }
    )

    result = df.group_by("A", maintain_order=True).agg(
        (pl.col("T").max() + pl.duration(seconds=1)) - pl.col("T")
    )

    # This checks that the `pl.duration` expression is flagged as AggState::Literal
    expected = pl.DataFrame(
        {
            "A": ["x", "y"],
            "T": [
                [timedelta(days=31, seconds=1), timedelta(seconds=1)],
                [
                    timedelta(days=61, seconds=1),
                    timedelta(days=30, seconds=1),
                    timedelta(seconds=1),
                ],
            ],
        }
    )
    assert_frame_equal(result, expected)
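

# `pl.duration(seconds=1)` above has no column inputs, so inside the agg
# context it is broadcast as a literal against each group (the AggState::Literal
# flag the comment refers to). A minimal illustrative sketch of the same
# broadcasting with a plain integer literal (not part of the original test):
#
#     df.group_by("A", maintain_order=True).agg(pl.col("S").max() + 1 - pl.col("S"))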


def test_string_par_materialize_8207() -> None:
    df = pl.LazyFrame(
        {
            "a": ["a", "b", "d", "c", "e"],
            "b": ["P", "L", "R", "T", "a long string"],
        }
    )

    assert df.group_by(["a"]).agg(pl.min("b")).sort("a").collect().to_dict(
        as_series=False
    ) == {
        "a": ["a", "b", "c", "d", "e"],
        "b": ["P", "L", "T", "R", "a long string"],
    }


def test_online_variance() -> None:
    df = pl.DataFrame(
        {
            "id": [1] * 5,
            "no_nulls": [1, 2, 3, 4, 5],
            "nulls": [1, None, 3, None, 5],
        }
    )

    assert_frame_equal(
        df.group_by("id")
        .agg(pl.all().exclude("id").std())
        .select(["no_nulls", "nulls"]),
        df.select(pl.all().exclude("id").std()),
    )


def test_implode_and_agg() -> None:
    df = pl.DataFrame({"type": ["water", "fire", "water", "earth"]})

    assert_frame_equal(
        df.group_by("type").agg(pl.col("type").implode().first().alias("foo")),
        pl.DataFrame(
            {
                "type": ["water", "fire", "earth"],
                "foo": [["water", "water"], ["fire"], ["earth"]],
            }
        ),
        check_row_order=False,
    )

    # implode + function should be allowed in group_by
    assert df.group_by("type", maintain_order=True).agg(
        pl.col("type").implode().list.head().alias("foo")
    ).to_dict(as_series=False) == {
        "type": ["water", "fire", "earth"],
        "foo": [["water", "water"], ["fire"], ["earth"]],
    }
    assert df.select(pl.col("type").implode().list.head(1).over("type")).to_dict(
        as_series=False
    ) == {"type": [["water"], ["fire"], ["water"], ["earth"]]}


def test_mapped_literal_to_literal_9217() -> None:
    df = pl.DataFrame({"unique_id": ["a", "b"]})
    assert df.group_by(True).agg(
        pl.struct(pl.lit("unique_id").alias("unique_id"))
    ).to_dict(as_series=False) == {
        "literal": [True],
        "unique_id": [{"unique_id": "unique_id"}],
    }


def test_sum_empty_and_null_set() -> None:
    series = pl.Series("a", [], dtype=pl.Float32)
    assert series.sum() == 0

    series = pl.Series("a", [None], dtype=pl.Float32)
    assert series.sum() == 0

    df = pl.DataFrame(
        {"a": [None, None, None], "b": [1, 1, 1]},
        schema={"a": pl.Float32, "b": pl.Int64},
    )
    assert df.select(pl.sum("a")).item() == 0.0
    assert df.group_by("b").agg(pl.sum("a"))["a"].item() == 0.0


def test_horizontal_sum_null_to_identity() -> None:
    assert pl.DataFrame({"a": [1, 5], "b": [10, None]}).select(
        pl.sum_horizontal(["a", "b"])
    ).to_series().to_list() == [11, 5]


def test_horizontal_sum_bool_dtype() -> None:
    out = pl.DataFrame({"a": [True, False]}).select(pl.sum_horizontal("a"))
    assert_frame_equal(
        out, pl.DataFrame({"a": pl.Series([1, 0], dtype=pl.get_index_type())})
    )


def test_horizontal_sum_in_group_by_15102() -> None:
    nbr_records = 1000
    out = (
        pl.LazyFrame(
            {
                "x": [None, "two", None] * nbr_records,
                "y": ["one", "two", None] * nbr_records,
                "z": [None, "two", None] * nbr_records,
            }
        )
        .select(pl.sum_horizontal(pl.all().is_null()).alias("num_null"))
        .group_by("num_null")
        .len()
        .sort(by="num_null")
        .collect()
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "num_null": pl.Series([0, 2, 3], dtype=pl.get_index_type()),
                "len": pl.Series([nbr_records] * 3, dtype=pl.get_index_type()),
            }
        ),
    )
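

# The three horizontal-sum tests above pin down null handling: null acts as the
# additive identity, so the row [5, None] sums to 5 rather than null, and
# booleans are summed as 0/1 into the index dtype. In the group-by variant each
# row has 0, 2, or 3 nulls across the three columns, hence the keys [0, 2, 3].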


def test_first_last_unit_length_12363() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [None, None],
        }
    )

    assert df.select(
        pl.all().drop_nulls().first().name.suffix("_first"),
        pl.all().drop_nulls().last().name.suffix("_last"),
    ).to_dict(as_series=False) == {
        "a_first": [1],
        "b_first": [None],
        "a_last": [2],
        "b_last": [None],
    }


def test_binary_op_agg_context_no_simplify_expr_12423() -> None:
    expect = pl.DataFrame({"x": [1], "y": [1]}, schema={"x": pl.Int64, "y": pl.Int32})

    for simplify_expression in (True, False):
        assert_frame_equal(
            expect,
            pl.LazyFrame({"x": [1]})
            .group_by("x")
            .agg(y=pl.lit(1) * pl.lit(1))
            .collect(
                optimizations=pl.QueryOptFlags(simplify_expression=simplify_expression)
            ),
        )


def test_nan_inf_aggregation() -> None:
    df = pl.DataFrame(
        [
            ("both nan", np.nan),
            ("both nan", np.nan),
            ("nan and 5", np.nan),
            ("nan and 5", 5),
            ("nan and null", np.nan),
            ("nan and null", None),
            ("both none", None),
            ("both none", None),
            ("both inf", np.inf),
            ("both inf", np.inf),
            ("inf and null", np.inf),
            ("inf and null", None),
        ],
        schema=["group", "value"],
        orient="row",
    )

    assert_frame_equal(
        df.group_by("group", maintain_order=True).agg(
            min=pl.col("value").min(),
            max=pl.col("value").max(),
            mean=pl.col("value").mean(),
        ),
        pl.DataFrame(
            [
                ("both nan", np.nan, np.nan, np.nan),
                ("nan and 5", 5, 5, np.nan),
                ("nan and null", np.nan, np.nan, np.nan),
                ("both none", None, None, None),
                ("both inf", np.inf, np.inf, np.inf),
                ("inf and null", np.inf, np.inf, np.inf),
            ],
            schema=["group", "min", "max", "mean"],
            orient="row",
        ),
    )


@pytest.mark.parametrize("dtype", [pl.Int16, pl.UInt16])
def test_int16_max_12904(dtype: PolarsDataType) -> None:
    s = pl.Series([None, 1], dtype=dtype)

    assert s.min() == 1
    assert s.max() == 1


def test_agg_filter_over_empty_df_13610() -> None:
    ldf = pl.LazyFrame(
        {
            "a": [1, 1, 1, 2, 3],
            "b": [True, True, True, True, True],
            "c": [None, None, None, None, None],
        }
    )

    out = (
        ldf.drop_nulls()
        .group_by(["a"], maintain_order=True)
        .agg(pl.col("b").filter(pl.col("b").shift(1)))
        .collect()
    )
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)})
    assert_frame_equal(out, expected)

    df = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Boolean})
    out = df.group_by("a").agg(pl.col("b").filter(pl.col("b").shift()))
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)})
    assert_frame_equal(out, expected)


@pytest.mark.may_fail_cloud  # reason: output order is defined for this in cloud
@pytest.mark.may_fail_auto_streaming
@pytest.mark.slow
def test_agg_empty_sum_after_filter_14734() -> None:
    f = (
        pl.DataFrame({"a": [1, 2], "b": [1, 2]})
        .lazy()
        .group_by("a")
        .agg(pl.col("b").filter(pl.lit(False)).sum())
        .collect
    )

    last = f()

    # We need both possible output orders, which should happen within
    # 1000 iterations (during testing it usually happens within 10).
    limit = 1000
    i = 0
    while (curr := f()).equals(last):
        i += 1
        assert i != limit

    expect = pl.Series("b", [0, 0]).to_frame()
    assert_frame_equal(expect, last.select("b"))
    assert_frame_equal(expect, curr.select("b"))


@pytest.mark.slow
def test_grouping_hash_14749() -> None:
    n_groups = 251
    rows_per_group = 4
    assert (
        pl.DataFrame(
            {
                "grp": np.repeat(np.arange(n_groups), rows_per_group),
                "x": np.tile(np.arange(rows_per_group), n_groups),
            }
        )
        .select(pl.col("x").max().over("grp"))["x"]
        .value_counts()
    ).to_dict(as_series=False) == {"x": [3], "count": [1004]}
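

# Arithmetic behind `test_grouping_hash_14749`: every group sees x = 0..3, so
# `max().over("grp")` is 3 for all n_groups * rows_per_group = 251 * 4 = 1004
# rows, and `value_counts` collapses to the single row {x: 3, count: 1004}.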


@pytest.mark.parametrize(
    ("in_dtype", "out_dtype"),
    [
        (pl.Boolean, pl.Float64),
        (pl.UInt8, pl.Float64),
        (pl.UInt16, pl.Float64),
        (pl.UInt32, pl.Float64),
        (pl.UInt64, pl.Float64),
        (pl.Int8, pl.Float64),
        (pl.Int16, pl.Float64),
        (pl.Int32, pl.Float64),
        (pl.Int64, pl.Float64),
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_horizontal_mean_single_column(
    in_dtype: PolarsDataType,
    out_dtype: PolarsDataType,
) -> None:
    out = (
        pl.LazyFrame({"a": pl.Series([1, 0]).cast(in_dtype)})
        .select(pl.mean_horizontal(pl.all()))
        .collect()
    )

    assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0]).cast(out_dtype)}))


def test_horizontal_mean_in_group_by_15115() -> None:
    nbr_records = 1000
    out = (
        pl.LazyFrame(
            {
                "w": [None, "one", "two", "three"] * nbr_records,
                "x": [None, None, "two", "three"] * nbr_records,
                "y": [None, None, None, "three"] * nbr_records,
                "z": [None, None, None, None] * nbr_records,
            }
        )
        .select(pl.mean_horizontal(pl.all().is_null()).alias("mean_null"))
        .group_by("mean_null")
        .len()
        .sort(by="mean_null")
        .collect()
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "mean_null": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64),
                "len": pl.Series([nbr_records] * 4, dtype=pl.get_index_type()),
            }
        ),
    )


def test_group_count_over_null_column_15705() -> None:
    df = pl.DataFrame(
        {"a": [1, 1, 2, 2, 3, 3], "c": [None, None, None, None, None, None]}
    )
    out = df.group_by("a", maintain_order=True).agg(pl.col("c").count())
    assert out["c"].to_list() == [0, 0, 0]


@pytest.mark.release
def test_min_max_2850() -> None:
    # https://github.com/pola-rs/polars/issues/2850
    df = pl.DataFrame(
        {
            "id": [
                130352432,
                130352277,
                130352611,
                130352833,
                130352305,
                130352258,
                130352764,
                130352475,
                130352368,
                130352346,
            ]
        }
    )

    minimum = 130352258
    maximum = 130352833.0

    for _ in range(10):
        permuted = df.sample(fraction=1.0, seed=0)
        computed = permuted.select(
            pl.col("id").min().alias("min"), pl.col("id").max().alias("max")
        )
        assert cast("int", computed[0, "min"]) == minimum
        assert cast("float", computed[0, "max"]) == maximum


def test_multi_arg_structify_15834() -> None:
    df = pl.DataFrame(
        {
            "group": [1, 2, 1, 2],
            "value": [
                0.1973209146402105,
                0.13380719982405365,
                0.6152394463707009,
                0.4558767896005155,
            ],
        }
    )

    assert df.lazy().group_by("group").agg(
        pl.struct(a=1, value=pl.col("value").sum())
    ).collect().sort("group").to_dict(as_series=False) == {
        "group": [1, 2],
        "a": [
            {"a": 1, "value": 0.8125603610109114},
            {"a": 1, "value": 0.5896839894245691},
        ],
    }


def test_filter_aggregation_16642() -> None:
    df = pl.DataFrame(
        {
            "datetime": [
                datetime(2022, 1, 1, 11, 0),
                datetime(2022, 1, 1, 11, 1),
                datetime(2022, 1, 1, 11, 2),
                datetime(2022, 1, 1, 11, 3),
                datetime(2022, 1, 1, 11, 4),
                datetime(2022, 1, 1, 11, 5),
                datetime(2022, 1, 1, 11, 6),
                datetime(2022, 1, 1, 11, 7),
                datetime(2022, 1, 1, 11, 8),
                datetime(2022, 1, 1, 11, 9, 1),
                datetime(2022, 1, 2, 11, 0),
                datetime(2022, 1, 2, 11, 1),
                datetime(2022, 1, 2, 11, 2),
                datetime(2022, 1, 2, 11, 3),
                datetime(2022, 1, 2, 11, 4),
                datetime(2022, 1, 2, 11, 5),
                datetime(2022, 1, 2, 11, 6),
                datetime(2022, 1, 2, 11, 7),
                datetime(2022, 1, 2, 11, 8),
                datetime(2022, 1, 2, 11, 9, 1),
            ],
            "alpha": [
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
                "I",
                "J",
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
                "I",
                "J",
            ],
            "num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    grouped = df.group_by(pl.col("datetime").dt.date())

    ts_filter = pl.col("datetime").dt.time() <= pl.time(11, 3)

    report = grouped.agg(pl.col("num").filter(ts_filter).max()).sort("datetime")
    assert report.to_dict(as_series=False) == {
        "datetime": [date(2022, 1, 1), date(2022, 1, 2)],
        "num": [3, 3],
    }
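

# In `test_filter_aggregation_16642` the filter is evaluated per group, so for
# each of the two dates only the rows at or before 11:03 survive (num 0..3) and
# the per-group max is 3; the later rows fall outside the `<= pl.time(11, 3)` mask.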


def test_sort_by_over_single_nulls_first() -> None:
    key = [0, 0, 0, 0, 1, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key": key,
            "value": [2, None, 1, 0, 2, None, 1, 0],
        }
    )
    out = df.select(
        pl.all().sort_by("value", nulls_last=False, maintain_order=True).over("key")
    )
    expected = pl.DataFrame(
        {
            "key": key,
            "value": [None, 0, 1, 2, None, 0, 1, 2],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_single_nulls_last() -> None:
    key = [0, 0, 0, 0, 1, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key": key,
            "value": [2, None, 1, 0, 2, None, 1, 0],
        }
    )
    out = df.select(
        pl.all().sort_by("value", nulls_last=True, maintain_order=True).over("key")
    )
    expected = pl.DataFrame(
        {
            "key": key,
            "value": [0, 1, 2, None, 0, 1, 2, None],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_multiple_nulls_first() -> None:
    key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0],
        }
    )
    out = df.select(
        pl.all()
        .sort_by("value", nulls_last=False, maintain_order=True)
        .over("key1", "key2")
    )
    expected = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [None, 0, 1, None, 0, 1, None, 0, 1, None, 0, 1],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_multiple_nulls_last() -> None:
    key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0],
        }
    )
    out = df.select(
        pl.all()
        .sort_by("value", nulls_last=True, maintain_order=True)
        .over("key1", "key2")
    )
    expected = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [0, 1, None, 0, 1, None, 0, 1, None, 0, 1, None],
        }
    )
    assert_frame_equal(out, expected)


def test_slice_after_agg() -> None:
    assert_frame_equal(
        pl.select(a=pl.lit(1, dtype=pl.Int64), b=pl.lit(1, dtype=pl.Int64))
        .group_by("a")
        .agg(pl.col("b").first().slice(99, 0)),
        pl.DataFrame({"a": [1], "b": [[]]}, schema_overrides={"b": pl.List(pl.Int64)}),
    )


def test_agg_scalar_empty_groups_20115() -> None:
    assert_frame_equal(
        (
            pl.DataFrame({"key": [123], "value": [456]})
            .group_by("key")
            .agg(pl.col("value").slice(1, 1).first())
        ),
        pl.select(key=pl.lit(123, pl.Int64), value=pl.lit(None, pl.Int64)),
    )


def test_agg_expr_returns_list_type_15574() -> None:
    assert (
        pl.LazyFrame({"a": [1, None], "b": [1, 2]})
        .group_by("b")
        .agg(pl.col("a").drop_nulls())
        .collect_schema()
    ) == {"b": pl.Int64, "a": pl.List(pl.Int64)}


def test_empty_agg_22005() -> None:
    out = (
        pl.concat([pl.LazyFrame({"a": [1, 2]}), pl.LazyFrame({"a": [1, 2]})])
        .limit(0)
        .select(pl.col("a").sum())
    )
    assert_frame_equal(out.collect(), pl.DataFrame({"a": 0}))
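

# The two tests above pin the empty-input conventions (cf.
# `test_sum_empty_and_null_set`): a scalar aggregation over an empty group
# yields null, while `sum` over zero rows yields its additive identity, 0.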


@pytest.mark.parametrize("wrap_numerical", [True, False])
@pytest.mark.parametrize("strict_cast", [True, False])
def test_agg_with_filter_then_cast_23682(
    strict_cast: bool, wrap_numerical: bool
) -> None:
    assert_frame_equal(
        pl.DataFrame([{"a": 123, "b": 12}, {"a": 123, "b": 257}])
        .group_by("a")
        .agg(
            pl.col("b")
            .filter(pl.col("b") < 256)
            .cast(pl.UInt8, strict=strict_cast, wrap_numerical=wrap_numerical)
        ),
        pl.DataFrame(
            [{"a": 123, "b": [12]}], schema={"a": pl.Int64, "b": pl.List(pl.UInt8)}
        ),
    )


@pytest.mark.parametrize("wrap_numerical", [True, False])
@pytest.mark.parametrize("strict_cast", [True, False])
def test_agg_with_slice_then_cast_23682(
    strict_cast: bool, wrap_numerical: bool
) -> None:
    assert_frame_equal(
        pl.DataFrame([{"a": 123, "b": 12}, {"a": 123, "b": 257}])
        .group_by("a")
        .agg(
            pl.col("b")
            .slice(0, 1)
            .cast(pl.UInt8, strict=strict_cast, wrap_numerical=wrap_numerical)
        ),
        pl.DataFrame(
            [{"a": 123, "b": [12]}], schema={"a": pl.Int64, "b": pl.List(pl.UInt8)}
        ),
    )
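

# In the two cast tests above, the filter/slice runs before the cast inside the
# aggregation, so the out-of-range value 257 (UInt8 holds at most 255) never
# reaches the cast and the result is [12] regardless of `strict`/`wrap_numerical`.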


@pytest.mark.parametrize(
    ("op", "expr"),
    [
        ("any", pl.all().cast(pl.Boolean).any()),
        ("all", pl.all().cast(pl.Boolean).all()),
        ("arg_max", pl.all().arg_max()),
        ("arg_min", pl.all().arg_min()),
        ("min", pl.all().min()),
        ("max", pl.all().max()),
        ("mean", pl.all().mean()),
        ("median", pl.all().median()),
        ("product", pl.all().product()),
        ("quantile", pl.all().quantile(0.5)),
        ("std", pl.all().std()),
        ("var", pl.all().var()),
        ("sum", pl.all().sum()),
        ("first", pl.all().first()),
        ("last", pl.all().last()),
        ("approx_n_unique", pl.all().approx_n_unique()),
        ("bitwise_and", pl.all().bitwise_and()),
        ("bitwise_or", pl.all().bitwise_or()),
        ("bitwise_xor", pl.all().bitwise_xor()),
    ],
)
@pytest.mark.parametrize(
    "df",
    [
        pl.DataFrame({"a": [[10]]}, schema={"a": pl.Array(shape=(1,), inner=pl.Int32)}),
        pl.DataFrame({"a": [[1]]}, schema={"a": pl.Struct(fields={"a": pl.Int32})}),
        pl.DataFrame({"a": [True]}, schema={"a": pl.Boolean}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.Categorical}),
        pl.DataFrame({"a": [b"a"]}, schema={"a": pl.Binary}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.Utf8}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Int32}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Float16}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Float32}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Float64}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Int128}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.UInt128}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.String}),
        pl.DataFrame({"a": [None]}, schema={"a": pl.Null}),
        pl.DataFrame({"a": [10]}, schema={"a": pl.Decimal()}),
        pl.DataFrame({"a": [datetime.now()]}, schema={"a": pl.Datetime}),
        pl.DataFrame({"a": [date.today()]}, schema={"a": pl.Date}),
        pl.DataFrame({"a": [timedelta(seconds=10)]}, schema={"a": pl.Duration}),
    ],
)
def test_agg_invalid_same_engines_behavior(
    op: str, expr: pl.Expr, df: pl.DataFrame
) -> None:
    # If the in-memory engine produces a good result, then the streaming engine
    # should also produce a good result, and the two results should match.

    if isinstance(df.schema["a"], pl.Struct) and op in {"any", "all"}:
        # TODO: Remove this exception when #24509 is resolved
        pytest.skip("polars/#24509")

    if isinstance(df.schema["a"], pl.Duration) and op in {"std", "var"}:
        # TODO: Remove this exception when std & var are implemented for Duration
        pytest.skip(f"'{op}' aggregation not yet implemented for Duration")

    inmemory_result, inmemory_error = None, None
    streaming_result, streaming_error = None, None

    try:
        inmemory_result = df.select(expr)
    except pl.exceptions.PolarsError as e:
        inmemory_error = e

    try:
        streaming_result = df.lazy().select(expr).collect(engine="streaming")
    except pl.exceptions.PolarsError as e:
        streaming_error = e

    assert (streaming_error is None) == (inmemory_error is None), (
        f"mismatch in errors for: {streaming_error} != {inmemory_error}"
    )
    if inmemory_error:
        assert streaming_error, (
            f"streaming engine did not error (expected in-memory error: {inmemory_error})"
        )
        assert streaming_error.__class__ == inmemory_error.__class__

    if not inmemory_error:
        assert streaming_result is not None
        assert inmemory_result is not None
        assert_frame_equal(streaming_result, inmemory_result)


@pytest.mark.parametrize(
    ("op", "expr"),
    [
        ("sum", pl.all().sum()),
        ("mean", pl.all().mean()),
        ("median", pl.all().median()),
        ("std", pl.all().std()),
        ("var", pl.all().var()),
        ("quantile", pl.all().quantile(0.5)),
        ("cum_sum", pl.all().cum_sum()),
    ],
)
@pytest.mark.parametrize(
    "df",
    [
        pl.DataFrame({"a": [[10]]}, schema={"a": pl.Array(shape=(1,), inner=pl.Int32)}),
        pl.DataFrame({"a": [[1]]}, schema={"a": pl.Struct(fields={"a": pl.Int32})}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.Categorical}),
        pl.DataFrame({"a": [b"a"]}, schema={"a": pl.Binary}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.Utf8}),
        pl.DataFrame({"a": ["a"]}, schema={"a": pl.String}),
    ],
)
def test_invalid_agg_dtypes_should_raise(
    op: str, expr: pl.Expr, df: pl.DataFrame
) -> None:
    with pytest.raises(
        pl.exceptions.PolarsError, match=rf"`{op}` operation not supported for dtype"
    ):
        df.select(expr)
    with pytest.raises(
        pl.exceptions.PolarsError, match=rf"`{op}` operation not supported for dtype"
    ):
        df.lazy().select(expr).collect(engine="streaming")


@given(
    df=dataframes(
        min_size=1,
        max_size=1,
        excluded_dtypes=[
            # TODO: polars/#24936
            pl.Struct,
        ],
    )
)
def test_single(df: pl.DataFrame) -> None:
    q = df.lazy().select(pl.all(ignore_nulls=False).item())
    assert_frame_equal(q.collect(), df)
    assert_frame_equal(q.collect(engine="streaming"), df)


@given(df=dataframes(max_size=0))
def test_single_empty(df: pl.DataFrame) -> None:
    q = df.lazy().select(pl.all().item())
    match = "aggregation 'item' expected a single value, got none"
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect()
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect(engine="streaming")


@given(df=dataframes(min_size=2))
def test_item_too_many(df: pl.DataFrame) -> None:
    q = df.lazy().select(pl.all(ignore_nulls=False).item())
    match = f"aggregation 'item' expected a single value, got {df.height} values"
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect()
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect(engine="streaming")
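

# The 'item' tests assert the same contract on both engines: `item` demands
# exactly one value per column (or per group below), raising a ComputeError of
# "got none" for zero values and "got N values" for more than one.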


@given(
    df=dataframes(
        min_size=1,
        max_size=1,
        allow_null=False,
        excluded_dtypes=[
            # TODO: polars/#24936
            pl.Struct,
        ],
    )
)
def test_item_on_groups(df: pl.DataFrame) -> None:
    df = df.with_columns(pl.col("col0").alias("key"))
    q = df.lazy().group_by("col0").agg(pl.all(ignore_nulls=False).item())
    assert_frame_equal(q.collect(), df)
    assert_frame_equal(q.collect(engine="streaming"), df)


def test_item_on_groups_empty() -> None:
    df = pl.DataFrame({"col0": [[]]})
    q = df.lazy().select(pl.all().list.item())
    match = "aggregation 'item' expected a single value, got none"
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect()
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect(engine="streaming")


def test_item_on_groups_too_many() -> None:
    df = pl.DataFrame({"col0": [[1, 2, 3]]})
    q = df.lazy().select(pl.all().list.item())
    match = "aggregation 'item' expected a single value, got 3 values"
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect()
    with pytest.raises(pl.exceptions.ComputeError, match=match):
        q.collect(engine="streaming")


def test_all_any_on_list_raises_error() -> None:
    # Ensure boolean reductions on non-boolean columns raise an error.
    # (regression for #24942).
    lf = pl.LazyFrame({"x": [[True]]}, schema={"x": pl.List(pl.Boolean)})

    # for in-memory engine
    for expr in (pl.col("x").all(), pl.col("x").any()):
        with pytest.raises(
            pl.exceptions.InvalidOperationError, match=r"expected boolean"
        ):
            lf.select(expr).collect()

    # for streaming engine
    for expr in (pl.col("x").all(), pl.col("x").any()):
        with pytest.raises(
            pl.exceptions.InvalidOperationError, match=r"expected boolean"
        ):
            lf.select(expr).collect(engine="streaming")


@pytest.mark.parametrize("null_endpoints", [True, False])
@pytest.mark.parametrize("ignore_nulls", [True, False])
@pytest.mark.parametrize(
    ("dtype", "first_value", "last_value"),
    [
        # Struct
        (
            pl.Struct({"x": pl.Enum(["c0", "c1"]), "y": pl.Float32}),
            {"x": "c0", "y": 1.2},
            {"x": "c1", "y": 3.4},
        ),
        # List
        (pl.List(pl.UInt8), [1], [2]),
        # Array
        (pl.Array(pl.Int16, 2), [1, 2], [3, 4]),
        # Date (logical test)
        (pl.Date, date(2025, 1, 1), date(2025, 1, 2)),
        # Float (primitive test)
        (pl.Float32, 1.0, 2.0),
    ],
)
def test_first_last_nested(
    null_endpoints: bool,
    ignore_nulls: bool,
    dtype: PolarsDataType,
    first_value: Any,
    last_value: Any,
) -> None:
    s = pl.Series([first_value, last_value], dtype=dtype)
    if null_endpoints:
        # Test the case where the first/last value is null
        null = pl.Series([None], dtype=dtype)
        s = pl.concat((null, s, null))

    lf = pl.LazyFrame({"a": s})

    # first
    result = lf.select(pl.col("a").first(ignore_nulls=ignore_nulls)).collect()
    expected = pl.DataFrame(
        {
            "a": pl.Series(
                [None if null_endpoints and not ignore_nulls else first_value],
                dtype=dtype,
            )
        }
    )
    assert_frame_equal(result, expected)

    # last
    result = lf.select(pl.col("a").last(ignore_nulls=ignore_nulls)).collect()
    expected = pl.DataFrame(
        {
            "a": pl.Series(
                [None if null_endpoints and not ignore_nulls else last_value],
                dtype=dtype,
            ),
        }
    )
    assert_frame_equal(result, expected)


def test_struct_enum_agg_streaming_24936() -> None:
    s = (
        pl.Series(
            "a",
            [{"f0": "c0"}],
            dtype=pl.Struct({"f0": pl.Enum(categories=["c0"])}),
        ),
    )
    df = pl.DataFrame(s)

    q = df.lazy().select(pl.all(ignore_nulls=False).first())
    assert_frame_equal(q.collect(), df)
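

# `test_sum_inf_not_nan_25849` below guards the IEEE identity inf + finite = inf:
# a sum whose inputs contain +inf (nulls are skipped) must return +inf, not NaN.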


def test_sum_inf_not_nan_25849() -> None:
    data = [10.0, None, 10.0, 10.0, 10.0, 10.0, float("inf"), 10.0, 10.0]
    df = pl.DataFrame({"x": data, "g": ["X"] * len(data)})
    assert df.group_by("g").agg(pl.col("x").sum())["x"].item() == float("inf")


COLS = ["flt", "dec", "int", "str", "cat", "enum", "date", "dt"]


@pytest.mark.parametrize(
    "agg_funcs", [(pl.Expr.min_by, pl.Expr.min), (pl.Expr.max_by, pl.Expr.max)]
)
@pytest.mark.parametrize("by_col", COLS)
def test_min_max_by(agg_funcs: Any, by_col: str) -> None:
    agg_by, agg = agg_funcs
    df = pl.DataFrame(
        {
            "flt": [3.0, 2.0, float("nan"), 5.0, None, 4.0],
            "dec": [3, 2, None, 5, None, 4],
            "int": [3, 2, None, 5, None, 4],
            "str": ["c", "b", None, "e", None, "d"],
            "cat": ["c", "b", None, "e", None, "d"],
            "enum": ["c", "b", None, "e", None, "d"],
            "date": [
                date(2023, 3, 3),
                date(2023, 2, 2),
                None,
                date(2023, 5, 5),
                None,
                date(2023, 4, 4),
            ],
            "dt": [
                datetime(2023, 3, 3),
                datetime(2023, 2, 2),
                None,
                datetime(2023, 5, 5),
                None,
                datetime(2023, 4, 4),
            ],
            "g": [1, 1, 1, 2, 2, 2],
        },
        schema_overrides={
            "dec": pl.Decimal(scale=5),
            "cat": pl.Categorical,
            "enum": pl.Enum(["a", "b", "c", "d", "e", "f"]),
        },
    )

    result = df.select([agg_by(pl.col(c), pl.col(by_col)) for c in COLS])
    expected = df.select([agg(pl.col(c)) for c in COLS])
    assert_frame_equal(result, expected)

    # TODO: remove after https://github.com/pola-rs/polars/issues/25906.
    if by_col != "cat":
        df = df.drop("cat")
        cols = [c for c in COLS if c != "cat"]

        result = df.group_by("g").agg(
            [agg_by(pl.col(c), pl.col(by_col)) for c in cols]
        )
        expected = df.group_by("g").agg([agg(pl.col(c)) for c in cols])
        assert_frame_equal(result, expected, check_row_order=False)


@pytest.mark.parametrize(("agg", "expected"), [("max", 2), ("min", 0)])
def test_grouped_minmax_after_reverse_on_sorted_column_26141(
    agg: str, expected: int
) -> None:
    df = pl.DataFrame({"a": [0, 1, 2]}).sort("a")

    expr = getattr(pl.col("a").reverse(), agg)()
    out = df.group_by(1).agg(expr)

    expected_df = pl.DataFrame(
        {
            "literal": pl.Series([1], dtype=pl.Int32),
            "a": [expected],
        }
    )
    assert_frame_equal(out, expected_df)


@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize("agg_by", [pl.Expr.min_by, pl.Expr.max_by])
def test_min_max_by_series_length_mismatch_26049(
    agg_by: Callable[[pl.Expr, pl.Expr], pl.Expr],
) -> None:
    lf = pl.LazyFrame(
        {
            "a": [0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
            "b": [18, 5, 8, 8, 4, 5, 6, 8, 1, -10],
            "group": ["A", "A", "A", "A", "A", "B", "B", "C", "C", "C"],
        }
    )

    q = lf.with_columns(
        agg_by(pl.col("group").filter(pl.col("b") % 2 == 0), pl.col("a"))
    )

    with pytest.raises(
        pl.exceptions.ShapeError,
        match=r"^'by' column in (min|max)_by expression has incorrect length: expected \d+, got \d+$",
    ):
        q.collect(engine="in-memory")
    with pytest.raises(
        pl.exceptions.ShapeError,
        match=r"^zip node received non-equal length inputs$",
    ):
        q.collect(engine="streaming")

    actual = (
        lf.group_by("group")
        .agg(
            pl.col("a")
            .max_by(pl.col("b").filter(pl.col("b") < 20).abs())
            .alias("max_by")
        )
        .sort("group")
    ).collect()
    expected = pl.DataFrame(
        {
            "group": ["A", "B", "C"],
            "max_by": [0, 60, 90],
        }
    )
    assert_frame_equal(actual, expected)

    q = (
        lf.group_by("group")
        .agg(
            pl.col("a")
            .max_by(pl.col("b").filter(pl.col("b") < 7).abs())
            .alias("group_length_mismatch")
        )
        .sort("group")
    )
    with pytest.raises(
        pl.exceptions.ShapeError,
        match=r"^expressions must have matching group lengths$",
    ):
        q.collect(engine="in-memory")
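

# `test_min_max_by` above works because every `by` column is ordered the same
# way as the value columns, so `min_by(c, by)` must match `min(c)` column by
# column (and likewise `max_by`/`max`).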


@pytest.mark.parametrize(
    "by_expr",
    [
        pl.struct("b", "c"),
        pl.concat_list("b", "c"),
    ],
)
def test_min_by_max_by_nested_type_key_26268(by_expr: pl.Expr) -> None:
    df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 6, 5], "c": [7, 5, 2]})

    with pytest.raises(
        pl.exceptions.InvalidOperationError,
        match="cannot use a nested type as `by` argument in `min_by`/`max_by`",
    ):
        df.select(pl.col("a").min_by(by_expr))