# Path: blob/main/py-polars/tests/unit/operations/map/test_map_groups.py
# 8415 views
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, ShapeError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from collections.abc import Sequence


def test_map_groups() -> None:
    """Eager ``group_by().map_groups``: sum column ``c`` within each ``a`` group."""
    df = pl.DataFrame(
        {
            "a": ["a", "b", "a", "b", "b", "c"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [6, 5, 4, 3, 2, 1],
        }
    )

    result = df.group_by("a").map_groups(lambda df: df[["c"]].sum())

    expected = pl.DataFrame({"c": [10, 10, 1]})
    # group_by is called without maintain_order, so row order is not compared.
    assert_frame_equal(result, expected, check_row_order=False)


def test_map_groups_lazy() -> None:
    """Lazy ``map_groups`` with an explicit output schema: double every value."""
    lf = pl.LazyFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 3.0]})

    schema = {"a": pl.Float64, "b": pl.Float64}
    result = lf.group_by("a").map_groups(lambda df: df * 2.0, schema=schema)

    expected = pl.LazyFrame({"a": [6.0, 2.0, 2.0], "b": [6.0, 2.0, 4.0]})
    assert_frame_equal(result, expected, check_row_order=False)
    # The declared schema must also be what the lazy plan reports.
    assert result.collect_schema() == expected.collect_schema()


def test_map_groups_rolling() -> None:
    """``rolling().map_groups``: min of ``a`` and max of ``b`` per ``"2i"`` window."""
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 2, 3, 4, 5],
        }
    ).set_sorted("a")

    def function(df: pl.DataFrame) -> pl.DataFrame:
        return df.select(
            pl.col("a").min(),
            pl.col("b").max(),
        )

    result = df.rolling("a", period="2i").map_groups(function, schema=df.schema)

    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 1, 2, 3, 4], dtype=pl.Int64),
            pl.Series("b", [1, 2, 3, 4, 5], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(result, expected)


def test_map_groups_empty() -> None:
    """Empty input: the eager call raises, the lazy call yields an empty frame."""
    df = pl.DataFrame(schema={"x": pl.Int64})
    with pytest.raises(
        ComputeError, match=r"cannot group_by \+ apply on empty 'DataFrame'"
    ):
        df.group_by("x").map_groups(lambda x: x)

    schema = {"x": pl.Int64, "y": pl.Int64}
    result = (
        df.lazy()
        .group_by("x")
        .map_groups(lambda df: df.with_columns(pl.col("x").alias("y")), schema=schema)
    )

    expected = pl.LazyFrame(schema=schema)
    assert_frame_equal(result, expected)
    assert result.collect_schema() == expected.collect_schema()


def test_map_groups_none() -> None:
    """Expression-level ``pl.map_groups`` in ``agg``, including a UDF returning None."""
    df = pl.DataFrame(
        {
            "g": [1, 1, 1, 2, 2, 2, 5],
            "a": [2, 4, 5, 190, 1, 4, 1],
            "b": [1, 3, 2, 1, 43, 3, 1],
        }
    )

    # Non-scalar UDF over three input expressions; the result is one list per group.
    out = (
        df.group_by("g", maintain_order=True).agg(
            pl.map_groups(
                exprs=["a", pl.col("b") ** 4, pl.col("a") / 4],
                function=lambda x: x[0] * x[1] + x[2].sum(),
                return_dtype=pl.Float64,
                returns_scalar=False,
            ).alias("multiple")
        )
    )["multiple"]
    assert out[0].to_list() == [4.75, 326.75, 82.75]
    assert out[1].to_list() == [238.75, 3418849.75, 372.75]

    out_df = df.select(
        pl.map_batches(exprs=["a", "b"], function=lambda s: s[0] * s[1])
    )
    assert out_df["a"].to_list() == (df["a"] * df["b"]).to_list()

    # check if we can return None
    def func(s: Sequence[pl.Series]) -> pl.Series | None:
        # 190 is the first "a" value of group g == 2, so only that group gets None.
        if s[0][0] == 190:
            return None
        else:
            return s[0].implode()

    out = (
        df.group_by("g", maintain_order=True).agg(
            pl.map_groups(
                exprs=["a", pl.col("b") ** 4, pl.col("a") / 4],
                function=func,
                return_dtype=pl.self_dtype().wrap_in_list(),
                returns_scalar=True,
            ).alias("multiple")
        )
    )["multiple"]
    assert out[1] is None


def test_map_groups_object_output() -> None:
    """A UDF returning an arbitrary Python object produces a ``pl.Object`` column."""
    df = pl.DataFrame(
        {
            "names": ["foo", "ham", "spam", "cheese", "egg", "foo"],
            "dates": ["1", "1", "2", "3", "3", "4"],
            "groups": ["A", "A", "B", "B", "B", "C"],
        }
    )

    class Foo:
        def __init__(self, payload: Any) -> None:
            self.payload = payload

    result = df.group_by("groups").agg(
        pl.map_groups(
            [pl.col("dates"), pl.col("names")],
            lambda s: Foo(dict(zip(s[0], s[1], strict=True))),
            return_dtype=pl.Object,
            returns_scalar=True,
        )
    )

    assert result.dtypes == [pl.String, pl.Object]


def test_map_groups_numpy_output_3057() -> None:
    """A UDF returning a NumPy scalar (``np.mean``) is accepted as a scalar result."""
    df = pl.DataFrame(
        {
            "id": [0, 0, 0, 1, 1, 1],
            "t": [2.0, 4.3, 5, 10, 11, 14],
            "y": [0.0, 1, 1.3, 2, 3, 4],
        }
    )

    result = df.group_by("id", maintain_order=True).agg(
        pl.map_groups(
            ["y", "t"],
            lambda lst: np.mean([lst[0], lst[1]]),
            returns_scalar=True,
            return_dtype=pl.self_dtype(),
        ).alias("result")
    )

    # Expected means are written to six decimals; this presumably relies on
    # assert_frame_equal comparing floats with a tolerance -- confirm.
    expected = pl.DataFrame({"id": [0, 1], "result": [2.266666, 7.333333]})
    assert_frame_equal(result, expected)


def test_map_groups_return_all_null_15260() -> None:
    """An all-null input column yields one null per group with a scalar UDF."""

    def foo(x: Sequence[pl.Series]) -> pl.Series:
        return pl.Series([x[0][0]], dtype=x[0].dtype)

    assert_frame_equal(
        pl.DataFrame({"key": [0, 0, 1], "a": [None, None, None]})
        .group_by("key")
        .agg(
            pl.map_groups(
                exprs=["a"],
                function=foo,
                returns_scalar=True,
                return_dtype=pl.self_dtype(),
            )
        )
        .sort("key"),
        pl.DataFrame({"key": [0, 1], "a": [None, None]}),
    )


@pytest.mark.parametrize(
    ("func", "result"),
    [
        (lambda n: n[0] + n[1], [[85], [85]]),
        (lambda _: pl.Series([1, 2, 3]), [[1, 2, 3], [1, 2, 3]]),
    ],
)
@pytest.mark.parametrize("maintain_order", [True, False])
def test_map_groups_multiple_all_literal(
    func: Any, result: list[list[int]], maintain_order: bool
) -> None:
    """``pl.map_groups`` over literal-only inputs is evaluated once per group.

    ``result`` holds one list per group (e.g. ``[[85], [85]]``), which becomes
    the expected ``out`` list column.
    """
    df = pl.DataFrame({"g": [10, 10, 20], "a": [1, 2, 3], "b": [2, 3, 4]})

    q = (
        df.lazy()
        .group_by(pl.col("g"), maintain_order=maintain_order)
        .agg(
            pl.map_groups(
                exprs=[pl.lit(42).cast(pl.Int64), pl.lit(43).cast(pl.Int64)],
                function=func,
                return_dtype=pl.Int64,
            ).alias("out")
        )
    )
    out = q.collect()
    expected = pl.DataFrame({"g": [10, 20], "out": result})
    # Row order is only checked when the group_by was asked to maintain it.
    assert_frame_equal(out, expected, check_row_order=maintain_order)


@pytest.mark.may_fail_auto_streaming  # reason: alternate error message
def test_map_groups_multiple_all_literal_elementwise_raises() -> None:
    """An elementwise UDF on literals that returns three values must raise."""
    df = pl.DataFrame({"g": [10, 10, 20], "a": [1, 2, 3], "b": [2, 3, 4]})
    q = (
        df.lazy()
        .group_by(pl.col("g"))
        .agg(
            pl.map_groups(
                exprs=[pl.lit(42), pl.lit(43)],
                function=lambda _: pl.Series([1, 2, 3]),
                return_dtype=pl.Int64,
                is_elementwise=True,
            ).alias("out")
        )
    )
    msg = "elementwise expression dyn int: 42.python_udf([dyn int: 43]) must return exactly 1 value on literals, got 3"
    with pytest.raises(ComputeError, match=re.escape(msg)):
        q.collect(engine="in-memory")

    # different error message in streaming, not specific to the problem
    with pytest.raises(ShapeError):
        q.collect(engine="streaming")


def test_nested_query_with_streaming_dispatch_25172() -> None:
    """A UDF that itself runs a nested ``sink_parquet`` query inside ``agg``."""

    def simple(_: Any) -> pl.Series:
        import io

        # Nested query executed from inside the UDF; its output is discarded.
        pl.LazyFrame({}).sink_parquet(
            pl.PartitionBy(
                "", file_path_provider=lambda _: io.BytesIO(), max_rows_per_file=1
            ),
        )

        return pl.Series([1])

    assert_frame_equal(
        pl.LazyFrame({"a": ["A", "B"] * 1000, "b": [1] * 2000})
        .group_by("a")
        .agg(pl.map_groups(["b"], simple, pl.Int64(), returns_scalar=True))
        .collect(engine="in-memory")
        .sort("a"),
        pl.DataFrame({"a": ["A", "B"], "b": [1, 1]}, schema_overrides={"b": pl.Int64}),
    )


def test_map_groups_with_slice_25805() -> None:
    """``head(1)`` after a lazy identity ``map_groups`` keeps only the first row."""
    schema = {"a": pl.Int8, "b": pl.Int8}

    df = (
        pl.LazyFrame(
            data={"a": [1, 1], "b": [1, 2]},
            schema=schema,
        )
        .group_by("a", maintain_order=True)
        .map_groups(lambda df: df, schema=schema)
        .head(1)
        .collect()
    )
    assert_frame_equal(df, pl.DataFrame({"a": [1], "b": [1]}, schema=schema))