Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/arithmetic/test_array.py
8391 views
1
from __future__ import annotations
2
3
from typing import TYPE_CHECKING, Any
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import InvalidOperationError
9
from polars.testing import assert_series_equal
10
from tests.unit.operations.arithmetic.utils import (
11
BROADCAST_SERIES_COMBINATIONS,
12
EXEC_OP_COMBINATIONS,
13
)
14
15
if TYPE_CHECKING:
16
from collections.abc import Callable
17
18
from polars._typing import PolarsDataType
19
20
21
@pytest.mark.parametrize(
    "array_side", ["left", "left3", "both", "both3", "right3", "right", "none"]
)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_array_arithmetic_values(
    array_side: str,
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Tests value correctness.

    This test checks for output value correctness (a + b == c) across different
    codepaths, by wrapping the values (a, b, c) in different combinations of
    array / primitive columns.

    ``array_side`` selects which of (a, b, c) are wrapped in an array column:
    ``"left"``, ``"right"`` and ``"both"`` use a width-3 ``pl.Array``; the
    ``"...3"`` variants use a three-level nested ``pl.Array`` of width 2 per
    level; ``"none"`` keeps every value in a plain primitive column.
    """
    import operator as op

    # (a, b, c) dtypes used by the following ``check`` calls. Reassigned
    # before each group of checks; the helpers below read it at call time.
    dtypes: list[Any] = [pl.Null, pl.Null, pl.Null]

    def materialize_array(v: Any, dtype: Any) -> pl.Series:
        # One row: a width-3 array with ``v`` in the middle slot, nulls around.
        return pl.Series([[None, v, None]], dtype=pl.Array(dtype, 3))

    def materialize_array3(v: Any, dtype: Any) -> pl.Series:
        # One row: a 2x2x2 nested array holding ``v`` at [0][0][1]; every
        # other slot (leaf or sub-array) is null.
        return pl.Series(
            [[[[None, v], None], None]],
            dtype=pl.Array(pl.Array(pl.Array(dtype, 2), 2), 2),
        )

    def materialize_primitive(v: Any, dtype: Any) -> pl.Series:
        # One row: the bare value.
        return pl.Series([v], dtype=dtype)

    # For each ``array_side``: the materializers used for (a, b, c).
    layout = {
        "left": (materialize_array, materialize_primitive, materialize_array),
        "left3": (materialize_array3, materialize_primitive, materialize_array3),
        "both": (materialize_array, materialize_array, materialize_array),
        "both3": (materialize_array3, materialize_array3, materialize_array3),
        "right": (materialize_primitive, materialize_array, materialize_array),
        "right3": (materialize_primitive, materialize_array3, materialize_array3),
        "none": (materialize_primitive, materialize_primitive, materialize_primitive),
    }[array_side]

    def materialize_series(
        a: Any, b: Any, c: Any
    ) -> tuple[pl.Series, pl.Series, pl.Series]:
        """Wrap (a, b, c) per ``layout``/``dtypes`` and apply ``broadcast_series``."""
        lhs, rhs, out = (
            make(value, dtype)
            for make, value, dtype in zip(layout, (a, b, c), dtypes)
        )
        assert lhs.len() == 1
        assert rhs.len() == 1
        assert out.len() == 1
        return broadcast_series(lhs, rhs, out)

    def check(a: Any, b: Any, c: Any, *ops: Any) -> None:
        """Assert ``exec_op(a, b, fn) == c`` for every ``fn`` in ``ops``."""
        lhs, rhs, out = materialize_series(a, b, c)
        for fn in ops:
            assert_series_equal(exec_op(lhs, rhs, fn), out)

    # Signed (sub and mul wrap around in Int8)
    dtypes = [pl.Int8, pl.Int8, pl.Int8]
    check(2, 3, 5, op.add)
    check(-5, 127, 124, op.sub)
    check(-5, 127, -123, op.mul)
    check(-5, 3, -2, op.floordiv)
    check(-5, 3, 1, op.mod)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    check(2, 128, 0.015625, op.truediv)

    # Unsigned (sub and mul wrap around in UInt8)
    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]
    check(2, 3, 5, op.add)
    check(2, 3, 255, op.sub)
    check(2, 128, 0, op.mul)
    check(5, 2, 2, op.floordiv)
    check(5, 2, 1, op.mod)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    check(2, 128, 0.015625, op.truediv)

    # Floats. Note we pick Float32 to ensure there is no accidental upcasting
    # to Float64.
    dtypes = [pl.Float32, pl.Float32, pl.Float32]
    check(1.7, 2.3, 4.0, op.add)
    check(1.7, 2.3, -0.5999999999999999, op.sub)
    check(1.7, 2.3, 3.9099999999999997, op.mul)
    check(7.0, 3.0, 2.0, op.floordiv)
    check(-5.0, 3.0, 1.0, op.mod)
    check(2.0, 128.0, 0.015625, op.truediv)

    #
    # Tests for zero behavior
    #

    # Integer: division by zero yields null for floordiv / mod.
    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]
    check(1, 0, None, op.floordiv, op.mod)
    check(0, 0, None, op.floordiv, op.mod)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    check(1, 0, float("inf"), op.truediv)
    check(0, 0, float("nan"), op.truediv)

    # Float: division by zero follows IEEE-754-style inf / nan results.
    dtypes = [pl.Float32, pl.Float32, pl.Float32]
    check(1, 0, float("inf"), op.floordiv)
    check(1, 0, float("nan"), op.mod)
    check(1, 0, float("inf"), op.truediv)
    check(0, 0, float("nan"), op.floordiv, op.mod, op.truediv)

    #
    # Tests for NULL behavior
    #

    for dtype, truediv_dtype in [
        (pl.Int8, pl.Float64),
        (pl.Float32, pl.Float32),
    ]:
        # Each distinct placement of a null operand (duplicates removed).
        for vals in [
            (None, None, None),
            (0, None, None),
            (None, 0, None),
            (3, None, None),
            (None, 3, None),
        ]:
            dtypes = [dtype, dtype, dtype]
            check(*vals, op.add, op.sub, op.mul, op.floordiv, op.mod)
            dtypes = [dtype, dtype, truediv_dtype]
            check(*vals, op.truediv)

    # Type upcasting for Boolean and Null

    # Check boolean upcasting
    dtypes = [pl.Boolean, pl.UInt8, pl.UInt8]
    check(True, 3, 4, op.add)
    check(True, 3, 254, op.sub)
    check(True, 3, 3, op.mul)
    if array_side != "none":
        # TODO: We get an error on primitive (non-array) columns with this:
        # "floor_div operation not supported for dtype `bool`"
        check(True, 3, 0, op.floordiv)
    check(True, 3, 1, op.mod)

    dtypes = [pl.Boolean, pl.UInt8, pl.Float64]
    check(True, 128, 0.0078125, op.truediv)

    # Check Null upcasting
    dtypes = [pl.Null, pl.UInt8, pl.UInt8]
    null_ops = [op.add, op.sub, op.mul]
    if array_side != "none":
        # NOTE(review): floordiv / mod are skipped for primitive columns here,
        # presumably for the same reason as the boolean floordiv TODO above.
        null_ops += [op.floordiv, op.mod]
    check(None, 3, None, *null_ops)

    dtypes = [pl.Null, pl.UInt8, pl.Float64]
    check(None, 3, None, op.truediv)
291
292
@pytest.mark.parametrize(
    ("lhs_dtype", "rhs_dtype", "expected_dtype"),
    [
        (pl.Array(pl.Int64, 2), pl.Int64, pl.Array(pl.Float64, 2)),
        (pl.Array(pl.Float32, 2), pl.Float32, pl.Array(pl.Float32, 2)),
        (pl.Array(pl.Duration("us"), 2), pl.Int64, pl.Array(pl.Duration("us"), 2)),
    ],
)
def test_array_truediv_schema(
    lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType
) -> None:
    """Check the dtype that ``collect_schema`` resolves for ``array_col / rhs``."""
    frame = pl.DataFrame(
        {"lhs": [[None, 10]], "rhs": 2},
        schema={"lhs": lhs_dtype, "rhs": rhs_dtype},
    )
    quotient = pl.col("lhs").truediv("rhs")
    resolved = frame.lazy().select(quotient).collect_schema()

    assert resolved["lhs"] == expected_dtype
307
308
309
def test_array_literal_broadcast() -> None:
    """An Array literal broadcasts against an Array column, on either side."""
    frame = pl.DataFrame({"A": [[0.1, 0.2], [0.3, 0.4]]}).cast(pl.Array(float, 2))
    arr_lit = pl.lit([3, 5], pl.Array(float, 2))
    column = pl.all()

    expected = {
        "mul": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
        "div": [[0.03333333333333333, 0.04], [0.09999999999999999, 0.08]],
        "add": [[3.1, 5.2], [3.3, 5.4]],
        "sub": [[-2.9, -4.8], [-2.7, -4.6]],
        "div_": [[30.0, 25.0], [10.0, 12.5]],
        "add_": [[3.1, 5.2], [3.3, 5.4]],
        "sub_": [[2.9, 4.8], [2.7, 4.6]],
        "mul_": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
    }

    # Column on the left (no suffix) vs literal on the left (``_`` suffix).
    result = frame.select(
        mul=column * arr_lit,
        div=column / arr_lit,
        add=column + arr_lit,
        sub=column - arr_lit,
        div_=arr_lit / column,
        add_=arr_lit + column,
        sub_=arr_lit - column,
        mul_=arr_lit * column,
    )
    assert result.to_dict(as_series=False) == expected
332
333
334
def test_array_arith_double_nested_shape() -> None:
    # Both operands hold 6 leaf values per row, but one is shaped (3, 2) and the
    # other (2, 3). The implementation must compare the nested dtypes rather
    # than naively adding the flattened leaf values.
    shape_3x2 = pl.Array(pl.Array(pl.Int64, 2), 3)
    shape_2x3 = pl.Array(pl.Array(pl.Int64, 3), 2)

    lhs = pl.Series([[[1, 1], [1, 1], [1, 1]]], dtype=shape_3x2)
    rhs = pl.Series([[[1, 1, 1], [1, 1, 1]]], dtype=shape_2x3)

    with pytest.raises(InvalidOperationError, match="differing dtypes"):
        lhs + rhs
343
344
345
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.slow
def test_array_numeric_op_validity_combination(
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Null rows and null elements propagate through Array arithmetic."""
    import operator as op

    int_array = pl.Array(pl.Int64, 1)

    # Multi-row: a null row on either side yields a null row, and a null
    # element on either side yields a null element.
    lhs = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=int_array)
    rhs = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=int_array)
    expected = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=int_array)
    assert_series_equal(exec_op(lhs, rhs, op.add), expected)

    # Single-row array vs primitive cases, run through ``broadcast_series``:
    # (array row, primitive value, expected array row, operator).
    single_row_cases = [
        ([1], None, [None], op.add),
        (None, 1, None, op.add),
        (None, 0, None, op.floordiv),
    ]
    for array_row, scalar, expected_row, fn in single_row_cases:
        lhs, rhs, expected = broadcast_series(
            pl.Series("a", [array_row], dtype=int_array),
            pl.Series("b", [scalar], dtype=pl.Int64),
            pl.Series("a", [expected_row], dtype=int_array),
        )
        assert_series_equal(exec_op(lhs, rhs, fn), expected)

    # >1 level nested array: adding zeros must leave every null in place.
    nested = pl.Array(pl.Array(pl.Int64, 2), 2)
    lhs = pl.Series(
        # row 1: [ [1, NULL], NULL ]
        # row 2: NULL
        [[[1, None], None], None],
        dtype=nested,
    )
    zeros = pl.Series([[[0, 0], [0, 0]]] * 2, dtype=nested)
    assert_series_equal(exec_op(lhs, zeros, op.add), lhs)
405
406
407
def test_array_elementwise_arithmetic_19682() -> None:
    """
    Elementwise arithmetic on a 2-D ``pl.Array`` dtype of shape (2, 3).

    Covers array + array, array + one-element Series on either side, and
    zero-width arrays. (Presumably a regression test for GitHub issue #19682.)
    """
    shape = pl.Array(pl.Int64, (2, 3))
    zero_width = pl.Array(pl.Int64, 0)

    grid = pl.Series("a", [[[1, 2, 3], [4, 5, 6]]], shape)
    one = pl.Series("a", [1])
    empty = pl.Series("a", [[]], zero_width)

    doubled = pl.Series("a", [[[2, 4, 6], [8, 10, 12]]], shape)
    incremented = pl.Series("a", [[[2, 3, 4], [5, 6, 7]]], shape)

    assert_series_equal(grid + grid, doubled)
    assert_series_equal(grid + one, incremented)
    assert_series_equal(one + grid, incremented)
    assert_series_equal(empty + empty, pl.Series("a", [[]], zero_width))
418
419
420
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_array_add_supertype(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Int8 + Int64 arrays add in Int64, so 2 + 999 does not overflow."""
    from operator import add

    small = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int8, 1))
    large = pl.Series("b", [[1], [999]], dtype=pl.Array(pl.Int64, 1))
    expected = pl.Series("a", [[2], [1001]], dtype=pl.Array(pl.Int64, 1))

    assert_series_equal(exec_op(small, large, add), expected)
433
434
435
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_array_arithmetic_dtype_mismatch(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Array arithmetic raises on incompatible array dtypes."""
    import operator as op

    # Arrays of different widths cannot be combined.
    width_one = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int64, 1))
    width_two = pl.Series("b", [[1, 1], [999, 999]], dtype=pl.Array(pl.Int64, 2))

    with pytest.raises(InvalidOperationError, match="differing dtypes"):
        exec_op(width_one, width_two, op.add)

    # An Array wrapping a List is rejected whether it is paired with itself or
    # with a primitive Series on either side.
    array_of_list = pl.Series([[[1]], [[1]]], dtype=pl.Array(pl.List(pl.Int64), 1))
    primitive = pl.Series([1], dtype=pl.Int64)

    for lhs, rhs in [
        (array_of_list, array_of_list),
        (array_of_list, primitive),
        (primitive, array_of_list),
    ]:
        with pytest.raises(
            InvalidOperationError, match="dtype was not array on all nesting levels"
        ):
            exec_op(lhs, rhs, op.add)
464
465