# Path: blob/main/py-polars/tests/unit/operations/map/test_map_groups.py
# 6940 views
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from collections.abc import Sequence


def test_map_groups() -> None:
    """Sum column ``c`` per group of ``a`` with ``group_by(...).map_groups``."""
    df = pl.DataFrame(
        {
            "a": ["a", "b", "a", "b", "b", "c"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [6, 5, 4, 3, 2, 1],
        }
    )

    result = df.group_by("a").map_groups(lambda df: df[["c"]].sum())

    # Group order is not fixed here, so rows are compared without ordering.
    expected = pl.DataFrame({"c": [10, 10, 1]})
    assert_frame_equal(result, expected, check_row_order=False)


def test_map_groups_lazy() -> None:
    """Apply ``map_groups`` on a LazyFrame with an explicit output schema."""
    lf = pl.LazyFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 3.0]})

    # The UDF multiplies every column by 2.0, so both outputs are declared Float64.
    schema = {"a": pl.Float64, "b": pl.Float64}
    result = lf.group_by("a").map_groups(lambda df: df * 2.0, schema=schema)

    expected = pl.LazyFrame({"a": [6.0, 2.0, 2.0], "b": [6.0, 2.0, 4.0]})
    assert_frame_equal(result, expected, check_row_order=False)
    assert result.collect_schema() == expected.collect_schema()


def test_map_groups_rolling() -> None:
    """Apply ``map_groups`` to the windows of ``rolling("a", period="2i")``."""
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 2, 3, 4, 5],
        }
    ).set_sorted("a")

    def function(df: pl.DataFrame) -> pl.DataFrame:
        # Reduce each window to a single row: smallest "a", largest "b".
        return df.select(
            pl.col("a").min(),
            pl.col("b").max(),
        )

    result = df.rolling("a", period="2i").map_groups(function, schema=df.schema)

    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 1, 2, 3, 4], dtype=pl.Int64),
            pl.Series("b", [1, 2, 3, 4, 5], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(result, expected)


def test_map_groups_empty() -> None:
    """``map_groups`` on an empty DataFrame raises ``ComputeError``."""
    df = pl.DataFrame(schema={"x": pl.Int64})
    with pytest.raises(
        ComputeError, match=r"cannot group_by \+ apply on empty 'DataFrame'"
    ):
        df.group_by("x").map_groups(lambda x: x)


def test_map_groups_none() -> None:
    """Exercise ``pl.map_groups`` in ``agg``, including a UDF that returns None."""
    df = pl.DataFrame(
        {
            "g": [1, 1, 1, 2, 2, 2, 5],
            "a": [2, 4, 5, 190, 1, 4, 1],
            "b": [1, 3, 2, 1, 43, 3, 1],
        }
    )

    # Per group: a * b**4 + sum(a / 4), returned as a (non-scalar) series.
    out = (
        df.group_by("g", maintain_order=True).agg(
            pl.map_groups(
                exprs=["a", pl.col("b") ** 4, pl.col("a") / 4],
                function=lambda x: x[0] * x[1] + x[2].sum(),
                return_dtype=pl.Float64,
                returns_scalar=False,
            ).alias("multiple")
        )
    )["multiple"]
    assert out[0].to_list() == [4.75, 326.75, 82.75]
    assert out[1].to_list() == [238.75, 3418849.75, 372.75]

    out_df = df.select(pl.map_batches(exprs=["a", "b"], function=lambda s: s[0] * s[1]))
    assert out_df["a"].to_list() == (df["a"] * df["b"]).to_list()

    # check if we can return None
    def func(s: Sequence[pl.Series]) -> pl.Series | None:
        # 190 is the first "a" value of group g == 2, so that group yields None.
        if s[0][0] == 190:
            return None
        else:
            return s[0].implode()

    out = (
        df.group_by("g", maintain_order=True).agg(
            pl.map_groups(
                exprs=["a", pl.col("b") ** 4, pl.col("a") / 4],
                function=func,
                return_dtype=pl.self_dtype().wrap_in_list(),
                returns_scalar=True,
            ).alias("multiple")
        )
    )["multiple"]
    assert out[1] is None


def test_map_groups_object_output() -> None:
    """A UDF returning an arbitrary Python object yields an ``Object`` column."""
    df = pl.DataFrame(
        {
            "names": ["foo", "ham", "spam", "cheese", "egg", "foo"],
            "dates": ["1", "1", "2", "3", "3", "4"],
            "groups": ["A", "A", "B", "B", "B", "C"],
        }
    )

    class Foo:
        def __init__(self, payload: Any) -> None:
            self.payload = payload

    result = df.group_by("groups").agg(
        pl.map_groups(
            [pl.col("dates"), pl.col("names")],
            lambda s: Foo(dict(zip(s[0], s[1]))),
            return_dtype=pl.Object,
            returns_scalar=True,
        )
    )

    assert result.dtypes == [pl.String, pl.Object]


def test_map_groups_numpy_output_3057() -> None:
    """A UDF returning the result of ``np.mean`` is used as a scalar per group."""
    df = pl.DataFrame(
        {
            "id": [0, 0, 0, 1, 1, 1],
            "t": [2.0, 4.3, 5, 10, 11, 14],
            "y": [0.0, 1, 1.3, 2, 3, 4],
        }
    )

    result = df.group_by("id", maintain_order=True).agg(
        pl.map_groups(
            ["y", "t"],
            lambda lst: np.mean([lst[0], lst[1]]),
            returns_scalar=True,
            return_dtype=pl.self_dtype(),
        ).alias("result")
    )

    expected = pl.DataFrame({"id": [0, 1], "result": [2.266666, 7.333333]})
    assert_frame_equal(result, expected)


def test_map_groups_return_all_null_15260() -> None:
    """An all-null input column with a UDF that returns a one-element series."""

    def foo(x: Sequence[pl.Series]) -> pl.Series:
        # Return the group's first value, keeping the input series' dtype.
        return pl.Series([x[0][0]], dtype=x[0].dtype)

    assert_frame_equal(
        pl.DataFrame({"key": [0, 0, 1], "a": [None, None, None]})
        .group_by("key")
        .agg(
            pl.map_groups(
                exprs=["a"],
                function=foo,
                returns_scalar=True,
                return_dtype=pl.self_dtype(),
            )
        )
        .sort("key"),
        pl.DataFrame({"key": [0, 1], "a": [None, None]}),
    )