Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/namespaces/array/test_eval.py
8354 views
1
from __future__ import annotations
2
3
from typing import TYPE_CHECKING
4
5
import pytest
6
7
import polars as pl
8
from polars.testing import assert_series_equal
9
10
if TYPE_CHECKING:
11
from collections.abc import Callable
12
13
14
def set_nulls(s: pl.Series, mask: list[bool]) -> pl.Series:
    """Return a copy of ``s`` where positions with a ``False`` mask entry are null.

    ``mask`` marks the rows to *keep*: ``True`` keeps the original value of ``s``
    and ``False`` replaces it with null (the ``when``/``then`` below has no
    ``otherwise``). The name of ``s`` is carried over to the result.
    """
    keep = pl.Series(mask)
    masked = pl.when(keep).then(s).alias(s.name)
    return pl.select(masked).to_series()
16
17
18
@pytest.mark.parametrize("as_list", [False, True])
@pytest.mark.parametrize(
    "nulls",
    [
        [True] * 3,
        [False, True, True],
        [True, False, True],
        [True, True, False],
        [False, False, True],
        [True, False, False],
        [False] * 3,
    ],
)
def test_eval_basic(as_list: bool, nulls: list[bool]) -> None:
    """Check ``arr.eval`` with length-preserving expressions, as array or list output.

    ``nulls`` is the keep-mask handed to ``set_nulls`` (``False`` -> null row). It
    is applied to the input and to every expected series alike.
    """

    def rtdt(dt: pl.DataType) -> pl.DataType:
        # Expected output container: a List when ``as_list``, else a width-2 Array
        # (the same width as the input arrays below).
        return pl.List(dt) if as_list else pl.Array(dt, 2)

    s = set_nulls(
        pl.Series("a", [[1, 4], [8, 5], [3, 2]], pl.Array(pl.Int64(), 2)), nulls
    )

    ranked = s.arr.eval(pl.element().rank(), as_list=as_list)
    assert_series_equal(
        ranked,
        set_nulls(
            pl.Series("a", [[1.0, 2.0], [2.0, 1.0], [2.0, 1.0]], rtdt(pl.Float64())),
            nulls,
        ),
    )

    incremented = s.arr.eval(pl.element() + 1, as_list=as_list)
    assert_series_equal(
        incremented,
        set_nulls(pl.Series("a", [[2, 5], [9, 6], [4, 3]], rtdt(pl.Int64())), nulls),
    )

    stringified = s.arr.eval(pl.element().cast(pl.String()), as_list=as_list)
    assert_series_equal(
        stringified,
        s.cast(rtdt(pl.Int64())).cast(rtdt(pl.String())),
    )

    if not as_list:
        return

    # ``unique`` may change row lengths, so it is only exercised with list output.
    # Every input row already holds distinct values, so nothing is removed.
    deduped = s.arr.eval(pl.element().unique(maintain_order=True), as_list=True)
    assert_series_equal(deduped, s.cast(rtdt(pl.Int64())))
66
67
68
def test_eval_raises_for_non_length_preserving() -> None:
    """A length-changing ``arr.eval`` without ``as_list`` must raise.

    The raised ``InvalidOperationError`` is expected to mention ``as_list``.
    """
    words = pl.Series(
        "a",
        [["A", "B", "C"], ["C", "C", "D"], ["D", "E", "E"]],
        pl.Array(pl.String, 3),
    )
    dedup = pl.element().unique(maintain_order=True)

    with pytest.raises(pl.exceptions.InvalidOperationError, match="as_list"):
        words.arr.eval(dedup)
75
76
77
@pytest.mark.parametrize(
    "nulls",
    [
        [True] * 3,
        [False, True, True],
        [True, False, True],
        [True, True, False],
        [False, False, True],
        [True, False, False],
        [False] * 3,
    ],
)
def test_eval_changing_length(nulls: list[bool]) -> None:
    """Check that ``arr.eval(..., as_list=True)`` may return rows of differing lengths.

    ``nulls`` is the keep-mask for ``set_nulls``. Masked rows must stay null in
    the output.
    """
    strings = pl.Series(
        "a",
        [["A", "B", "C"], ["C", "C", "D"], ["D", "E", "E"]],
        pl.Array(pl.String, 3),
    )
    s = set_nulls(strings, nulls)

    actual = s.arr.eval(pl.element().unique(maintain_order=True), as_list=True)
    expected = set_nulls(
        pl.Series(
            "a", [["A", "B", "C"], ["C", "D"], ["D", "E"]], pl.List(pl.String)
        ),
        nulls,
    )
    assert_series_equal(actual, expected)
108
109
110
def set_validity(s: pl.Series, validity: list[bool]) -> pl.Series:
    """Return ``s`` with every position whose ``validity`` flag is False made null.

    Implemented with ``Series.zip_with``: where the flag is True the value comes
    from ``s``, otherwise from a single-null series of the same dtype as ``s``.
    """
    null_fill = pl.Series([None], dtype=s.dtype)
    return s.zip_with(pl.Series(validity), null_fill)
112
113
114
@pytest.mark.parametrize(
    "sum_expr",
    [pl.element().sum(), pl.element().unique().sum(), pl.element().fill_null(1).sum()],
)
def test_arr_agg_sum(sum_expr: pl.Expr) -> None:
    """Check ``arr.agg`` with sum-style expressions over several array shapes.

    Covers an empty series, plain rows, zero-width arrays (which sum to 0), and
    null rows given directly or introduced via ``set_validity``, which must sum
    to null.
    """
    cases = [
        # empty series
        (pl.Series("a", [], pl.Array(pl.Int64, 2)), pl.Series("a", [], pl.Int64)),
        # plain rows
        (
            pl.Series("a", [[0, 1, 2], [1, 3, 5]], pl.Array(pl.Int64, 3)),
            pl.Series("a", [3, 9]),
        ),
        # zero-width arrays
        (pl.Series("a", [[], []], pl.Array(pl.Int64, 0)), pl.Series("a", [0, 0])),
        # null row given directly
        (
            pl.Series("a", [None, [1, 3, 5]], pl.Array(pl.Int64, 3)),
            pl.Series("a", [None, 9]),
        ),
        # null row introduced through the outer validity mask
        (
            set_validity(
                pl.Series(
                    "a", [[1, 2, 3], [3, 4, 5], [1, 3, 5]], pl.Array(pl.Int64, 3)
                ),
                [True, False, True],
            ),
            pl.Series("a", [6, None, 9]),
        ),
    ]

    for arrays, expected in cases:
        assert_series_equal(arrays.arr.agg(sum_expr), expected)
146
147
148
@pytest.mark.parametrize(
    ("expr", "is_scalar"),
    [
        (pl.Expr.null_count, True),
        (lambda e: e.rank().null_count(), True),
        (pl.Expr.rank, False),
        (lambda e: e + pl.lit(1), False),
        (lambda e: e.filter(e != 0), False),
        (pl.Expr.drop_nulls, False),
        (pl.Expr.n_unique, True),
    ],
)
def test_arr_agg_parametric(
    expr: Callable[[pl.Expr], pl.Expr], is_scalar: bool
) -> None:
    """Compare ``arr.agg(expr)`` row by row with ``expr`` run on each row alone.

    For every non-null row, ``expr`` is applied to that row as a standalone
    unnamed column. The result, imploded to a single list value unless
    ``is_scalar``, must equal the matching row of ``arr.agg``. Null rows must
    aggregate to null.
    """

    def check(series: pl.Series) -> None:
        aggregated = series.arr.agg(expr(pl.element()))

        for idx, row in enumerate(series):
            if row is None:
                assert aggregated[idx] is None
                continue

            assert isinstance(row, pl.Series)
            expected = (
                row.rename("").to_frame().select(expr(pl.col(""))).to_series()
            )
            if not is_scalar:
                expected = expected.implode()

            assert_series_equal(aggregated.rename("").slice(idx, 1), expected)

    samples = (
        pl.Series("a", [], pl.Array(pl.Int64, 2)),
        pl.Series("a", [[]], pl.Array(pl.Int64, 0)),
        pl.Series("a", [[7], [0]], pl.Array(pl.Int64, 1)),
        pl.Series("a", [[8], [0], None], pl.Array(pl.Int64, 1)),
        pl.Series("a", [None, [0], None], pl.Array(pl.Int64, 1)),
        pl.Series("a", [[1, 2, 3], [4, 5, 6]], pl.Array(pl.Int64, 3)),
    )
    for sample in samples:
        check(sample)
187
188
189
@pytest.mark.parametrize("insert_none", [False, True])
@pytest.mark.parametrize("keys", [pl.lit(42), pl.col.g])
@pytest.mark.parametrize("filter", [None, pl.lit(True), pl.col.b])
@pytest.mark.parametrize(
    ("expr", "as_list", "result"),
    [
        (
            pl.element(),
            False,
            pl.Series("a", [[0, 1, 2], [5, 3, 4], [7, 7, 8]], pl.Array(pl.Int64, 3)),
        ),
        (
            pl.element() + pl.element(),
            False,
            pl.Series(
                "a", [[0, 2, 4], [10, 6, 8], [14, 14, 16]], pl.Array(pl.Int64, 3)
            ),
        ),
        (
            pl.element().rank(),
            False,
            pl.Series(
                "a",
                [[1.0, 2.0, 3.0], [3.0, 1.0, 2.0], [1.5, 1.5, 3.0]],
                pl.Array(pl.Float64, 3),
            ),
        ),
        # ``maintain_order=True``: without it ``unique`` does not guarantee the
        # first-occurrence order that the expected ``[5, 3, 4]`` row relies on.
        (
            pl.element().unique(maintain_order=True),
            True,
            pl.Series("a", [[0, 1, 2], [5, 3, 4], [7, 8]]),
        ),
    ],
)
def test_arr_eval_with_filter_in_agg_25384(
    insert_none: bool,
    keys: pl.Expr,
    filter: pl.Expr | None,
    expr: pl.Expr,
    as_list: bool,
    result: pl.Series,
) -> None:
    """Check ``arr.eval`` behind an optional ``filter`` in three evaluation contexts.

    The same expression is evaluated in a plain ``select``, in a window
    (``over``) and in ``group_by(...).agg``, and each must give ``result``. With
    ``insert_none`` the row at index 1 of ``a`` is nulled in both the input and
    ``result``. (The ``25384`` suffix presumably refers to the upstream issue
    number -- confirm.)
    """
    s = pl.Series("a", [[0, 1, 2], [5, 3, 4], [7, 7, 8]], pl.Array(pl.Int64, 3))
    df = s.to_frame().with_columns(
        pl.Series("g", [10, 10, 20]), pl.Series("b", [True, True, True])
    )
    # ``b`` is all True, so whichever filter is chosen keeps every row; it only
    # routes the column through ``filter`` before ``arr.eval``.
    source = pl.col("a") if filter is None else pl.col("a").filter(filter)
    q_inner = source.arr.eval(expr, as_list=as_list)

    if insert_none:
        # Null out the row at index 1 of column ``a``, leaving the others as-is.
        null_second_row = (
            pl.when(pl.int_range(0, pl.len()) != 1).then(pl.col.a).otherwise(None)
        )
        df = df.with_columns(null_second_row)
        result = result.to_frame().with_columns(null_second_row).to_series()

    # no agg
    q = df.lazy().select(q_inner)
    assert_series_equal(q.collect().to_series(), result)

    # over
    q = df.lazy().select(q_inner.over(keys))
    assert_series_equal(q.collect().to_series(), result)

    # group_by: one list per group, exploded back to one row per input row
    q = df.lazy().group_by(keys, maintain_order=True).agg(q_inner)
    out = q.collect().select(pl.col.a).explode("a")
    assert_series_equal(out.to_series(), result)
261
262
263
@pytest.mark.parametrize("insert_none", [False, True])
@pytest.mark.parametrize("keys", [pl.lit(42), pl.col.g])
@pytest.mark.parametrize("filter", [None, pl.lit(True), pl.col.b])
@pytest.mark.parametrize(
    ("expr", "result"),
    [
        (pl.element().sum(), pl.Series("a", [1, 8, 22])),
        (pl.element().null_count(), pl.Series("a", [1, 1, 0], pl.get_index_type())),
    ],
)
def test_arr_agg_with_filter_in_agg_25384(
    insert_none: bool,
    keys: pl.Expr,
    filter: pl.Expr | None,
    expr: pl.Expr,
    result: pl.Series,
) -> None:
    """Check ``arr.agg`` behind an optional ``filter`` in select, ``over`` and ``group_by``.

    All three evaluation contexts must produce ``result``. With ``insert_none``
    the row at index 1 of ``a`` is nulled in both the input and ``result``.
    """
    frame = pl.Series(
        "a", [[0, 1, None], [5, 3, None], [7, 7, 8]], pl.Array(pl.Int64, 3)
    ).to_frame()
    frame = frame.with_columns(
        pl.Series("g", [10, 10, 20]), pl.Series("b", [True, True, True])
    )

    column = pl.col("a")
    if filter is not None:
        column = column.filter(filter)
    agg_expr = column.arr.agg(expr)

    if insert_none:
        # Null out the row at index 1 of column ``a``, leaving the others as-is.
        null_second_row = (
            pl.when(pl.int_range(0, pl.len()) != 1).then(pl.col.a).otherwise(None)
        )
        frame = frame.with_columns(null_second_row)
        result = result.to_frame().with_columns(null_second_row).to_series()

    lazy = frame.lazy()

    # plain select: no aggregation context
    assert_series_equal(lazy.select(agg_expr).collect().to_series(), result)

    # window context
    assert_series_equal(lazy.select(agg_expr.over(keys)).collect().to_series(), result)

    # group_by: one list per group, exploded back to one value per row
    grouped = lazy.group_by(keys, maintain_order=True).agg(agg_expr).collect()
    assert_series_equal(grouped.select(pl.col.a).explode("a").to_series(), result)
314
315