Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/map/test_map_batches.py
6940 views
1
from __future__ import annotations
2
3
from functools import reduce
4
5
import numpy as np
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import ComputeError, InvalidOperationError
10
from polars.testing import assert_frame_equal
11
12
13
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_return_py_object() -> None:
    # Each column is folded into a single value with ``reduce``; because the
    # UDF is declared with ``returns_scalar=True`` the result is one row.
    frame = pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
    column_total = pl.all().map_batches(
        lambda series: reduce(lambda acc, value: acc + value, series),
        returns_scalar=True,
    )

    assert_frame_equal(
        frame.select([column_total]),
        pl.DataFrame({"A": [6], "B": [15]}),
    )
27
28
29
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_map_no_dtype_set_8531() -> None:
    # A ``map_batches`` without an explicit ``return_dtype``, chained with a
    # no-op ``shift(n=0, fill_value=0)``, must still produce the doubled values.
    # (The ``8531`` suffix presumably refers to an upstream issue number.)
    doubled = pl.col("a").map_batches(lambda x: x * 2)
    out = pl.DataFrame({"a": [1]}).with_columns(doubled.shift(n=0, fill_value=0))

    assert_frame_equal(out, pl.DataFrame({"a": [2]}))
39
40
41
def test_error_on_reducing_map() -> None:
    # Part 1: a reducing multi-column ``pl.map_batches`` declared with
    # ``returns_scalar=True`` works inside a group-by aggregation. The expected
    # numbers are the mean of the pooled ``t`` and ``y`` values of each group
    # (e.g. (2+4+5+0+1+1)/6 for id 0), reported under the column name ``t``.
    frame = pl.DataFrame(
        {"id": [0, 0, 0, 1, 1, 1], "t": [2, 4, 5, 10, 11, 14], "y": [0, 1, 1, 2, 3, 4]}
    )
    pooled_mean = pl.map_batches(["t", "y"], np.mean, pl.Float64(), returns_scalar=True)
    expected = pl.DataFrame({"id": [0, 1], "t": [2.166667, 7.333333]})

    # Group order is not guaranteed, so row order is not compared.
    assert_frame_equal(
        frame.group_by("id").agg(pooled_mean),
        expected,
        check_row_order=False,
    )

    # Part 2: a UDF declared elementwise whose output length differs from its
    # input length must be rejected when evaluated under ``over``.
    windowed = pl.DataFrame({"x": [1, 2, 3, 4], "group": [1, 2, 1, 2]})
    length_mismatch = (
        r"output length of `map` \(1\) must be equal to "
        r"the input length \(4\); consider using `apply` instead"
    )
    bad_udf = pl.col("x").map_batches(
        lambda x: pl.Series(
            [x.cut(breaks=[1, 2, 3], include_breaks=True).struct.unnest()]
        ),
        is_elementwise=True,
    )

    with pytest.raises(InvalidOperationError, match=length_mismatch):
        windowed.select(bad_udf.over("group"))
77
78
79
def test_map_batches_group() -> None:
    # Inside ``group_by().agg``, a UDF that returns a Python scalar raises a
    # TypeError unless ``returns_scalar=True`` is declared.
    df = pl.DataFrame(
        {"id": [0, 0, 0, 1, 1, 1], "t": [2, 4, 5, 10, 11, 14], "y": [0, 1, 1, 2, 3, 4]}
    )

    undeclared_sum = pl.col("t").map_batches(
        lambda s: s.sum(), return_dtype=pl.self_dtype()
    )
    with pytest.raises(
        TypeError,
        match="`map` with `returns_scalar=False` must return a Series; found 'int'",
    ):
        df.group_by("id").agg(undeclared_sum)

    # If returns_scalar is True, the result won't be wrapped in a list:
    declared_sum = pl.col("t").map_batches(
        lambda s: s.sum(), returns_scalar=True, return_dtype=pl.self_dtype()
    )
    per_group = df.group_by("id").agg(declared_sum).sort("id")
    assert per_group.to_dict(as_series=False) == {"id": [0, 1], "t": [11, 35]}
96
97
98
def test_ufunc_args() -> None:
    # ``np.add`` accepts polars expressions as operands, both when both
    # operands are expressions and when one is a plain Python scalar.
    df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]})
    cases = [
        (np.add(pl.col("a"), pl.col("b")), [3, 6, 9]),  # type: ignore[call-overload]
        (np.add(2, pl.col("a")), [3, 4, 5]),  # type: ignore[call-overload]
    ]

    for ufunc_expr, expected_values in cases:
        assert_frame_equal(
            df.select(z=ufunc_expr),
            pl.DataFrame({"z": expected_values}),
        )
108
109
110
def test_lazy_map_schema() -> None:
    # ``LazyFrame.map_batches`` validates the UDF's return type and, unless
    # disabled, that the output schema matches the input schema.
    df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
    lf = df.lazy()

    # identity
    assert_frame_equal(lf.map_batches(lambda x: x).collect(), df)

    def to_series(frame: pl.DataFrame) -> pl.Series:
        # Deliberately returns a Series instead of a DataFrame.
        return frame["a"]

    with pytest.raises(
        ComputeError,
        match="Expected 'LazyFrame.map' to return a 'DataFrame', got a",
    ):
        lf.map_batches(to_series).collect()  # type: ignore[arg-type]

    def stringify(frame: pl.DataFrame) -> pl.DataFrame:
        # Keeps the column names but changes every dtype to String.
        return frame.select(pl.all().cast(pl.String))

    with pytest.raises(
        ComputeError,
        match="The output schema of 'LazyFrame.map' is incorrect. Expected",
    ):
        lf.map_batches(stringify).collect()

    # Turning schema validation off lets the schema-changing UDF through.
    unchecked = lf.map_batches(stringify, validate_output_schema=False).collect()
    assert unchecked.to_dict(as_series=False) == {
        "a": ["1", "2", "3"],
        "b": ["a", "b", "c"],
    }
140
141
142
def test_map_batches_collect_schema_17327() -> None:
    # A non-scalar ``map_batches`` with ``return_dtype=pl.self_dtype()`` used as
    # a group-by aggregation should report ``List(Int64)`` for ``b`` in the
    # schema returned by ``collect_schema()``.
    identity = pl.col("b").map_batches(lambda s: s, return_dtype=pl.self_dtype())
    q = pl.LazyFrame({"a": [1, 1, 1], "b": [2, 3, 4]}).group_by("a").agg(identity)

    expected_schema = pl.Schema({"a": pl.Int64(), "b": pl.List(pl.Int64)})
    assert q.collect_schema() == expected_schema
149
150