# Path: blob/main/py-polars/tests/unit/operations/map/test_map_elements.py
# 6940 views
from __future__ import annotations12import json3from datetime import date, datetime, timedelta4from typing import Any, NamedTuple56import numpy as np7import pytest89import polars as pl10from polars.exceptions import PolarsInefficientMapWarning11from polars.testing import assert_frame_equal, assert_series_equal12from tests.unit.conftest import NUMERIC_DTYPES, TEMPORAL_DTYPES1314pytestmark = pytest.mark.filterwarnings(15"ignore::polars.exceptions.PolarsInefficientMapWarning"16)171819@pytest.mark.may_fail_auto_streaming # dtype not set20@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set21def test_map_elements_infer_list() -> None:22df = pl.DataFrame(23{24"int": [1, 2],25"str": ["a", "b"],26"bool": [True, None],27}28)29assert df.select([pl.all().map_elements(lambda x: [x])]).dtypes == [pl.List] * 3303132def test_map_elements_upcast_null_dtype_empty_list() -> None:33df = pl.DataFrame({"a": [1, 2]})34out = df.select(35pl.col("a").map_elements(lambda _: [], return_dtype=pl.List(pl.Int64))36)37assert_frame_equal(38out, pl.DataFrame({"a": [[], []]}, schema={"a": pl.List(pl.Int64)})39)404142@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set43def test_map_elements_arithmetic_consistency() -> None:44df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})45with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"):46assert df.group_by("A").agg(47pl.col("B")48.implode()49.map_elements(lambda x: x + 1.0, return_dtype=pl.List(pl.Float64))50)["B"].to_list() == [[3.0, 4.0]]515253@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set54def test_map_elements_struct() -> None:55df = pl.DataFrame(56{57"A": ["a", "a", None],58"B": [2, 3, None],59"C": [True, False, None],60"D": [12.0, None, None],61"E": [None, [1], [2, 3]],62}63)6465out = df.with_columns(pl.struct(df.columns).alias("struct")).select(66pl.col("struct").map_elements(lambda x: x["A"]).alias("A_field"),67pl.col("struct").map_elements(lambda x: 
x["B"]).alias("B_field"),68pl.col("struct").map_elements(lambda x: x["C"]).alias("C_field"),69pl.col("struct").map_elements(lambda x: x["D"]).alias("D_field"),70pl.col("struct").map_elements(lambda x: x["E"]).alias("E_field"),71)72expected = pl.DataFrame(73{74"A_field": ["a", "a", None],75"B_field": [2, 3, None],76"C_field": [True, False, None],77"D_field": [12.0, None, None],78"E_field": [None, [1], [2, 3]],79}80)8182assert_frame_equal(out, expected)838485@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set86def test_map_elements_numpy_int_out() -> None:87df = pl.DataFrame({"col1": [2, 4, 8, 16]})88result = df.with_columns(89pl.col("col1").map_elements(lambda x: np.left_shift(x, 8)).alias("result")90)91expected = pl.DataFrame({"col1": [2, 4, 8, 16], "result": [512, 1024, 2048, 4096]})92assert_frame_equal(result, expected)9394df = pl.DataFrame({"col1": [2, 4, 8, 16], "shift": [1, 1, 2, 2]})95result = df.select(96pl.struct(["col1", "shift"])97.map_elements(lambda cols: np.left_shift(cols["col1"], cols["shift"]))98.alias("result")99)100expected = pl.DataFrame({"result": [4, 8, 32, 64]})101assert_frame_equal(result, expected)102103104def test_datelike_identity() -> None:105for s in [106pl.Series([datetime(year=2000, month=1, day=1)]),107pl.Series([timedelta(hours=2)]),108pl.Series([date(year=2000, month=1, day=1)]),109]:110assert s.map_elements(lambda x: x).to_list() == s.to_list()111112113def test_map_elements_list_any_value_fallback() -> None:114with pytest.warns(115PolarsInefficientMapWarning,116match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()',117):118df = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']})119assert df.select(120pl.col("text").map_elements(121json.loads,122return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),123)124).to_dict(as_series=False) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]}125126# starts with empty list '[]'127df = pl.DataFrame(128{129"text": [130"[]",131'[{"x": 1, 
"y": 2}, {"x": 3, "y": 4}]',132'[{"x": 1, "y": 2}]',133]134}135)136assert df.select(137pl.col("text").map_elements(138json.loads,139return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),140)141).to_dict(as_series=False) == {142"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]143}144145146def test_map_elements_all_types() -> None:147# test we don't panic148dtypes = NUMERIC_DTYPES + TEMPORAL_DTYPES + [pl.Decimal(None, 2)]149for dtype in dtypes:150pl.Series([1, 2, 3, 4, 5], dtype=dtype).map_elements(lambda x: x)151152153def test_map_elements_type_propagation() -> None:154assert (155pl.from_dict(156{157"a": [1, 2, 3],158"b": [{"c": 1, "d": 2}, {"c": 2, "d": 3}, {"c": None, "d": None}],159}160)161.group_by("a", maintain_order=True)162.agg(163[164pl.when(~pl.col("b").has_nulls())165.then(166pl.col("b")167.implode()168.map_elements(169lambda s: s[0]["c"],170return_dtype=pl.Float64,171)172)173.otherwise(None)174]175)176).to_dict(as_series=False) == {"a": [1, 2, 3], "b": [1.0, 2.0, None]}177178179@pytest.mark.may_fail_auto_streaming # dtype not set180@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set181def test_empty_list_in_map_elements() -> None:182df = pl.DataFrame(183{"a": [[1], [1, 2], [3, 4], [5, 6]], "b": [[3], [1, 2], [1, 2], [4, 5]]}184)185186assert df.select(187pl.struct(["a", "b"]).map_elements(188lambda row: list(set(row["a"]) & set(row["b"]))189)190).to_dict(as_series=False) == {"a": [[], [1, 2], [], [5]]}191192193@pytest.mark.parametrize("value", [1, True, "abc", [1, 2], {"a": 1}])194@pytest.mark.parametrize("return_value", [1, True, "abc", [1, 2], {"a": 1}])195def test_map_elements_skip_nulls(value: Any, return_value: Any) -> None:196s = pl.Series([value, None])197198result = s.map_elements(lambda x: return_value, skip_nulls=True).to_list()199assert result == [return_value, None]200201result = s.map_elements(lambda x: return_value, skip_nulls=False).to_list()202assert result == [return_value, 
return_value]203204205@pytest.mark.may_fail_cloud # reason: Object type not supported206def test_map_elements_object_dtypes() -> None:207with pytest.warns(208PolarsInefficientMapWarning,209match=r"(?s)Replace this expression.*lambda x:",210):211assert pl.DataFrame(212{"a": pl.Series([1, 2, "a", 4, 5], dtype=pl.Object)}213).with_columns(214pl.col("a").map_elements(lambda x: x * 2, return_dtype=pl.Object),215pl.col("a")216.map_elements(217lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean218)219.alias("is_numeric1"),220pl.col("a")221.map_elements(222lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean223)224.alias("is_numeric_infer"),225).to_dict(as_series=False) == {226"a": [2, 4, "aa", 8, 10],227"is_numeric1": [True, True, False, True, True],228"is_numeric_infer": [True, True, False, True, True],229}230231232def test_map_elements_explicit_list_output_type() -> None:233out = pl.DataFrame({"str": ["a", "b"]}).with_columns(234pl.col("str").map_elements(235lambda _: pl.Series([1, 2, 3]), return_dtype=pl.List(pl.Int64)236)237)238239assert out.dtypes == [pl.List(pl.Int64)]240assert out.to_dict(as_series=False) == {"str": [[1, 2, 3], [1, 2, 3]]}241242243@pytest.mark.may_fail_auto_streaming # dtype not set244def test_map_elements_dict() -> None:245with pytest.warns(246PolarsInefficientMapWarning,247match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()',248):249df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']})250assert df.select(251pl.col("abc").map_elements(252json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})253)254).to_dict(as_series=False) == {255"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]256}257assert pl.DataFrame(258{"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']}259).select(260pl.col("abc").map_elements(261json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})262)263).to_dict(as_series=False) == {264"abc": [{"A": "Value1", "B": "Value2"}, {"A": None, 
"B": "Value3"}]265}266267268def test_map_elements_pass_name() -> None:269df = pl.DataFrame(270{271"bar": [1, 1, 2],272"foo": [1, 2, 3],273}274)275276mapper = {"foo": "foo1"}277278def element_mapper(s: pl.Series) -> pl.Series:279return pl.Series([mapper[s.name]])280281assert df.group_by("bar", maintain_order=True).agg(282pl.col("foo")283.implode()284.map_elements(element_mapper, pass_name=True, return_dtype=pl.List(pl.String)),285).to_dict(as_series=False) == {"bar": [1, 2], "foo": [["foo1"], ["foo1"]]}286287288@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set289def test_map_elements_binary() -> None:290assert pl.DataFrame({"bin": [b"\x11" * 12, b"\x22" * 12, b"\xaa" * 12]}).select(291pl.col("bin").map_elements(bytes.hex)292).to_dict(as_series=False) == {293"bin": [294"111111111111111111111111",295"222222222222222222222222",296"aaaaaaaaaaaaaaaaaaaaaaaa",297]298}299300301def test_map_elements_set_datetime_output_8984() -> None:302df = pl.DataFrame({"a": [""]})303payload = datetime(2001, 1, 1)304assert df.select(305pl.col("a").map_elements(lambda _: payload, return_dtype=pl.Datetime),306)["a"].to_list() == [payload]307308309@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set310def test_map_elements_dict_order_10128() -> None:311df = pl.select(pl.lit("").map_elements(lambda x: {"c": 1, "b": 2, "a": 3}))312assert df.to_dict(as_series=False) == {"literal": [{"c": 1, "b": 2, "a": 3}]}313314315@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set316def test_map_elements_10237() -> None:317df = pl.DataFrame({"a": [1, 2, 3]})318assert (319df.select(pl.all().map_elements(lambda x: x > 50))["a"].to_list() == [False] * 3320)321322323@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set324def test_map_elements_on_empty_col_10639() -> None:325df = pl.DataFrame({"A": [], "B": []}, schema={"A": pl.Float32, "B": pl.Float32})326res = df.group_by("B").agg(327pl.col("A")328.map_elements(lambda x: x, 
return_dtype=pl.Int32, strategy="threading")329.alias("Foo")330)331assert res.to_dict(as_series=False) == {332"B": [],333"Foo": [],334}335336res = df.group_by("B").agg(337pl.col("A")338.map_elements(lambda x: x, return_dtype=pl.Int32, strategy="thread_local")339.alias("Foo")340)341assert res.to_dict(as_series=False) == {342"B": [],343"Foo": [],344}345346347def test_map_elements_chunked_14390() -> None:348s = pl.concat(2 * [pl.Series([1])], rechunk=False)349assert s.n_chunks() > 1350with pytest.warns(PolarsInefficientMapWarning):351assert_series_equal(352s.map_elements(str, return_dtype=pl.String),353pl.Series(["1", "1"]),354check_names=False,355)356357358def test_cabbage_strategy_14396() -> None:359df = pl.DataFrame({"x": [1, 2, 3]})360with (361pytest.raises(ValueError, match="strategy 'cabbage' is not supported"),362pytest.warns(PolarsInefficientMapWarning),363):364df.select(pl.col("x").map_elements(lambda x: 2 * x, strategy="cabbage")) # type: ignore[arg-type]365366367def test_map_elements_list_dtype_18472() -> None:368s = pl.Series([[None], ["abc ", None]])369result = s.map_elements(lambda s: [i.strip() if i else None for i in s])370expected = pl.Series([[None], ["abc", None]])371assert_series_equal(result, expected)372373374def test_map_elements_list_return_dtype() -> None:375s = pl.Series([[1], [2, 3]])376return_dtype = pl.List(pl.UInt16)377378result = s.map_elements(379lambda s: [i + 1 for i in s],380return_dtype=return_dtype,381)382expected = pl.Series([[2], [3, 4]], dtype=return_dtype)383assert_series_equal(result, expected)384385386def test_map_elements_list_of_named_tuple_15425() -> None:387class Foo(NamedTuple):388x: int389390df = pl.DataFrame({"a": [0, 1, 2]})391result = df.select(392pl.col("a").map_elements(393lambda x: [Foo(i) for i in range(x)],394return_dtype=pl.List(pl.Struct({"x": pl.Int64})),395)396)397expected = pl.DataFrame({"a": [[], [{"x": 0}], [{"x": 0}, {"x": 1}]]})398assert_frame_equal(result, expected)399400401def 
test_map_elements_list_dtype_24006() -> None:402values = [None, [1, 2], [2, 3]]403dtype = pl.List(pl.Int64)404405s1 = pl.Series([0, 1, 2]).map_elements(lambda x: values[x])406s2 = pl.Series([0, 1, 2]).map_elements(lambda x: values[x], return_dtype=dtype)407408assert_series_equal(s1, s2)409assert_series_equal(s1, pl.Series(values, dtype=dtype))410411412def test_map_elements_reentrant_mutable_no_deadlock() -> None:413s = pl.Series("a", [1, 2, 3])414s.map_elements(lambda _: s.rechunk(in_place=True)[0])415416417