Path: blob/main/py-polars/tests/unit/operations/map/test_map_batches.py
6940 views
from __future__ import annotations

from functools import reduce

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal


@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_return_py_object() -> None:
    """Fold each column with a Python-level `reduce`; expect one sum per column."""
    frame = pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
    fold_columns = pl.all().map_batches(
        lambda s: reduce(lambda a, b: a + b, s), returns_scalar=True
    )

    assert_frame_equal(
        frame.select([fold_columns]),
        pl.DataFrame({"A": [6], "B": [15]}),
    )


@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_no_dtype_set_8531() -> None:
    """Chain `shift` after a `map_batches` call that declares no return dtype."""
    frame = pl.DataFrame({"a": [1]})
    doubled = pl.col("a").map_batches(lambda x: x * 2)

    actual = frame.with_columns(doubled.shift(n=0, fill_value=0))

    assert_frame_equal(actual, pl.DataFrame({"a": [2]}))


def test_error_on_reducing_map() -> None:
    """Check a grouped scalar reduction, then the length error under `over`."""
    grouped_input = pl.DataFrame(
        {
            "id": [0, 0, 0, 1, 1, 1],
            "t": [2, 4, 5, 10, 11, 14],
            "y": [0, 1, 1, 2, 3, 4],
        }
    )
    # `np.mean` is applied to "t" and "y" together; one value per group is expected.
    means_per_group = grouped_input.group_by("id").agg(
        pl.map_batches(["t", "y"], np.mean, pl.Float64(), returns_scalar=True)
    )
    assert_frame_equal(
        means_per_group,
        pl.DataFrame({"id": [0, 1], "t": [2.166667, 7.333333]}),
        check_row_order=False,
    )

    windowed_input = pl.DataFrame({"x": [1, 2, 3, 4], "group": [1, 2, 1, 2]})
    length_mismatch = (
        r"output length of `map` \(1\) must be equal to "
        r"the input length \(4\); consider using `apply` instead"
    )

    # The whole expression is built and run inside the context manager so that
    # the error is caught wherever it is raised.
    with pytest.raises(InvalidOperationError, match=length_mismatch):
        windowed_input.select(
            pl.col("x")
            .map_batches(
                lambda x: pl.Series(
                    [x.cut(breaks=[1, 2, 3], include_breaks=True).struct.unnest()]
                ),
                is_elementwise=True,
            )
            .over("group")
        )


def test_map_batches_group() -> None:
    """Require a Series from a grouped map unless `returns_scalar=True` is set."""
    frame = pl.DataFrame(
        {
            "id": [0, 0, 0, 1, 1, 1],
            "t": [2, 4, 5, 10, 11, 14],
            "y": [0, 1, 1, 2, 3, 4],
        }
    )
    not_a_series = (
        "`map` with `returns_scalar=False` must return a Series; found 'int'"
    )

    with pytest.raises(TypeError, match=not_a_series):
        frame.group_by("id").agg(
            pl.col("t").map_batches(lambda s: s.sum(), return_dtype=pl.self_dtype())
        )

    # If returns_scalar is True, the result won't be wrapped in a list:
    totals = (
        frame.group_by("id")
        .agg(
            pl.col("t").map_batches(
                lambda s: s.sum(), returns_scalar=True, return_dtype=pl.self_dtype()
            )
        )
        .sort("id")
        .to_dict(as_series=False)
    )
    assert totals == {"id": [0, 1], "t": [11, 35]}


def test_ufunc_args() -> None:
    """Apply `np.add` to expressions, including with a Python scalar operand."""
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]})

    column_sum = np.add(pl.col("a"), pl.col("b"))  # type: ignore[call-overload]
    assert_frame_equal(frame.select(z=column_sum), pl.DataFrame({"z": [3, 6, 9]}))

    scalar_sum = np.add(2, pl.col("a"))  # type: ignore[call-overload]
    assert_frame_equal(frame.select(z=scalar_sum), pl.DataFrame({"z": [3, 4, 5]}))


def test_lazy_map_schema() -> None:
    """Exercise `LazyFrame.map_batches` output-type and output-schema checks."""
    frame = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})

    # identity
    unchanged = frame.lazy().map_batches(lambda x: x).collect()
    assert_frame_equal(unchanged, frame)

    def first_column(data: pl.DataFrame) -> pl.Series:
        # returns a Series where a DataFrame is required
        return data["a"]

    wrong_type = "Expected 'LazyFrame.map' to return a 'DataFrame', got a"
    with pytest.raises(ComputeError, match=wrong_type):
        frame.lazy().map_batches(first_column).collect()  # type: ignore[arg-type]

    def stringify(data: pl.DataFrame) -> pl.DataFrame:
        # changes schema
        return data.select(pl.all().cast(pl.String))

    wrong_schema = "The output schema of 'LazyFrame.map' is incorrect. Expected"
    with pytest.raises(ComputeError, match=wrong_schema):
        frame.lazy().map_batches(stringify).collect()

    # With validation switched off, the schema-changing output is accepted.
    unvalidated = (
        frame.lazy()
        .map_batches(stringify, validate_output_schema=False)
        .collect()
        .to_dict(as_series=False)
    )
    assert unvalidated == {"a": ["1", "2", "3"], "b": ["a", "b", "c"]}


def test_map_batches_collect_schema_17327() -> None:
    """Resolve the schema of a grouped `map_batches` query; expect a list column."""
    lazy_frame = pl.LazyFrame({"a": [1, 1, 1], "b": [2, 3, 4]})
    query = lazy_frame.group_by("a").agg(
        pl.col("b").map_batches(lambda s: s, return_dtype=pl.self_dtype())
    )

    assert query.collect_schema() == pl.Schema(
        {"a": pl.Int64(), "b": pl.List(pl.Int64)}
    )