Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/arithmetic/test_array.py
6940 views
1
from __future__ import annotations
2
3
from typing import TYPE_CHECKING, Any, Callable
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import InvalidOperationError
9
from polars.testing import assert_series_equal
10
from tests.unit.operations.arithmetic.utils import (
11
BROADCAST_SERIES_COMBINATIONS,
12
EXEC_OP_COMBINATIONS,
13
)
14
15
if TYPE_CHECKING:
16
from polars._typing import PolarsDataType
17
18
19
@pytest.mark.parametrize(
    "array_side", ["left", "left3", "both", "both3", "right3", "right", "none"]
)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_array_arithmetic_values(
    array_side: str,
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Tests value correctness.

    This test checks for output value correctness (a + b == c) across different
    codepaths, by wrapping the values (a, b, c) in different combinations of
    array / primitive columns.

    Parameters
    ----------
    array_side
        Selects which operands are wrapped in an Array column. "left", "right"
        and "both" use a single-level ``Array(dtype, 3)``; "left3", "right3" and
        "both3" use a three-level nested ``Array`` of width 2; "none" keeps both
        operands primitive. The expected output is an array whenever at least
        one operand is.
    broadcast_series
        Callable receiving the single-row (left, right, expected) series and
        returning the three series that are actually fed to the assertion.
    exec_op
        Callable applying the given ``operator`` function to two series and
        returning the resulting series.
    """
    import operator as op

    # `dtypes` holds the (left, right, output) dtypes for the next call to
    # `materialize_series`; `dtype` is the scratch slot read by the
    # `materialize_*` helpers and rebound by `materialize_series`.
    dtypes: list[Any] = [pl.Null, pl.Null, pl.Null]
    dtype: Any = pl.Null

    def materialize_array(v: Any) -> pl.Series:
        # One row, width-3 array with `v` surrounded by nulls.
        return pl.Series(
            [[None, v, None]],
            dtype=pl.Array(dtype, 3),
        )

    def materialize_array3(v: Any) -> pl.Series:
        # One row, three nesting levels of width 2; `v` is the only non-null leaf.
        return pl.Series(
            [[[[None, v], None], None]],
            dtype=pl.Array(pl.Array(pl.Array(dtype, 2), 2), 2),
        )

    def materialize_primitive(v: Any) -> pl.Series:
        return pl.Series([v], dtype=dtype)

    def materialize_series(
        l: Any,  # noqa: E741
        r: Any,
        o: Any,
    ) -> tuple[pl.Series, pl.Series, pl.Series]:
        # Build the (left, right, expected) series for the current `array_side`
        # using the dtypes currently stored in `dtypes`.
        nonlocal dtype

        dtype = dtypes[0]
        l = {  # noqa: E741
            "left": materialize_array,
            "left3": materialize_array3,
            "both": materialize_array,
            "both3": materialize_array3,
            "right": materialize_primitive,
            "right3": materialize_primitive,
            "none": materialize_primitive,
        }[array_side](l)  # fmt: skip

        dtype = dtypes[1]
        r = {
            "left": materialize_primitive,
            "left3": materialize_primitive,
            "both": materialize_array,
            "both3": materialize_array3,
            "right": materialize_array,
            "right3": materialize_array3,
            "none": materialize_primitive,
        }[array_side](r)  # fmt: skip

        dtype = dtypes[2]
        o = {
            "left": materialize_array,
            "left3": materialize_array3,
            "both": materialize_array,
            "both3": materialize_array3,
            "right": materialize_array,
            "right3": materialize_array3,
            "none": materialize_primitive,
        }[array_side](o)  # fmt: skip

        assert l.len() == 1
        assert r.len() == 1
        assert o.len() == 1

        return broadcast_series(l, r, o)

    # Signed
    dtypes = [pl.Int8, pl.Int8, pl.Int8]

    l, r, o = materialize_series(2, 3, 5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(-5, 127, 124)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(-5, 127, -123)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(-5, 3, -2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(-5, 3, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(2, 128, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Unsigned
    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(2, 3, 5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(2, 3, 255)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(2, 128, 0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(5, 2, 2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(5, 2, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(2, 128, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Floats. Note we pick Float32 to ensure there is no accidental upcasting
    # to Float64.
    dtypes = [pl.Float32, pl.Float32, pl.Float32]
    l, r, o = materialize_series(1.7, 2.3, 4.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(7.0, 3.0, 2.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(-5.0, 3.0, 1.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(2.0, 128.0, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    #
    # Tests for zero behavior
    #

    # Integer

    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(1, 0, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(0, 0, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Float

    dtypes = [pl.Float32, pl.Float32, pl.Float32]

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(1, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    #
    # Tests for NULL behavior
    #

    for dtype, truediv_dtype in [  # type: ignore[misc]
        [pl.Int8, pl.Float64],
        [pl.Float32, pl.Float32],
    ]:
        # `materialize_series` rebinds `dtype` through `nonlocal`, leaving it at
        # the output dtype of the last call (`truediv_dtype`). Snapshot the
        # dtype under test so every `vals` case uses it, not just the first.
        operand_dtype = dtype

        for vals in [
            [None, None, None],
            [0, None, None],
            [None, 0, None],
            [3, None, None],
            [None, 3, None],
        ]:
            dtypes = 3 * [operand_dtype]

            l, r, o = materialize_series(*vals)  # type: ignore[misc]  # noqa: E741
            assert_series_equal(exec_op(l, r, op.add), o)
            assert_series_equal(exec_op(l, r, op.sub), o)
            assert_series_equal(exec_op(l, r, op.mul), o)
            assert_series_equal(exec_op(l, r, op.floordiv), o)
            assert_series_equal(exec_op(l, r, op.mod), o)
            dtypes[2] = truediv_dtype  # type: ignore[has-type]
            l, r, o = materialize_series(*vals)  # type: ignore[misc]  # noqa: E741
            assert_series_equal(exec_op(l, r, op.truediv), o)

    # Type upcasting for Boolean and Null

    # Check boolean upcasting
    dtypes = [pl.Boolean, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(True, 3, 4)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(True, 3, 254)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(True, 3, 3)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(True, 3, 0)  # noqa: E741
    if array_side != "none":
        # TODO: FIXME: We get an error on non-lists with this:
        # "floor_div operation not supported for dtype `bool`"
        assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(True, 3, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.Boolean, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(True, 128, 0.0078125)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Check Null upcasting
    dtypes = [pl.Null, pl.UInt8, pl.UInt8]
    l, r, o = materialize_series(None, 3, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)
    assert_series_equal(exec_op(l, r, op.sub), o)
    assert_series_equal(exec_op(l, r, op.mul), o)
    if array_side != "none":
        assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.Null, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(None, 3, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)
288
289
290
@pytest.mark.parametrize(
    ("lhs_dtype", "rhs_dtype", "expected_dtype"),
    [
        (pl.Array(pl.Int64, 2), pl.Int64, pl.Array(pl.Float64, 2)),
        (pl.Array(pl.Float32, 2), pl.Float32, pl.Array(pl.Float32, 2)),
        (pl.Array(pl.Duration("us"), 2), pl.Int64, pl.Array(pl.Duration("us"), 2)),
    ],
)
def test_array_truediv_schema(
    lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType
) -> None:
    """
    Check the lazily resolved output dtype of `truediv` on an array column.

    A one-row frame with columns "lhs" and "rhs" is built using the
    parametrized dtypes; the schema of the lazy `truediv` selection must report
    `expected_dtype` for the "lhs" output column.
    """
    frame = pl.DataFrame(
        {"lhs": [[None, 10]], "rhs": 2},
        schema={"lhs": lhs_dtype, "rhs": rhs_dtype},
    )

    # Only the schema is collected; the query itself is never executed.
    quotient = pl.col("lhs").truediv("rhs")
    resolved_schema = frame.lazy().select(quotient).collect_schema()

    assert resolved_schema["lhs"] == expected_dtype
305
306
307
def test_array_literal_broadcast() -> None:
    """
    Check arithmetic between a two-row array column and a single array literal.

    Each of the four operators is exercised with the literal on the right-hand
    side and on the left-hand side; the literal is expected to be applied to
    every row of the column.
    """
    array_dtype = pl.Array(float, 2)
    frame = pl.DataFrame({"A": [[0.1, 0.2], [0.3, 0.4]]}).cast(array_dtype)

    lit = pl.lit([3, 5], array_dtype)
    cols = pl.all()

    # Values are spelled exactly as float64 arithmetic produces them.
    expected = {
        "mul": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
        "div": [[0.03333333333333333, 0.04], [0.09999999999999999, 0.08]],
        "add": [[3.1, 5.2], [3.3, 5.4]],
        "sub": [[-2.9, -4.8], [-2.7, -4.6]],
        "div_": [[30.0, 25.0], [10.0, 12.5]],
        "add_": [[3.1, 5.2], [3.3, 5.4]],
        "sub_": [[2.9, 4.8], [2.7, 4.6]],
        "mul_": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
    }

    result = frame.select(
        mul=cols * lit,
        div=cols / lit,
        add=cols + lit,
        sub=cols - lit,
        div_=lit / cols,
        add_=lit + cols,
        sub_=lit - cols,
        mul_=lit * cols,
    ).to_dict(as_series=False)

    assert result == expected
330
331
332
def test_array_arith_double_nested_shape() -> None:
    """Adding nested arrays of shape (3, 2) and (2, 3) must be rejected."""
    # Both operands hold 6 leaf values per row, so an implementation that only
    # compared the flattened leaf arrays (without checking each dimension)
    # would wrongly accept this addition.
    dtype_3x2 = pl.Array(pl.Array(pl.Int64, 2), 3)
    dtype_2x3 = pl.Array(pl.Array(pl.Int64, 3), 2)

    three_by_two = pl.Series([[[1, 1], [1, 1], [1, 1]]], dtype=dtype_3x2)
    two_by_three = pl.Series([[[1, 1, 1], [1, 1, 1]]], dtype=dtype_2x3)

    with pytest.raises(InvalidOperationError, match="differing dtypes"):
        three_by_two + two_by_three
341
342
343
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.slow
def test_array_numeric_op_validity_combination(
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Check how row-level and element-level nulls combine in array arithmetic.

    Covers array/array operands, array/primitive operands (passed through
    `broadcast_series`), and a two-level nested array.
    """
    import operator as op

    array_dtype = pl.Array(pl.Int64, 1)

    # Array (+) array: every pairing of valid row, null row and null element.
    lhs = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=array_dtype)
    rhs = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=array_dtype)
    expected = pl.Series(
        "a", [[2], [5], None, [None], None, [None]], dtype=array_dtype
    )

    assert_series_equal(exec_op(lhs, rhs, op.add), expected)

    # Array (op) primitive, one row each:
    # (array values, primitive values, expected array values, operator)
    single_row_cases: list[tuple[list[Any], list[Any], list[Any], Any]] = [
        ([[1]], [None], [[None]], op.add),
        ([None], [1], [None], op.add),
        ([None], [0], [None], op.floordiv),
    ]
    for array_values, primitive_values, expected_values, operator_fn in (
        single_row_cases
    ):
        lhs = pl.Series("a", array_values, dtype=array_dtype)
        rhs = pl.Series("b", primitive_values, dtype=pl.Int64)
        expected = pl.Series("a", expected_values, dtype=array_dtype)

        lhs, rhs, expected = broadcast_series(lhs, rhs, expected)
        assert_series_equal(exec_op(lhs, rhs, operator_fn), expected)

    # >1 level nested array
    nested_dtype = pl.Array(pl.Array(pl.Int64, 2), 2)
    lhs = pl.Series(
        # row 1: [ [1, NULL], NULL ]
        # row 2: NULL
        [[[1, None], None], None],
        dtype=nested_dtype,
    )
    zeros = pl.Series(
        [[[0, 0], [0, 0]], [[0, 0], [0, 0]]],
        dtype=nested_dtype,
    )
    # Adding zero is expected to return the left operand unchanged.
    expected = lhs
    assert_series_equal(exec_op(lhs, zeros, op.add), expected)
403
404
405
def test_array_elementwise_arithmetic_19682() -> None:
    """
    Check addition on a (2, 3)-shaped array series and on a zero-width array.

    Covers array + array, array + single-value series (both operand orders),
    and the sum of two zero-width array series.
    """
    shape_dtype = pl.Array(pl.Int64, (2, 3))
    empty_dtype = pl.Array(pl.Int64, 0)

    matrix = pl.Series("a", [[[1, 2, 3], [4, 5, 6]]], dtype=shape_dtype)
    scalar = pl.Series("a", [1])
    zero_width = pl.Series("a", [[]], dtype=empty_dtype)

    doubled = pl.Series("a", [[[2, 4, 6], [8, 10, 12]]], dtype=shape_dtype)
    assert_series_equal(matrix + matrix, doubled)

    incremented_values = [[[2, 3, 4], [5, 6, 7]]]
    assert_series_equal(
        matrix + scalar, pl.Series("a", incremented_values, dtype=shape_dtype)
    )
    assert_series_equal(
        scalar + matrix, pl.Series("a", incremented_values, dtype=shape_dtype)
    )

    assert_series_equal(
        zero_width + zero_width, pl.Series("a", [[]], dtype=empty_dtype)
    )
416
417
418
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_array_add_supertype(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Adding Array(Int8) and Array(Int64) series must yield Array(Int64)."""
    import operator as op

    narrow = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int8, 1))
    wide = pl.Series("b", [[1], [999]], dtype=pl.Array(pl.Int64, 1))

    # 2 + 999 = 1001 does not fit in Int8, so the Int64 output is required.
    expected = pl.Series("a", [[2], [1001]], dtype=pl.Array(pl.Int64, 1))

    assert_series_equal(exec_op(narrow, wide, op.add), expected)
431
432
433
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_array_arithmetic_dtype_mismatch(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Check that array arithmetic on incompatible dtypes raises.

    Two situations are covered: arrays of different widths, and an array whose
    inner dtype is a List (in every operand position).
    """
    import operator as op

    # Same leaf dtype, different array widths.
    width_one = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int64, 1))
    width_two = pl.Series("b", [[1, 1], [999, 999]], dtype=pl.Array(pl.Int64, 2))

    with pytest.raises(InvalidOperationError, match="differing dtypes"):
        exec_op(width_one, width_two, op.add)

    # An Array wrapping a List: array with itself, array with primitive, and
    # primitive with array.
    array_of_list = pl.Series([[[1]], [[1]]], dtype=pl.Array(pl.List(pl.Int64), 1))
    primitive = pl.Series([1], dtype=pl.Int64)

    operand_pairs = (
        (array_of_list, array_of_list),
        (array_of_list, primitive),
        (primitive, array_of_list),
    )
    for lhs, rhs in operand_pairs:
        with pytest.raises(
            InvalidOperationError, match="dtype was not array on all nesting levels"
        ):
            exec_op(lhs, rhs, op.add)
462
463