Path: blob/main/py-polars/tests/unit/operations/map/test_map_elements.py
8408 views
from __future__ import annotations12import json3from datetime import date, datetime, timedelta4from typing import Any, NamedTuple56import numpy as np7import pytest89import polars as pl10from polars.exceptions import PolarsInefficientMapWarning11from polars.testing import assert_frame_equal, assert_series_equal12from tests.unit.conftest import NUMERIC_DTYPES, TEMPORAL_DTYPES1314pytestmark = pytest.mark.filterwarnings(15"ignore::polars.exceptions.PolarsInefficientMapWarning"16)171819@pytest.mark.may_fail_auto_streaming # dtype not set20@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set21def test_map_elements_infer_list() -> None:22df = pl.DataFrame(23{24"int": [1, 2],25"str": ["a", "b"],26"bool": [True, None],27}28)29assert df.select([pl.all().map_elements(lambda x: [x])]).dtypes == [pl.List] * 3303132def test_map_elements_upcast_null_dtype_empty_list() -> None:33df = pl.DataFrame({"a": [1, 2]})34out = df.select(35pl.col("a").map_elements(lambda _: [], return_dtype=pl.List(pl.Int64))36)37assert_frame_equal(38out, pl.DataFrame({"a": [[], []]}, schema={"a": pl.List(pl.Int64)})39)404142@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set43def test_map_elements_arithmetic_consistency() -> None:44df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})45with pytest.warns(46PolarsInefficientMapWarning,47match="with this one instead",48):49assert df.group_by("A").agg(50pl.col("B")51.implode()52.map_elements(lambda x: x + 1.0, return_dtype=pl.List(pl.Float64))53)["B"].to_list() == [[3.0, 4.0]]545556@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set57def test_map_elements_struct() -> None:58df = pl.DataFrame(59{60"A": ["a", "a", None],61"B": [2, 3, None],62"C": [True, False, None],63"D": [12.0, None, None],64"E": [None, [1], [2, 3]],65}66)6768out = df.with_columns(pl.struct(df.columns).alias("struct")).select(69pl.col("struct").map_elements(lambda x: x["A"]).alias("A_field"),70pl.col("struct").map_elements(lambda x: x["B"]).alias("B_field"),71pl.col("struct").map_elements(lambda x: x["C"]).alias("C_field"),72pl.col("struct").map_elements(lambda x: x["D"]).alias("D_field"),73pl.col("struct").map_elements(lambda x: x["E"]).alias("E_field"),74)75expected = pl.DataFrame(76{77"A_field": ["a", "a", None],78"B_field": [2, 3, None],79"C_field": [True, False, None],80"D_field": [12.0, None, None],81"E_field": [None, [1], [2, 3]],82}83)8485assert_frame_equal(out, expected)868788@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set89def test_map_elements_numpy_int_out() -> None:90df = pl.DataFrame({"col1": [2, 4, 8, 16]})91result = df.with_columns(92pl.col("col1").map_elements(lambda x: np.left_shift(x, 8)).alias("result")93)94expected = pl.DataFrame({"col1": [2, 4, 8, 16], "result": [512, 1024, 2048, 4096]})95assert_frame_equal(result, expected)9697df = pl.DataFrame({"col1": [2, 4, 8, 16], "shift": [1, 1, 2, 2]})98result = df.select(99pl.struct(["col1", "shift"])100.map_elements(lambda cols: np.left_shift(cols["col1"], cols["shift"]))101.alias("result")102)103expected = pl.DataFrame({"result": [4, 8, 32, 64]})104assert_frame_equal(result, expected)105106107def test_datelike_identity() -> None:108for s in [109pl.Series([datetime(year=2000, month=1, day=1)]),110pl.Series([timedelta(hours=2)]),111pl.Series([date(year=2000, month=1, day=1)]),112]:113assert s.map_elements(lambda x: x).to_list() == s.to_list()114115116def test_map_elements_list_any_value_fallback() -> None:117df = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']})118with pytest.warns(119PolarsInefficientMapWarning,120match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()',121):122assert df.select(123pl.col("text").map_elements(124json.loads,125return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),126)127).to_dict(as_series=False) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]}128129# starts with empty list '[]'130df = pl.DataFrame(131{132"text": [133"[]",134'[{"x": 1, "y": 2}, {"x": 3, "y": 4}]',135'[{"x": 1, "y": 2}]',136]137}138)139with pytest.warns(140PolarsInefficientMapWarning,141match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()',142):143assert df.select(144pl.col("text").map_elements(145json.loads,146return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),147)148).to_dict(as_series=False) == {149"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]150}151152153def test_map_elements_all_types() -> None:154# test we don't panic155dtypes = NUMERIC_DTYPES + TEMPORAL_DTYPES + [pl.Decimal(None, 2)]156for dtype in dtypes:157pl.Series([1, 2, 3, 4, 5], dtype=dtype).map_elements(lambda x: x)158159160def test_map_elements_type_propagation() -> None:161assert (162pl.from_dict(163{164"a": [1, 2, 3],165"b": [{"c": 1, "d": 2}, {"c": 2, "d": 3}, {"c": None, "d": None}],166}167)168.group_by("a", maintain_order=True)169.agg(170[171pl.when(~pl.col("b").has_nulls())172.then(173pl.col("b")174.implode()175.map_elements(176lambda s: float(s[0]["c"]) if s[0]["c"] is not None else None,177return_dtype=pl.Float64,178)179)180.otherwise(None)181]182)183).to_dict(as_series=False) == {"a": [1, 2, 3], "b": [1.0, 2.0, None]}184185186@pytest.mark.may_fail_auto_streaming # dtype not set187@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set188def test_empty_list_in_map_elements() -> None:189df = pl.DataFrame(190{"a": [[1], [1, 2], [3, 4], [5, 6]], "b": [[3], [1, 2], [1, 2], [4, 5]]}191)192193assert df.select(194pl.struct(["a", "b"]).map_elements(195lambda row: list(set(row["a"]) & set(row["b"]))196)197).to_dict(as_series=False) == {"a": [[], [1, 2], [], [5]]}198199200@pytest.mark.parametrize("value", [1, True, "abc", [1, 2], {"a": 1}])201@pytest.mark.parametrize("return_value", [1, True, "abc", [1, 2], {"a": 1}])202def test_map_elements_skip_nulls(value: Any, return_value: Any) -> None:203s = pl.Series([value, None])204205result = s.map_elements(lambda x: return_value, skip_nulls=True).to_list()206assert result == [return_value, None]207208result = s.map_elements(lambda x: return_value, skip_nulls=False).to_list()209assert result == [return_value, return_value]210211212@pytest.mark.may_fail_cloud # reason: Object type not supported213def test_map_elements_object_dtypes() -> None:214with pytest.warns(215PolarsInefficientMapWarning,216match=r"(?s)Replace this expression.*lambda x:",217):218assert pl.DataFrame(219{"a": pl.Series([1, 2, "a", 4, 5], dtype=pl.Object)}220).with_columns(221pl.col("a").map_elements(lambda x: x * 2, return_dtype=pl.Object),222pl.col("a")223.map_elements(224lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean225)226.alias("is_numeric1"),227pl.col("a")228.map_elements(229lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean230)231.alias("is_numeric_infer"),232).to_dict(as_series=False) == {233"a": [2, 4, "aa", 8, 10],234"is_numeric1": [True, True, False, True, True],235"is_numeric_infer": [True, True, False, True, True],236}237238239def test_map_elements_explicit_list_output_type() -> None:240out = pl.DataFrame({"str": ["a", "b"]}).with_columns(241pl.col("str").map_elements(242lambda _: pl.Series([1, 2, 3]), return_dtype=pl.List(pl.Int64)243)244)245246assert out.dtypes == [pl.List(pl.Int64)]247assert out.to_dict(as_series=False) == {"str": [[1, 2, 3], [1, 2, 3]]}248249250@pytest.mark.may_fail_auto_streaming # dtype not set251def test_map_elements_dict() -> None:252df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']})253with pytest.warns(254PolarsInefficientMapWarning,255match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()',256):257assert df.select(258pl.col("abc").map_elements(259json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})260)261).to_dict(as_series=False) == {262"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]263}264265with pytest.warns(266PolarsInefficientMapWarning,267match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()',268):269assert pl.DataFrame(270{"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']}271).select(272pl.col("abc").map_elements(273json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})274)275).to_dict(as_series=False) == {276"abc": [{"A": "Value1", "B": "Value2"}, {"A": None, "B": "Value3"}]277}278279280def test_map_elements_pass_name() -> None:281df = pl.DataFrame(282{283"bar": [1, 1, 2],284"foo": [1, 2, 3],285}286)287288mapper = {"foo": "foo1"}289290def element_mapper(s: pl.Series) -> pl.Series:291return pl.Series([mapper[s.name]])292293assert df.group_by("bar", maintain_order=True).agg(294pl.col("foo")295.implode()296.map_elements(element_mapper, pass_name=True, return_dtype=pl.List(pl.String)),297).to_dict(as_series=False) == {"bar": [1, 2], "foo": [["foo1"], ["foo1"]]}298299300@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set301def test_map_elements_binary() -> None:302assert pl.DataFrame({"bin": [b"\x11" * 12, b"\x22" * 12, b"\xaa" * 12]}).select(303pl.col("bin").map_elements(bytes.hex)304).to_dict(as_series=False) == {305"bin": [306"111111111111111111111111",307"222222222222222222222222",308"aaaaaaaaaaaaaaaaaaaaaaaa",309]310}311312313def test_map_elements_set_datetime_output_8984() -> None:314df = pl.DataFrame({"a": [""]})315payload = datetime(2001, 1, 1)316assert df.select(317pl.col("a").map_elements(lambda _: payload, return_dtype=pl.Datetime),318)["a"].to_list() == [payload]319320321@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set322def test_map_elements_dict_order_10128() -> None:323df = pl.select(pl.lit("").map_elements(lambda x: {"c": 1, "b": 2, "a": 3}))324assert df.to_dict(as_series=False) == {"literal": [{"c": 1, "b": 2, "a": 3}]}325326327@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set328def test_map_elements_10237() -> None:329df = pl.DataFrame({"a": [1, 2, 3]})330assert (331df.select(pl.all().map_elements(lambda x: x > 50))["a"].to_list() == [False] * 3332)333334335@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set336def test_map_elements_on_empty_col_10639() -> None:337df = pl.DataFrame({"A": [], "B": []}, schema={"A": pl.Float32, "B": pl.Float32})338res = df.group_by("B").agg(339pl.col("A")340.map_elements(lambda x: x, return_dtype=pl.Int32, strategy="threading")341.alias("Foo")342)343assert res.to_dict(as_series=False) == {344"B": [],345"Foo": [],346}347348res = df.group_by("B").agg(349pl.col("A")350.map_elements(lambda x: x, return_dtype=pl.Int32, strategy="thread_local")351.alias("Foo")352)353assert res.to_dict(as_series=False) == {354"B": [],355"Foo": [],356}357358359def test_map_elements_chunked_14390() -> None:360s = pl.concat(2 * [pl.Series([1])], rechunk=False)361assert s.n_chunks() > 1362with pytest.warns(PolarsInefficientMapWarning):363assert_series_equal(364s.map_elements(str, return_dtype=pl.String),365pl.Series(["1", "1"]),366check_names=False,367)368369370def test_cabbage_strategy_14396() -> None:371df = pl.DataFrame({"x": [1, 2, 3]})372with (373pytest.raises(ValueError, match="strategy 'cabbage' is not supported"),374pytest.warns(PolarsInefficientMapWarning),375):376df.select(pl.col("x").map_elements(lambda x: 2 * x, strategy="cabbage")) # type: ignore[arg-type]377378379def test_map_elements_list_dtype_18472() -> None:380s = pl.Series([[None], ["abc ", None]])381result = s.map_elements(lambda s: [i.strip() if i else None for i in s])382expected = pl.Series([[None], ["abc", None]])383assert_series_equal(result, expected)384385386def test_map_elements_list_return_dtype() -> None:387s = pl.Series([[1], [2, 3]])388return_dtype = pl.List(pl.UInt16)389390result = s.map_elements(391lambda s: [i + 1 for i in s],392return_dtype=return_dtype,393)394expected = pl.Series([[2], [3, 4]], dtype=return_dtype)395assert_series_equal(result, expected)396397398def test_map_elements_list_of_named_tuple_15425() -> None:399class Foo(NamedTuple):400x: int401402df = pl.DataFrame({"a": [0, 1, 2]})403result = df.select(404pl.col("a").map_elements(405lambda x: [Foo(i) for i in range(x)],406return_dtype=pl.List(pl.Struct({"x": pl.Int64})),407)408)409expected = pl.DataFrame({"a": [[], [{"x": 0}], [{"x": 0}, {"x": 1}]]})410assert_frame_equal(result, expected)411412413def test_map_elements_list_dtype_24006() -> None:414values = [None, [1, 2], [2, 3]]415dtype = pl.List(pl.Int64)416417s1 = pl.Series([0, 1, 2]).map_elements(lambda x: values[x])418s2 = pl.Series([0, 1, 2]).map_elements(lambda x: values[x], return_dtype=dtype)419420assert_series_equal(s1, s2)421assert_series_equal(s1, pl.Series(values, dtype=dtype))422423424def test_map_elements_reentrant_mutable_no_deadlock() -> None:425s = pl.Series("a", [1, 2, 3])426s.map_elements(lambda _: s.rechunk(in_place=True)[0])427428429