# Path: py-polars/tests/unit/operations/map/test_map_rows.py
from __future__ import annotations12from typing import Any34import pytest56import polars as pl7from polars.exceptions import ComputeError8from polars.testing import assert_frame_equal91011def test_map_rows() -> None:12df = pl.DataFrame({"a": ["foo", "2"], "b": [1, 2], "c": [1.0, 2.0]})1314result = df.map_rows(lambda x: len(x), None)1516expected = pl.DataFrame({"map": [3, 3]})17assert_frame_equal(result, expected)181920def test_map_rows_list_return() -> None:21df = pl.DataFrame({"start": [1, 2], "end": [3, 5]})2223result = df.map_rows(lambda r: pl.Series(range(r[0], r[1] + 1)))2425expected = pl.DataFrame({"map": [[1, 2, 3], [2, 3, 4, 5]]})26assert_frame_equal(result, expected)272829def test_map_rows_dataframe_return() -> None:30df = pl.DataFrame({"a": [1, 2, 3], "b": ["c", "d", None]})3132result = df.map_rows(lambda row: (row[0] * 10, "foo", True, row[-1]))3334expected = pl.DataFrame(35{36"column_0": [10, 20, 30],37"column_1": ["foo", "foo", "foo"],38"column_2": [True, True, True],39"column_3": ["c", "d", None],40}41)42assert_frame_equal(result, expected)434445def test_map_rows_error_return_type() -> None:46df = pl.DataFrame({"a": [[1, 2], [2, 3]], "b": [[4, 5], [6, 7]]})4748def combine(row: tuple[Any, ...]) -> list[Any]:49res = [x + y for x, y in zip(row[0], row[1])]50return [res]5152with pytest.raises(ComputeError, match="expected tuple, got list"):53df.map_rows(combine)545556def test_map_rows_shifted_chunks() -> None:57df = pl.DataFrame(pl.Series("texts", ["test", "test123", "tests"]))58df = df.select(pl.col("texts"), pl.col("texts").shift(1).alias("texts_shifted"))5960result = df.map_rows(lambda x: x)6162expected = pl.DataFrame(63{64"column_0": ["test", "test123", "tests"],65"column_1": [None, "test", "test123"],66}67)68assert_frame_equal(result, expected)697071def test_map_elements_infer() -> None:72lf = pl.LazyFrame(73{74"a": [1, 2, 3],75}76)77lf = lf.select(pl.col.a.map_elements(lambda v: f"pre-{v}"))7879# this should not go through execution, solely through 
the planner80schema = lf.collect_schema()8182assert schema.names() == ["a"]83assert schema.dtypes() == [pl.String]848586