Path: blob/main/py-polars/tests/unit/operations/aggregation/test_aggregations.py
from __future__ import annotations

from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any, cast

import numpy as np
import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    import numpy.typing as npt

    from polars._typing import PolarsDataType


def test_quantile_expr_input() -> None:
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0.0, 0.0, 0.3, 0.2, 0.0]})

    assert_frame_equal(
        df.select([pl.col("a").quantile(pl.col("b").sum() + 0.1)]),
        df.select(pl.col("a").quantile(0.6)),
    )


def test_boolean_aggs() -> None:
    df = pl.DataFrame({"bool": [True, False, None, True]})

    aggs = [
        pl.mean("bool").alias("mean"),
        pl.std("bool").alias("std"),
        pl.var("bool").alias("var"),
    ]
    assert df.select(aggs).to_dict(as_series=False) == {
        "mean": [0.6666666666666666],
        "std": [0.5773502691896258],
        "var": [0.33333333333333337],
    }

    assert df.group_by(pl.lit(1)).agg(aggs).to_dict(as_series=False) == {
        "literal": [1],
        "mean": [0.6666666666666666],
        "std": [0.5773502691896258],
        "var": [0.33333333333333337],
    }


def test_duration_aggs() -> None:
    df = pl.DataFrame(
        {
            "time1": pl.datetime_range(
                start=datetime(2022, 12, 12),
                end=datetime(2022, 12, 18),
                interval="1d",
                eager=True,
            ),
            "time2": pl.datetime_range(
                start=datetime(2023, 1, 12),
                end=datetime(2023, 1, 18),
                interval="1d",
                eager=True,
            ),
        }
    )

    df = df.with_columns((pl.col("time2") - pl.col("time1")).alias("time_difference"))

    assert df.select("time_difference").mean().to_dict(as_series=False) == {
        "time_difference": [timedelta(days=31)]
    }
    assert df.group_by(pl.lit(1)).agg(pl.mean("time_difference")).to_dict(
        as_series=False
    ) == {
        "literal": [1],
        "time_difference": [timedelta(days=31)],
    }


def test_list_aggregation_that_filters_all_data_6017() -> None:
    out = (
        pl.DataFrame({"col_to_group_by": [2], "flt": [1672740910.967138], "col3": [1]})
        .group_by("col_to_group_by")
        .agg((pl.col("flt").filter(col3=0).diff() * 1000).diff().alias("calc"))
    )

    assert out.schema == {"col_to_group_by": pl.Int64, "calc": pl.List(pl.Float64)}
    assert out.to_dict(as_series=False) == {"col_to_group_by": [2], "calc": [[]]}


def test_median() -> None:
    s = pl.Series([1, 2, 3])
    assert s.median() == 2


def test_single_element_std() -> None:
    s = pl.Series([1])
    assert s.std(ddof=1) is None
    assert s.std(ddof=0) == 0.0


def test_quantile() -> None:
    s = pl.Series([1, 2, 3])
    assert s.quantile(0.5, "nearest") == 2
    assert s.quantile(0.5, "lower") == 2
    assert s.quantile(0.5, "higher") == 2


@pytest.mark.slow
@pytest.mark.parametrize("tp", [int, float])
@pytest.mark.parametrize("n", [1, 2, 10, 100])
def test_quantile_vs_numpy(tp: type, n: int) -> None:
    a: np.ndarray[Any, Any] = np.random.randint(0, 50, n).astype(tp)
    np_result: npt.ArrayLike | None = np.median(a)
    # nan check
    if np_result != np_result:
        np_result = None
    median = pl.Series(a).median()
    if median is not None:
        assert np.isclose(median, np_result)  # type: ignore[arg-type]
    else:
        assert np_result is None

    q = np.random.sample()
    try:
        np_result = np.quantile(a, q)
    except IndexError:
        np_result = None
    if np_result:
        # nan check
        if np_result != np_result:
            np_result = None
        assert np.isclose(
            pl.Series(a).quantile(q, interpolation="linear"),  # type: ignore[arg-type]
            np_result,  # type: ignore[arg-type]
        )


def test_mean_overflow() -> None:
    assert np.isclose(
        pl.Series([9_223_372_036_854_775_800, 100]).mean(),  # type: ignore[arg-type]
        4.611686018427388e18,
    )


def test_mean_null_simd() -> None:
    for dtype in [int, float]:
        df = (
            pl.Series(np.random.randint(0, 100, 1000))
            .cast(dtype)
            .to_frame("a")
            .select(pl.when(pl.col("a") > 40).then(pl.col("a")))
        )

        s = df["a"]
        assert s.mean() == s.to_pandas().mean()


def test_literal_group_agg_chunked_7968() -> None:
    df = pl.DataFrame({"A": [1, 1], "B": [1, 3]})
    ser = pl.concat([pl.Series([3]), pl.Series([4, 5])], rechunk=False)

    assert_frame_equal(
        df.group_by("A").agg(pl.col("B").search_sorted(ser)),
        pl.DataFrame(
            [
                pl.Series("A", [1], dtype=pl.Int64),
                pl.Series("B", [[1, 2, 2]], dtype=pl.List(pl.UInt32)),
            ]
        ),
    )


def test_duration_function_literal() -> None:
    df = pl.DataFrame(
        {
            "A": ["x", "x", "y", "y", "y"],
            "T": pl.datetime_range(
                date(2022, 1, 1), date(2022, 5, 1), interval="1mo", eager=True
            ),
            "S": [1, 2, 4, 8, 16],
        }
    )

    result = df.group_by("A", maintain_order=True).agg(
        (pl.col("T").max() + pl.duration(seconds=1)) - pl.col("T")
    )

    # this checks if the `pl.duration` is flagged as AggState::Literal
    expected = pl.DataFrame(
        {
            "A": ["x", "y"],
            "T": [
                [timedelta(days=31, seconds=1), timedelta(seconds=1)],
                [
                    timedelta(days=61, seconds=1),
                    timedelta(days=30, seconds=1),
                    timedelta(seconds=1),
                ],
            ],
        }
    )
    assert_frame_equal(result, expected)


def test_string_par_materialize_8207() -> None:
    df = pl.LazyFrame(
        {
            "a": ["a", "b", "d", "c", "e"],
            "b": ["P", "L", "R", "T", "a long string"],
        }
    )

    assert df.group_by(["a"]).agg(pl.min("b")).sort("a").collect().to_dict(
        as_series=False
    ) == {
        "a": ["a", "b", "c", "d", "e"],
        "b": ["P", "L", "T", "R", "a long string"],
    }


def test_online_variance() -> None:
    df = pl.DataFrame(
        {
            "id": [1] * 5,
            "no_nulls": [1, 2, 3, 4, 5],
            "nulls": [1, None, 3, None, 5],
        }
    )

    assert_frame_equal(
        df.group_by("id")
        .agg(pl.all().exclude("id").std())
        .select(["no_nulls", "nulls"]),
        df.select(pl.all().exclude("id").std()),
    )


def test_implode_and_agg() -> None:
    df = pl.DataFrame({"type": ["water", "fire", "water", "earth"]})

    # this would OOB
    with pytest.raises(
        InvalidOperationError,
        match=r"'implode' followed by an aggregation is not allowed",
    ):
        df.group_by("type").agg(pl.col("type").implode().first().alias("foo"))

    # implode + function should be allowed in group_by
    assert df.group_by("type", maintain_order=True).agg(
        pl.col("type").implode().list.head().alias("foo")
    ).to_dict(as_series=False) == {
        "type": ["water", "fire", "earth"],
        "foo": [["water", "water"], ["fire"], ["earth"]],
    }
    assert df.select(pl.col("type").implode().list.head(1).over("type")).to_dict(
        as_series=False
    ) == {"type": [["water"], ["fire"], ["water"], ["earth"]]}


def test_mapped_literal_to_literal_9217() -> None:
    df = pl.DataFrame({"unique_id": ["a", "b"]})
    assert df.group_by(True).agg(
        pl.struct(pl.lit("unique_id").alias("unique_id"))
    ).to_dict(as_series=False) == {
        "literal": [True],
        "unique_id": [{"unique_id": "unique_id"}],
    }


def test_sum_empty_and_null_set() -> None:
    series = pl.Series("a", [], dtype=pl.Float32)
    assert series.sum() == 0

    series = pl.Series("a", [None], dtype=pl.Float32)
    assert series.sum() == 0

    df = pl.DataFrame(
        {"a": [None, None, None], "b": [1, 1, 1]},
        schema={"a": pl.Float32, "b": pl.Int64},
    )
    assert df.select(pl.sum("a")).item() == 0.0
    assert df.group_by("b").agg(pl.sum("a"))["a"].item() == 0.0


def test_horizontal_sum_null_to_identity() -> None:
    assert pl.DataFrame({"a": [1, 5], "b": [10, None]}).select(
        pl.sum_horizontal(["a", "b"])
    ).to_series().to_list() == [11, 5]


def test_horizontal_sum_bool_dtype() -> None:
    out = pl.DataFrame({"a": [True, False]}).select(pl.sum_horizontal("a"))
    assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1, 0], dtype=pl.UInt32)}))


def test_horizontal_sum_in_group_by_15102() -> None:
    nbr_records = 1000
    out = (
        pl.LazyFrame(
            {
                "x": [None, "two", None] * nbr_records,
                "y": ["one", "two", None] * nbr_records,
                "z": [None, "two", None] * nbr_records,
            }
        )
        .select(pl.sum_horizontal(pl.all().is_null()).alias("num_null"))
        .group_by("num_null")
        .len()
        .sort(by="num_null")
        .collect()
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "num_null": pl.Series([0, 2, 3], dtype=pl.UInt32),
                "len": pl.Series([nbr_records] * 3, dtype=pl.UInt32),
            }
        ),
    )


def test_first_last_unit_length_12363() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [None, None],
        }
    )

    assert df.select(
        pl.all().drop_nulls().first().name.suffix("_first"),
        pl.all().drop_nulls().last().name.suffix("_last"),
    ).to_dict(as_series=False) == {
        "a_first": [1],
        "b_first": [None],
        "a_last": [2],
        "b_last": [None],
    }


def test_binary_op_agg_context_no_simplify_expr_12423() -> None:
    expect = pl.DataFrame({"x": [1], "y": [1]}, schema={"x": pl.Int64, "y": pl.Int32})

    for simplify_expression in (True, False):
        assert_frame_equal(
            expect,
            pl.LazyFrame({"x": [1]})
            .group_by("x")
            .agg(y=pl.lit(1) * pl.lit(1))
            .collect(
                optimizations=pl.QueryOptFlags(simplify_expression=simplify_expression)
            ),
        )


def test_nan_inf_aggregation() -> None:
    df = pl.DataFrame(
        [
            ("both nan", np.nan),
            ("both nan", np.nan),
            ("nan and 5", np.nan),
            ("nan and 5", 5),
            ("nan and null", np.nan),
            ("nan and null", None),
            ("both none", None),
            ("both none", None),
            ("both inf", np.inf),
            ("both inf", np.inf),
            ("inf and null", np.inf),
            ("inf and null", None),
        ],
        schema=["group", "value"],
        orient="row",
    )

    assert_frame_equal(
        df.group_by("group", maintain_order=True).agg(
            min=pl.col("value").min(),
            max=pl.col("value").max(),
            mean=pl.col("value").mean(),
        ),
        pl.DataFrame(
            [
                ("both nan", np.nan, np.nan, np.nan),
                ("nan and 5", 5, 5, np.nan),
                ("nan and null", np.nan, np.nan, np.nan),
                ("both none", None, None, None),
                ("both inf", np.inf, np.inf, np.inf),
                ("inf and null", np.inf, np.inf, np.inf),
            ],
            schema=["group", "min", "max", "mean"],
            orient="row",
        ),
    )


@pytest.mark.parametrize("dtype", [pl.Int16, pl.UInt16])
def test_int16_max_12904(dtype: PolarsDataType) -> None:
    s = pl.Series([None, 1], dtype=dtype)

    assert s.min() == 1
    assert s.max() == 1


def test_agg_filter_over_empty_df_13610() -> None:
    ldf = pl.LazyFrame(
        {
            "a": [1, 1, 1, 2, 3],
            "b": [True, True, True, True, True],
            "c": [None, None, None, None, None],
        }
    )

    out = (
        ldf.drop_nulls()
        .group_by(["a"], maintain_order=True)
        .agg(pl.col("b").filter(pl.col("b").shift(1)))
        .collect()
    )
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)})
    assert_frame_equal(out, expected)

    df = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Boolean})
    out = df.group_by("a").agg(pl.col("b").filter(pl.col("b").shift()))
    expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)})
    assert_frame_equal(out, expected)


@pytest.mark.may_fail_cloud  # reason: output order is defined for this in cloud
@pytest.mark.slow
def test_agg_empty_sum_after_filter_14734() -> None:
    f = (
        pl.DataFrame({"a": [1, 2], "b": [1, 2]})
        .lazy()
        .group_by("a")
        .agg(pl.col("b").filter(pl.lit(False)).sum())
        .collect
    )

    last = f()

    # We need both possible output orders, which should happen within
    # 1000 iterations (during testing it usually happens within 10).
    limit = 1000
    i = 0
    while (curr := f()).equals(last):
        i += 1
        assert i != limit

    expect = pl.Series("b", [0, 0]).to_frame()
    assert_frame_equal(expect, last.select("b"))
    assert_frame_equal(expect, curr.select("b"))


@pytest.mark.slow
def test_grouping_hash_14749() -> None:
    n_groups = 251
    rows_per_group = 4
    assert (
        pl.DataFrame(
            {
                "grp": np.repeat(np.arange(n_groups), rows_per_group),
                "x": np.tile(np.arange(rows_per_group), n_groups),
            }
        )
        .select(pl.col("x").max().over("grp"))["x"]
        .value_counts()
    ).to_dict(as_series=False) == {"x": [3], "count": [1004]}


@pytest.mark.parametrize(
    ("in_dtype", "out_dtype"),
    [
        (pl.Boolean, pl.Float64),
        (pl.UInt8, pl.Float64),
        (pl.UInt16, pl.Float64),
        (pl.UInt32, pl.Float64),
        (pl.UInt64, pl.Float64),
        (pl.Int8, pl.Float64),
        (pl.Int16, pl.Float64),
        (pl.Int32, pl.Float64),
        (pl.Int64, pl.Float64),
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_horizontal_mean_single_column(
    in_dtype: PolarsDataType,
    out_dtype: PolarsDataType,
) -> None:
    out = (
        pl.LazyFrame({"a": pl.Series([1, 0]).cast(in_dtype)})
        .select(pl.mean_horizontal(pl.all()))
        .collect()
    )

    assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0]).cast(out_dtype)}))


def test_horizontal_mean_in_group_by_15115() -> None:
    nbr_records = 1000
    out = (
        pl.LazyFrame(
            {
                "w": [None, "one", "two", "three"] * nbr_records,
                "x": [None, None, "two", "three"] * nbr_records,
                "y": [None, None, None, "three"] * nbr_records,
                "z": [None, None, None, None] * nbr_records,
            }
        )
        .select(pl.mean_horizontal(pl.all().is_null()).alias("mean_null"))
        .group_by("mean_null")
        .len()
        .sort(by="mean_null")
        .collect()
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "mean_null": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64),
                "len": pl.Series([nbr_records] * 4, dtype=pl.UInt32),
            }
        ),
    )


def test_group_count_over_null_column_15705() -> None:
    df = pl.DataFrame(
        {"a": [1, 1, 2, 2, 3, 3], "c": [None, None, None, None, None, None]}
    )
    out = df.group_by("a", maintain_order=True).agg(pl.col("c").count())
    assert out["c"].to_list() == [0, 0, 0]


@pytest.mark.release
def test_min_max_2850() -> None:
    # https://github.com/pola-rs/polars/issues/2850
    df = pl.DataFrame(
        {
            "id": [
                130352432,
                130352277,
                130352611,
                130352833,
                130352305,
                130352258,
                130352764,
                130352475,
                130352368,
                130352346,
            ]
        }
    )

    minimum = 130352258
    maximum = 130352833.0

    for _ in range(10):
        permuted = df.sample(fraction=1.0, seed=0)
        computed = permuted.select(
            pl.col("id").min().alias("min"), pl.col("id").max().alias("max")
        )
        assert cast(int, computed[0, "min"]) == minimum
        assert cast(float, computed[0, "max"]) == maximum


def test_multi_arg_structify_15834() -> None:
    df = pl.DataFrame(
        {
            "group": [1, 2, 1, 2],
            "value": [
                0.1973209146402105,
                0.13380719982405365,
                0.6152394463707009,
                0.4558767896005155,
            ],
        }
    )

    assert df.lazy().group_by("group").agg(
        pl.struct(a=1, value=pl.col("value").sum())
    ).collect().sort("group").to_dict(as_series=False) == {
        "group": [1, 2],
        "a": [
            {"a": 1, "value": 0.8125603610109114},
            {"a": 1, "value": 0.5896839894245691},
        ],
    }


def test_filter_aggregation_16642() -> None:
    df = pl.DataFrame(
        {
            "datetime": [
                datetime(2022, 1, 1, 11, 0),
                datetime(2022, 1, 1, 11, 1),
                datetime(2022, 1, 1, 11, 2),
                datetime(2022, 1, 1, 11, 3),
                datetime(2022, 1, 1, 11, 4),
                datetime(2022, 1, 1, 11, 5),
                datetime(2022, 1, 1, 11, 6),
                datetime(2022, 1, 1, 11, 7),
                datetime(2022, 1, 1, 11, 8),
                datetime(2022, 1, 1, 11, 9, 1),
                datetime(2022, 1, 2, 11, 0),
                datetime(2022, 1, 2, 11, 1),
                datetime(2022, 1, 2, 11, 2),
                datetime(2022, 1, 2, 11, 3),
                datetime(2022, 1, 2, 11, 4),
                datetime(2022, 1, 2, 11, 5),
                datetime(2022, 1, 2, 11, 6),
                datetime(2022, 1, 2, 11, 7),
                datetime(2022, 1, 2, 11, 8),
                datetime(2022, 1, 2, 11, 9, 1),
            ],
            "alpha": [
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
                "I",
                "J",
                "A",
                "B",
                "C",
                "D",
                "E",
                "F",
                "G",
                "H",
                "I",
                "J",
            ],
            "num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    grouped = df.group_by(pl.col("datetime").dt.date())

    ts_filter = pl.col("datetime").dt.time() <= pl.time(11, 3)

    report = grouped.agg(pl.col("num").filter(ts_filter).max()).sort("datetime")
    assert report.to_dict(as_series=False) == {
        "datetime": [date(2022, 1, 1), date(2022, 1, 2)],
        "num": [3, 3],
    }


def test_sort_by_over_single_nulls_first() -> None:
    key = [0, 0, 0, 0, 1, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key": key,
            "value": [2, None, 1, 0, 2, None, 1, 0],
        }
    )
    out = df.select(
        pl.all().sort_by("value", nulls_last=False, maintain_order=True).over("key")
    )
    expected = pl.DataFrame(
        {
            "key": key,
            "value": [None, 0, 1, 2, None, 0, 1, 2],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_single_nulls_last() -> None:
    key = [0, 0, 0, 0, 1, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key": key,
            "value": [2, None, 1, 0, 2, None, 1, 0],
        }
    )
    out = df.select(
        pl.all().sort_by("value", nulls_last=True, maintain_order=True).over("key")
    )
    expected = pl.DataFrame(
        {
            "key": key,
            "value": [0, 1, 2, None, 0, 1, 2, None],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_multiple_nulls_first() -> None:
    key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0],
        }
    )
    out = df.select(
        pl.all()
        .sort_by("value", nulls_last=False, maintain_order=True)
        .over("key1", "key2")
    )
    expected = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [None, 0, 1, None, 0, 1, None, 0, 1, None, 0, 1],
        }
    )
    assert_frame_equal(out, expected)


def test_sort_by_over_multiple_nulls_last() -> None:
    key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    df = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0],
        }
    )
    out = df.select(
        pl.all()
        .sort_by("value", nulls_last=True, maintain_order=True)
        .over("key1", "key2")
    )
    expected = pl.DataFrame(
        {
            "key1": key1,
            "key2": key2,
            "value": [0, 1, None, 0, 1, None, 0, 1, None, 0, 1, None],
        }
    )
    assert_frame_equal(out, expected)


def test_slice_after_agg_raises() -> None:
    with pytest.raises(
        InvalidOperationError, match=r"cannot slice\(\) an aggregated scalar value"
    ):
        pl.select(a=1, b=1).group_by("a").agg(pl.col("b").first().slice(99, 0))


def test_agg_scalar_empty_groups_20115() -> None:
    assert_frame_equal(
        (
            pl.DataFrame({"key": [123], "value": [456]})
            .group_by("key")
            .agg(pl.col("value").slice(1, 1).first())
        ),
        pl.select(key=pl.lit(123, pl.Int64), value=pl.lit(None, pl.Int64)),
    )


def test_agg_expr_returns_list_type_15574() -> None:
    assert (
        pl.LazyFrame({"a": [1, None], "b": [1, 2]})
        .group_by("b")
        .agg(pl.col("a").drop_nulls())
        .collect_schema()
    ) == {"b": pl.Int64, "a": pl.List(pl.Int64)}


def test_empty_agg_22005() -> None:
    out = (
        pl.concat([pl.LazyFrame({"a": [1, 2]}), pl.LazyFrame({"a": [1, 2]})])
        .limit(0)
        .select(pl.col("a").sum())
    )
    assert_frame_equal(out.collect(), pl.DataFrame({"a": 0}))


@pytest.mark.parametrize("wrap_numerical", [True, False])
@pytest.mark.parametrize("strict_cast", [True, False])
def test_agg_with_filter_then_cast_23682(
    strict_cast: bool, wrap_numerical: bool
) -> None:
    assert_frame_equal(
        pl.DataFrame([{"a": 123, "b": 12}, {"a": 123, "b": 257}])
        .group_by("a")
        .agg(
            pl.col("b")
            .filter(pl.col("b") < 256)
            .cast(pl.UInt8, strict=strict_cast, wrap_numerical=wrap_numerical)
        ),
        pl.DataFrame(
            [{"a": 123, "b": [12]}], schema={"a": pl.Int64, "b": pl.List(pl.UInt8)}
        ),
    )


@pytest.mark.parametrize("wrap_numerical", [True, False])
@pytest.mark.parametrize("strict_cast", [True, False])
def test_agg_with_slice_then_cast_23682(
    strict_cast: bool, wrap_numerical: bool
) -> None:
    assert_frame_equal(
        pl.DataFrame([{"a": 123, "b": 12}, {"a": 123, "b": 257}])
        .group_by("a")
        .agg(
            pl.col("b")
            .slice(0, 1)
            .cast(pl.UInt8, strict=strict_cast, wrap_numerical=wrap_numerical)
        ),
        pl.DataFrame(
            [{"a": 123, "b": [12]}], schema={"a": pl.Int64, "b": pl.List(pl.UInt8)}
        ),
    )