Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/map/test_map_elements.py
6940 views
1
from __future__ import annotations
2
3
import json
4
from datetime import date, datetime, timedelta
5
from typing import Any, NamedTuple
6
7
import numpy as np
8
import pytest
9
10
import polars as pl
11
from polars.exceptions import PolarsInefficientMapWarning
12
from polars.testing import assert_frame_equal, assert_series_equal
13
from tests.unit.conftest import NUMERIC_DTYPES, TEMPORAL_DTYPES
14
15
# Module-wide pytest mark: PolarsInefficientMapWarning is ignored unless a test
# captures it explicitly with ``pytest.warns``.
pytestmark = pytest.mark.filterwarnings(
    "ignore::polars.exceptions.PolarsInefficientMapWarning",
)
18
19
20
@pytest.mark.may_fail_auto_streaming  # dtype not set
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_elements_infer_list() -> None:
    """Wrapping every element in a list yields a ``List`` dtype per column."""
    frame = pl.DataFrame({"int": [1, 2], "str": ["a", "b"], "bool": [True, None]})
    # No return_dtype is passed here; the test checks the resulting dtypes only.
    wrapped = frame.select([pl.all().map_elements(lambda x: [x])])
    assert wrapped.dtypes == [pl.List] * 3
31
32
33
def test_map_elements_upcast_null_dtype_empty_list() -> None:
    """Empty lists returned by the UDF take the requested ``List(Int64)`` dtype."""
    list_of_ints = pl.List(pl.Int64)
    result = pl.DataFrame({"a": [1, 2]}).select(
        pl.col("a").map_elements(lambda _: [], return_dtype=list_of_ints)
    )
    expected = pl.DataFrame({"a": [[], []]}, schema={"a": list_of_ints})
    assert_frame_equal(result, expected)
41
42
43
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_elements_arithmetic_consistency() -> None:
    """Adding a float inside the UDF gives float lists after a group-by."""
    frame = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})
    with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"):
        # The expression is built inside the ``with`` so the warning it
        # triggers is seen by ``pytest.warns``.
        plus_one = (
            pl.col("B")
            .implode()
            .map_elements(lambda x: x + 1.0, return_dtype=pl.List(pl.Float64))
        )
        aggregated = frame.group_by("A").agg(plus_one)
        assert aggregated["B"].to_list() == [[3.0, 4.0]]
52
53
54
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_elements_struct() -> None:
    """Extracting each struct field through a UDF reproduces the source columns."""
    columns = {
        "A": ["a", "a", None],
        "B": [2, 3, None],
        "C": [True, False, None],
        "D": [12.0, None, None],
        "E": [None, [1], [2, 3]],
    }
    frame = pl.DataFrame(columns)

    packed = pl.col("struct")
    with_struct = frame.with_columns(pl.struct(frame.columns).alias("struct"))
    out = with_struct.select(
        packed.map_elements(lambda x: x["A"]).alias("A_field"),
        packed.map_elements(lambda x: x["B"]).alias("B_field"),
        packed.map_elements(lambda x: x["C"]).alias("C_field"),
        packed.map_elements(lambda x: x["D"]).alias("D_field"),
        packed.map_elements(lambda x: x["E"]).alias("E_field"),
    )

    # Same values as the input, with every column renamed to "<name>_field".
    expected = pl.DataFrame(
        {f"{name}_field": values for name, values in columns.items()}
    )
    assert_frame_equal(out, expected)
84
85
86
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_elements_numpy_int_out() -> None:
    """UDFs that return ``np.left_shift`` results produce the expected integers."""
    base = [2, 4, 8, 16]

    # Single column, constant shift.
    shifted = pl.DataFrame({"col1": base}).with_columns(
        pl.col("col1").map_elements(lambda x: np.left_shift(x, 8)).alias("result")
    )
    assert_frame_equal(
        shifted,
        pl.DataFrame({"col1": [2, 4, 8, 16], "result": [512, 1024, 2048, 4096]}),
    )

    # Struct input, per-row shift amount.
    shift_expr = pl.struct(["col1", "shift"]).map_elements(
        lambda cols: np.left_shift(cols["col1"], cols["shift"])
    )
    per_row = pl.DataFrame({"col1": base, "shift": [1, 1, 2, 2]}).select(
        shift_expr.alias("result")
    )
    assert_frame_equal(per_row, pl.DataFrame({"result": [4, 8, 32, 64]}))
103
104
105
def test_datelike_identity() -> None:
    """An identity UDF round-trips datetime, timedelta and date values."""
    samples = (
        datetime(year=2000, month=1, day=1),
        timedelta(hours=2),
        date(year=2000, month=1, day=1),
    )
    for sample in samples:
        series = pl.Series([sample])
        assert series.map_elements(lambda x: x).to_list() == series.to_list()
112
113
114
def test_map_elements_list_any_value_fallback() -> None:
    """``json.loads`` results are collected into a list-of-struct column."""
    # NOTE: in this pattern the dots are unescaped and the trailing ``()`` is an
    # empty regex group, so the match is looser than the literal text suggests.
    pattern = r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()'
    with pytest.warns(PolarsInefficientMapWarning, match=pattern):
        points = pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64}))

        single = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']})
        decoded = single.select(
            pl.col("text").map_elements(json.loads, return_dtype=points)
        )
        assert decoded.to_dict(as_series=False) == {
            "text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]
        }

        # starts with empty list '[]'
        leading_empty = pl.DataFrame(
            {
                "text": [
                    "[]",
                    '[{"x": 1, "y": 2}, {"x": 3, "y": 4}]',
                    '[{"x": 1, "y": 2}]',
                ]
            }
        )
        decoded = leading_empty.select(
            pl.col("text").map_elements(json.loads, return_dtype=points)
        )
        assert decoded.to_dict(as_series=False) == {
            "text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]
        }
145
146
147
def test_map_elements_all_types() -> None:
    """Smoke test: an identity UDF runs on each listed dtype without panicking."""
    # test we don't panic
    for dtype in [*NUMERIC_DTYPES, *TEMPORAL_DTYPES, pl.Decimal(None, 2)]:
        series = pl.Series([1, 2, 3, 4, 5], dtype=dtype)
        series.map_elements(lambda x: x)
152
153
154
def test_map_elements_type_propagation() -> None:
    """The ``Float64`` return dtype survives ``when/then/otherwise`` in an agg."""
    frame = pl.from_dict(
        {
            "a": [1, 2, 3],
            "b": [{"c": 1, "d": 2}, {"c": 2, "d": 3}, {"c": None, "d": None}],
        }
    )

    # Field "c" of the first struct in each imploded group.
    first_c = (
        pl.col("b")
        .implode()
        .map_elements(lambda s: s[0]["c"], return_dtype=pl.Float64)
    )
    guarded = pl.when(~pl.col("b").has_nulls()).then(first_c).otherwise(None)

    result = frame.group_by("a", maintain_order=True).agg([guarded])
    assert result.to_dict(as_series=False) == {
        "a": [1, 2, 3],
        "b": [1.0, 2.0, None],
    }
178
179
180
@pytest.mark.may_fail_auto_streaming  # dtype not set
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_empty_list_in_map_elements() -> None:
    """Row-wise set intersections may be empty lists, including the first row."""
    frame = pl.DataFrame(
        {"a": [[1], [1, 2], [3, 4], [5, 6]], "b": [[3], [1, 2], [1, 2], [4, 5]]}
    )

    common = pl.struct(["a", "b"]).map_elements(
        lambda row: list(set(row["a"]) & set(row["b"]))
    )
    result = frame.select(common).to_dict(as_series=False)
    assert result == {"a": [[], [1, 2], [], [5]]}
192
193
194
@pytest.mark.parametrize("value", [1, True, "abc", [1, 2], {"a": 1}])
@pytest.mark.parametrize("return_value", [1, True, "abc", [1, 2], {"a": 1}])
def test_map_elements_skip_nulls(value: Any, return_value: Any) -> None:
    """``skip_nulls`` decides whether the UDF is applied to the null element."""
    series = pl.Series([value, None])

    # (skip_nulls flag, value expected in the slot that held the null)
    cases = ((True, None), (False, return_value))
    for skip, null_slot in cases:
        result = series.map_elements(
            lambda x: return_value, skip_nulls=skip
        ).to_list()
        assert result == [return_value, null_slot]
204
205
206
@pytest.mark.may_fail_cloud  # reason: Object type not supported
def test_map_elements_object_dtypes() -> None:
    """UDFs over an ``Object`` column can return ``Object`` or ``Boolean``."""
    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Replace this expression.*lambda x:",
    ):
        objects = pl.DataFrame({"a": pl.Series([1, 2, "a", 4, 5], dtype=pl.Object)})
        # The lambdas keep the parameter name ``x`` because the warning match
        # above looks for "lambda x:".
        result = objects.with_columns(
            pl.col("a").map_elements(lambda x: x * 2, return_dtype=pl.Object),
            pl.col("a")
            .map_elements(
                lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean
            )
            .alias("is_numeric1"),
            pl.col("a")
            .map_elements(
                lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean
            )
            .alias("is_numeric_infer"),
        )
        numeric_mask = [True, True, False, True, True]
        assert result.to_dict(as_series=False) == {
            "a": [2, 4, "aa", 8, 10],
            "is_numeric1": numeric_mask,
            "is_numeric_infer": numeric_mask,
        }
231
232
233
def test_map_elements_explicit_list_output_type() -> None:
    """A UDF returning a Series fills a column of the declared list dtype."""
    int_list = pl.List(pl.Int64)
    to_series = pl.col("str").map_elements(
        lambda _: pl.Series([1, 2, 3]), return_dtype=int_list
    )
    out = pl.DataFrame({"str": ["a", "b"]}).with_columns(to_series)

    assert out.dtypes == [pl.List(pl.Int64)]
    assert out.to_dict(as_series=False) == {"str": [[1, 2, 3], [1, 2, 3]]}
242
243
244
@pytest.mark.may_fail_auto_streaming  # dtype not set
def test_map_elements_dict() -> None:
    """Dicts from ``json.loads`` become structs; absent keys come back as null."""
    # NOTE: in this pattern the dots are unescaped and the trailing ``()`` is an
    # empty regex group, so the match is looser than the literal text suggests.
    pattern = r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()'
    with pytest.warns(PolarsInefficientMapWarning, match=pattern):
        disjoint = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']})
        decoded = disjoint.select(
            pl.col("abc").map_elements(
                json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})
            )
        )
        assert decoded.to_dict(as_series=False) == {
            "abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]
        }

        overlapping = pl.DataFrame(
            {"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']}
        )
        decoded = overlapping.select(
            pl.col("abc").map_elements(
                json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})
            )
        )
        assert decoded.to_dict(as_series=False) == {
            "abc": [{"A": "Value1", "B": "Value2"}, {"A": None, "B": "Value3"}]
        }
267
268
269
def test_map_elements_pass_name() -> None:
    """With ``pass_name=True`` the UDF can read the column name from the Series."""
    frame = pl.DataFrame({"bar": [1, 1, 2], "foo": [1, 2, 3]})

    mapper = {"foo": "foo1"}

    def element_mapper(s: pl.Series) -> pl.Series:
        # Look the replacement up by the name of the Series handed to the UDF.
        renamed = mapper[s.name]
        return pl.Series([renamed])

    by_name = (
        pl.col("foo")
        .implode()
        .map_elements(element_mapper, pass_name=True, return_dtype=pl.List(pl.String))
    )
    result = frame.group_by("bar", maintain_order=True).agg(by_name)
    assert result.to_dict(as_series=False) == {
        "bar": [1, 2],
        "foo": [["foo1"], ["foo1"]],
    }
287
288
289
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_elements_binary() -> None:
    """``bytes.hex`` maps each 12-byte binary value to its hex string."""
    payloads = [b"\x11" * 12, b"\x22" * 12, b"\xaa" * 12]
    hexed = pl.DataFrame({"bin": payloads}).select(
        pl.col("bin").map_elements(bytes.hex)
    )
    # Each byte renders as two hex digits, so every string is 24 characters.
    assert hexed.to_dict(as_series=False) == {
        "bin": ["11" * 12, "22" * 12, "aa" * 12]
    }
300
301
302
def test_map_elements_set_datetime_output_8984() -> None:
    """A UDF returning a ``datetime`` works with ``return_dtype=pl.Datetime``."""
    payload = datetime(2001, 1, 1)
    frame = pl.DataFrame({"a": [""]})
    result = frame.select(
        pl.col("a").map_elements(lambda _: payload, return_dtype=pl.Datetime),
    )
    assert result["a"].to_list() == [payload]
308
309
310
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_elements_dict_order_10128() -> None:
    """Struct fields keep the key order of the dict returned by the UDF."""
    to_struct = pl.lit("").map_elements(lambda x: {"c": 1, "b": 2, "a": 3})
    result = pl.select(to_struct).to_dict(as_series=False)
    assert result == {"literal": [{"c": 1, "b": 2, "a": 3}]}
314
315
316
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_elements_10237() -> None:
    """A comparison UDF that is false for every row returns all ``False``."""
    frame = pl.DataFrame({"a": [1, 2, 3]})
    flags = frame.select(pl.all().map_elements(lambda x: x > 50))["a"]
    assert flags.to_list() == [False] * 3
322
323
324
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_elements_on_empty_col_10639() -> None:
    """Aggregating an empty frame through a UDF works for both strategies."""
    empty = pl.DataFrame({"A": [], "B": []}, schema={"A": pl.Float32, "B": pl.Float32})
    nothing = {"B": [], "Foo": []}

    threaded = (
        pl.col("A")
        .map_elements(lambda x: x, return_dtype=pl.Int32, strategy="threading")
        .alias("Foo")
    )
    assert empty.group_by("B").agg(threaded).to_dict(as_series=False) == nothing

    thread_local = (
        pl.col("A")
        .map_elements(lambda x: x, return_dtype=pl.Int32, strategy="thread_local")
        .alias("Foo")
    )
    assert empty.group_by("B").agg(thread_local).to_dict(as_series=False) == nothing
346
347
348
def test_map_elements_chunked_14390() -> None:
    """``map_elements`` handles a Series that consists of more than one chunk."""
    chunked = pl.concat([pl.Series([1]), pl.Series([1])], rechunk=False)
    # Precondition: the concat above must not have merged the chunks.
    assert chunked.n_chunks() > 1

    with pytest.warns(PolarsInefficientMapWarning):
        stringified = chunked.map_elements(str, return_dtype=pl.String)
        assert_series_equal(stringified, pl.Series(["1", "1"]), check_names=False)
357
358
359
def test_cabbage_strategy_14396() -> None:
    """An unknown ``strategy`` value raises ``ValueError`` and still warns."""
    frame = pl.DataFrame({"x": [1, 2, 3]})
    # ``raises`` stays the outer context manager and ``warns`` the inner one.
    with pytest.raises(ValueError, match="strategy 'cabbage' is not supported"):
        with pytest.warns(PolarsInefficientMapWarning):
            frame.select(pl.col("x").map_elements(lambda x: 2 * x, strategy="cabbage"))  # type: ignore[arg-type]
366
367
368
def test_map_elements_list_dtype_18472() -> None:
    """List output is handled when the first returned list holds only a null."""
    raw = pl.Series([[None], ["abc ", None]])
    stripped = raw.map_elements(
        lambda items: [item.strip() if item else None for item in items]
    )
    assert_series_equal(stripped, pl.Series([[None], ["abc", None]]))
373
374
375
def test_map_elements_list_return_dtype() -> None:
    """An explicit ``List(UInt16)`` return dtype is applied to list output."""
    return_dtype = pl.List(pl.UInt16)
    source = pl.Series([[1], [2, 3]])

    incremented = source.map_elements(
        lambda items: [item + 1 for item in items],
        return_dtype=return_dtype,
    )
    assert_series_equal(incremented, pl.Series([[2], [3, 4]], dtype=return_dtype))
385
386
387
def test_map_elements_list_of_named_tuple_15425() -> None:
    """Lists of NamedTuple instances are stored as lists of structs."""

    class Foo(NamedTuple):
        x: int

    struct_list = pl.List(pl.Struct({"x": pl.Int64}))
    frame = pl.DataFrame({"a": [0, 1, 2]})
    # Row value n produces n tuples: Foo(0) .. Foo(n - 1).
    result = frame.select(
        pl.col("a").map_elements(
            lambda x: [Foo(i) for i in range(x)],
            return_dtype=struct_list,
        )
    )
    assert_frame_equal(
        result, pl.DataFrame({"a": [[], [{"x": 0}], [{"x": 0}, {"x": 1}]]})
    )
400
401
402
def test_map_elements_list_dtype_24006() -> None:
    """Results with and without ``return_dtype`` agree when the first is null."""
    values = [None, [1, 2], [2, 3]]
    dtype = pl.List(pl.Int64)
    indices = [0, 1, 2]

    inferred = pl.Series(indices).map_elements(lambda x: values[x])
    explicit = pl.Series(indices).map_elements(
        lambda x: values[x], return_dtype=dtype
    )

    assert_series_equal(inferred, explicit)
    assert_series_equal(inferred, pl.Series(values, dtype=dtype))
411
412
413
def test_map_elements_reentrant_mutable_no_deadlock() -> None:
    """The UDF rechunks the mapped Series in place; the call must complete."""
    series = pl.Series("a", [1, 2, 3])
    # No assertion: the test passes as long as this call returns.
    series.map_elements(lambda _: series.rechunk(in_place=True)[0])
416
417