Path: blob/main/py-polars/tests/unit/operations/test_group_by.py
6939 views
from __future__ import annotations12import typing3from collections import OrderedDict4from datetime import date, datetime, timedelta5from typing import TYPE_CHECKING, Any67import numpy as np8import pytest910import polars as pl11import polars.selectors as cs12from polars.exceptions import ColumnNotFoundError13from polars.meta import get_index_type14from polars.testing import assert_frame_equal, assert_series_equal1516if TYPE_CHECKING:17from polars._typing import PolarsDataType181920def test_group_by() -> None:21df = pl.DataFrame(22{23"a": ["a", "b", "a", "b", "b", "c"],24"b": [1, 2, 3, 4, 5, 6],25"c": [6, 5, 4, 3, 2, 1],26}27)2829# Use lazy API in eager group_by30assert sorted(df.group_by("a").agg([pl.sum("b")]).rows()) == [31("a", 4),32("b", 11),33("c", 6),34]35# test if it accepts a single expression36assert df.group_by("a", maintain_order=True).agg(pl.sum("b")).rows() == [37("a", 4),38("b", 11),39("c", 6),40]4142df = pl.DataFrame(43{44"a": [1, 2, 3, 4, 5],45"b": ["a", "a", "b", "b", "b"],46"c": [None, 1, None, 1, None],47}48)4950# check if this query runs and thus column names propagate51df.group_by("b").agg(pl.col("c").fill_null(strategy="forward")).explode("c")5253# get a specific column54result = df.group_by("b", maintain_order=True).agg(pl.count("a"))55assert result.rows() == [("a", 2), ("b", 3)]56assert result.columns == ["b", "a"]575859@pytest.mark.parametrize(60("input", "expected", "input_dtype", "output_dtype"),61[62([1, 2, 3, 4], [2, 4], pl.UInt8, pl.Float64),63([1, 2, 3, 4], [2, 4], pl.Int8, pl.Float64),64([1, 2, 3, 4], [2, 4], pl.UInt16, pl.Float64),65([1, 2, 3, 4], [2, 4], pl.Int16, pl.Float64),66([1, 2, 3, 4], [2, 4], pl.UInt32, pl.Float64),67([1, 2, 3, 4], [2, 4], pl.Int32, pl.Float64),68([1, 2, 3, 4], [2, 4], pl.UInt64, pl.Float64),69([1, 2, 3, 4], [2, 4], pl.Float32, pl.Float32),70([1, 2, 3, 4], [2, 4], pl.Float64, pl.Float64),71([False, True, True, True], [2 / 3, 1], pl.Boolean, pl.Float64),72(73[date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)],74[datetime(2023, 1, 2, 8, 0, 0), datetime(2023, 1, 5)],75pl.Date,76pl.Datetime("us"),77),78(79[80datetime(2023, 1, 1),81datetime(2023, 1, 2),82datetime(2023, 1, 3),83datetime(2023, 1, 4),84],85[datetime(2023, 1, 2), datetime(2023, 1, 4)],86pl.Datetime("ms"),87pl.Datetime("ms"),88),89(90[91datetime(2023, 1, 1),92datetime(2023, 1, 2),93datetime(2023, 1, 3),94datetime(2023, 1, 4),95],96[datetime(2023, 1, 2), datetime(2023, 1, 4)],97pl.Datetime("us"),98pl.Datetime("us"),99),100(101[102datetime(2023, 1, 1),103datetime(2023, 1, 2),104datetime(2023, 1, 3),105datetime(2023, 1, 4),106],107[datetime(2023, 1, 2), datetime(2023, 1, 4)],108pl.Datetime("ns"),109pl.Datetime("ns"),110),111(112[timedelta(1), timedelta(2), timedelta(3), timedelta(4)],113[timedelta(2), timedelta(4)],114pl.Duration("ms"),115pl.Duration("ms"),116),117(118[timedelta(1), timedelta(2), timedelta(3), timedelta(4)],119[timedelta(2), timedelta(4)],120pl.Duration("us"),121pl.Duration("us"),122),123(124[timedelta(1), timedelta(2), timedelta(3), timedelta(4)],125[timedelta(2), timedelta(4)],126pl.Duration("ns"),127pl.Duration("ns"),128),129],130)131def test_group_by_mean_by_dtype(132input: list[Any],133expected: list[Any],134input_dtype: PolarsDataType,135output_dtype: PolarsDataType,136) -> None:137# groups are defined by first 3 values, then last value138name = str(input_dtype)139key = ["a", "a", "a", "b"]140df = pl.LazyFrame(141{142"key": key,143name: pl.Series(input, dtype=input_dtype),144}145)146result = df.group_by("key", maintain_order=True).mean()147df_expected = pl.DataFrame(148{149"key": ["a", "b"],150name: pl.Series(expected, dtype=output_dtype),151}152)153assert result.collect_schema() == df_expected.schema154assert_frame_equal(result.collect(), df_expected)155156157@pytest.mark.parametrize(158("input", "expected", "input_dtype", "output_dtype"),159[160([1, 2, 4, 5], [2, 5], pl.UInt8, pl.Float64),161([1, 2, 4, 5], [2, 5], pl.Int8, pl.Float64),162([1, 2, 4, 5], [2, 5], pl.UInt16, pl.Float64),163([1, 2, 4, 5], [2, 5], pl.Int16, pl.Float64),164([1, 2, 4, 5], [2, 5], pl.UInt32, pl.Float64),165([1, 2, 4, 5], [2, 5], pl.Int32, pl.Float64),166([1, 2, 4, 5], [2, 5], pl.UInt64, pl.Float64),167([1, 2, 4, 5], [2, 5], pl.Float32, pl.Float32),168([1, 2, 4, 5], [2, 5], pl.Float64, pl.Float64),169([False, True, True, True], [1, 1], pl.Boolean, pl.Float64),170(171[date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)],172[datetime(2023, 1, 2), datetime(2023, 1, 5)],173pl.Date,174pl.Datetime("us"),175),176(177[178datetime(2023, 1, 1),179datetime(2023, 1, 2),180datetime(2023, 1, 4),181datetime(2023, 1, 5),182],183[datetime(2023, 1, 2), datetime(2023, 1, 5)],184pl.Datetime("ms"),185pl.Datetime("ms"),186),187(188[189datetime(2023, 1, 1),190datetime(2023, 1, 2),191datetime(2023, 1, 4),192datetime(2023, 1, 5),193],194[datetime(2023, 1, 2), datetime(2023, 1, 5)],195pl.Datetime("us"),196pl.Datetime("us"),197),198(199[200datetime(2023, 1, 1),201datetime(2023, 1, 2),202datetime(2023, 1, 4),203datetime(2023, 1, 5),204],205[datetime(2023, 1, 2), datetime(2023, 1, 5)],206pl.Datetime("ns"),207pl.Datetime("ns"),208),209(210[timedelta(1), timedelta(2), timedelta(4), timedelta(5)],211[timedelta(2), timedelta(5)],212pl.Duration("ms"),213pl.Duration("ms"),214),215(216[timedelta(1), timedelta(2), timedelta(4), timedelta(5)],217[timedelta(2), timedelta(5)],218pl.Duration("us"),219pl.Duration("us"),220),221(222[timedelta(1), timedelta(2), timedelta(4), timedelta(5)],223[timedelta(2), timedelta(5)],224pl.Duration("ns"),225pl.Duration("ns"),226),227],228)229def test_group_by_median_by_dtype(230input: list[Any],231expected: list[Any],232input_dtype: PolarsDataType,233output_dtype: PolarsDataType,234) -> None:235# groups are defined by first 3 values, then last value236name = str(input_dtype)237key = ["a", "a", "a", "b"]238df = pl.LazyFrame(239{240"key": key,241name: pl.Series(input, dtype=input_dtype),242}243)244result = df.group_by("key", maintain_order=True).median()245df_expected = pl.DataFrame(246{247"key": ["a", "b"],248name: pl.Series(expected, dtype=output_dtype),249}250)251assert result.collect_schema() == df_expected.schema252assert_frame_equal(result.collect(), df_expected)253254255@pytest.fixture256def df() -> pl.DataFrame:257return pl.DataFrame(258{259"a": [1, 2, 3, 4, 5],260"b": ["a", "a", "b", "b", "b"],261"c": [None, 1, None, 1, None],262}263)264265266@pytest.mark.parametrize(267("method", "expected"),268[269("all", [("a", [1, 2], [None, 1]), ("b", [3, 4, 5], [None, 1, None])]),270("len", [("a", 2), ("b", 3)]),271("first", [("a", 1, None), ("b", 3, None)]),272("last", [("a", 2, 1), ("b", 5, None)]),273("max", [("a", 2, 1), ("b", 5, 1)]),274("mean", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),275("median", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),276("min", [("a", 1, 1), ("b", 3, 1)]),277("n_unique", [("a", 2, 2), ("b", 3, 2)]),278],279)280def test_group_by_shorthands(281df: pl.DataFrame, method: str, expected: list[tuple[Any]]282) -> None:283gb = df.group_by("b", maintain_order=True)284result = getattr(gb, method)()285assert result.rows() == expected286287gb_lazy = df.lazy().group_by("b", maintain_order=True)288result = getattr(gb_lazy, method)().collect()289assert result.rows() == expected290291292def test_group_by_shorthand_quantile(df: pl.DataFrame) -> None:293result = df.group_by("b", maintain_order=True).quantile(0.5)294expected = [("a", 2.0, 1.0), ("b", 4.0, 1.0)]295assert result.rows() == expected296297result = df.lazy().group_by("b", maintain_order=True).quantile(0.5).collect()298assert result.rows() == expected299300301def test_group_by_args() -> None:302df = pl.DataFrame(303{304"a": ["a", "b", "a", "b", "b", "c"],305"b": [1, 2, 3, 4, 5, 6],306"c": [6, 5, 4, 3, 2, 1],307}308)309310# Single column name311assert df.group_by("a").agg("b").columns == ["a", "b"]312# Column names as list313expected = ["a", "b", "c"]314assert df.group_by(["a", "b"]).agg("c").columns == expected315# Column names as positional arguments316assert df.group_by("a", "b").agg("c").columns == expected317# With keyword argument318assert df.group_by("a", "b", maintain_order=True).agg("c").columns == expected319# Multiple aggregations as list320assert df.group_by("a").agg(["b", "c"]).columns == expected321# Multiple aggregations as positional arguments322assert df.group_by("a").agg("b", "c").columns == expected323# Multiple aggregations as keyword arguments324assert df.group_by("a").agg(q="b", r="c").columns == ["a", "q", "r"]325326327def test_group_by_empty() -> None:328df = pl.DataFrame({"a": [1, 1, 2]})329result = df.group_by("a").agg()330expected = pl.DataFrame({"a": [1, 2]})331assert_frame_equal(result, expected, check_row_order=False)332333334def test_group_by_iteration() -> None:335df = pl.DataFrame(336{337"foo": ["a", "b", "a", "b", "b", "c"],338"bar": [1, 2, 3, 4, 5, 6],339"baz": [6, 5, 4, 3, 2, 1],340}341)342expected_names = ["a", "b", "c"]343expected_rows = [344[("a", 1, 6), ("a", 3, 4)],345[("b", 2, 5), ("b", 4, 3), ("b", 5, 2)],346[("c", 6, 1)],347]348gb_iter = enumerate(df.group_by("foo", maintain_order=True))349for i, (group, data) in gb_iter:350assert group == (expected_names[i],)351assert data.rows() == expected_rows[i]352353# Grouped by ALL columns should give groups of a single row354result = list(df.group_by(["foo", "bar", "baz"]))355assert len(result) == 6356357# Iterating over groups should also work when grouping by expressions358result2 = list(df.group_by(["foo", pl.col("bar") * pl.col("baz")]))359assert len(result2) == 5360361# Single expression, alias in group_by362df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6]})363gb = df.group_by((pl.col("foo") // 2).alias("bar"), maintain_order=True)364result3 = [(group, df.rows()) for group, df in gb]365expected3 = [366((0,), [(1,)]),367((1,), [(2,), (3,)]),368((2,), [(4,), (5,)]),369((3,), [(6,)]),370]371assert result3 == expected3372373374def test_group_by_iteration_selector() -> None:375df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]})376result = dict(df.group_by(cs.string()))377result_first = result["one",]378assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]}379380381@pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()])382def test_group_by_agg_input_types(input: Any) -> None:383df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})384result = df.group_by("a", maintain_order=True).agg(input)385expected = pl.LazyFrame({"a": [1, 2], "b": [3, 7]})386assert_frame_equal(result, expected)387388389@pytest.mark.parametrize("input", [str, "b".join])390def test_group_by_agg_bad_input_types(input: Any) -> None:391df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})392with pytest.raises(TypeError):393df.group_by("a").agg(input)394395396def test_group_by_sorted_empty_dataframe_3680() -> None:397df = (398pl.DataFrame(399[400pl.Series("key", [], dtype=pl.Categorical),401pl.Series("val", [], dtype=pl.Float64),402]403)404.lazy()405.sort("key")406.group_by("key")407.tail(1)408.collect(optimizations=pl.QueryOptFlags(check_order_observe=False))409)410assert df.rows() == []411assert df.shape == (0, 2)412assert df.schema == {"key": pl.Categorical(ordering="lexical"), "val": pl.Float64}413414415def test_group_by_custom_agg_empty_list() -> None:416assert (417pl.DataFrame(418[419pl.Series("key", [], dtype=pl.Categorical),420pl.Series("val", [], dtype=pl.Float64),421]422)423.group_by("key")424.agg(425[426pl.col("val").mean().alias("mean"),427pl.col("val").std().alias("std"),428pl.col("val").skew().alias("skew"),429pl.col("val").kurtosis().alias("kurt"),430]431)432).dtypes == [pl.Categorical, pl.Float64, pl.Float64, pl.Float64, pl.Float64]433434435def test_apply_after_take_in_group_by_3869() -> None:436assert (437pl.DataFrame(438{439"k": list("aaabbb"),440"t": [1, 2, 3, 4, 5, 6],441"v": [3, 1, 2, 5, 6, 4],442}443)444.group_by("k", maintain_order=True)445.agg(446pl.col("v").get(pl.col("t").arg_max()).sqrt()447) # <- fails for sqrt, exp, log, pow, etc.448).to_dict(as_series=False) == {"k": ["a", "b"], "v": [1.4142135623730951, 2.0]}449450451def test_group_by_signed_transmutes() -> None:452df = pl.DataFrame({"foo": [-1, -2, -3, -4, -5], "bar": [500, 600, 700, 800, 900]})453454for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:455df = (456df.with_columns([pl.col("foo").cast(dt), pl.col("bar")])457.group_by("foo", maintain_order=True)458.agg(pl.col("bar").median())459)460461assert df.to_dict(as_series=False) == {462"foo": [-1, -2, -3, -4, -5],463"bar": [500.0, 600.0, 700.0, 800.0, 900.0],464}465466467def test_arg_sort_sort_by_groups_update__4360() -> None:468df = pl.DataFrame(469{470"group": ["a"] * 3 + ["b"] * 3 + ["c"] * 3,471"col1": [1, 2, 3] * 3,472"col2": [1, 2, 3, 3, 2, 1, 2, 3, 1],473}474)475476out = df.with_columns(477pl.col("col2").arg_sort().over("group").alias("col2_arg_sort")478).with_columns(479pl.col("col1").sort_by(pl.col("col2_arg_sort")).over("group").alias("result_a"),480pl.col("col1")481.sort_by(pl.col("col2").arg_sort())482.over("group")483.alias("result_b"),484)485486assert_series_equal(out["result_a"], out["result_b"], check_names=False)487assert out["result_a"].to_list() == [1, 2, 3, 3, 2, 1, 2, 3, 1]488489490def test_unique_order() -> None:491df = pl.DataFrame({"a": [1, 2, 1]}).with_row_index()492assert df.unique(keep="last", subset="a", maintain_order=True).to_dict(493as_series=False494) == {495"index": [1, 2],496"a": [2, 1],497}498assert df.unique(keep="first", subset="a", maintain_order=True).to_dict(499as_series=False500) == {501"index": [0, 1],502"a": [1, 2],503}504505506def test_group_by_dynamic_flat_agg_4814() -> None:507df = pl.DataFrame({"a": [1, 2, 2], "b": [1, 8, 12]}).set_sorted("a")508509assert df.group_by_dynamic("a", every="1i", period="2i").agg(510[511(pl.col("b").sum() / pl.col("a").sum()).alias("sum_ratio_1"),512(pl.col("b").last() / pl.col("a").last()).alias("last_ratio_1"),513(pl.col("b") / pl.col("a")).last().alias("last_ratio_2"),514]515).to_dict(as_series=False) == {516"a": [1, 2],517"sum_ratio_1": [4.2, 5.0],518"last_ratio_1": [6.0, 6.0],519"last_ratio_2": [6.0, 6.0],520}521522523@pytest.mark.parametrize(524("every", "period"),525[526("10s", timedelta(seconds=100)),527(timedelta(seconds=10), "100s"),528],529)530@pytest.mark.parametrize("time_zone", [None, "UTC", "Asia/Kathmandu"])531def test_group_by_dynamic_overlapping_groups_flat_apply_multiple_5038(532every: str | timedelta, period: str | timedelta, time_zone: str | None533) -> None:534res = (535(536pl.DataFrame(537{538"a": [539datetime(2021, 1, 1) + timedelta(seconds=2**i)540for i in range(10)541],542"b": [float(i) for i in range(10)],543}544)545.with_columns(pl.col("a").dt.replace_time_zone(time_zone))546.lazy()547.set_sorted("a")548.group_by_dynamic("a", every=every, period=period)549.agg([pl.col("b").var().sqrt().alias("corr")])550)551.collect()552.sum()553.to_dict(as_series=False)554)555556assert res["corr"] == pytest.approx([6.988674024215477])557assert res["a"] == [None]558559560def test_take_in_group_by() -> None:561df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "values": [10, 200, 3, 40, 500, 6]})562assert df.group_by("group").agg(563pl.col("values").get(1) - pl.col("values").get(2)564).sort("group").to_dict(as_series=False) == {"group": [1, 2], "values": [197, 494]}565566567def test_group_by_wildcard() -> None:568df = pl.DataFrame(569{570"a": [1, 2],571"b": [1, 2],572}573)574assert df.group_by([pl.col("*")], maintain_order=True).agg(575[pl.col("a").first().name.suffix("_agg")]576).to_dict(as_series=False) == {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]}577578579def test_group_by_all_masked_out() -> None:580df = pl.DataFrame(581{582"val": pl.Series(583[None, None, None, None], dtype=pl.Categorical, nan_to_null=True584).set_sorted(),585"col": [4, 4, 4, 4],586}587)588parts = df.partition_by("val")589assert len(parts) == 1590assert_frame_equal(parts[0], df)591592593def test_group_by_null_propagation_6185() -> None:594df_1 = pl.DataFrame({"A": [0, 0], "B": [1, 2]})595596expr = pl.col("A").filter(pl.col("A") > 0)597598expected = {"B": [1, 2], "A": [None, None]}599assert (600df_1.group_by("B")601.agg((expr - expr.mean()).mean())602.sort("B")603.to_dict(as_series=False)604== expected605)606607608def test_group_by_when_then_with_binary_and_agg_in_pred_6202() -> None:609df = pl.DataFrame(610{"code": ["a", "b", "b", "b", "a"], "xx": [1.0, -1.5, -0.2, -3.9, 3.0]}611)612assert (613df.group_by("code", maintain_order=True).agg(614[pl.when(pl.col("xx") > pl.min("xx")).then(True).otherwise(False)]615)616).to_dict(as_series=False) == {617"code": ["a", "b"],618"literal": [[False, True], [True, True, False]],619}620621622def test_group_by_binary_agg_with_literal() -> None:623df = pl.DataFrame({"id": ["a", "a", "b", "b"], "value": [1, 2, 3, 4]})624625out = df.group_by("id", maintain_order=True).agg(626pl.col("value") + pl.Series([1, 3])627)628assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 5], [4, 7]]}629630out = df.group_by("id", maintain_order=True).agg(pl.col("value") + pl.lit(1))631assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 3], [4, 5]]}632633out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.lit(2))634assert out.to_dict(as_series=False) == {"id": ["a", "b"], "literal": [3, 3]}635636out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.Series([2, 3]))637assert out.to_dict(as_series=False) == {638"id": ["a", "b"],639"literal": [[3, 4], [3, 4]],640}641642out = df.group_by("id", maintain_order=True).agg(643value=pl.lit(pl.Series([1, 2])) + pl.lit(pl.Series([3, 4]))644)645assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[4, 6], [4, 6]]}646647648@pytest.mark.slow649@pytest.mark.parametrize("dtype", [pl.Int32, pl.UInt32])650def test_overflow_mean_partitioned_group_by_5194(dtype: PolarsDataType) -> None:651df = pl.DataFrame(652[653pl.Series("data", [10_00_00_00] * 100_000, dtype=dtype),654pl.Series("group", [1, 2] * 50_000, dtype=dtype),655]656)657result = df.group_by("group").agg(pl.col("data").mean()).sort(by="group")658expected = {"group": [1, 2], "data": [10000000.0, 10000000.0]}659assert result.to_dict(as_series=False) == expected660661662# https://github.com/pola-rs/polars/issues/7181663def test_group_by_multiple_column_reference() -> None:664df = pl.DataFrame(665{666"gr": ["a", "b", "a", "b", "a", "b"],667"val": [1, 20, 100, 2000, 10000, 200000],668}669)670result = df.group_by("gr").agg(671pl.col("val") + pl.col("val").shift().fill_null(0),672)673674assert result.sort("gr").to_dict(as_series=False) == {675"gr": ["a", "b"],676"val": [[1, 101, 10100], [20, 2020, 202000]],677}678679680@pytest.mark.parametrize(681("aggregation", "args", "expected_values", "expected_dtype"),682[683("first", [], [1, None], pl.Int64),684("last", [], [1, None], pl.Int64),685("max", [], [1, None], pl.Int64),686("mean", [], [1.0, None], pl.Float64),687("median", [], [1.0, None], pl.Float64),688("min", [], [1, None], pl.Int64),689("n_unique", [], [1, 0], pl.UInt32),690("quantile", [0.5], [1.0, None], pl.Float64),691],692)693def test_group_by_empty_groups(694aggregation: str,695args: list[object],696expected_values: list[object],697expected_dtype: pl.DataType,698) -> None:699df = pl.DataFrame({"a": [1, 2], "b": [1, 2]})700result = df.group_by("b", maintain_order=True).agg(701getattr(pl.col("a").filter(pl.col("b") != 2), aggregation)(*args)702)703expected = pl.DataFrame({"b": [1, 2], "a": expected_values}).with_columns(704pl.col("a").cast(expected_dtype)705)706assert_frame_equal(result, expected)707708709# https://github.com/pola-rs/polars/issues/8663710def test_perfect_hash_table_null_values() -> None:711# fmt: off712values = ["3", "41", "17", "5", "26", "27", "43", "45", "41", "13", "45", "48", "17", "22", "31", "25", "28", "13", "7", "26", "17", "4", "43", "47", "30", "28", "8", "27", "6", "7", "26", "11", "37", "29", "49", "20", "29", "28", "23", "9", None, "38", "19", "7", "38", "3", "30", "37", "41", "5", "16", "26", "31", "6", "25", "11", "17", "31", "31", "20", "26", None, "39", "10", "38", "4", "39", "15", "13", "35", "38", "11", "39", "11", "48", "36", "18", "11", "34", "16", "28", "9", "37", "8", "17", "48", "44", "28", "25", "30", "37", "30", "18", "12", None, "27", "10", "3", "16", "27", "6"]713groups = ["3", "41", "17", "5", "26", "27", "43", "45", "13", "48", "22", "31", "25", "28", "7", "4", "47", "30", "8", "6", "11", "37", "29", "49", "20", "23", "9", None, "38", "19", "16", "39", "10", "15", "35", "36", "18", "34", "44", "12"]714# fmt: on715716s = pl.Series("a", values, dtype=pl.Categorical)717718result = (719s.to_frame("a").group_by("a", maintain_order=True).agg(pl.col("a").alias("agg"))720)721722agg_values = [723["3", "3", "3"],724["41", "41", "41"],725["17", "17", "17", "17", "17"],726["5", "5"],727["26", "26", "26", "26", "26"],728["27", "27", "27", "27"],729["43", "43"],730["45", "45"],731["13", "13", "13"],732["48", "48", "48"],733["22"],734["31", "31", "31", "31"],735["25", "25", "25"],736["28", "28", "28", "28", "28"],737["7", "7", "7"],738["4", "4"],739["47"],740["30", "30", "30", "30"],741["8", "8"],742["6", "6", "6"],743["11", "11", "11", "11", "11"],744["37", "37", "37", "37"],745["29", "29"],746["49"],747["20", "20"],748["23"],749["9", "9"],750[None, None, None],751["38", "38", "38", "38"],752["19"],753["16", "16", "16"],754["39", "39", "39"],755["10", "10"],756["15"],757["35"],758["36"],759["18", "18"],760["34"],761["44"],762["12"],763]764expected = pl.DataFrame(765{766"a": groups,767"agg": agg_values,768},769schema={"a": pl.Categorical, "agg": pl.List(pl.Categorical)},770)771assert_frame_equal(result, expected)772773774def test_group_by_partitioned_ending_cast(monkeypatch: Any) -> None:775monkeypatch.setenv("POLARS_FORCE_PARTITION", "1")776df = pl.DataFrame({"a": [1] * 5, "b": [1] * 5})777out = df.group_by(["a", "b"]).agg(pl.len().cast(pl.Int64).alias("num"))778expected = pl.DataFrame({"a": [1], "b": [1], "num": [5]})779assert_frame_equal(out, expected)780781782def test_group_by_series_partitioned(partition_limit: int) -> None:783# test 15354784df = pl.DataFrame([0, 0] * partition_limit)785groups = pl.Series([0, 1] * partition_limit)786df.group_by(groups).agg(pl.all().is_not_null().sum())787788789def test_group_by_list_scalar_11749() -> None:790df = pl.DataFrame(791{792"group_name": ["a;b", "a;b", "c;d", "c;d", "a;b", "a;b"],793"parent_name": ["a", "b", "c", "d", "a", "b"],794"measurement": [795["x1", "x2"],796["x1", "x2"],797["y1", "y2"],798["z1", "z2"],799["x1", "x2"],800["x1", "x2"],801],802}803)804assert (805df.group_by("group_name").agg(806(pl.col("measurement").first() == pl.col("measurement")).alias("eq"),807)808).sort("group_name").to_dict(as_series=False) == {809"group_name": ["a;b", "c;d"],810"eq": [[True, True, True, True], [True, False]],811}812813814def test_group_by_with_expr_as_key() -> None:815gb = pl.select(x=1).group_by(pl.col("x").alias("key"))816result = gb.agg(pl.all().first())817expected = gb.agg(pl.first("x"))818assert_frame_equal(result, expected)819820# tests: 11766821result = gb.head(0)822expected = gb.agg(pl.col("x").head(0)).explode("x")823assert_frame_equal(result, expected)824825result = gb.tail(0)826expected = gb.agg(pl.col("x").tail(0)).explode("x")827assert_frame_equal(result, expected)828829830def test_lazy_group_by_reuse_11767() -> None:831lgb = pl.select(x=1).lazy().group_by("x")832a = lgb.len()833b = lgb.len()834assert_frame_equal(a, b)835836837def test_group_by_double_on_empty_12194() -> None:838df = pl.DataFrame({"group": [1], "x": [1]}).clear()839squared_deviation_sum = ((pl.col("x") - pl.col("x").mean()) ** 2).sum()840assert df.group_by("group").agg(squared_deviation_sum).schema == OrderedDict(841[("group", pl.Int64), ("x", pl.Float64)]842)843844845def test_group_by_when_then_no_aggregation_predicate() -> None:846df = pl.DataFrame(847{848"key": ["aa", "aa", "bb", "bb", "aa", "aa"],849"val": [-3, -2, 1, 4, -3, 5],850}851)852assert df.group_by("key").agg(853pos=pl.when(pl.col("val") >= 0).then(pl.col("val")).sum(),854neg=pl.when(pl.col("val") < 0).then(pl.col("val")).sum(),855).sort("key").to_dict(as_series=False) == {856"key": ["aa", "bb"],857"pos": [5, 5],858"neg": [-8, 0],859}860861862def test_group_by_apply_first_input_is_literal() -> None:863df = pl.DataFrame({"x": [1, 2, 3, 4, 5], "g": [1, 1, 2, 2, 2]})864pow = df.group_by("g").agg(2 ** pl.col("x"))865assert pow.sort("g").to_dict(as_series=False) == {866"g": [1, 2],867"literal": [[2.0, 4.0], [8.0, 16.0, 32.0]],868}869870871def test_group_by_all_12869() -> None:872df = pl.DataFrame({"a": [1]})873result = next(iter(df.group_by(pl.all())))[1]874assert_frame_equal(df, result)875876877def test_group_by_named() -> None:878df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)})879result = df.group_by(z=pl.col("a") * 2, maintain_order=True).agg(pl.col("b").min())880expected = df.group_by((pl.col("a") * 2).alias("z"), maintain_order=True).agg(881pl.col("b").min()882)883assert_frame_equal(result, expected)884885886def test_group_by_with_null() -> None:887df = pl.DataFrame(888{"a": [None, None, None, None], "b": [1, 1, 2, 2], "c": ["x", "y", "z", "u"]}889)890expected = pl.DataFrame(891{"a": [None, None], "b": [1, 2], "c": [["x", "y"], ["z", "u"]]}892)893output = df.group_by(["a", "b"], maintain_order=True).agg(pl.col("c"))894assert_frame_equal(expected, output)895896897def test_partitioned_group_by_14954(monkeypatch: Any) -> None:898monkeypatch.setenv("POLARS_FORCE_PARTITION", "1")899assert (900pl.DataFrame({"a": range(20)})901.select(pl.col("a") % 2)902.group_by("a")903.agg(904(pl.col("a") > 1000).alias("a > 1000"),905)906).sort("a").to_dict(as_series=False) == {907"a": [0, 1],908"a > 1000": [909[False, False, False, False, False, False, False, False, False, False],910[False, False, False, False, False, False, False, False, False, False],911],912}913914915def test_partitioned_group_by_nulls_mean_21838() -> None:916size = 10917a = [1 for i in range(size)] + [2 for i in range(size)] + [3 for i in range(size)]918b = [1 for i in range(size)] + [None for i in range(size * 2)]919df = pl.DataFrame({"a": a, "b": b})920assert df.group_by("a").mean().sort("a").to_dict(as_series=False) == {921"a": [1, 2, 3],922"b": [1.0, None, None],923}924925926def test_aggregated_scalar_elementwise_15602() -> None:927df = pl.DataFrame({"group": [1, 2, 1]})928929out = df.group_by("group", maintain_order=True).agg(930foo=pl.col("group").is_between(1, pl.max("group"))931)932expected = pl.DataFrame({"group": [1, 2], "foo": [[True, True], [True]]})933assert_frame_equal(out, expected)934935936def test_group_by_multiple_null_cols_15623() -> None:937df = pl.DataFrame(schema={"a": pl.Null, "b": pl.Null}).group_by(pl.all()).len()938assert df.is_empty()939940941@pytest.mark.release942def test_categorical_vs_str_group_by() -> None:943# this triggers the perfect hash table944s = pl.Series("a", np.random.randint(0, 50, 100))945s_with_nulls = pl.select(946pl.when(s < 3).then(None).otherwise(s).alias("a")947).to_series()948949for s_ in [s, s_with_nulls]:950s_ = s_.cast(str)951cat_out = (952s_.cast(pl.Categorical)953.to_frame("a")954.group_by("a")955.agg(pl.first().alias("first"))956)957958str_out = s_.to_frame("a").group_by("a").agg(pl.first().alias("first"))959cat_out.with_columns(pl.col("a").cast(str))960assert_frame_equal(961cat_out.with_columns(962pl.col("a").cast(str), pl.col("first").cast(pl.List(str))963).sort("a"),964str_out.sort("a"),965)966967968@pytest.mark.release969def test_boolean_min_max_agg() -> None:970np.random.seed(0)971idx = np.random.randint(0, 500, 1000)972c = np.random.randint(0, 500, 1000) > 250973974df = pl.DataFrame({"idx": idx, "c": c})975aggs = [pl.col("c").min().alias("c_min"), pl.col("c").max().alias("c_max")]976977result = df.group_by("idx").agg(aggs).sum()978979schema = {"idx": pl.Int64, "c_min": pl.UInt32, "c_max": pl.UInt32}980expected = pl.DataFrame(981{982"idx": [107583],983"c_min": [120],984"c_max": [321],985},986schema=schema,987)988assert_frame_equal(result, expected)989990nulls = np.random.randint(0, 500, 1000) < 100991992result = (993df.with_columns(c=pl.when(pl.lit(nulls)).then(None).otherwise(pl.col("c")))994.group_by("idx")995.agg(aggs)996.sum()997)998999expected = pl.DataFrame(1000{1001"idx": [107583],1002"c_min": [133],1003"c_max": [276],1004},1005schema=schema,1006)1007assert_frame_equal(result, expected)100810091010def test_partitioned_group_by_chunked(partition_limit: int) -> None:1011n = partition_limit1012df1 = pl.DataFrame(np.random.randn(n, 2))1013df2 = pl.DataFrame(np.random.randn(n, 2))1014gps = pl.Series(name="oo", values=[0] * n + [1] * n)1015df = pl.concat([df1, df2], rechunk=False)1016assert_frame_equal(1017df.group_by(gps).sum().sort("oo"),1018df.rechunk().group_by(gps, maintain_order=True).sum(),1019)102010211022def test_schema_on_agg() -> None:1023lf = pl.LazyFrame({"a": ["x", "x", "y", "n"], "b": [1, 2, 3, 4]})10241025result = lf.group_by("a").agg(1026pl.col("b").min().alias("min"),1027pl.col("b").max().alias("max"),1028pl.col("b").sum().alias("sum"),1029pl.col("b").first().alias("first"),1030pl.col("b").last().alias("last"),1031)1032expected_schema = {1033"a": pl.String,1034"min": pl.Int64,1035"max": pl.Int64,1036"sum": pl.Int64,1037"first": pl.Int64,1038"last": pl.Int64,1039}1040assert result.collect_schema() == expected_schema104110421043def test_group_by_schema_err() -> None:1044lf = pl.LazyFrame({"foo": [None, 1, 2], "bar": [1, 2, 3]})1045with pytest.raises(ColumnNotFoundError):1046lf.group_by("not-existent").agg(1047pl.col("bar").max().alias("max_bar")1048).collect_schema()104910501051@pytest.mark.parametrize(1052("data", "expr", "expected_select", "expected_gb"),1053[1054(1055{"x": ["x"], "y": ["y"]},1056pl.coalesce(pl.col("x"), pl.col("y")),1057{"x": pl.String},1058{"x": pl.List(pl.String)},1059),1060(1061{"x": [True]},1062pl.col("x").sum(),1063{"x": pl.UInt32},1064{"x": pl.UInt32},1065),1066(1067{"a": [[1, 2]]},1068pl.col("a").list.sum(),1069{"a": pl.Int64},1070{"a": pl.List(pl.Int64)},1071),1072],1073)1074def test_schemas(1075data: dict[str, list[Any]],1076expr: pl.Expr,1077expected_select: dict[str, PolarsDataType],1078expected_gb: dict[str, PolarsDataType],1079) -> None:1080df = pl.DataFrame(data)10811082# test selection schema1083schema = df.select(expr).schema1084for key, dtype in expected_select.items():1085assert schema[key] == dtype10861087# test group_by schema1088schema = df.group_by(pl.lit(1)).agg(expr).schema1089for key, dtype in expected_gb.items():1090assert schema[key] == dtype109110921093def test_lit_iter_schema() -> None:1094df = pl.DataFrame(1095{1096"key": ["A", "A", "A", "A"],1097"dates": [1098date(1970, 1, 1),1099date(1970, 1, 1),1100date(1970, 1, 2),1101date(1970, 1, 3),1102],1103}1104)11051106result = df.group_by("key").agg(pl.col("dates").unique() + timedelta(days=1))1107expected = {1108"key": ["A"],1109"dates": [[date(1970, 1, 2), date(1970, 1, 3), date(1970, 1, 4)]],1110}1111assert result.to_dict(as_series=False) == expected111211131114def test_absence_off_null_prop_8224() -> None:1115# a reminder to self to not do null propagation1116# it is inconsistent and makes output dtype1117# dependent of the data, big no!11181119def sub_col_min(column: str, min_column: str) -> pl.Expr:1120return pl.col(column).sub(pl.col(min_column).min())11211122df = pl.DataFrame(1123{1124"group": [1, 1, 2, 2],1125"vals_num": [10.0, 11.0, 12.0, 13.0],1126"vals_partial": [None, None, 12.0, 13.0],1127"vals_null": [None, None, None, None],1128}1129)11301131q = (1132df.lazy()1133.group_by("group")1134.agg(1135sub_col_min("vals_num", "vals_num").alias("sub_num"),1136sub_col_min("vals_num", "vals_partial").alias("sub_partial"),1137sub_col_min("vals_num", "vals_null").alias("sub_null"),1138)1139)11401141assert q.collect().dtypes == [1142pl.Int64,1143pl.List(pl.Float64),1144pl.List(pl.Float64),1145pl.List(pl.Float64),1146]114711481149@pytest.mark.parametrize("maintain_order", [False, True])1150def test_grouped_slice_literals(maintain_order: bool) -> None:1151df = pl.DataFrame({"idx": [1, 2, 3]})1152q = (1153df.lazy()1154.group_by(True, maintain_order=maintain_order)1155.agg(1156x=pl.lit([1, 2]).slice(1157-1, 11158), # slices a list of 1 element, so remains the same element1159x2=pl.lit(pl.Series([1, 2])).slice(-1, 1),1160x3=pl.lit(pl.Series([[1, 2]])).slice(-1, 1),1161)1162)1163out = q.collect()1164expected = pl.DataFrame(1165{"literal": [True], "x": [[[1, 2]]], "x2": [[2]], "x3": [[[1, 2]]]}1166)1167assert_frame_equal(1168out,1169expected,1170check_row_order=maintain_order,1171)1172assert q.collect_schema() == q.collect().schema117311741175def test_positional_by_with_list_or_tuple_17540() -> None:1176with pytest.raises(TypeError, match="Hint: if you"):1177pl.DataFrame({"a": [1, 2, 3]}).group_by(by=["a"])1178with pytest.raises(TypeError, match="Hint: if you"):1179pl.LazyFrame({"a": [1, 2, 3]}).group_by(by=["a"])118011811182def test_group_by_agg_19173() -> None:1183df = pl.DataFrame({"x": [1.0], "g": [0]})1184out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2)1185assert out.to_dict(as_series=False) == {"g": [], "x": []}1186assert out.schema == pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))])118711881189def test_group_by_map_groups_slice_pushdown_20002() -> None:1190schema = {1191"a": pl.Int8,1192"b": pl.UInt8,1193}11941195df = (1196pl.LazyFrame(1197data={"a": [1, 2, 3, 4, 5], "b": [90, 80, 70, 60, 50]},1198schema=schema,1199)1200.group_by("a", maintain_order=True)1201.map_groups(lambda df: df * 2.0, schema=schema)1202.head(3)1203.collect()1204)12051206assert_frame_equal(1207df,1208pl.DataFrame(1209{1210"a": [2.0, 4.0, 6.0],1211"b": [180.0, 160.0, 140.0],1212}1213),1214)121512161217@typing.no_type_check1218def test_group_by_lit_series(capfd: Any, monkeypatch: Any) -> None:1219monkeypatch.setenv("POLARS_VERBOSE", "1")1220n = 101221df = pl.DataFrame({"x": np.ones(2 * n), "y": n * list(range(2))})1222a = np.ones(n, dtype=float)1223df.lazy().group_by("y").agg(pl.col("x").dot(a)).collect()1224captured = capfd.readouterr().err1225assert "are not partitionable" in captured122612271228def test_group_by_list_column() -> None:1229df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 2], [3], [1, 2]]})1230result = df.group_by("b").agg(pl.sum("a")).sort("b")1231expected = pl.DataFrame({"b": [[1, 2], [3]], "a": [4, 2]})1232assert_frame_equal(result, expected)123312341235def test_enum_perfect_group_by_21360() -> None:1236dtype = pl.Enum(categories=["a", "b"])12371238assert_frame_equal(1239pl.from_dicts([{"col": "a"}], schema={"col": dtype})1240.group_by("col")1241.agg(pl.len()),1242pl.DataFrame(1243[1244pl.Series("col", ["a"], dtype),1245pl.Series("len", [1], get_index_type()),1246]1247),1248)124912501251def test_partitioned_group_by_21634(partition_limit: int) -> None:1252n = partition_limit1253df = pl.DataFrame({"grp": [1] * n, "x": [1] * n})1254assert df.group_by("grp", True).agg().to_dict(as_series=False) == {1255"grp": [1],1256"literal": [True],1257}125812591260def test_group_by_cse_dup_key_alias_22238() -> None:1261df = pl.LazyFrame({"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 10]})1262result = df.group_by(1263pl.col("a").abs(),1264pl.col("a").abs().alias("a_with_alias"),1265).agg(pl.col("x").sum())1266assert_frame_equal(1267result.collect(),1268pl.DataFrame({"a": [1, 2], "a_with_alias": [1, 2], "x": [11, 5]}),1269check_row_order=False,1270)127112721273def test_group_by_22328() -> None:1274N = 2012751276df1 = pl.select(1277x=pl.repeat(1, N // 2).append(pl.repeat(2, N // 2)).shuffle(),1278y=pl.lit(3.0, pl.Float32),1279).lazy()12801281df2 = pl.select(x=pl.repeat(4, N)).lazy()12821283assert (1284df2.join(df1.group_by("x").mean().with_columns(z="y"), how="left", on="x")1285.with_columns(pl.col("z").fill_null(0))1286.collect()1287).shape == (20, 3)128812891290@pytest.mark.parametrize("maintain_order", [False, True])1291def test_group_by_arrays_22574(maintain_order: bool) -> None:1292assert_frame_equal(1293pl.Series("a", [[1], [2], [2]], pl.Array(pl.Int64, 1))1294.to_frame()1295.group_by("a", maintain_order=maintain_order)1296.agg(pl.len()),1297pl.DataFrame(1298[1299pl.Series("a", [[1], [2]], pl.Array(pl.Int64, 1)),1300pl.Series("len", [1, 2], pl.get_index_type()),1301]1302),1303check_row_order=maintain_order,1304)13051306assert_frame_equal(1307pl.Series(1308"a", [[[1, 2]], [[2, 3]], [[2, 3]]], pl.Array(pl.Array(pl.Int64, 2), 1)1309)1310.to_frame()1311.group_by("a", maintain_order=maintain_order)1312.agg(pl.len()),1313pl.DataFrame(1314[1315pl.Series(1316"a", [[[1, 2]], [[2, 3]]], pl.Array(pl.Array(pl.Int64, 2), 1)1317),1318pl.Series("len", [1, 2], pl.get_index_type()),1319]1320),1321check_row_order=maintain_order,1322)132313241325def test_group_by_empty_rows_with_literal_21959() -> None:1326out = (1327pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 1, 3]})1328.filter(pl.col("c") == 99)1329.group_by(pl.lit(1).alias("d"), pl.col("a"), pl.col("b"))1330.agg()1331.collect()1332)1333expected = pl.DataFrame(1334{"d": [], "a": [], "b": []},1335schema={"d": pl.Int32, "a": pl.Int64, "b": pl.Int64},1336)1337assert_frame_equal(out, expected)133813391340def test_group_by_empty_dtype_22716() -> None:1341df = pl.DataFrame(schema={"a": pl.String, "b": pl.Int64})1342out = df.group_by("a").agg(x=(pl.col("b") == pl.int_range(pl.len())).all())1343assert_frame_equal(out, pl.DataFrame(schema={"a": pl.String, "x": pl.Boolean}))134413451346def test_group_by_implode_22870() -> None:1347out = (1348pl.DataFrame({"x": ["a", "b"]})1349.group_by(pl.col.x)1350.agg(1351y=pl.col.x.replace_strict(1352pl.lit(pl.Series(["a", "b"])).implode(),1353pl.lit(pl.Series([1, 2])).implode(),1354default=-1,1355)1356)1357)1358assert_frame_equal(1359out,1360pl.DataFrame({"x": ["a", "b"], "y": [[1], [2]]}),1361check_row_order=False,1362)136313641365# Note: the underlying bug is not guaranteed to manifest itself as it depends1366# on the internal group order, i.e., for the bug to materialize, there must be1367# empty groups before the non-empty group1368def test_group_by_empty_groups_23338() -> None:1369# We need one non-empty and many groups1370df = pl.DataFrame(1371{1372"k": [10, 10, 20, 30, 40, 50, 60, 70, 80, 90],1373"a": [1, 1, 2, 3, 4, 5, 6, 7, 8, 9],1374}1375)1376out = df.group_by("k").agg(1377pl.col("a").filter(pl.col("a") == 1).fill_nan(None).sum()1378)1379expected = df.group_by("k").agg(pl.col("a").filter(pl.col("a") == 1).sum())1380assert_frame_equal(out.sort("k"), expected.sort("k"))138113821383def test_group_by_filter_all_22955() -> None:1384df = pl.DataFrame(1385{1386"grp": [1, 2, 3, 4, 5],1387"value": [10, 20, 30, 40, 50],1388}1389)13901391assert_frame_equal(1392df.group_by("grp").agg(1393pl.all().filter(pl.col("value") > 20),1394),1395pl.DataFrame(1396{1397"grp": [1, 2, 3, 4, 5],1398"value": [[], [], [30], [40], [50]],1399}1400),1401check_row_order=False,1402)140314041405@pytest.mark.parametrize("maintain_order", [False, True])1406def test_group_by_series_lit_22103(maintain_order: bool) -> None:1407df = pl.DataFrame(1408{1409"g": [0, 1],1410}1411)1412assert_frame_equal(1413df.group_by("g", maintain_order=maintain_order).agg(1414foo=pl.lit(pl.Series([42, 2, 3]))1415),1416pl.DataFrame(1417{1418"g": [0, 1],1419"foo": [[42, 2, 3], [42, 2, 3]],1420}1421),1422check_row_order=maintain_order,1423)142414251426@pytest.mark.parametrize("maintain_order", [False, True])1427def test_group_by_filter_sum_23897(maintain_order: bool) -> None:1428testdf = pl.DataFrame(1429{1430"id": [8113, 9110, 9110],1431"value": [None, None, 1.0],1432"weight": [1.0, 1.0, 1.0],1433}1434)14351436w = pl.col("weight").filter(pl.col("value").is_finite())14371438w = w / w.sum()14391440result = w.sum()14411442assert_frame_equal(1443testdf.group_by("id", maintain_order=maintain_order).agg(result),1444pl.DataFrame({"id": [8113, 9110], "weight": [0.0, 1.0]}),1445check_row_order=maintain_order,1446)144714481449@pytest.mark.parametrize("maintain_order", [False, True])1450def test_group_by_shift_filter_23910(maintain_order: bool) -> None:1451df = pl.DataFrame({"a": [3, 7, 5, 9, 2, 1], "b": [2, 2, 2, 3, 3, 1]})14521453out = df.group_by("b", maintain_order=maintain_order).agg(1454pl.col("a").filter(pl.col("a") > pl.col("a").shift(1)).sum().alias("tt")1455)14561457assert_frame_equal(1458out,1459pl.DataFrame(1460{1461"b": [2, 3, 1],1462"tt": [7, 0, 0],1463}1464),1465check_row_order=maintain_order,1466)146714681469def test_group_by_tuple_typing_24112() -> None:1470df = pl.DataFrame({"id": ["a", "b", "a"], "val": [1, 2, 3]})1471for (id_,), _ in df.group_by("id"):1472_should_work: str = id_147314741475def test_group_by_input_independent_with_len_23868() -> None:1476out = pl.DataFrame({"a": ["A", "B", "C"]}).group_by(pl.lit("G")).agg(pl.len())1477assert_frame_equal(1478out,1479pl.DataFrame(1480{"literal": "G", "len": 3},1481schema={"literal": pl.String, "len": pl.get_index_type()},1482),1483)148414851486@pytest.mark.parametrize("maintain_order", [False, True])1487def test_group_by_head_tail_24215(maintain_order: bool) -> None:1488df = pl.DataFrame(1489{1490"station": ["A", "A", "B"],1491"num_rides": [1, 2, 3],1492}1493)1494expected = pl.DataFrame(1495{"station": ["A", "B"], "num_rides": [1.5, 3], "rides_per_day": [[1, 2], [3]]}1496)14971498result = (1499df.group_by("station", maintain_order=maintain_order)1500.agg(1501cs.numeric().mean(),1502pl.col("num_rides").alias("rides_per_day"),1503)1504.group_by("station", maintain_order=maintain_order)1505.head(1)1506)1507assert_frame_equal(result, expected, check_row_order=maintain_order)15081509result = (1510df.group_by("station", maintain_order=maintain_order)1511.agg(1512cs.numeric().mean(),1513pl.col("num_rides").alias("rides_per_day"),1514)1515.group_by("station", maintain_order=maintain_order)1516.tail(1)1517)1518assert_frame_equal(result, expected, check_row_order=maintain_order)151915201521def test_slice_group_by_offset_24259() -> None:1522df = pl.DataFrame(1523{1524"letters": ["c", "c", "a", "c", "a", "b", "d"],1525"nrs": [1, 2, 3, 4, 5, 6, None],1526}1527)1528assert df.group_by("letters").agg(1529x=pl.col("nrs").drop_nulls(),1530tail=pl.col("nrs").drop_nulls().tail(1),1531).sort("letters").to_dict(as_series=False) == {1532"letters": ["a", "b", "c", "d"],1533"x": [[3, 5], [6], [1, 2, 4], []],1534"tail": [[5], [6], [4], []],1535}153615371538def test_group_by_first_nondet_24278() -> None:1539values = [154096, 86, 0, 86, 43, 50, 9, 14, 98, 39, 93, 7, 71, 1, 93, 41, 56,154156, 93, 41, 58, 91, 81, 29, 81, 68, 5, 9, 32, 93, 78, 34, 17, 40,154214, 2, 52, 77, 81, 4, 56, 42, 64, 12, 29, 58, 71, 98, 32, 49, 34,154386, 29, 94, 37, 21, 41, 36, 9, 72, 23, 28, 71, 9, 66, 72, 84, 81,154423, 12, 64, 57, 99, 15, 77, 38, 95, 64, 13, 91, 43, 61, 70, 47,154539, 75, 47, 93, 45, 1, 95, 55, 29, 5, 83, 8, 3, 6, 45, 84,1546] # fmt: skip1547q = (1548pl.LazyFrame({"a": values, "idx": range(100)})1549.group_by("a")1550.agg(pl.col.idx.first())1551.select(a=pl.col.idx)1552)15531554fst_value = q.collect().to_series().sum()1555for _ in range(10):1556assert q.collect().to_series().sum() == fst_value155715581559