GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/map/test_map_groups.py
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, ShapeError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from collections.abc import Sequence


def test_map_groups() -> None:
    df = pl.DataFrame(
        {
            "a": ["a", "b", "a", "b", "b", "c"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [6, 5, 4, 3, 2, 1],
        }
    )

    result = df.group_by("a").map_groups(lambda df: df[["c"]].sum())

    expected = pl.DataFrame({"c": [10, 10, 1]})
    assert_frame_equal(result, expected, check_row_order=False)


def test_map_groups_lazy() -> None:
    lf = pl.LazyFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 3.0]})

    schema = {"a": pl.Float64, "b": pl.Float64}
    result = lf.group_by("a").map_groups(lambda df: df * 2.0, schema=schema)

    expected = pl.LazyFrame({"a": [6.0, 2.0, 2.0], "b": [6.0, 2.0, 4.0]})
    assert_frame_equal(result, expected, check_row_order=False)
    assert result.collect_schema() == expected.collect_schema()


def test_map_groups_rolling() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 2, 3, 4, 5],
        }
    ).set_sorted("a")

    def function(df: pl.DataFrame) -> pl.DataFrame:
        return df.select(
            pl.col("a").min(),
            pl.col("b").max(),
        )

    result = df.rolling("a", period="2i").map_groups(function, schema=df.schema)

    expected = pl.DataFrame(
        [
            pl.Series("a", [1, 1, 2, 3, 4], dtype=pl.Int64),
            pl.Series("b", [1, 2, 3, 4, 5], dtype=pl.Int64),
        ]
    )
    assert_frame_equal(result, expected)


def test_map_groups_empty() -> None:
    df = pl.DataFrame(schema={"x": pl.Int64})
    with pytest.raises(
        ComputeError, match=r"cannot group_by \+ apply on empty 'DataFrame'"
    ):
        df.group_by("x").map_groups(lambda x: x)

    schema = {"x": pl.Int64, "y": pl.Int64}
    result = (
        df.lazy()
        .group_by("x")
        .map_groups(lambda df: df.with_columns(pl.col("x").alias("y")), schema=schema)
    )

    expected = pl.LazyFrame(schema=schema)
    assert_frame_equal(result, expected)
    assert result.collect_schema() == expected.collect_schema()


def test_map_groups_none() -> None:
    df = pl.DataFrame(
        {
            "g": [1, 1, 1, 2, 2, 2, 5],
            "a": [2, 4, 5, 190, 1, 4, 1],
            "b": [1, 3, 2, 1, 43, 3, 1],
        }
    )

    out = (
        df.group_by("g", maintain_order=True).agg(
            pl.map_groups(
                exprs=["a", pl.col("b") ** 4, pl.col("a") / 4],
                function=lambda x: x[0] * x[1] + x[2].sum(),
                return_dtype=pl.Float64,
                returns_scalar=False,
            ).alias("multiple")
        )
    )["multiple"]
    assert out[0].to_list() == [4.75, 326.75, 82.75]
    assert out[1].to_list() == [238.75, 3418849.75, 372.75]

    out_df = df.select(pl.map_batches(exprs=["a", "b"], function=lambda s: s[0] * s[1]))
    assert out_df["a"].to_list() == (df["a"] * df["b"]).to_list()

    # check if we can return None
    def func(s: Sequence[pl.Series]) -> pl.Series | None:
        if s[0][0] == 190:
            return None
        else:
            return s[0].implode()

    out = (
        df.group_by("g", maintain_order=True).agg(
            pl.map_groups(
                exprs=["a", pl.col("b") ** 4, pl.col("a") / 4],
                function=func,
                return_dtype=pl.self_dtype().wrap_in_list(),
                returns_scalar=True,
            ).alias("multiple")
        )
    )["multiple"]
    assert out[1] is None


def test_map_groups_object_output() -> None:
    df = pl.DataFrame(
        {
            "names": ["foo", "ham", "spam", "cheese", "egg", "foo"],
            "dates": ["1", "1", "2", "3", "3", "4"],
            "groups": ["A", "A", "B", "B", "B", "C"],
        }
    )

    class Foo:
        def __init__(self, payload: Any) -> None:
            self.payload = payload

    result = df.group_by("groups").agg(
        pl.map_groups(
            [pl.col("dates"), pl.col("names")],
            lambda s: Foo(dict(zip(s[0], s[1], strict=True))),
            return_dtype=pl.Object,
            returns_scalar=True,
        )
    )

    assert result.dtypes == [pl.String, pl.Object]


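# Regression test for issue 3057: a NumPy scalar returned by the UDF is accepted
# as the per-group aggregation result.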
def test_map_groups_numpy_output_3057() -> None:
    df = pl.DataFrame(
        {
            "id": [0, 0, 0, 1, 1, 1],
            "t": [2.0, 4.3, 5, 10, 11, 14],
            "y": [0.0, 1, 1.3, 2, 3, 4],
        }
    )

    result = df.group_by("id", maintain_order=True).agg(
        pl.map_groups(
            ["y", "t"],
            lambda lst: np.mean([lst[0], lst[1]]),
            returns_scalar=True,
            return_dtype=pl.self_dtype(),
        ).alias("result")
    )

    expected = pl.DataFrame({"id": [0, 1], "result": [2.266666, 7.333333]})
    assert_frame_equal(result, expected)


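# Regression test for issue 15260: map_groups over an all-null column keeps the
# null values in the aggregated output.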
def test_map_groups_return_all_null_15260() -> None:
    def foo(x: Sequence[pl.Series]) -> pl.Series:
        return pl.Series([x[0][0]], dtype=x[0].dtype)

    assert_frame_equal(
        pl.DataFrame({"key": [0, 0, 1], "a": [None, None, None]})
        .group_by("key")
        .agg(
            pl.map_groups(
                exprs=["a"],
                function=foo,
                returns_scalar=True,
                return_dtype=pl.self_dtype(),
            )
        )
        .sort("key"),
        pl.DataFrame({"key": [0, 1], "a": [None, None]}),
    )


@pytest.mark.parametrize(
    ("func", "result"),
    [
        (lambda n: n[0] + n[1], [[85], [85]]),
        (lambda _: pl.Series([1, 2, 3]), [[1, 2, 3], [1, 2, 3]]),
    ],
)
@pytest.mark.parametrize("maintain_order", [True, False])
def test_map_groups_multiple_all_literal(
    func: Any, result: list[int], maintain_order: bool
) -> None:
    df = pl.DataFrame({"g": [10, 10, 20], "a": [1, 2, 3], "b": [2, 3, 4]})

    q = (
        df.lazy()
        .group_by(pl.col("g"), maintain_order=maintain_order)
        .agg(
            pl.map_groups(
                exprs=[pl.lit(42).cast(pl.Int64), pl.lit(43).cast(pl.Int64)],
                function=func,
                return_dtype=pl.Int64,
            ).alias("out")
        )
    )
    out = q.collect()
    expected = pl.DataFrame({"g": [10, 20], "out": result})
    assert_frame_equal(out, expected, check_row_order=maintain_order)


@pytest.mark.may_fail_auto_streaming  # reason: alternate error message
def test_map_groups_multiple_all_literal_elementwise_raises() -> None:
    df = pl.DataFrame({"g": [10, 10, 20], "a": [1, 2, 3], "b": [2, 3, 4]})
    q = (
        df.lazy()
        .group_by(pl.col("g"))
        .agg(
            pl.map_groups(
                exprs=[pl.lit(42), pl.lit(43)],
                function=lambda _: pl.Series([1, 2, 3]),
                return_dtype=pl.Int64,
                is_elementwise=True,
            ).alias("out")
        )
    )
    msg = "elementwise expression dyn int: 42.python_udf([dyn int: 43]) must return exactly 1 value on literals, got 3"
    with pytest.raises(ComputeError, match=re.escape(msg)):
        q.collect(engine="in-memory")

    # different error message in streaming, not specific to the problem
    with pytest.raises(ShapeError):
        q.collect(engine="streaming")


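# Regression test for issue 25172: a UDF that dispatches its own Polars query
# (a partitioned sink_parquet here) still works when called from an aggregation.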
def test_nested_query_with_streaming_dispatch_25172() -> None:
    def simple(_: Any) -> pl.Series:
        import io

        pl.LazyFrame({}).sink_parquet(
            pl.PartitionBy(
                "", file_path_provider=lambda _: io.BytesIO(), max_rows_per_file=1
            ),
        )

        return pl.Series([1])

    assert_frame_equal(
        pl.LazyFrame({"a": ["A", "B"] * 1000, "b": [1] * 2000})
        .group_by("a")
        .agg(pl.map_groups(["b"], simple, pl.Int64(), returns_scalar=True))
        .collect(engine="in-memory")
        .sort("a"),
        pl.DataFrame({"a": ["A", "B"], "b": [1, 1]}, schema_overrides={"b": pl.Int64}),
    )


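# Regression test for issue 25805: slicing (head) the result of a lazy
# group_by().map_groups() is applied correctly.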
def test_map_groups_with_slice_25805() -> None:
    schema = {"a": pl.Int8, "b": pl.Int8}

    df = (
        pl.LazyFrame(
            data={"a": [1, 1], "b": [1, 2]},
            schema=schema,
        )
        .group_by("a", maintain_order=True)
        .map_groups(lambda df: df, schema=schema)
        .head(1)
        .collect()
    )
    assert_frame_equal(df, pl.DataFrame({"a": [1], "b": [1]}, schema=schema))