Path: blob/main/py-polars/tests/unit/operations/test_group_by.py
8429 views
from __future__ import annotations12import typing3from collections import OrderedDict4from datetime import date, datetime, time, timedelta5from typing import TYPE_CHECKING, Any6from zoneinfo import ZoneInfo78import numpy as np9import pytest10from hypothesis import given1112import polars as pl13import polars.selectors as cs14from polars import Expr15from polars.exceptions import (16ColumnNotFoundError,17InvalidOperationError,18)19from polars.meta import get_index_type20from polars.testing import assert_frame_equal, assert_series_equal21from polars.testing.parametric import column, dataframes, series2223if TYPE_CHECKING:24from collections.abc import Callable2526from polars._typing import PolarsDataType, TimeUnit27from tests.conftest import PlMonkeyPatch282930def test_group_by() -> None:31df = pl.DataFrame(32{33"a": ["a", "b", "a", "b", "b", "c"],34"b": [1, 2, 3, 4, 5, 6],35"c": [6, 5, 4, 3, 2, 1],36}37)3839# Use lazy API in eager group_by40assert sorted(df.group_by("a").agg([pl.sum("b")]).rows()) == [41("a", 4),42("b", 11),43("c", 6),44]45# test if it accepts a single expression46assert df.group_by("a", maintain_order=True).agg(pl.sum("b")).rows() == [47("a", 4),48("b", 11),49("c", 6),50]5152df = pl.DataFrame(53{54"a": [1, 2, 3, 4, 5],55"b": ["a", "a", "b", "b", "b"],56"c": [None, 1, None, 1, None],57}58)5960# check if this query runs and thus column names propagate61df.group_by("b").agg(pl.col("c").fill_null(strategy="forward")).explode("c")6263# get a specific column64result = df.group_by("b", maintain_order=True).agg(pl.count("a"))65assert result.rows() == [("a", 2), ("b", 3)]66assert result.columns == ["b", "a"]676869@pytest.mark.parametrize(70("input", "expected", "input_dtype", "output_dtype"),71[72([1, 2, 3, 4], [2, 4], pl.UInt8, pl.Float64),73([1, 2, 3, 4], [2, 4], pl.Int8, pl.Float64),74([1, 2, 3, 4], [2, 4], pl.UInt16, pl.Float64),75([1, 2, 3, 4], [2, 4], pl.Int16, pl.Float64),76([1, 2, 3, 4], [2, 4], pl.UInt32, pl.Float64),77([1, 2, 3, 4], [2, 4], pl.Int32, pl.Float64),78([1, 2, 3, 4], [2, 4], pl.UInt64, pl.Float64),79([1, 2, 3, 4], [2, 4], pl.Float32, pl.Float32),80([1, 2, 3, 4], [2, 4], pl.Float64, pl.Float64),81([False, True, True, True], [2 / 3, 1], pl.Boolean, pl.Float64),82(83[date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)],84[datetime(2023, 1, 2, 8, 0, 0), datetime(2023, 1, 5)],85pl.Date,86pl.Datetime("us"),87),88(89[90datetime(2023, 1, 1),91datetime(2023, 1, 2),92datetime(2023, 1, 3),93datetime(2023, 1, 4),94],95[datetime(2023, 1, 2), datetime(2023, 1, 4)],96pl.Datetime("ms"),97pl.Datetime("ms"),98),99(100[101datetime(2023, 1, 1),102datetime(2023, 1, 2),103datetime(2023, 1, 3),104datetime(2023, 1, 4),105],106[datetime(2023, 1, 2), datetime(2023, 1, 4)],107pl.Datetime("us"),108pl.Datetime("us"),109),110(111[112datetime(2023, 1, 1),113datetime(2023, 1, 2),114datetime(2023, 1, 3),115datetime(2023, 1, 4),116],117[datetime(2023, 1, 2), datetime(2023, 1, 4)],118pl.Datetime("ns"),119pl.Datetime("ns"),120),121(122[timedelta(1), timedelta(2), timedelta(3), timedelta(4)],123[timedelta(2), timedelta(4)],124pl.Duration("ms"),125pl.Duration("ms"),126),127(128[timedelta(1), timedelta(2), timedelta(3), timedelta(4)],129[timedelta(2), timedelta(4)],130pl.Duration("us"),131pl.Duration("us"),132),133(134[timedelta(1), timedelta(2), timedelta(3), timedelta(4)],135[timedelta(2), timedelta(4)],136pl.Duration("ns"),137pl.Duration("ns"),138),139],140)141def test_group_by_mean_by_dtype(142input: list[Any],143expected: list[Any],144input_dtype: PolarsDataType,145output_dtype: PolarsDataType,146) -> None:147# groups are defined by first 3 values, then last value148name = str(input_dtype)149key = ["a", "a", "a", "b"]150df = pl.LazyFrame(151{152"key": key,153name: pl.Series(input, dtype=input_dtype),154}155)156result = df.group_by("key", maintain_order=True).mean()157df_expected = pl.DataFrame(158{159"key": ["a", "b"],160name: pl.Series(expected, dtype=output_dtype),161}162)163assert result.collect_schema() == df_expected.schema164assert_frame_equal(result.collect(), df_expected)165166167@pytest.mark.parametrize(168("input", "expected", "input_dtype", "output_dtype"),169[170([1, 2, 4, 5], [2, 5], pl.UInt8, pl.Float64),171([1, 2, 4, 5], [2, 5], pl.Int8, pl.Float64),172([1, 2, 4, 5], [2, 5], pl.UInt16, pl.Float64),173([1, 2, 4, 5], [2, 5], pl.Int16, pl.Float64),174([1, 2, 4, 5], [2, 5], pl.UInt32, pl.Float64),175([1, 2, 4, 5], [2, 5], pl.Int32, pl.Float64),176([1, 2, 4, 5], [2, 5], pl.UInt64, pl.Float64),177([1, 2, 4, 5], [2, 5], pl.Float32, pl.Float32),178([1, 2, 4, 5], [2, 5], pl.Float64, pl.Float64),179([False, True, True, True], [1, 1], pl.Boolean, pl.Float64),180(181[date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)],182[datetime(2023, 1, 2), datetime(2023, 1, 5)],183pl.Date,184pl.Datetime("us"),185),186(187[188datetime(2023, 1, 1),189datetime(2023, 1, 2),190datetime(2023, 1, 4),191datetime(2023, 1, 5),192],193[datetime(2023, 1, 2), datetime(2023, 1, 5)],194pl.Datetime("ms"),195pl.Datetime("ms"),196),197(198[199datetime(2023, 1, 1),200datetime(2023, 1, 2),201datetime(2023, 1, 4),202datetime(2023, 1, 5),203],204[datetime(2023, 1, 2), datetime(2023, 1, 5)],205pl.Datetime("us"),206pl.Datetime("us"),207),208(209[210datetime(2023, 1, 1),211datetime(2023, 1, 2),212datetime(2023, 1, 4),213datetime(2023, 1, 5),214],215[datetime(2023, 1, 2), datetime(2023, 1, 5)],216pl.Datetime("ns"),217pl.Datetime("ns"),218),219(220[timedelta(1), timedelta(2), timedelta(4), timedelta(5)],221[timedelta(2), timedelta(5)],222pl.Duration("ms"),223pl.Duration("ms"),224),225(226[timedelta(1), timedelta(2), timedelta(4), timedelta(5)],227[timedelta(2), timedelta(5)],228pl.Duration("us"),229pl.Duration("us"),230),231(232[timedelta(1), timedelta(2), timedelta(4), timedelta(5)],233[timedelta(2), timedelta(5)],234pl.Duration("ns"),235pl.Duration("ns"),236),237],238)239def test_group_by_median_by_dtype(240input: list[Any],241expected: list[Any],242input_dtype: PolarsDataType,243output_dtype: PolarsDataType,244) -> None:245# groups are defined by first 3 values, then last value246name = str(input_dtype)247key = ["a", "a", "a", "b"]248df = pl.LazyFrame(249{250"key": key,251name: pl.Series(input, dtype=input_dtype),252}253)254result = df.group_by("key", maintain_order=True).median()255df_expected = pl.DataFrame(256{257"key": ["a", "b"],258name: pl.Series(expected, dtype=output_dtype),259}260)261assert result.collect_schema() == df_expected.schema262assert_frame_equal(result.collect(), df_expected)263264265@pytest.fixture266def df() -> pl.DataFrame:267return pl.DataFrame(268{269"a": [1, 2, 3, 4, 5],270"b": ["a", "a", "b", "b", "b"],271"c": [None, 1, None, 1, None],272}273)274275276@pytest.mark.parametrize(277("method", "expected"),278[279("all", [("a", [1, 2], [None, 1]), ("b", [3, 4, 5], [None, 1, None])]),280("len", [("a", 2), ("b", 3)]),281("first", [("a", 1, None), ("b", 3, None)]),282("last", [("a", 2, 1), ("b", 5, None)]),283("max", [("a", 2, 1), ("b", 5, 1)]),284("mean", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),285("median", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),286("min", [("a", 1, 1), ("b", 3, 1)]),287("n_unique", [("a", 2, 2), ("b", 3, 2)]),288],289)290def test_group_by_shorthands(291df: pl.DataFrame, method: str, expected: list[tuple[Any]]292) -> None:293gb = df.group_by("b", maintain_order=True)294result = getattr(gb, method)()295assert result.rows() == expected296297gb_lazy = df.lazy().group_by("b", maintain_order=True)298result = getattr(gb_lazy, method)().collect()299assert result.rows() == expected300301302def test_group_by_shorthand_quantile(df: pl.DataFrame) -> None:303result = df.group_by("b", maintain_order=True).quantile(0.5)304expected = [("a", 2.0, 1.0), ("b", 4.0, 1.0)]305assert result.rows() == expected306307result = df.lazy().group_by("b", maintain_order=True).quantile(0.5).collect()308assert result.rows() == expected309310311def test_group_by_quantile_date() -> None:312df = pl.DataFrame(313{314"group": [1, 1, 1, 1, 2, 2, 2, 2],315"value": [date(2025, 1, x) for x in range(1, 9)],316}317)318result = (319df.lazy()320.group_by("group", maintain_order=True)321.agg(322nearest=pl.col("value").quantile(0.5, "nearest"),323higher=pl.col("value").quantile(0.5, "higher"),324lower=pl.col("value").quantile(0.5, "lower"),325linear=pl.col("value").quantile(0.5, "linear"),326)327)328dt = pl.Datetime("us")329expected = pl.DataFrame(330{331"group": [1, 2],332"nearest": pl.Series(333[datetime(2025, 1, 3), datetime(2025, 1, 7)], dtype=dt334),335"higher": pl.Series([datetime(2025, 1, 3), datetime(2025, 1, 7)], dtype=dt),336"lower": pl.Series([datetime(2025, 1, 2), datetime(2025, 1, 6)], dtype=dt),337"linear": pl.Series(338[datetime(2025, 1, 2, 12), datetime(2025, 1, 6, 12)], dtype=dt339),340}341)342assert result.collect_schema() == pl.Schema(343{ # type: ignore[arg-type]344"group": pl.Int64,345"nearest": dt,346"higher": dt,347"lower": dt,348"linear": dt,349}350)351assert_frame_equal(result.collect(), expected)352353354@pytest.mark.parametrize("tu", ["ms", "us", "ns"])355@pytest.mark.parametrize("time_zone", [None, "Asia/Tokyo"])356def test_group_by_quantile_datetime(tu: TimeUnit, time_zone: str) -> None:357dt = pl.Datetime(tu, time_zone)358tz = ZoneInfo(time_zone) if time_zone else None359df = pl.DataFrame(360{361"group": [1, 1, 1, 1, 2, 2, 2, 2],362"value": pl.Series(363[datetime(2025, 1, x, tzinfo=tz) for x in range(1, 9)],364dtype=dt,365),366}367)368result = (369df.lazy()370.group_by("group", maintain_order=True)371.agg(372nearest=pl.col("value").quantile(0.5, "nearest"),373higher=pl.col("value").quantile(0.5, "higher"),374lower=pl.col("value").quantile(0.5, "lower"),375linear=pl.col("value").quantile(0.5, "linear"),376)377)378expected = pl.DataFrame(379{380"group": [1, 2],381"nearest": pl.Series(382[datetime(2025, 1, 3, tzinfo=tz), datetime(2025, 1, 7, tzinfo=tz)],383dtype=dt,384),385"higher": pl.Series(386[datetime(2025, 1, 3, tzinfo=tz), datetime(2025, 1, 7, tzinfo=tz)],387dtype=dt,388),389"lower": pl.Series(390[datetime(2025, 1, 2, tzinfo=tz), datetime(2025, 1, 6, tzinfo=tz)],391dtype=dt,392),393"linear": pl.Series(394[395datetime(2025, 1, 2, 12, tzinfo=tz),396datetime(2025, 1, 6, 12, tzinfo=tz),397],398dtype=dt,399),400}401)402assert result.collect_schema() == pl.Schema(403{ # type: ignore[arg-type]404"group": pl.Int64,405"nearest": dt,406"higher": dt,407"lower": dt,408"linear": dt,409}410)411assert_frame_equal(result.collect(), expected)412413414@pytest.mark.parametrize("tu", ["ms", "us", "ns"])415def test_group_by_quantile_duration(tu: TimeUnit) -> None:416dt = pl.Duration(tu)417df = pl.DataFrame(418{419"group": [1, 1, 1, 1, 2, 2, 2, 2],420"value": pl.Series([timedelta(hours=x) for x in range(1, 9)], dtype=dt),421}422)423result = (424df.lazy()425.group_by("group", maintain_order=True)426.agg(427nearest=pl.col("value").quantile(0.5, "nearest"),428higher=pl.col("value").quantile(0.5, "higher"),429lower=pl.col("value").quantile(0.5, "lower"),430linear=pl.col("value").quantile(0.5, "linear"),431)432)433expected = pl.DataFrame(434{435"group": [1, 2],436"nearest": pl.Series([timedelta(hours=3), timedelta(hours=7)], dtype=dt),437"higher": pl.Series([timedelta(hours=3), timedelta(hours=7)], dtype=dt),438"lower": pl.Series([timedelta(hours=2), timedelta(hours=6)], dtype=dt),439"linear": pl.Series(440[timedelta(hours=2, minutes=30), timedelta(hours=6, minutes=30)],441dtype=dt,442),443}444)445assert result.collect_schema() == pl.Schema(446{ # type: ignore[arg-type]447"group": pl.Int64,448"nearest": dt,449"higher": dt,450"lower": dt,451"linear": dt,452}453)454assert_frame_equal(result.collect(), expected)455456457def test_group_by_quantile_time() -> None:458df = pl.DataFrame(459{460"group": [1, 1, 1, 1, 2, 2, 2, 2],461"value": pl.Series([time(hour=x) for x in range(1, 9)]),462}463)464result = (465df.lazy()466.group_by("group", maintain_order=True)467.agg(468nearest=pl.col("value").quantile(0.5, "nearest"),469higher=pl.col("value").quantile(0.5, "higher"),470lower=pl.col("value").quantile(0.5, "lower"),471linear=pl.col("value").quantile(0.5, "linear"),472)473)474expected = pl.DataFrame(475{476"group": [1, 2],477"nearest": pl.Series([time(hour=3), time(hour=7)]),478"higher": pl.Series([time(hour=3), time(hour=7)]),479"lower": pl.Series([time(hour=2), time(hour=6)]),480"linear": pl.Series([time(hour=2, minute=30), time(hour=6, minute=30)]),481}482)483assert result.collect_schema() == pl.Schema(484{485"group": pl.Int64,486"nearest": pl.Time,487"higher": pl.Time,488"lower": pl.Time,489"linear": pl.Time,490}491)492assert_frame_equal(result.collect(), expected)493494495def test_group_by_args() -> None:496df = pl.DataFrame(497{498"a": ["a", "b", "a", "b", "b", "c"],499"b": [1, 2, 3, 4, 5, 6],500"c": [6, 5, 4, 3, 2, 1],501}502)503504# Single column name505assert df.group_by("a").agg("b").columns == ["a", "b"]506# Column names as list507expected = ["a", "b", "c"]508assert df.group_by(["a", "b"]).agg("c").columns == expected509# Column names as positional arguments510assert df.group_by("a", "b").agg("c").columns == expected511# With keyword argument512assert df.group_by("a", "b", maintain_order=True).agg("c").columns == expected513# Multiple aggregations as list514assert df.group_by("a").agg(["b", "c"]).columns == expected515# Multiple aggregations as positional arguments516assert df.group_by("a").agg("b", "c").columns == expected517# Multiple aggregations as keyword arguments518assert df.group_by("a").agg(q="b", r="c").columns == ["a", "q", "r"]519520521def test_group_by_empty() -> None:522df = pl.DataFrame({"a": [1, 1, 2]})523result = df.group_by("a").agg()524expected = pl.DataFrame({"a": [1, 2]})525assert_frame_equal(result, expected, check_row_order=False)526527528def test_group_by_iteration() -> None:529df = pl.DataFrame(530{531"foo": ["a", "b", "a", "b", "b", "c"],532"bar": [1, 2, 3, 4, 5, 6],533"baz": [6, 5, 4, 3, 2, 1],534}535)536expected_names = ["a", "b", "c"]537expected_rows = [538[("a", 1, 6), ("a", 3, 4)],539[("b", 2, 5), ("b", 4, 3), ("b", 5, 2)],540[("c", 6, 1)],541]542gb_iter = enumerate(df.group_by("foo", maintain_order=True))543for i, (group, data) in gb_iter:544assert group == (expected_names[i],)545assert data.rows() == expected_rows[i]546547# Grouped by ALL columns should give groups of a single row548result = list(df.group_by(["foo", "bar", "baz"]))549assert len(result) == 6550551# Iterating over groups should also work when grouping by expressions552result2 = list(df.group_by(["foo", pl.col("bar") * pl.col("baz")]))553assert len(result2) == 5554555# Single expression, alias in group_by556df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6]})557gb = df.group_by((pl.col("foo") // 2).alias("bar"), maintain_order=True)558result3 = [(group, df.rows()) for group, df in gb]559expected3 = [560((0,), [(1,)]),561((1,), [(2,), (3,)]),562((2,), [(4,), (5,)]),563((3,), [(6,)]),564]565assert result3 == expected3566567568def test_group_by_iteration_selector() -> None:569df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]})570result = dict(df.group_by(cs.string()))571result_first = result["one",]572assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]}573574575@pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()])576def test_group_by_agg_input_types(input: Any) -> None:577df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})578result = df.group_by("a", maintain_order=True).agg(input)579expected = pl.LazyFrame({"a": [1, 2], "b": [3, 7]})580assert_frame_equal(result, expected)581582583@pytest.mark.parametrize("input", [str, "b".join])584def test_group_by_agg_bad_input_types(input: Any) -> None:585df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})586with pytest.raises(TypeError):587df.group_by("a").agg(input)588589590def test_group_by_sorted_empty_dataframe_3680() -> None:591df = (592pl.DataFrame(593[594pl.Series("key", [], dtype=pl.Categorical),595pl.Series("val", [], dtype=pl.Float64),596]597)598.lazy()599.sort("key")600.group_by("key")601.tail(1)602.collect(optimizations=pl.QueryOptFlags(check_order_observe=False))603)604assert df.rows() == []605assert df.shape == (0, 2)606assert df.schema == {"key": pl.Categorical(), "val": pl.Float64}607608609def test_group_by_custom_agg_empty_list() -> None:610assert (611pl.DataFrame(612[613pl.Series("key", [], dtype=pl.Categorical),614pl.Series("val", [], dtype=pl.Float64),615]616)617.group_by("key")618.agg(619[620pl.col("val").mean().alias("mean"),621pl.col("val").std().alias("std"),622pl.col("val").skew().alias("skew"),623pl.col("val").kurtosis().alias("kurt"),624]625)626).dtypes == [pl.Categorical, pl.Float64, pl.Float64, pl.Float64, pl.Float64]627628629def test_apply_after_take_in_group_by_3869() -> None:630assert (631pl.DataFrame(632{633"k": list("aaabbb"),634"t": [1, 2, 3, 4, 5, 6],635"v": [3, 1, 2, 5, 6, 4],636}637)638.group_by("k", maintain_order=True)639.agg(640pl.col("v").get(pl.col("t").arg_max()).sqrt()641) # <- fails for sqrt, exp, log, pow, etc.642).to_dict(as_series=False) == {"k": ["a", "b"], "v": [1.4142135623730951, 2.0]}643644645def test_group_by_signed_transmutes() -> None:646df = pl.DataFrame({"foo": [-1, -2, -3, -4, -5], "bar": [500, 600, 700, 800, 900]})647648for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]:649df = (650df.with_columns([pl.col("foo").cast(dt), pl.col("bar")])651.group_by("foo", maintain_order=True)652.agg(pl.col("bar").median())653)654655assert df.to_dict(as_series=False) == {656"foo": [-1, -2, -3, -4, -5],657"bar": [500.0, 600.0, 700.0, 800.0, 900.0],658}659660661def test_arg_sort_sort_by_groups_update__4360() -> None:662df = pl.DataFrame(663{664"group": ["a"] * 3 + ["b"] * 3 + ["c"] * 3,665"col1": [1, 2, 3] * 3,666"col2": [1, 2, 3, 3, 2, 1, 2, 3, 1],667}668)669670out = df.with_columns(671pl.col("col2").arg_sort().over("group").alias("col2_arg_sort")672).with_columns(673pl.col("col1").sort_by(pl.col("col2_arg_sort")).over("group").alias("result_a"),674pl.col("col1")675.sort_by(pl.col("col2").arg_sort())676.over("group")677.alias("result_b"),678)679680assert_series_equal(out["result_a"], out["result_b"], check_names=False)681assert out["result_a"].to_list() == [1, 2, 3, 3, 2, 1, 2, 3, 1]682683684def test_unique_order() -> None:685df = pl.DataFrame({"a": [1, 2, 1]}).with_row_index()686assert df.unique(keep="last", subset="a", maintain_order=True).to_dict(687as_series=False688) == {689"index": [1, 2],690"a": [2, 1],691}692assert df.unique(keep="first", subset="a", maintain_order=True).to_dict(693as_series=False694) == {695"index": [0, 1],696"a": [1, 2],697}698699700def test_group_by_dynamic_flat_agg_4814() -> None:701df = pl.DataFrame({"a": [1, 2, 2], "b": [1, 8, 12]}).set_sorted("a")702703assert df.group_by_dynamic("a", every="1i", period="2i").agg(704[705(pl.col("b").sum() / pl.col("a").sum()).alias("sum_ratio_1"),706(pl.col("b").last() / pl.col("a").last()).alias("last_ratio_1"),707(pl.col("b") / pl.col("a")).last().alias("last_ratio_2"),708]709).to_dict(as_series=False) == {710"a": [1, 2],711"sum_ratio_1": [4.2, 5.0],712"last_ratio_1": [6.0, 6.0],713"last_ratio_2": [6.0, 6.0],714}715716717@pytest.mark.parametrize(718("every", "period"),719[720("10s", timedelta(seconds=100)),721(timedelta(seconds=10), "100s"),722],723)724@pytest.mark.parametrize("time_zone", [None, "UTC", "Asia/Kathmandu"])725def test_group_by_dynamic_overlapping_groups_flat_apply_multiple_5038(726every: str | timedelta, period: str | timedelta, time_zone: str | None727) -> None:728res = (729(730pl.DataFrame(731{732"a": [733datetime(2021, 1, 1) + timedelta(seconds=2**i)734for i in range(10)735],736"b": [float(i) for i in range(10)],737}738)739.with_columns(pl.col("a").dt.replace_time_zone(time_zone))740.lazy()741.set_sorted("a")742.group_by_dynamic("a", every=every, period=period)743.agg([pl.col("b").var().sqrt().alias("corr")])744)745.collect()746.sum()747.to_dict(as_series=False)748)749750assert res["corr"] == pytest.approx([6.988674024215477])751assert res["a"] == [None]752753754def test_take_in_group_by() -> None:755df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "values": [10, 200, 3, 40, 500, 6]})756assert df.group_by("group").agg(757pl.col("values").get(1) - pl.col("values").get(2)758).sort("group").to_dict(as_series=False) == {"group": [1, 2], "values": [197, 494]}759760761def test_group_by_wildcard() -> None:762df = pl.DataFrame(763{764"a": [1, 2],765"b": [1, 2],766}767)768assert df.group_by([pl.col("*")], maintain_order=True).agg(769[pl.col("a").first().name.suffix("_agg")]770).to_dict(as_series=False) == {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]}771772773def test_group_by_all_masked_out() -> None:774df = pl.DataFrame(775{776"val": pl.Series(777[None, None, None, None], dtype=pl.Categorical, nan_to_null=True778).set_sorted(),779"col": [4, 4, 4, 4],780}781)782parts = df.partition_by("val")783assert len(parts) == 1784assert_frame_equal(parts[0], df)785786787def test_group_by_null_propagation_6185() -> None:788df_1 = pl.DataFrame({"A": [0, 0], "B": [1, 2]})789790expr = pl.col("A").filter(pl.col("A") > 0)791792expected = {"B": [1, 2], "A": [None, None]}793assert (794df_1.group_by("B")795.agg((expr - expr.mean()).mean())796.sort("B")797.to_dict(as_series=False)798== expected799)800801802def test_group_by_when_then_with_binary_and_agg_in_pred_6202() -> None:803df = pl.DataFrame(804{"code": ["a", "b", "b", "b", "a"], "xx": [1.0, -1.5, -0.2, -3.9, 3.0]}805)806assert (807df.group_by("code", maintain_order=True).agg(808[pl.when(pl.col("xx") > pl.min("xx")).then(True).otherwise(False)]809)810).to_dict(as_series=False) == {811"code": ["a", "b"],812"literal": [[False, True], [True, True, False]],813}814815816def test_group_by_binary_agg_with_literal() -> None:817df = pl.DataFrame({"id": ["a", "a", "b", "b"], "value": [1, 2, 3, 4]})818819out = df.group_by("id", maintain_order=True).agg(820pl.col("value") + pl.Series([1, 3])821)822assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 5], [4, 7]]}823824out = df.group_by("id", maintain_order=True).agg(pl.col("value") + pl.lit(1))825assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 3], [4, 5]]}826827out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.lit(2))828assert out.to_dict(as_series=False) == {"id": ["a", "b"], "literal": [3, 3]}829830out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.Series([2, 3]))831assert out.to_dict(as_series=False) == {832"id": ["a", "b"],833"literal": [[3, 4], [3, 4]],834}835836out = df.group_by("id", maintain_order=True).agg(837value=pl.lit(pl.Series([1, 2])) + pl.lit(pl.Series([3, 4]))838)839assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[4, 6], [4, 6]]}840841842@pytest.mark.slow843@pytest.mark.parametrize("dtype", [pl.Int32, pl.UInt32])844def test_overflow_mean_partitioned_group_by_5194(dtype: PolarsDataType) -> None:845df = pl.DataFrame(846[847pl.Series("data", [10_00_00_00] * 100_000, dtype=dtype),848pl.Series("group", [1, 2] * 50_000, dtype=dtype),849]850)851result = df.group_by("group").agg(pl.col("data").mean()).sort(by="group")852expected = {"group": [1, 2], "data": [10000000.0, 10000000.0]}853assert result.to_dict(as_series=False) == expected854855856# https://github.com/pola-rs/polars/issues/7181857def test_group_by_multiple_column_reference() -> None:858df = pl.DataFrame(859{860"gr": ["a", "b", "a", "b", "a", "b"],861"val": [1, 20, 100, 2000, 10000, 200000],862}863)864result = df.group_by("gr").agg(865pl.col("val") + pl.col("val").shift().fill_null(0),866)867868assert result.sort("gr").to_dict(as_series=False) == {869"gr": ["a", "b"],870"val": [[1, 101, 10100], [20, 2020, 202000]],871}872873874@pytest.mark.parametrize(875("aggregation", "args", "expected_values", "expected_dtype"),876[877("first", [], [1, None], pl.Int64),878("last", [], [1, None], pl.Int64),879("max", [], [1, None], pl.Int64),880("mean", [], [1.0, None], pl.Float64),881("median", [], [1.0, None], pl.Float64),882("min", [], [1, None], pl.Int64),883("n_unique", [], [1, 0], pl.get_index_type()),884("quantile", [0.5], [1.0, None], pl.Float64),885],886)887def test_group_by_empty_groups(888aggregation: str,889args: list[object],890expected_values: list[object],891expected_dtype: pl.DataType,892) -> None:893df = pl.DataFrame({"a": [1, 2], "b": [1, 2]})894result = df.group_by("b", maintain_order=True).agg(895getattr(pl.col("a").filter(pl.col("b") != 2), aggregation)(*args)896)897expected = pl.DataFrame({"b": [1, 2], "a": expected_values}).with_columns(898pl.col("a").cast(expected_dtype)899)900assert_frame_equal(result, expected)901902903# https://github.com/pola-rs/polars/issues/8663904def test_perfect_hash_table_null_values() -> None:905# fmt: off906values = ["3", "41", "17", "5", "26", "27", "43", "45", "41", "13", "45", "48", "17", "22", "31", "25", "28", "13", "7", "26", "17", "4", "43", "47", "30", "28", "8", "27", "6", "7", "26", "11", "37", "29", "49", "20", "29", "28", "23", "9", None, "38", "19", "7", "38", "3", "30", "37", "41", "5", "16", "26", "31", "6", "25", "11", "17", "31", "31", "20", "26", None, "39", "10", "38", "4", "39", "15", "13", "35", "38", "11", "39", "11", "48", "36", "18", "11", "34", "16", "28", "9", "37", "8", "17", "48", "44", "28", "25", "30", "37", "30", "18", "12", None, "27", "10", "3", "16", "27", "6"]907groups = ["3", "41", "17", "5", "26", "27", "43", "45", "13", "48", "22", "31", "25", "28", "7", "4", "47", "30", "8", "6", "11", "37", "29", "49", "20", "23", "9", None, "38", "19", "16", "39", "10", "15", "35", "36", "18", "34", "44", "12"]908# fmt: on909910s = pl.Series("a", values, dtype=pl.Categorical)911912result = (913s.to_frame("a").group_by("a", maintain_order=True).agg(pl.col("a").alias("agg"))914)915916agg_values = [917["3", "3", "3"],918["41", "41", "41"],919["17", "17", "17", "17", "17"],920["5", "5"],921["26", "26", "26", "26", "26"],922["27", "27", "27", "27"],923["43", "43"],924["45", "45"],925["13", "13", "13"],926["48", "48", "48"],927["22"],928["31", "31", "31", "31"],929["25", "25", "25"],930["28", "28", "28", "28", "28"],931["7", "7", "7"],932["4", "4"],933["47"],934["30", "30", "30", "30"],935["8", "8"],936["6", "6", "6"],937["11", "11", "11", "11", "11"],938["37", "37", "37", "37"],939["29", "29"],940["49"],941["20", "20"],942["23"],943["9", "9"],944[None, None, None],945["38", "38", "38", "38"],946["19"],947["16", "16", "16"],948["39", "39", "39"],949["10", "10"],950["15"],951["35"],952["36"],953["18", "18"],954["34"],955["44"],956["12"],957]958expected = pl.DataFrame(959{960"a": groups,961"agg": agg_values,962},963schema={"a": pl.Categorical, "agg": pl.List(pl.Categorical)},964)965assert_frame_equal(result, expected)966967968def test_group_by_partitioned_ending_cast(plmonkeypatch: PlMonkeyPatch) -> None:969plmonkeypatch.setenv("POLARS_FORCE_PARTITION", "1")970df = pl.DataFrame({"a": [1] * 5, "b": [1] * 5})971out = df.group_by(["a", "b"]).agg(pl.len().cast(pl.Int64).alias("num"))972expected = pl.DataFrame({"a": [1], "b": [1], "num": [5]})973assert_frame_equal(out, expected)974975976def test_group_by_series_partitioned(partition_limit: int) -> None:977# test 15354978df = pl.DataFrame([0, 0] * partition_limit)979groups = pl.Series([0, 1] * partition_limit)980df.group_by(groups).agg(pl.all().is_not_null().sum())981982983def test_group_by_list_scalar_11749() -> None:984df = pl.DataFrame(985{986"group_name": ["a;b", "a;b", "c;d", "c;d", "a;b", "a;b"],987"parent_name": ["a", "b", "c", "d", "a", "b"],988"measurement": [989["x1", "x2"],990["x1", "x2"],991["y1", "y2"],992["z1", "z2"],993["x1", "x2"],994["x1", "x2"],995],996}997)998assert (999df.group_by("group_name").agg(1000(pl.col("measurement").first() == pl.col("measurement")).alias("eq"),1001)1002).sort("group_name").to_dict(as_series=False) == {1003"group_name": ["a;b", "c;d"],1004"eq": [[True, True, True, True], [True, False]],1005}100610071008def test_group_by_with_expr_as_key() -> None:1009gb = pl.select(x=1).group_by(pl.col("x").alias("key"))1010result = gb.agg(pl.all().first())1011expected = gb.agg(pl.first("x"))1012assert_frame_equal(result, expected)10131014# tests: 117661015result = gb.head(0)1016expected = gb.agg(pl.col("x").head(0)).explode("x")1017assert_frame_equal(result, expected)10181019result = gb.tail(0)1020expected = gb.agg(pl.col("x").tail(0)).explode("x")1021assert_frame_equal(result, expected)102210231024def test_lazy_group_by_reuse_11767() -> None:1025lgb = pl.select(x=1).lazy().group_by("x")1026a = lgb.len()1027b = lgb.len()1028assert_frame_equal(a, b)102910301031def test_group_by_double_on_empty_12194() -> None:1032df = pl.DataFrame({"group": [1], "x": [1]}).clear()1033squared_deviation_sum = ((pl.col("x") - pl.col("x").mean()) ** 2).sum()1034assert df.group_by("group").agg(squared_deviation_sum).schema == OrderedDict(1035[("group", pl.Int64), ("x", pl.Float64)]1036)103710381039def test_group_by_when_then_no_aggregation_predicate() -> None:1040df = pl.DataFrame(1041{1042"key": ["aa", "aa", "bb", "bb", "aa", "aa"],1043"val": [-3, -2, 1, 4, -3, 5],1044}1045)1046assert df.group_by("key").agg(1047pos=pl.when(pl.col("val") >= 0).then(pl.col("val")).sum(),1048neg=pl.when(pl.col("val") < 0).then(pl.col("val")).sum(),1049).sort("key").to_dict(as_series=False) == {1050"key": ["aa", "bb"],1051"pos": [5, 5],1052"neg": [-8, 0],1053}105410551056def test_group_by_apply_first_input_is_literal() -> None:1057df = pl.DataFrame({"x": [1, 2, 3, 4, 5], "g": [1, 1, 2, 2, 2]})1058pow = df.group_by("g").agg(2 ** pl.col("x"))1059assert pow.sort("g").to_dict(as_series=False) == {1060"g": [1, 2],1061"literal": [[2.0, 4.0], [8.0, 16.0, 32.0]],1062}106310641065def test_group_by_all_12869() -> None:1066df = pl.DataFrame({"a": [1]})1067result = next(iter(df.group_by(pl.all())))[1]1068assert_frame_equal(df, result)106910701071def test_group_by_named() -> None:1072df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)})1073result = df.group_by(z=pl.col("a") * 2, maintain_order=True).agg(pl.col("b").min())1074expected = df.group_by((pl.col("a") * 2).alias("z"), maintain_order=True).agg(1075pl.col("b").min()1076)1077assert_frame_equal(result, expected)107810791080def test_group_by_with_null() -> None:1081df = pl.DataFrame(1082{"a": [None, None, None, None], "b": [1, 1, 2, 2], "c": ["x", "y", "z", "u"]}1083)1084expected = pl.DataFrame(1085{"a": [None, None], "b": [1, 2], "c": [["x", "y"], ["z", "u"]]}1086)1087output = df.group_by(["a", "b"], maintain_order=True).agg(pl.col("c"))1088assert_frame_equal(expected, output)108910901091def test_partitioned_group_by_14954(plmonkeypatch: PlMonkeyPatch) -> None:1092plmonkeypatch.setenv("POLARS_FORCE_PARTITION", "1")1093assert (1094pl.DataFrame({"a": range(20)})1095.select(pl.col("a") % 2)1096.group_by("a")1097.agg(1098(pl.col("a") > 1000).alias("a > 1000"),1099)1100).sort("a").to_dict(as_series=False) == {1101"a": [0, 1],1102"a > 1000": [1103[False, False, False, False, False, False, False, False, False, False],1104[False, False, False, False, False, False, False, False, False, False],1105],1106}110711081109def test_partitioned_group_by_nulls_mean_21838() -> None:1110size = 101111a = [1 for i in range(size)] + [2 for i in range(size)] + [3 for i in range(size)]1112b = [1 for i in range(size)] + [None for i in range(size * 2)]1113df = pl.DataFrame({"a": a, "b": b})1114assert df.group_by("a").mean().sort("a").to_dict(as_series=False) == {1115"a": [1, 2, 3],1116"b": [1.0, None, None],1117}111811191120def test_aggregated_scalar_elementwise_15602() -> None:1121df = pl.DataFrame({"group": [1, 2, 1]})11221123out = df.group_by("group", maintain_order=True).agg(1124foo=pl.col("group").is_between(1, pl.max("group"))1125)1126expected = pl.DataFrame({"group": [1, 2], "foo": [[True, True], [True]]})1127assert_frame_equal(out, expected)112811291130def test_group_by_multiple_null_cols_15623() -> None:1131df = pl.DataFrame(schema={"a": pl.Null, "b": pl.Null}).group_by(pl.all()).len()1132assert df.is_empty()113311341135@pytest.mark.release1136def test_categorical_vs_str_group_by() -> None:1137# this triggers the perfect hash table1138s = pl.Series("a", np.random.randint(0, 50, 100))1139s_with_nulls = pl.select(1140pl.when(s < 3).then(None).otherwise(s).alias("a")1141).to_series()11421143for s_ in [s, s_with_nulls]:1144s_ = s_.cast(str)1145cat_out = (1146s_.cast(pl.Categorical)1147.to_frame("a")1148.group_by("a")1149.agg(pl.first().alias("first"))1150)11511152str_out = s_.to_frame("a").group_by("a").agg(pl.first().alias("first"))1153cat_out.with_columns(pl.col("a").cast(str))1154assert_frame_equal(1155cat_out.with_columns(1156pl.col("a").cast(str), pl.col("first").cast(pl.List(str))1157).sort("a"),1158str_out.sort("a"),1159)116011611162@pytest.mark.release1163def test_boolean_min_max_agg() -> None:1164np.random.seed(0)1165idx = np.random.randint(0, 500, 1000)1166c = np.random.randint(0, 500, 1000) > 25011671168df = pl.DataFrame({"idx": idx, "c": c})1169aggs = [pl.col("c").min().alias("c_min"), pl.col("c").max().alias("c_max")]11701171result = df.group_by("idx").agg(aggs).sum()11721173schema = {"idx": pl.Int64, "c_min": pl.UInt32, "c_max": pl.UInt32}1174expected = pl.DataFrame(1175{1176"idx": [107583],1177"c_min": [120],1178"c_max": [321],1179},1180schema=schema,1181)1182assert_frame_equal(result, expected)11831184nulls = np.random.randint(0, 500, 1000) < 10011851186result = (1187df.with_columns(c=pl.when(pl.lit(nulls)).then(None).otherwise(pl.col("c")))1188.group_by("idx")1189.agg(aggs)1190.sum()1191)11921193expected = pl.DataFrame(1194{1195"idx": [107583],1196"c_min": [133],1197"c_max": [276],1198},1199schema=schema,1200)1201assert_frame_equal(result, expected)120212031204def test_partitioned_group_by_chunked(partition_limit: int) -> None:1205n = partition_limit1206df1 = pl.DataFrame(np.random.randn(n, 2))1207df2 = pl.DataFrame(np.random.randn(n, 2))1208gps = pl.Series(name="oo", values=[0] * n + [1] * n)1209df = pl.concat([df1, df2], rechunk=False)1210assert_frame_equal(1211df.group_by(gps).sum().sort("oo"),1212df.rechunk().group_by(gps, maintain_order=True).sum(),1213)121412151216def test_schema_on_agg() -> None:1217lf = pl.LazyFrame({"a": ["x", "x", "y", "n"], "b": [1, 2, 3, 4]})12181219result = lf.group_by("a").agg(1220pl.col("b").min().alias("min"),1221pl.col("b").max().alias("max"),1222pl.col("b").sum().alias("sum"),1223pl.col("b").first().alias("first"),1224pl.col("b").last().alias("last"),1225pl.col("b").item().alias("item"),1226)1227expected_schema = {1228"a": pl.String,1229"min": pl.Int64,1230"max": pl.Int64,1231"sum": pl.Int64,1232"first": pl.Int64,1233"last": pl.Int64,1234"item": pl.Int64,1235}1236assert result.collect_schema() == expected_schema123712381239def test_group_by_schema_err() -> None:1240lf = pl.LazyFrame({"foo": [None, 1, 2], "bar": [1, 2, 3]})1241with pytest.raises(ColumnNotFoundError):1242lf.group_by("not-existent").agg(1243pl.col("bar").max().alias("max_bar")1244).collect_schema()124512461247@pytest.mark.parametrize(1248("data", "expr", "expected_select", "expected_gb"),1249[1250(1251{"x": ["x"], "y": ["y"]},1252pl.coalesce(pl.col("x"), pl.col("y")),1253{"x": pl.String},1254{"x": pl.List(pl.String)},1255),1256(1257{"x": [True]},1258pl.col("x").sum(),1259{"x": pl.get_index_type()},1260{"x": pl.get_index_type()},1261),1262(1263{"a": [[1, 2]]},1264pl.col("a").list.sum(),1265{"a": pl.Int64},1266{"a": pl.List(pl.Int64)},1267),1268],1269)1270def test_schemas(1271data: dict[str, list[Any]],1272expr: pl.Expr,1273expected_select: dict[str, PolarsDataType],1274expected_gb: dict[str, PolarsDataType],1275) -> None:1276df = pl.DataFrame(data)12771278# test selection schema1279schema = df.select(expr).schema1280for key, dtype in expected_select.items():1281assert schema[key] == dtype12821283# test group_by schema1284schema = df.group_by(pl.lit(1)).agg(expr).schema1285for key, dtype in expected_gb.items():1286assert schema[key] == dtype128712881289def test_lit_iter_schema() -> None:1290df = pl.DataFrame(1291{1292"key": ["A", "A", "A", "A"],1293"dates": [1294date(1970, 1, 1),1295date(1970, 1, 1),1296date(1970, 1, 2),1297date(1970, 1, 3),1298],1299}1300)13011302result = df.group_by("key").agg(pl.col("dates").unique() + timedelta(days=1))1303expected = {1304"key": ["A"],1305"dates": [[date(1970, 1, 2), date(1970, 1, 3), date(1970, 1, 4)]],1306}1307assert result.to_dict(as_series=False) == expected130813091310def test_absence_off_null_prop_8224() -> None:1311# a reminder to self to not do null propagation1312# it is inconsistent and makes output dtype1313# dependent of the data, big no!13141315def sub_col_min(column: str, min_column: str) -> pl.Expr:1316return pl.col(column).sub(pl.col(min_column).min())13171318df = pl.DataFrame(1319{1320"group": [1, 1, 2, 2],1321"vals_num": [10.0, 11.0, 12.0, 13.0],1322"vals_partial": [None, None, 12.0, 13.0],1323"vals_null": [None, None, None, None],1324}1325)13261327q = (1328df.lazy()1329.group_by("group")1330.agg(1331sub_col_min("vals_num", "vals_num").alias("sub_num"),1332sub_col_min("vals_num", "vals_partial").alias("sub_partial"),1333sub_col_min("vals_num", "vals_null").alias("sub_null"),1334)1335)13361337assert q.collect().dtypes == [1338pl.Int64,1339pl.List(pl.Float64),1340pl.List(pl.Float64),1341pl.List(pl.Float64),1342]134313441345@pytest.mark.parametrize("maintain_order", [False, True])1346def test_grouped_slice_literals(maintain_order: bool) -> None:1347df = pl.DataFrame({"idx": [1, 2, 3]})1348q = (1349df.lazy()1350.group_by(True, maintain_order=maintain_order)1351.agg(1352x=pl.lit([1, 2]).slice(1353-1, 11354), # slices a list of 1 element, so remains the same element1355x2=pl.lit(pl.Series([1, 2])).slice(-1, 1),1356x3=pl.lit(pl.Series([[1, 2]])).slice(-1, 1),1357)1358)1359out = q.collect()1360expected = pl.DataFrame(1361{"literal": [True], "x": [[[1, 2]]], "x2": [[2]], "x3": [[[1, 2]]]}1362)1363assert_frame_equal(1364out,1365expected,1366check_row_order=maintain_order,1367)1368assert q.collect_schema() == q.collect().schema136913701371def test_positional_by_with_list_or_tuple_17540() -> None:1372with pytest.raises(TypeError, match="Hint: if you"):1373pl.DataFrame({"a": [1, 2, 3]}).group_by(by=["a"])1374with pytest.raises(TypeError, match="Hint: if you"):1375pl.LazyFrame({"a": [1, 2, 3]}).group_by(by=["a"])137613771378def test_group_by_agg_19173() -> None:1379df = pl.DataFrame({"x": [1.0], "g": [0]})1380out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2)1381assert out.to_dict(as_series=False) == {"g": [], "x": []}1382assert out.schema == pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))])138313841385def test_group_by_map_groups_slice_pushdown_20002() -> None:1386schema = {1387"a": pl.Int8,1388"b": pl.UInt8,1389}13901391df = (1392pl.LazyFrame(1393data={"a": [1, 2, 3, 4, 5], "b": [90, 80, 70, 60, 50]},1394schema=schema,1395)1396.group_by("a", maintain_order=True)1397.map_groups(lambda df: df * 2.0, schema=schema)1398.head(3)1399.collect()1400)14011402assert_frame_equal(1403df,1404pl.DataFrame(1405{1406"a": [2.0, 4.0, 6.0],1407"b": [180.0, 160.0, 140.0],1408}1409),1410)141114121413@typing.no_type_check1414def test_group_by_lit_series(capfd: Any, plmonkeypatch: PlMonkeyPatch) -> None:1415plmonkeypatch.setenv("POLARS_VERBOSE", "1")1416n = 101417df = pl.DataFrame({"x": np.ones(2 * n), "y": n * list(range(2))})1418a = np.ones(n, dtype=float)1419df.lazy().group_by("y").agg(pl.col("x").dot(a)).collect()1420captured = capfd.readouterr().err1421assert "are not partitionable" in captured142214231424def test_group_by_list_column() -> None:1425df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 2], [3], [1, 2]]})1426result = df.group_by("b").agg(pl.sum("a")).sort("b")1427expected = pl.DataFrame({"b": [[1, 2], [3]], "a": [4, 2]})1428assert_frame_equal(result, expected)142914301431def test_enum_perfect_group_by_21360() -> None:1432dtype = pl.Enum(categories=["a", "b"])14331434assert_frame_equal(1435pl.from_dicts([{"col": "a"}], schema={"col": dtype})1436.group_by("col")1437.agg(pl.len()),1438pl.DataFrame(1439[1440pl.Series("col", ["a"], dtype),1441pl.Series("len", [1], get_index_type()),1442]1443),1444)144514461447def test_partitioned_group_by_21634(partition_limit: int) -> None:1448n = partition_limit1449df = pl.DataFrame({"grp": [1] * n, "x": [1] * n})1450assert df.group_by("grp", True).agg().to_dict(as_series=False) == {1451"grp": [1],1452"literal": [True],1453}145414551456def test_group_by_cse_dup_key_alias_22238() -> None:1457df = pl.LazyFrame({"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 10]})1458result = df.group_by(1459pl.col("a").abs(),1460pl.col("a").abs().alias("a_with_alias"),1461).agg(pl.col("x").sum())1462assert_frame_equal(1463result.collect(),1464pl.DataFrame({"a": [1, 2], "a_with_alias": [1, 2], "x": [11, 5]}),1465check_row_order=False,1466)146714681469def test_group_by_22328() -> None:1470N = 2014711472df1 = pl.select(1473x=pl.repeat(1, N // 2).append(pl.repeat(2, N // 2)).shuffle(),1474y=pl.lit(3.0, pl.Float32),1475).lazy()14761477df2 = pl.select(x=pl.repeat(4, N)).lazy()14781479assert (1480df2.join(df1.group_by("x").mean().with_columns(z="y"), how="left", on="x")1481.with_columns(pl.col("z").fill_null(0))1482.collect()1483).shape == (20, 3)148414851486@pytest.mark.parametrize("maintain_order", [False, True])1487def test_group_by_arrays_22574(maintain_order: bool) -> None:1488assert_frame_equal(1489pl.Series("a", [[1], [2], [2]], pl.Array(pl.Int64, 1))1490.to_frame()1491.group_by("a", maintain_order=maintain_order)1492.agg(pl.len()),1493pl.DataFrame(1494[1495pl.Series("a", [[1], [2]], pl.Array(pl.Int64, 1)),1496pl.Series("len", [1, 2], pl.get_index_type()),1497]1498),1499check_row_order=maintain_order,1500)15011502assert_frame_equal(1503pl.Series(1504"a", [[[1, 2]], [[2, 3]], [[2, 3]]], pl.Array(pl.Array(pl.Int64, 2), 1)1505)1506.to_frame()1507.group_by("a", maintain_order=maintain_order)1508.agg(pl.len()),1509pl.DataFrame(1510[1511pl.Series(1512"a", [[[1, 2]], [[2, 3]]], pl.Array(pl.Array(pl.Int64, 2), 1)1513),1514pl.Series("len", [1, 2], pl.get_index_type()),1515]1516),1517check_row_order=maintain_order,1518)151915201521def test_group_by_empty_rows_with_literal_21959() -> None:1522out = (1523pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 1, 3]})1524.filter(pl.col("c") == 99)1525.group_by(pl.lit(1).alias("d"), pl.col("a"), pl.col("b"))1526.agg()1527.collect()1528)1529expected = pl.DataFrame(1530{"d": [], "a": [], "b": []},1531schema={"d": pl.Int32, "a": pl.Int64, "b": pl.Int64},1532)1533assert_frame_equal(out, expected)153415351536def test_group_by_empty_dtype_22716() -> None:1537df = pl.DataFrame(schema={"a": pl.String, "b": pl.Int64})1538out = df.group_by("a").agg(x=(pl.col("b") == pl.int_range(pl.len())).all())1539assert_frame_equal(out, pl.DataFrame(schema={"a": pl.String, "x": pl.Boolean}))154015411542def test_group_by_implode_22870() -> None:1543out = (1544pl.DataFrame({"x": ["a", "b"]})1545.group_by(pl.col.x)1546.agg(1547y=pl.col.x.replace_strict(1548pl.lit(pl.Series(["a", "b"])).implode(),1549pl.lit(pl.Series([1, 2])).implode(),1550default=-1,1551)1552)1553)1554assert_frame_equal(1555out,1556pl.DataFrame({"x": ["a", "b"], "y": [[1], [2]]}),1557check_row_order=False,1558)155915601561# Note: the underlying bug is not guaranteed to manifest itself as it depends1562# on the internal group order, i.e., for the bug to materialize, there must be1563# empty groups before the non-empty group1564def test_group_by_empty_groups_23338() -> None:1565# We need one non-empty and many groups1566df = pl.DataFrame(1567{1568"k": [10, 10, 20, 30, 40, 50, 60, 70, 80, 90],1569"a": [1, 1, 2, 3, 4, 5, 6, 7, 8, 9],1570}1571)1572out = df.group_by("k").agg(1573pl.col("a").filter(pl.col("a") == 1).fill_nan(None).sum()1574)1575expected = df.group_by("k").agg(pl.col("a").filter(pl.col("a") == 1).sum())1576assert_frame_equal(out.sort("k"), expected.sort("k"))157715781579def test_group_by_filter_all_22955() -> None:1580df = pl.DataFrame(1581{1582"grp": [1, 2, 3, 4, 5],1583"value": [10, 20, 30, 40, 50],1584}1585)15861587assert_frame_equal(1588df.group_by("grp").agg(1589pl.all().filter(pl.col("value") > 20),1590),1591pl.DataFrame(1592{1593"grp": [1, 2, 3, 4, 5],1594"value": [[], [], [30], [40], [50]],1595}1596),1597check_row_order=False,1598)159916001601@pytest.mark.parametrize("maintain_order", [False, True])1602def test_group_by_series_lit_22103(maintain_order: bool) -> None:1603df = pl.DataFrame(1604{1605"g": [0, 1],1606}1607)1608assert_frame_equal(1609df.group_by("g", maintain_order=maintain_order).agg(1610foo=pl.lit(pl.Series([42, 2, 3]))1611),1612pl.DataFrame(1613{1614"g": [0, 1],1615"foo": [[42, 2, 3], [42, 2, 3]],1616}1617),1618check_row_order=maintain_order,1619)162016211622@pytest.mark.parametrize("maintain_order", [False, True])1623def test_group_by_filter_sum_23897(maintain_order: bool) -> None:1624testdf = pl.DataFrame(1625{1626"id": [8113, 9110, 9110],1627"value": [None, None, 1.0],1628"weight": [1.0, 1.0, 1.0],1629}1630)16311632w = pl.col("weight").filter(pl.col("value").is_finite())16331634w = w / w.sum()16351636result = w.sum()16371638assert_frame_equal(1639testdf.group_by("id", maintain_order=maintain_order).agg(result),1640pl.DataFrame({"id": [8113, 9110], "weight": [0.0, 1.0]}),1641check_row_order=maintain_order,1642)164316441645@pytest.mark.parametrize("maintain_order", [False, True])1646def test_group_by_shift_filter_23910(maintain_order: bool) -> None:1647df = pl.DataFrame({"a": [3, 7, 5, 9, 2, 1], "b": [2, 2, 2, 3, 3, 1]})16481649out = df.group_by("b", maintain_order=maintain_order).agg(1650pl.col("a").filter(pl.col("a") > pl.col("a").shift(1)).sum().alias("tt")1651)16521653assert_frame_equal(1654out,1655pl.DataFrame(1656{1657"b": [2, 3, 1],1658"tt": [7, 0, 0],1659}1660),1661check_row_order=maintain_order,1662)166316641665@pytest.mark.parametrize("maintain_order", [False, True])1666def test_group_by_having(maintain_order: bool) -> None:1667df = pl.DataFrame(1668{1669"grp": ["A", "A", "B", "B", "C", "C"],1670"value": [10, 15, 5, 15, 5, 10],1671}1672)16731674result = (1675df.group_by("grp", maintain_order=maintain_order)1676.having(pl.col("value").mean() >= 10)1677.agg()1678)1679expected = pl.DataFrame({"grp": ["A", "B"]})1680assert_frame_equal(result, expected, check_row_order=maintain_order)168116821683def test_group_by_tuple_typing_24112() -> None:1684df = pl.DataFrame({"id": ["a", "b", "a"], "val": [1, 2, 3]})1685for (id_,), _ in df.group_by("id"):1686_should_work: str = id_168716881689def test_group_by_input_independent_with_len_23868() -> None:1690out = pl.DataFrame({"a": ["A", "B", "C"]}).group_by(pl.lit("G")).agg(pl.len())1691assert_frame_equal(1692out,1693pl.DataFrame(1694{"literal": "G", "len": 3},1695schema={"literal": pl.String, "len": pl.get_index_type()},1696),1697)169816991700@pytest.mark.parametrize("maintain_order", [False, True])1701def test_group_by_head_tail_24215(maintain_order: bool) -> None:1702df = pl.DataFrame(1703{1704"station": ["A", "A", "B"],1705"num_rides": [1, 2, 3],1706}1707)1708expected = pl.DataFrame(1709{"station": ["A", "B"], "num_rides": [1.5, 3], "rides_per_day": [[1, 2], [3]]}1710)17111712result = (1713df.group_by("station", maintain_order=maintain_order)1714.agg(1715cs.numeric().mean(),1716pl.col("num_rides").alias("rides_per_day"),1717)1718.group_by("station", maintain_order=maintain_order)1719.head(1)1720)1721assert_frame_equal(result, expected, check_row_order=maintain_order)17221723result = (1724df.group_by("station", maintain_order=maintain_order)1725.agg(1726cs.numeric().mean(),1727pl.col("num_rides").alias("rides_per_day"),1728)1729.group_by("station", maintain_order=maintain_order)1730.tail(1)1731)1732assert_frame_equal(result, expected, check_row_order=maintain_order)173317341735def test_slice_group_by_offset_24259() -> None:1736df = pl.DataFrame(1737{1738"letters": ["c", "c", "a", "c", "a", "b", "d"],1739"nrs": [1, 2, 3, 4, 5, 6, None],1740}1741)1742assert df.group_by("letters").agg(1743x=pl.col("nrs").drop_nulls(),1744tail=pl.col("nrs").drop_nulls().tail(1),1745).sort("letters").to_dict(as_series=False) == {1746"letters": ["a", "b", "c", "d"],1747"x": [[3, 5], [6], [1, 2, 4], []],1748"tail": [[5], [6], [4], []],1749}175017511752def test_group_by_first_nondet_24278() -> None:1753values = [175496, 86, 0, 86, 43, 50, 9, 14, 98, 39, 93, 7, 71, 1, 93, 41, 56,175556, 93, 41, 58, 91, 81, 29, 81, 68, 5, 9, 32, 93, 78, 34, 17, 40,175614, 2, 52, 77, 81, 4, 56, 42, 64, 12, 29, 58, 71, 98, 32, 49, 34,175786, 29, 94, 37, 21, 41, 36, 9, 72, 23, 28, 71, 9, 66, 72, 84, 81,175823, 12, 64, 57, 99, 15, 77, 38, 95, 64, 13, 91, 43, 61, 70, 47,175939, 75, 47, 93, 45, 1, 95, 55, 29, 5, 83, 8, 3, 6, 45, 84,1760] # fmt: skip1761q = (1762pl.LazyFrame({"a": values, "idx": range(100)})1763.group_by("a")1764.agg(pl.col.idx.first())1765.select(a=pl.col.idx)1766)17671768fst_value = q.collect().to_series().sum()1769for _ in range(10):1770assert q.collect().to_series().sum() == fst_value177117721773@pytest.mark.parametrize("maintain_order", [False, True])1774def test_group_by_agg_on_lit(maintain_order: bool) -> None:1775fs: list[Callable[[Expr], Expr]] = [1776Expr.min,1777Expr.max,1778Expr.mean,1779Expr.sum,1780Expr.len,1781Expr.count,1782Expr.first,1783Expr.last,1784Expr.n_unique,1785Expr.implode,1786Expr.std,1787Expr.var,1788lambda e: e.quantile(0.5),1789Expr.nan_min,1790Expr.nan_max,1791Expr.skew,1792Expr.null_count,1793Expr.product,1794lambda e: pl.corr(e, e),1795]17961797df = pl.DataFrame({"a": [1, 2], "b": [1, 1]})17981799assert_frame_equal(1800df.group_by("a", maintain_order=maintain_order).agg(1801f(pl.lit(1)).alias(f"c{i}") for i, f in enumerate(fs)1802),1803pl.select(1804[pl.lit(pl.Series("a", [1, 2]))]1805+ [f(pl.lit(1)).alias(f"c{i}") for i, f in enumerate(fs)]1806),1807check_row_order=maintain_order,1808)18091810df = pl.DataFrame({"a": [1, 2], "b": [None, 1]})18111812assert_frame_equal(1813df.group_by("a", maintain_order=maintain_order).agg(1814f(pl.lit(1)).alias(f"c{i}") for i, f in enumerate(fs)1815),1816pl.select(1817[pl.lit(pl.Series("a", [1, 2]))]1818+ [f(pl.lit(1)).alias(f"c{i}") for i, f in enumerate(fs)]1819),1820check_row_order=maintain_order,1821)182218231824def test_group_by_cum_sum_key_24489() -> None:1825df = pl.LazyFrame({"x": [1, 2]})1826out = df.group_by((pl.col.x > 1).cum_sum()).agg().collect()1827expected = pl.DataFrame({"x": [0, 1]}, schema={"x": pl.UInt32})1828assert_frame_equal(out, expected, check_row_order=False)182918301831@pytest.mark.parametrize("maintain_order", [False, True])1832def test_double_aggregations(maintain_order: bool) -> None:1833fs: list[Callable[[pl.Expr], pl.Expr]] = [1834Expr.min,1835Expr.max,1836Expr.mean,1837Expr.sum,1838Expr.len,1839Expr.count,1840Expr.first,1841Expr.last,1842Expr.n_unique,1843Expr.implode,1844Expr.std,1845Expr.var,1846lambda e: e.quantile(0.5),1847Expr.nan_min,1848Expr.nan_max,1849Expr.skew,1850Expr.null_count,1851Expr.product,1852lambda e: pl.corr(e, e),1853]18541855df = pl.DataFrame({"a": [1, 2], "b": [1, 1]})18561857assert_frame_equal(1858df.group_by("a", maintain_order=maintain_order).agg(1859f(pl.col.b).alias(f"c{i}") for i, f in enumerate(fs)1860),1861df.group_by("a", maintain_order=maintain_order).agg(1862f(pl.col.b.first()).alias(f"c{i}") for i, f in enumerate(fs)1863),1864check_row_order=maintain_order,1865)18661867df = pl.DataFrame({"a": [1, 2], "b": [None, 1]})18681869assert_frame_equal(1870df.group_by("a", maintain_order=maintain_order).agg(1871f(pl.col.b).alias(f"c{i}") for i, f in enumerate(fs)1872),1873df.group_by("a", maintain_order=maintain_order).agg(1874f(pl.col.b.first()).alias(f"c{i}") for i, f in enumerate(fs)1875),1876check_row_order=maintain_order,1877)187818791880def test_group_by_length_preserving_on_scalar() -> None:1881df = pl.DataFrame({"a": [[1], [2], [3]]})1882df = df.group_by(pl.lit(1, pl.Int64)).agg(1883a=pl.col.a.first().reverse(),1884b=pl.col.a.first(),1885c=pl.col.a.reverse(),1886d=pl.lit(1, pl.Int64).reverse(),1887e=pl.lit(1, pl.Int64).unique(),1888)18891890assert_frame_equal(1891df,1892pl.DataFrame(1893{1894"literal": [1],1895"a": [[1]],1896"b": [[1]],1897"c": [[[3], [2], [1]]],1898"d": [1],1899"e": [[1]],1900}1901),1902)190319041905def test_group_by_enum_min_max_18394() -> None:1906df = pl.DataFrame(1907{1908"id": ["a", "a", "b", "b", "c", "c"],1909"degree": ["low", "high", "high", "mid", "mid", "low"],1910}1911).with_columns(pl.col("degree").cast(pl.Enum(["low", "mid", "high"])))1912out = df.group_by("id").agg(1913min_degree=pl.col("degree").min(),1914max_degree=pl.col("degree").max(),1915)1916expected = pl.DataFrame(1917{1918"id": ["a", "b", "c"],1919"min_degree": ["low", "mid", "low"],1920"max_degree": ["high", "high", "mid"],1921},1922schema={1923"id": pl.String,1924"min_degree": pl.Enum(["low", "mid", "high"]),1925"max_degree": pl.Enum(["low", "mid", "high"]),1926},1927)1928assert_frame_equal(out, expected, check_row_order=False)192919301931@pytest.mark.parametrize("maintain_order", [False, True])1932def test_group_by_filter_24838(maintain_order: bool) -> None:1933df = pl.DataFrame({"a": [1, 1, 2, 2, 3], "b": [1, 2, 1, 2, 1]})19341935assert_frame_equal(1936df.group_by("a", maintain_order=maintain_order).agg(1937b=pl.lit(2, pl.Int64).filter(pl.col.b != 1)1938),1939pl.DataFrame(1940[1941pl.Series("a", [1, 2, 3], pl.Int64),1942pl.Series("b", [[2], [2], []], pl.List(pl.Int64)),1943]1944),1945check_row_order=maintain_order,1946)194719481949@pytest.mark.parametrize(1950"lhs",1951[1952pl.lit(2),1953pl.col.a,1954pl.col.a.first(),1955pl.col.a.reverse(),1956pl.col.a.fill_null(strategy="forward"),1957],1958)1959@pytest.mark.parametrize(1960"rhs",1961[1962pl.col.b == 3,1963pl.col.b != 3,1964pl.col.b.reverse() == 3,1965pl.col.b.reverse() != 3,1966pl.col.b.fill_null(1) != 3,1967pl.col.b.fill_null(1) == 3,1968pl.lit(True),1969pl.lit(False),1970pl.lit(pl.Series([True])),1971pl.lit(pl.Series([False])),1972pl.lit(pl.Series([True])).first(),1973pl.lit(pl.Series([False])).first(),1974],1975)1976@pytest.mark.parametrize(1977"agg",1978[1979Expr.implode,1980Expr.sum,1981Expr.first,1982],1983)1984def test_group_by_filter_parametric(1985lhs: pl.Expr, rhs: pl.Expr, agg: Callable[[pl.Expr], pl.Expr]1986) -> None:1987df = pl.DataFrame({"a": [1, 1, 2, 2, 3], "b": [1, 2, 1, 2, 1]})1988gb = df.group_by(pl.lit(1)).agg(a=agg(lhs.filter(rhs))).to_series(1)1989gb = gb.rename("a")1990sl = df.select(a=agg(lhs.filter(rhs))).to_series()1991assert_series_equal(gb, sl)199219931994@given(s=series(name="a", min_size=1))1995@pytest.mark.parametrize(1996("expr", "is_scalar", "maintain_order"),1997[1998(pl.Expr.n_unique, True, True),1999(pl.Expr.unique, False, False),2000(lambda e: e.unique(maintain_order=True), False, True),2001],2002)2003def test_group_by_unique_parametric(2004s: pl.Series,2005expr: Callable[[pl.Expr], pl.Expr],2006is_scalar: bool,2007maintain_order: bool,2008) -> None:2009df = s.to_frame()20102011sl = df.select(expr(pl.col.a))2012gb = df.group_by(pl.lit(1)).agg(expr(pl.col.a)).drop("literal")2013if not is_scalar:2014gb = gb.select(pl.col.a.explode())2015assert_frame_equal(sl, gb, check_row_order=maintain_order)20162017# check scalar case2018sl_first = df.select(expr(pl.col.a.first()))2019gb = df.group_by(pl.lit(1)).agg(expr(pl.col.a.first())).drop("literal")2020if not is_scalar:2021gb = gb.select(pl.col.a.explode())2022assert_frame_equal(sl_first, gb, check_row_order=maintain_order)20232024li = df.select(pl.col.a.implode().list.eval(expr(pl.element())))2025li = li.select(pl.col.a.explode())2026assert_frame_equal(sl, li, check_row_order=maintain_order)202720282029@pytest.mark.parametrize(2030"expr",2031[2032pl.Expr.any,2033pl.Expr.all,2034lambda e: e.any(ignore_nulls=False),2035lambda e: e.all(ignore_nulls=False),2036],2037)2038def test_group_by_any_all(expr: Callable[[pl.Expr], pl.Expr]) -> None:2039combinations = [2040[True, None],2041[None, None],2042[False, None],2043[True, True],2044[False, False],2045[True, False],2046]20472048cl = cs.starts_with("x")2049df = pl.DataFrame(2050[pl.Series("g", [1, 1])]2051+ [pl.Series(f"x{i}", c, pl.Boolean()) for i, c in enumerate(combinations)]2052)20532054# verify that we are actually calculating something2055assert len(df.lazy().select(expr(cl)).collect_schema()) == len(combinations)20562057assert_frame_equal(2058df.select(expr(cl)),2059df.group_by(lit=pl.lit(1)).agg(expr(cl)).drop("lit"),2060)20612062assert_frame_equal(2063df.select(expr(cl)),2064df.group_by("g").agg(expr(cl)).drop("g"),2065)20662067assert_frame_equal(2068df.select(expr(cl)),2069df.select(cl.implode().list.agg(expr(pl.element()))),2070)20712072df = pl.Schema({"x": pl.Boolean()}).to_frame()20732074assert_frame_equal(2075df.select(expr(cl)),2076pl.DataFrame({"x": [None]})2077.group_by(lit=pl.lit(1))2078.agg(expr(pl.lit(pl.Series("x", [], pl.Boolean()))))2079.drop("lit"),2080)20812082assert_frame_equal(2083df.select(expr(cl)),2084df.select(cl.implode().list.agg(expr(pl.element()))),2085)208620872088@given(2089s=series(2090name="f",2091dtype=pl.Float64(),2092allow_chunks=False, # bug: See #249602093)2094)2095@pytest.mark.may_fail_auto_streaming # bug: See #249602096def test_group_by_skew_kurtosis(s: pl.Series) -> None:2097df = s.to_frame()20982099exprs: dict[str, Callable[[pl.Expr], pl.Expr]] = {2100"skew": lambda e: e.skew(),2101"skew_b": lambda e: e.skew(bias=False),2102"kurt": lambda e: e.kurtosis(),2103"kurt_f": lambda e: e.kurtosis(fisher=False),2104"kurt_b": lambda e: e.kurtosis(bias=False),2105"kurt_fb": lambda e: e.kurtosis(fisher=False, bias=False),2106}21072108sl = df.select([e(pl.col.f).alias(n) for n, e in exprs.items()])2109if s.len() > 0:2110gb = (2111df.group_by(pl.lit(1))2112.agg([e(pl.col.f).alias(n) for n, e in exprs.items()])2113.drop("literal")2114)2115assert_frame_equal(sl, gb)21162117# check scalar case2118sl_first = df.select([e(pl.col.f.first()).alias(n) for n, e in exprs.items()])2119gb = (2120df.group_by(pl.lit(1))2121.agg([e(pl.col.f.first()).alias(n) for n, e in exprs.items()])2122.drop("literal")2123)2124assert_frame_equal(sl_first, gb)21252126li = df.select(pl.col.f.implode()).select(2127[pl.col.f.list.agg(e(pl.element())).alias(n) for n, e in exprs.items()]2128)2129assert_frame_equal(sl, li)213021312132def test_group_by_rolling_fill_null_25036() -> None:2133frame = pl.DataFrame(2134{2135"date": [date(2013, 1, 1), date(2013, 1, 2), date(2013, 1, 3)] * 2,2136"group": ["A"] * 3 + ["B"] * 3,2137"value": [None, None, 3, 4, None, None],2138}2139)2140result = frame.rolling(index_column="date", period="2d", group_by="group").agg(2141pl.col("value").forward_fill(limit=None).last()2142)21432144expected = pl.DataFrame(2145{2146"group": ["A"] * 3 + ["B"] * 3,2147"date": [date(2013, 1, 1), date(2013, 1, 2), date(2013, 1, 3)] * 2,2148"value": [None, None, 3, 4, 4, None],2149}2150)21512152assert_frame_equal(result, expected)215321542155exprs = [2156pl.col.a,2157pl.col.a.filter(pl.col.a <= 1),2158pl.col.a.first(),2159pl.lit(1).alias("one"),2160pl.lit(pl.Series([1])),2161]216221632164@pytest.mark.parametrize("lhs", exprs)2165@pytest.mark.parametrize("rhs", exprs)2166@pytest.mark.parametrize("op", [pl.Expr.add, pl.Expr.pow])2167def test_group_broadcast_binary_apply_expr_25046(2168lhs: pl.Expr, rhs: pl.Expr, op: Any2169) -> None:2170df = pl.DataFrame({"g": [10, 10, 20], "a": [1, 2, 3]})2171groups = pl.lit(1)2172out = df.group_by(groups).agg((op(lhs, rhs)).implode()).to_series(1)2173expected = df.select((op(lhs, rhs)).implode()).to_series()2174assert_series_equal(out, expected)217521762177def test_group_by_explode_none_dtype_25045() -> None:2178df = pl.DataFrame({"a": [None, None, None], "b": [1.0, 2.0, None]})2179out_a = df.group_by(pl.lit(1)).agg(pl.col.a.explode())2180expected_a = pl.DataFrame({"literal": 1, "a": [[None, None, None]]})2181assert_frame_equal(out_a, expected_a)21822183out_b = df.group_by(pl.lit(1)).agg(pl.col.b.explode())2184assert len(out_a["a"][0]) == len(out_b["b"][0])21852186out_c = df.select(2187pl.coalesce(pl.col.a.explode(), pl.col.b.explode())2188.implode()2189.over(pl.int_range(pl.len()))2190)2191expected_c = pl.DataFrame({"a": [[1.0], [2.0], [None]]})2192assert_frame_equal(out_c, expected_c)219321942195@pytest.mark.parametrize(2196("expr", "is_scalar"),2197[2198(pl.Expr.forward_fill, False),2199(pl.Expr.backward_fill, False),2200(lambda e: e.forward_fill(1), False),2201(lambda e: e.backward_fill(1), False),2202(lambda e: e.forward_fill(2), False),2203(lambda e: e.backward_fill(2), False),2204(lambda e: e.forward_fill().min(), True),2205(lambda e: e.backward_fill().min(), True),2206(lambda e: e.forward_fill().first(), True),2207(lambda e: e.backward_fill().first(), True),2208],2209)2210def test_group_by_forward_backward_fill(2211expr: Callable[[pl.Expr], pl.Expr], is_scalar: bool2212) -> None:2213combinations = [2214[1, None, 2, None, None],2215[None, 1, 2, 3, 4],2216[None, None, None, None, None],2217[1, 2, 3, 4, 5],2218[1, None, 2, 3, 4],2219[None, None, None, None, 1],2220[1, None, None, None, None],2221[None, None, None, 1, None],2222[None, 1, None, None, None],2223]22242225cl = cs.starts_with("x")2226df = pl.DataFrame(2227[pl.Series("g", [1] * 5)]2228+ [pl.Series(f"x{i}", c, pl.Int64()) for i, c in enumerate(combinations)]2229)22302231# verify that we are actually calculating something2232assert len(df.lazy().select(expr(cl)).collect_schema()) == len(combinations)22332234data = df.group_by(lit=pl.lit(1)).agg(expr(cl)).drop("lit")2235if not is_scalar:2236data = data.explode(cs.all())2237assert_frame_equal(df.select(expr(cl)), data)22382239data = df.group_by("g").agg(expr(cl)).drop("g")2240if not is_scalar:2241data = data.explode(cs.all())2242assert_frame_equal(df.select(expr(cl)), data)22432244assert_frame_equal(2245df.select(expr(cl)),2246df.select(cl.implode().list.eval(expr(pl.element())).explode()),2247)22482249df = pl.Schema({"x": pl.Int64()}).to_frame()22502251data = (2252pl.DataFrame({"x": [None]})2253.group_by(lit=pl.lit(1))2254.agg(expr(pl.lit(pl.Series("x", [], pl.Int64()))))2255.drop("lit")2256)2257if not is_scalar:2258data = data.select(cs.all().reshape((-1,)))2259assert_frame_equal(df.select(expr(cl)), data)22602261assert_frame_equal(2262df.select(expr(cl)),2263df.select(cl.implode().list.eval(expr(pl.element())).reshape((-1,))),2264)226522662267@given(s=series())2268def test_group_by_drop_nulls(s: pl.Series) -> None:2269df = s.rename("f").to_frame()22702271data = (2272df.group_by(lit=pl.lit(1))2273.agg(pl.col.f.drop_nulls())2274.drop("lit")2275.select(pl.col.f.reshape((-1,)))2276)2277assert_frame_equal(df.select(pl.col.f.drop_nulls()), data)22782279assert_frame_equal(2280df.select(pl.col.f.drop_nulls()),2281df.select(2282pl.col.f.implode().list.eval(pl.element().drop_nulls()).reshape((-1,))2283),2284)22852286df = pl.Schema({"f": pl.Int64()}).to_frame()22872288data = (2289pl.DataFrame({"x": [None]})2290.group_by(lit=pl.lit(1))2291.agg(pl.lit(pl.Series("f", [], pl.Int64())).drop_nulls())2292.drop("lit")2293)2294data = data.select(cs.all().reshape((-1,)))2295assert_frame_equal(df.select(pl.col.f.drop_nulls()), data)22962297assert_frame_equal(2298df.select(pl.col.f.drop_nulls()),2299df.select(2300pl.col.f.implode().list.eval(pl.element().drop_nulls()).reshape((-1,))2301),2302)230323042305@given(s=series())2306def test_group_by_drop_nans(s: pl.Series) -> None:2307df = s.rename("f").to_frame()23082309data = (2310df.group_by(lit=pl.lit(1))2311.agg(pl.col.f.drop_nans())2312.select(pl.col.f.reshape((-1,)))2313)2314assert_frame_equal(df.select(pl.col.f.drop_nans()), data)23152316assert_frame_equal(2317df.select(pl.col.f.drop_nans()),2318df.select(2319pl.col.f.implode().list.eval(pl.element().drop_nans()).reshape((-1,))2320),2321)23222323df = pl.Schema({"f": pl.Int64()}).to_frame()23242325data = (2326pl.DataFrame({"x": [None]})2327.group_by(lit=pl.lit(1))2328.agg(pl.lit(pl.Series("f", [], pl.Int64())).drop_nans())2329.drop("lit")2330)2331data = data.select(cs.all().reshape((-1,)))2332assert_frame_equal(df.select(pl.col.f.drop_nans()), data)23332334assert_frame_equal(2335df.select(pl.col.f.drop_nans()),2336df.select(2337pl.col.f.implode().list.eval(pl.element().drop_nans()).reshape((-1,))2338),2339)234023412342@given(2343df=dataframes(2344min_size=1,2345include_cols=[column(name="key", dtype=pl.UInt8, allow_null=False)],2346),2347)2348@pytest.mark.parametrize(2349("expr", "check_order", "returns_scalar", "length_preserving", "is_window"),2350[2351(pl.Expr.unique, False, False, False, False),2352(lambda e: e.unique(maintain_order=True), True, False, False, False),2353(pl.Expr.drop_nans, True, False, False, False),2354(pl.Expr.drop_nulls, True, False, False, False),2355(pl.Expr.null_count, True, False, False, False),2356(pl.Expr.n_unique, True, True, False, False),2357(2358lambda e: e.filter(pl.int_range(0, e.len()) % 3 == 0),2359True,2360False,2361False,2362False,2363),2364(pl.Expr.shift, True, False, True, False),2365(pl.Expr.forward_fill, True, False, True, False),2366(pl.Expr.backward_fill, True, False, True, False),2367(pl.Expr.reverse, True, False, True, False),2368(2369lambda e: (pl.int_range(e.len() - e.len(), e.len()) % 3 == 0).any(),2370True,2371True,2372False,2373False,2374),2375(2376lambda e: (pl.int_range(e.len() - e.len(), e.len()) % 3 == 0).all(),2377True,2378True,2379False,2380False,2381),2382(lambda e: e.head(2), True, False, False, False),2383(pl.Expr.first, True, True, False, False),2384(pl.Expr.mode, False, False, False, False),2385(lambda e: e.fill_null(e.first()).over(e), True, False, True, True),2386(lambda e: e.first().over(e), True, False, True, True),2387(2388lambda e: e.fill_null(e.first()).over(e, mapping_strategy="join"),2389True,2390False,2391True,2392True,2393),2394(2395lambda e: e.fill_null(e.first()).over(e, mapping_strategy="explode"),2396True,2397False,2398False,2399True,2400),2401(2402lambda e: e.fill_null(strategy="forward").over([e, e]),2403True,2404False,2405True,2406True,2407),2408(lambda e: e.fill_null(e.first()).over(e, order_by=e), True, False, True, True),2409(2410lambda e: e.fill_null(e.first()).over(e, order_by=e, descending=True),2411True,2412False,2413True,2414True,2415),2416(2417lambda e: e.gather(pl.int_range(0, e.len()).slice(1, 3)),2418True,2419False,2420False,2421False,2422),2423],2424)2425def test_grouped_agg_parametric(2426df: pl.DataFrame,2427expr: Callable[[pl.Expr], pl.Expr],2428check_order: bool,2429returns_scalar: bool,2430length_preserving: bool,2431is_window: bool,2432) -> None:2433types: dict[str, tuple[Callable[[pl.Expr], pl.Expr], bool, bool]] = {2434"basic": (lambda e: e, False, True),2435}24362437if not is_window:2438types["first"] = (pl.Expr.first, True, False)2439types["slice"] = (lambda e: e.slice(1, 3), False, False)2440types["impl_expl"] = (lambda e: e.implode().explode(), False, False)2441types["rolling"] = (2442lambda e: e.rolling(pl.row_index(), period="3i"),2443False,2444True,2445)2446types["over"] = (lambda e: e.forward_fill().over(e), False, True)24472448def slit(s: pl.Series) -> pl.Expr:2449import polars._plr as plr24502451return pl.Expr._from_pyexpr(plr.lit(s._s, False, is_scalar=True))24522453df = df.with_columns(pl.col.key % 4)2454gb = df.group_by("key").agg(2455*[2456expr(t(~cs.by_name("key"))).name.prefix(f"{k}_")2457for k, (t, _, _) in types.items()2458],2459*[2460expr(slit(df[c].head(1))).alias(f"literal_{c}")2461for c in filter(lambda c: c != "key", df.columns)2462],2463)2464ls = (2465df.group_by("key")2466.agg(pl.all())2467.select(2468pl.col.key,2469*[2470(~cs.by_name("key"))2471.list.agg(expr(t(pl.element())))2472.name.prefix(f"{k}_")2473for k, (t, _, _) in types.items()2474],2475*[2476pl.col(c).list.agg(expr(slit(df[c].head(1)))).alias(f"literal_{c}")2477for c in filter(lambda c: c != "key", df.columns)2478],2479)2480)24812482if not is_window:2483types["literal"] = (lambda e: e, True, False)24842485def verify_index(i: int) -> None:2486idx_df = df.filter(pl.col.key == pl.lit(i, pl.UInt8))2487idx_gb = gb.filter(pl.col.key == pl.lit(i, pl.UInt8))2488idx_ls = ls.filter(pl.col.key == pl.lit(i, pl.UInt8))24892490for col in df.columns:2491if col == "key":2492continue24932494for k, (t, t_is_scalar, t_is_length_preserving) in types.items():2495c = f"{k}_{col}"24962497if k == "literal":2498df_s = idx_df.select(2499expr(t(slit(df[col].head(1)))).alias(c)2500).to_series()2501else:2502df_s = idx_df.select(expr(t(pl.col(col))).alias(c)).to_series()25032504gb_s = idx_gb[c]2505ls_s = idx_ls[c]25062507result_is_scalar = False2508result_is_scalar |= returns_scalar and t_is_length_preserving2509result_is_scalar |= t_is_scalar and length_preserving2510result_is_scalar &= not is_window25112512if not result_is_scalar:2513gb_s = gb_s.explode(empty_as_null=False)2514ls_s = ls_s.explode(empty_as_null=False)25152516assert_series_equal(df_s, gb_s, check_order=check_order)2517assert_series_equal(df_s, ls_s, check_order=check_order)25182519if 0 in df["key"]:2520verify_index(0)2521if 1 in df["key"]:2522verify_index(1)2523if 2 in df["key"]:2524verify_index(2)2525if 3 in df["key"]:2526verify_index(3)252725282529@pytest.mark.parametrize("maintain_order", [False, True])2530@pytest.mark.parametrize(2531("df", "out"),2532[2533(2534pl.DataFrame(2535{2536"key": [0, 0, 0, 0, 1],2537"a": [True, False, False, False, False],2538}2539).with_columns(2540a=pl.when(pl.Series([False, False, False, False, True])).then(pl.col.a)2541),2542pl.DataFrame(2543{2544"key": [0, 1],2545"a": [1, 1],2546},2547schema_overrides={"a": pl.get_index_type()},2548),2549),2550(2551pl.DataFrame(2552{2553"key": [0, 0, 1, 1],2554"a": [False, False, False, False],2555}2556).with_columns(2557a=pl.when(pl.Series([False, False, True, True])).then(pl.col.a)2558),2559pl.DataFrame(2560{2561"key": [0, 1],2562"a": [1, 1],2563},2564schema_overrides={"a": pl.get_index_type()},2565),2566),2567],2568)2569def test_n_unique_masked_bools(2570maintain_order: bool, df: pl.DataFrame, out: pl.DataFrame2571) -> None:2572df = df25732574assert_frame_equal(2575df.group_by("key", maintain_order=maintain_order).agg(pl.col.a.n_unique()),2576out,2577check_row_order=maintain_order,2578)2579assert_frame_equal(2580df.group_by("key", maintain_order=maintain_order)2581.agg(pl.col.a)2582.with_columns(pl.col.a.list.agg(pl.element().n_unique())),2583out,2584check_row_order=maintain_order,2585)258625872588@pytest.mark.parametrize("maintain_order", [False, True])2589@pytest.mark.parametrize("stable", [False, True])2590def test_group_bool_unique_25267(maintain_order: bool, stable: bool) -> None:2591df = pl.DataFrame(2592{2593"id": ["A", "A", "B", "B", "C", "C"],2594"str_values": ["D", "E", "F", "F", "G", "G"],2595"bool_values": [True, False, True, True, False, False],2596}2597)25982599gb = df.group_by("id", maintain_order=maintain_order).agg(2600pl.col("str_values", "bool_values").unique(maintain_order=stable),2601)26022603ls = (2604df.group_by("id", maintain_order=maintain_order)2605.agg("str_values", "bool_values")2606.with_columns(2607pl.col("str_values", "bool_values").list.agg(2608pl.element().unique(maintain_order=stable)2609)2610)2611)26122613for i in ["A", "B", "C"]:2614for c in ["str_values", "bool_values"]:2615df_s = (2616df.select(pl.col(c).filter(pl.col.id == pl.lit(i)))2617.to_series()2618.unique(maintain_order=stable)2619)2620gb_s = gb.select(2621pl.col(c).filter(pl.col.id == pl.lit(i)).reshape((-1,))2622).to_series()2623ls_s = ls.select(2624pl.col(c).filter(pl.col.id == pl.lit(i)).reshape((-1,))2625).to_series()26262627assert_series_equal(df_s, gb_s, check_order=stable)2628assert_series_equal(df_s, ls_s, check_order=stable)262926302631@pytest.mark.parametrize("group_as_slice", [False, True])2632@pytest.mark.parametrize("n", [10, 100, 519])2633@pytest.mark.parametrize(2634"dtype", [pl.Int32, pl.Boolean, pl.String, pl.Categorical, pl.List(pl.Int32)]2635)2636def test_group_by_first_last(2637group_as_slice: bool, n: int, dtype: PolarsDataType2638) -> None:2639group_by_first_last_test_impl(group_as_slice, n, dtype)264026412642@pytest.mark.slow2643@pytest.mark.parametrize("group_as_slice", [False, True])2644@pytest.mark.parametrize("n", [1056, 10_432])2645@pytest.mark.parametrize(2646"dtype", [pl.Int32, pl.Boolean, pl.String, pl.Categorical, pl.List(pl.Int32)]2647)2648def test_group_by_first_last_big(2649group_as_slice: bool, n: int, dtype: PolarsDataType2650) -> None:2651group_by_first_last_test_impl(group_as_slice, n, dtype)265226532654def group_by_first_last_test_impl(2655group_as_slice: bool, n: int, dtype: PolarsDataType2656) -> None:2657idx = pl.Series([1, 2, 3, 4, 5], dtype=pl.Int32)26582659lf = pl.LazyFrame(2660{2661"idx": pl.Series(2662[1] * n + [2] * n + [3] * n + [4] * n + [5] * n, dtype=pl.Int322663),2664# Each successive group has an additional None spanning the elements2665"a": pl.Series(2666[2667*[None] * 0, *list(range(1, n + 1)), *[None] * 0, # idx = 12668*[None] * 1, *list(range(2, n - 0)), *[None] * 1, # idx = 22669*[None] * 2, *list(range(3, n - 1)), *[None] * 2, # idx = 32670*[None] * 3, *list(range(4, n - 2)), *[None] * 3, # idx = 42671*[None] * 4, *list(range(5, n - 3)), *[None] * 4, # idx = 52672],2673dtype=pl.Int32,2674),2675}2676) # fmt: skip2677if group_as_slice:2678lf = lf.set_sorted("idx") # Use GroupSlice path26792680if dtype == pl.Categorical:2681# for Categorical, we must first go through String2682lf = lf.with_columns(pl.col("a").cast(pl.String))2683lf = lf.with_columns(pl.col("a").cast(dtype))26842685# first()2686result = lf.group_by("idx", maintain_order=True).agg(pl.col("a").first()).collect()2687expected_vals = pl.Series([1, None, None, None, None])2688if dtype == pl.Categorical:2689# for Categorical, we must first go through String2690expected_vals = expected_vals.cast(pl.String)26912692expected_vals = expected_vals.cast(dtype)2693expected = pl.DataFrame({"idx": idx, "a": expected_vals})2694assert_frame_equal(result, expected)2695result = lf.group_by("idx", maintain_order=True).first().collect()2696assert_frame_equal(result, expected)26972698# first(ignore_nulls=True)2699result = (2700lf.group_by("idx", maintain_order=True)2701.agg(pl.col("a").first(ignore_nulls=True))2702.collect()2703)2704expected_vals = pl.Series([1, 2, 3, 4, 5])2705if dtype == pl.Categorical:2706# for Categorical, we must first go through String2707expected_vals = expected_vals.cast(pl.String)27082709expected_vals = expected_vals.cast(dtype)2710expected = pl.DataFrame({"idx": idx, "a": expected_vals})2711assert_frame_equal(result, expected)2712result = lf.group_by("idx", maintain_order=True).first(ignore_nulls=True).collect()2713assert_frame_equal(result, expected)27142715# last()2716result = lf.group_by("idx", maintain_order=True).agg(pl.col("a").last()).collect()2717expected_vals = pl.Series([n, None, None, None, None])2718if dtype == pl.Categorical:2719# for Categorical, we must first go through String2720expected_vals = expected_vals.cast(pl.String)27212722expected_vals = expected_vals.cast(dtype)2723expected = pl.DataFrame({"idx": idx, "a": expected_vals})2724assert_frame_equal(result, expected)2725result = lf.group_by("idx", maintain_order=True).last().collect()2726assert_frame_equal(result, expected)27272728# last_non_null2729result = (2730lf.group_by("idx", maintain_order=True)2731.agg(pl.col("a").last(ignore_nulls=True))2732.collect()2733)2734expected_vals = pl.Series([n, n - 1, n - 2, n - 3, n - 4])2735if dtype == pl.Categorical:2736# for Categorical, we must first go through String2737expected_vals = expected_vals.cast(pl.String)27382739expected_vals = expected_vals.cast(dtype)2740expected = pl.DataFrame({"idx": idx, "a": expected_vals})2741assert_frame_equal(result, expected)2742result = lf.group_by("idx", maintain_order=True).last(ignore_nulls=True).collect()2743assert_frame_equal(result, expected)27442745# Test with no nulls2746lf = pl.LazyFrame(2747{2748"idx": pl.Series(2749[1] * n + [2] * n + [3] * n + [4] * n + [5] * n, dtype=pl.Int322750),2751# Each successive group has an additional None spanning the elements2752"a": pl.Series(2753[2754*list(range(1, n + 1)), # idx = 12755*list(range(2, n + 2)), # idx = 22756*list(range(3, n + 3)), # idx = 32757*list(range(4, n + 4)), # idx = 42758*list(range(5, n + 5)), # idx = 52759],2760dtype=pl.Int32,2761),2762}2763)2764if group_as_slice:2765lf = lf.set_sorted("idx") # Use GroupSlice path27662767if dtype == pl.Categorical:2768# for Categorical, we must first go through String2769lf = lf.with_columns(pl.col("a").cast(pl.String))2770lf = lf.with_columns(pl.col("a").cast(dtype))27712772# first()2773expected_vals = pl.Series([1, 2, 3, 4, 5])2774if dtype == pl.Categorical:2775# for Categorical, we must first go through String2776expected_vals = expected_vals.cast(pl.String)27772778expected_vals = expected_vals.cast(dtype)2779expected = pl.DataFrame({"idx": idx, "a": expected_vals})2780result = lf.group_by("idx", maintain_order=True).agg(pl.col("a").first()).collect()2781assert_frame_equal(result, expected)2782result = lf.group_by("idx", maintain_order=True).first().collect()2783assert_frame_equal(result, expected)27842785# first_non_null2786result = (2787lf.group_by("idx", maintain_order=True)2788.agg(pl.col("a").first(ignore_nulls=True))2789.collect()2790)2791assert_frame_equal(result, expected)2792result = lf.group_by("idx", maintain_order=True).first(ignore_nulls=True).collect()2793assert_frame_equal(result, expected)27942795# last()2796expected_vals = pl.Series([n, n + 1, n + 2, n + 3, n + 4])2797if dtype == pl.Categorical:2798# for Categorical, we must first go through String2799expected_vals = expected_vals.cast(pl.String)28002801expected_vals = expected_vals.cast(dtype)2802expected = pl.DataFrame({"idx": idx, "a": expected_vals})2803result = lf.group_by("idx", maintain_order=True).agg(pl.col("a").last()).collect()2804assert_frame_equal(result, expected)2805result = lf.group_by("idx", maintain_order=True).last().collect()2806assert_frame_equal(result, expected)28072808# last_non_null2809result = (2810lf.group_by("idx", maintain_order=True)2811.agg(pl.col("a").last(ignore_nulls=True))2812.collect()2813)2814assert_frame_equal(result, expected)2815result = lf.group_by("idx", maintain_order=True).last(ignore_nulls=True).collect()2816assert_frame_equal(result, expected)281728182819def test_sorted_group_by() -> None:2820lf = pl.LazyFrame(2821{2822"a": [1, 1, 2, 2, 3, 3, 3],2823"b": [4, 5, 8, 1, 0, 1, 3],2824}2825)28262827lf1 = lf2828lf2 = lf.set_sorted("a")28292830assert_frame_equal(2831*[2832q.group_by("a")2833.agg(b_first=pl.col.b.first(), b_sum=pl.col.b.sum(), b=pl.col.b)2834.collect(engine="streaming")2835for q in (lf1, lf2)2836],2837check_row_order=False,2838)28392840lf = lf.with_columns(c=pl.col.a.rle_id())2841lf1 = lf2842lf2 = lf.set_sorted("a", "c")28432844assert_frame_equal(2845*[2846q.group_by("a", "c")2847.agg(b_first=pl.col.b.first(), b_sum=pl.col.b.sum(), b=pl.col.b)2848.collect(engine="streaming")2849for q in (lf1, lf2)2850],2851check_row_order=False,2852)285328542855def test_sorted_group_by_slice() -> None:2856lf = (2857pl.DataFrame({"a": [0, 5, 2, 1, 3] * 50})2858.with_row_index()2859.with_columns(pl.col.index // 5)2860.lazy()2861.set_sorted("index")2862.group_by("index", maintain_order=True)2863.agg(pl.col.a.sum() + pl.col.index.first())2864)28652866expected = pl.DataFrame(2867[2868pl.Series("index", range(50), pl.get_index_type()),2869pl.Series("a", range(11, 11 + 50), pl.Int64),2870]2871)28722873assert_frame_equal(lf.head(2).collect(), expected.head(2))2874assert_frame_equal(lf.slice(1, 3).collect(), expected.slice(1, 3))2875assert_frame_equal(lf.tail(2).collect(), expected.tail(2))2876assert_frame_equal(lf.slice(5, 1).collect(), expected.slice(5, 1))2877assert_frame_equal(lf.slice(5, 0).collect(), expected.slice(5, 0))2878assert_frame_equal(lf.slice(2, 1).collect(), expected.slice(2, 1))2879assert_frame_equal(lf.slice(50, 1).collect(), expected.slice(50, 1))2880assert_frame_equal(lf.slice(20, 30).collect(), expected.slice(20, 30))2881assert_frame_equal(lf.slice(20, 30).collect(), expected.slice(20, 30))288228832884def test_agg_first_last_non_null_25405() -> None:2885lf = pl.LazyFrame(2886{2887"a": [1, 1, 1, 1, 2, 2, 2, 2, 2],2888"b": pl.Series([1, 2, 3, None, None, 4, 5, 6, None]),2889}2890)28912892# first2893result = lf.group_by("a", maintain_order=True).agg(2894pl.col("b").first(ignore_nulls=True)2895)2896expected = pl.DataFrame(2897{2898"a": [1, 2],2899"b": [1, 4],2900}2901)2902assert_frame_equal(result.collect(), expected)29032904result = lf.with_columns(pl.col("b").first(ignore_nulls=True).over("a"))2905expected = pl.DataFrame(2906{2907"a": [1, 1, 1, 1, 2, 2, 2, 2, 2],2908"b": [1, 1, 1, 1, 4, 4, 4, 4, 4],2909}2910)2911assert_frame_equal(result.collect(), expected)29122913# last2914result = lf.group_by("a", maintain_order=True).agg(2915pl.col("b").last(ignore_nulls=True)2916)2917expected = pl.DataFrame(2918{2919"a": [1, 2],2920"b": [3, 6],2921}2922)2923assert_frame_equal(result.collect(), expected)29242925result = lf.with_columns(pl.col("b").last(ignore_nulls=True).over("a"))2926expected = pl.DataFrame(2927{2928"a": [1, 1, 1, 1, 2, 2, 2, 2, 2],2929"b": [3, 3, 3, 3, 6, 6, 6, 6, 6],2930}2931)2932assert_frame_equal(result.collect(), expected)293329342935def test_group_by_sum_on_strings_should_error_24659() -> None:2936with pytest.raises(2937InvalidOperationError,2938match=r"`sum`.*operation not supported for dtype.*str",2939):2940pl.DataFrame({"str": ["a", "b"]}).group_by(1).agg(pl.col.str.sum())294129422943@pytest.mark.parametrize("tail", [0, 1, 4, 5, 6, 10])2944def test_unique_head_tail_26429(tail: int) -> None:2945df = pl.DataFrame(2946{2947"x": [1, 2, 3, 4, 5],2948}2949)2950out = df.lazy().unique().tail(tail).collect()2951expected = min(tail, df.height)2952assert len(out) == expected295329542955def test_group_by_cse_alias_26423() -> None:2956df = pl.LazyFrame({"a": [1, 2, 1, 2, 3, 4]})2957result = df.group_by("a").agg(pl.len(), pl.len().alias("len_a")).collect()2958expected = pl.DataFrame(2959{"a": [1, 2, 3, 4], "len": [2, 2, 1, 1], "len_a": [2, 2, 1, 1]},2960schema={2961"a": pl.Int64,2962"len": pl.get_index_type(),2963"len_a": pl.get_index_type(),2964},2965)2966assert_frame_equal(result, expected, check_row_order=False)296729682969