Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/arithmetic/test_list.py
6940 views
1
from __future__ import annotations
2
3
import operator
4
from typing import TYPE_CHECKING, Any, Callable
5
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import InvalidOperationError, ShapeError
10
from polars.testing import assert_frame_equal, assert_series_equal
11
from tests.unit.operations.arithmetic.utils import (
12
BROADCAST_SERIES_COMBINATIONS,
13
EXEC_OP_COMBINATIONS,
14
)
15
16
if TYPE_CHECKING:
17
from polars._typing import PolarsDataType
18
19
20
@pytest.mark.parametrize(
    "list_side", ["left", "left3", "both", "right3", "right", "none"]
)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_list_arithmetic_values(
    list_side: str,
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Tests value correctness.

    This test checks for output value correctness (a + b == c) across different
    codepaths, by wrapping the values (a, b, c) in different combinations of
    list / primitive columns.

    `list_side` selects how the operands are wrapped: `left` / `right` / `both`
    wrap the corresponding operand(s) in a single-level list, `left3` /
    `right3` wrap one operand in a triple-nested list, and `none` keeps both
    operands primitive. Each operand is built as a length-1 Series and is then
    passed through `broadcast_series` before `exec_op` is applied.
    """

    def as_list(v: Any, dtype: Any) -> pl.Series:
        # Single-level list; the value is surrounded by nulls.
        return pl.Series([[None, v, None]], dtype=pl.List(dtype))

    def as_list3(v: Any, dtype: Any) -> pl.Series:
        # Triple-nested list with nulls at every nesting level.
        return pl.Series(
            [[[[None, v], None], None]],
            dtype=pl.List(pl.List(pl.List(dtype))),
        )

    def as_primitive(v: Any, dtype: Any) -> pl.Series:
        return pl.Series([v], dtype=dtype)

    # (left operand, right operand, expected output) wrappers for this run.
    wrap_lhs, wrap_rhs, wrap_out = {
        "left": (as_list, as_primitive, as_list),
        "left3": (as_list3, as_primitive, as_list3),
        "both": (as_list, as_list, as_list),
        "right": (as_primitive, as_list, as_list),
        "right3": (as_primitive, as_list3, as_list3),
        "none": (as_primitive, as_primitive, as_primitive),
    }[list_side]  # fmt: skip

    def check(dtypes: list[Any], lhs: Any, rhs: Any, out: Any, *ops: Any) -> None:
        """
        Assert that `lhs <op> rhs == out` for every operator in `ops`.

        `dtypes` holds the element dtypes of (lhs, rhs, out), in that order.
        """
        left = wrap_lhs(lhs, dtypes[0])
        right = wrap_rhs(rhs, dtypes[1])
        expected = wrap_out(out, dtypes[2])

        assert left.len() == 1
        assert right.len() == 1
        assert expected.len() == 1

        left, right, expected = broadcast_series(left, right, expected)

        for op in ops:
            assert_series_equal(exec_op(left, right, op), expected)

    inf = float("inf")
    nan = float("nan")

    # Signed
    dtypes: list[Any] = [pl.Int8, pl.Int8, pl.Int8]

    check(dtypes, 2, 3, 5, operator.add)
    check(dtypes, -5, 127, 124, operator.sub)
    check(dtypes, -5, 127, -123, operator.mul)
    check(dtypes, -5, 3, -2, operator.floordiv)
    check(dtypes, -5, 3, 1, operator.mod)

    # True division of signed integers. This case previously used UInt8
    # operands, which duplicated the unsigned case below and left signed
    # true division untested.
    check([pl.Int8, pl.Int8, pl.Float64], -5, 2, -2.5, operator.truediv)

    # Unsigned
    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    check(dtypes, 2, 3, 5, operator.add)
    check(dtypes, 2, 3, 255, operator.sub)
    check(dtypes, 2, 128, 0, operator.mul)
    check(dtypes, 5, 2, 2, operator.floordiv)
    check(dtypes, 5, 2, 1, operator.mod)

    check([pl.UInt8, pl.UInt8, pl.Float64], 2, 128, 0.015625, operator.truediv)

    # Floats. Note we pick Float32 to ensure there is no accidental upcasting
    # to Float64.
    dtypes = [pl.Float32, pl.Float32, pl.Float32]

    check(dtypes, 1.7, 2.3, 4.0, operator.add)
    check(dtypes, 1.7, 2.3, -0.5999999999999999, operator.sub)
    check(dtypes, 1.7, 2.3, 3.9099999999999997, operator.mul)
    check(dtypes, 7.0, 3.0, 2.0, operator.floordiv)
    check(dtypes, -5.0, 3.0, 1.0, operator.mod)
    check(dtypes, 2.0, 128.0, 0.015625, operator.truediv)

    #
    # Tests for zero behavior
    #

    # Integer

    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    check(dtypes, 1, 0, None, operator.floordiv, operator.mod)
    check(dtypes, 0, 0, None, operator.floordiv, operator.mod)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]

    check(dtypes, 1, 0, inf, operator.truediv)
    check(dtypes, 0, 0, nan, operator.truediv)

    # Float

    dtypes = [pl.Float32, pl.Float32, pl.Float32]

    check(dtypes, 1, 0, inf, operator.floordiv)
    check(dtypes, 1, 0, nan, operator.mod)
    check(dtypes, 1, 0, inf, operator.truediv)

    check(dtypes, 0, 0, nan, operator.floordiv)
    check(dtypes, 0, 0, nan, operator.mod)
    check(dtypes, 0, 0, nan, operator.truediv)

    #
    # Tests for NULL behavior
    #

    null_dtypes: list[tuple[Any, Any]] = [
        (pl.Int8, pl.Float64),
        (pl.Float32, pl.Float32),
    ]
    null_cases: list[tuple[Any, Any, Any]] = [
        (None, None, None),
        (0, None, None),
        (None, 0, None),
        (3, None, None),
        (None, 3, None),
    ]

    for dtype, truediv_dtype in null_dtypes:
        for lhs, rhs, out in null_cases:
            check(
                [dtype, dtype, dtype],
                lhs,
                rhs,
                out,
                operator.add,
                operator.sub,
                operator.mul,
                operator.floordiv,
                operator.mod,
            )
            # True division has its own output dtype.
            check([dtype, dtype, truediv_dtype], lhs, rhs, out, operator.truediv)

    # Type upcasting for Boolean and Null

    # Check boolean upcasting
    dtypes = [pl.Boolean, pl.UInt8, pl.UInt8]

    check(dtypes, True, 3, 4, operator.add)
    check(dtypes, True, 3, 254, operator.sub)
    check(dtypes, True, 3, 3, operator.mul)

    if list_side != "none":
        # TODO: FIXME: We get an error on non-lists with this:
        # "floor_div operation not supported for dtype `bool`"
        check(dtypes, True, 3, 0, operator.floordiv)

    check(dtypes, True, 3, 1, operator.mod)

    check([pl.Boolean, pl.UInt8, pl.Float64], True, 128, 0.0078125, operator.truediv)

    # Check Null upcasting
    dtypes = [pl.Null, pl.UInt8, pl.UInt8]

    check(dtypes, None, 3, None, operator.add, operator.sub, operator.mul)
    if list_side != "none":
        check(dtypes, None, 3, None, operator.floordiv)
    check(dtypes, None, 3, None, operator.mod)

    check([pl.Null, pl.UInt8, pl.Float64], None, 3, None, operator.truediv)
286
287
288
@pytest.mark.parametrize(
    ("lhs_dtype", "rhs_dtype", "expected_dtype"),
    [
        (pl.List(pl.Int64), pl.Int64, pl.List(pl.Float64)),
        (pl.List(pl.Float32), pl.Float32, pl.List(pl.Float32)),
        (pl.List(pl.Duration("us")), pl.Int64, pl.List(pl.Duration("us"))),
    ],
)
def test_list_truediv_schema(
    lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType
) -> None:
    """Check the dtype that the lazy schema reports for a list `truediv`."""
    frame = pl.DataFrame(
        {"lhs": [[None, 10]], "rhs": 2},
        schema={"lhs": lhs_dtype, "rhs": rhs_dtype},
    )

    # Only the schema is resolved; the query itself is never collected.
    resolved = frame.lazy().select(pl.col("lhs").truediv("rhs")).collect_schema()

    assert resolved["lhs"] == expected_dtype
303
304
305
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_add_supertype(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Adding List(Int8) to List(Int64) must produce List(Int64) values."""
    narrow = pl.Series("a", [[1], [2]], dtype=pl.List(pl.Int8))
    wide = pl.Series("b", [[1], [999]], dtype=pl.List(pl.Int64))

    result = exec_op(narrow, wide, operator.add)

    # 2 + 999 does not fit in Int8, so the expected output is built as Int64.
    expected = pl.Series("a", [[2], [1001]], dtype=pl.List(pl.Int64))
    assert_series_equal(result, expected)
318
319
320
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.slow
def test_list_numeric_op_validity_combination(
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Check how nulls on either operand combine in the output of list ops."""
    list_i32 = pl.List(pl.Int32)
    list_i64 = pl.List(pl.Int64)

    # list<->list: a null on either side (outer row or inner value) gives a
    # null at the same position in the output.
    left = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=list_i32)
    right = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=list_i64)
    expected = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=list_i64)

    assert_series_equal(exec_op(left, right, operator.add), expected)

    # list<->primitive, single-row inputs that go through `broadcast_series`:
    # (list rows, primitive rows, expected rows, operator)
    broadcast_cases: list[tuple[list[Any], list[Any], list[Any], Any]] = [
        ([[1]], [None], [[None]], operator.add),
        ([None], [1], [None], operator.add),
        ([None], [0], [None], operator.floordiv),
    ]

    for list_rows, primitive_rows, expected_rows, op in broadcast_cases:
        left = pl.Series("a", list_rows, dtype=list_i32)
        right = pl.Series("b", primitive_rows, dtype=pl.Int64)
        expected = pl.Series("a", expected_rows, dtype=list_i64)

        left, right, expected = broadcast_series(left, right, expected)
        assert_series_equal(exec_op(left, right, op), expected)
364
365
366
def test_list_add_alignment() -> None:
    """Check list addition with mismatched list lengths and with filtered/sliced input."""
    mismatched = pl.DataFrame(
        [
            pl.Series("a", [[1, 1], [1, 1, 1]]),
            pl.Series("b", [[1, 1, 1], [1, 1]]),
        ]
    )

    with pytest.raises(ShapeError):
        mismatched.select(x=pl.col("a") + pl.col("b"))

    def assert_sum(frame: pl.DataFrame, expected_rows: list[Any]) -> None:
        # Column "a" is added to every kind of right-hand side in turn.
        for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]:
            result = frame.select(x=pl.col("a") + rhs).to_series()
            assert_series_equal(result, pl.Series("x", expected_rows))

    # Test masking and slicing
    columns = [
        pl.Series("a", [[1, 1, 1], [1], [1, 1], [1, 1, 1]]),
        pl.Series("b", [[1, 1], [1], [1, 1, 1], [1]]),
        pl.Series("c", [1, 1, 1, 1]),
        pl.Series("p", [True, True, False, False]),
    ]
    df = pl.DataFrame(columns).filter("p").slice(1)

    assert_sum(df, [[2]])

    # Same frame stacked on top of itself.
    df = df.vstack(df)

    assert_sum(df, [[2], [2]])
394
395
396
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_list_add_empty_lists(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Adding an integer to nested lists that hold no values keeps their structure."""
    expected_dtype = pl.List(pl.List(pl.List(pl.Int64)))

    # Two inputs: one with only empty lists, one with a null nested list.
    for rows in ([[[[]], []], []], [[[[]], None], []]):
        nested = pl.Series("x", rows)
        scalar = pl.Series([1])

        assert_series_equal(
            exec_op(nested, scalar, operator.add),
            pl.Series("x", rows, dtype=expected_dtype),
        )
422
423
424
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_to_list_arithmetic_double_nesting_raises_error(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Adding two List(List(Int32)) series must raise InvalidOperationError."""
    nested = pl.Series(dtype=pl.List(pl.List(pl.Int32)))
    expected_message = "cannot add two list columns with non-numeric inner types"

    with pytest.raises(InvalidOperationError, match=expected_message):
        exec_op(nested, nested, operator.add)
435
436
437
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_add_height_mismatch(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Adding a 3-row list series to a 2-row numeric series must raise."""
    three_rows = pl.Series([[1], [2], [3]], dtype=pl.List(pl.Int32))
    two_rows = pl.Series([1, 1])

    # TODO: Make the error type consistently a ShapeError
    accepted_errors = (ShapeError, InvalidOperationError)

    with pytest.raises(accepted_errors, match="length"):
        exec_op(three_rows, two_rows, operator.add)
449
450
451
@pytest.mark.parametrize(
    "op",
    [
        operator.add,
        operator.sub,
        operator.mul,
        operator.floordiv,
        operator.mod,
        operator.truediv,
    ],
)
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_list_date_to_numeric_arithmetic_raises_error(
    op: Callable[[Any], Any], exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series]
) -> None:
    """Arithmetic between a Date series and a numeric list series must raise."""
    dates = pl.Series([1], dtype=pl.Date)
    int_lists = pl.Series([[1]], dtype=pl.List(pl.Int32))

    # The same operation on the physical representation of the dates is
    # executed first, outside of any `raises` context.
    exec_op(dates.to_physical(), int_lists, op)

    # TODO(_): Ideally this always raises InvalidOperationError. The TypeError
    # is being raised by checks on the Python side that should be moved to Rust.
    with pytest.raises((InvalidOperationError, TypeError)):
        exec_op(dates, int_lists, op)
476
477
478
@pytest.mark.parametrize(
    ("expected", "expr", "column_names"),
    [
        ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")),
        ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")),
        ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")),
        ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")),
        ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")),
        (
            [[3, 4], [7]],
            lambda a, b: a + b,
            ("a", "uint8"),
        ),
    ],
)
def test_list_arithmetic_same_size(
    expected: Any,
    expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],
    column_names: tuple[str, str],
) -> None:
    """Check list<->list arithmetic through both the Expr and the Series API."""
    nested_uint8_dtype = pl.List(pl.List(pl.UInt8()))
    df = pl.DataFrame(
        [
            pl.Series("a", [[1, 2], [3]]),
            pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())),
            pl.Series("nested", [[[1, 2]], [[3]]]),
            pl.Series("nested_uint8", [[[1, 2]], [[3]]], dtype=nested_uint8_dtype),
        ]
    )

    left_name, right_name = column_names
    # The expected output carries the name of the left-hand column.
    expected_series = pl.Series(left_name, expected)

    # Expr-based arithmetic:
    via_expr = df.select(expr(pl.col(left_name), pl.col(right_name)))
    assert_frame_equal(via_expr, expected_series.to_frame())

    # Direct arithmetic on the Series:
    via_series = expr(df[left_name], df[right_name])
    assert_series_equal(via_series, expected_series)
518
519
520
@pytest.mark.parametrize(
    ("a", "b", "expected"),
    [
        ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]),
        ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]),
    ],
)
def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None:
    """Check null handling when two list series are added."""
    lhs = pl.Series(a)
    rhs = pl.Series(b)
    want = pl.Series(expected)

    # Same dtype:
    assert_series_equal(lhs + rhs, want)

    # Different dtype:
    lhs_i32 = lhs._recursive_cast_to_dtype(pl.Int32())
    rhs_i64 = rhs._recursive_cast_to_dtype(pl.Int64())
    want_i64 = want._recursive_cast_to_dtype(pl.Int64())
    assert_series_equal(lhs_i32 + rhs_i64, want_i64)
541
542
543
def test_list_arithmetic_error_cases() -> None:
    """Check the errors raised by list<->list division on mismatched shapes."""
    # (exception, message pattern, left rows, right rows)
    cases: list[tuple[type[Exception], str, list[Any], list[Any]]] = [
        # Different series length:
        (
            InvalidOperationError,
            "different lengths",
            [[1, 2], [1, 2], [1, 2]],
            [[1, 2], [3, 4]],
        ),
        (
            InvalidOperationError,
            "different lengths",
            [[1, 2], [1, 2], [1, 2]],
            [[1, 2], None],
        ),
        # Different list length:
        (
            ShapeError,
            "lengths differed at index 0: 2 != 1",
            [[1, 2], [1, 2], [1, 2]],
            [[1]],
        ),
        (
            ShapeError,
            "lengths differed at index 0: 2 != 1",
            [[1, 2], [2, 3]],
            [[1], None],
        ),
    ]

    for exception, pattern, left_rows, right_rows in cases:
        with pytest.raises(exception, match=pattern):
            _ = pl.Series("a", left_rows) / pl.Series("b", right_rows)
556
557
558
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_arithmetic_invalid_dtypes(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Check the errors raised when a list is added to an unsupported dtype."""
    # Wrong types:
    int_list = pl.Series([[1, 2]])
    strings = pl.Series(["hello"])

    with pytest.raises(
        InvalidOperationError, match="add operation not supported for dtypes"
    ):
        exec_op(int_list, strings, operator.add)

    # list<->list is restricted to 1 level of nesting
    flat = pl.Series("a", [[1]])
    nested = pl.Series("b", [[[1]]])

    with pytest.raises(
        InvalidOperationError,
        match="cannot add two list columns with non-numeric inner types",
    ):
        exec_op(flat, nested, operator.add)

    # Ensure dtype is validated to be `List` at all nesting levels instead of panicking.
    list_of_array = pl.Series([[[1]], [[1]]], dtype=pl.List(pl.Array(pl.Int64, 1)))
    integer = pl.Series([1], dtype=pl.Int64)
    not_list_message = "dtype was not list on all nesting levels"

    # The error is expected regardless of the side the list column is on.
    for left, right in ((list_of_array, integer), (integer, list_of_array)):
        with pytest.raises(InvalidOperationError, match=not_list_message):
            exec_op(left, right, operator.add)
596
597
598
@pytest.mark.parametrize(
    ("expected", "expr", "column_names"),
    [
        # All 5 arithmetic operations:
        ([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")),
        ([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")),
        ([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")),
        ([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")),
        ([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")),
        # Different types:
        (
            [[3, 4], [7]],
            lambda a, b: a + b,
            ("list", "uint8"),
        ),
        # Extra nesting + different types:
        (
            [[[3, 4]], [[8]]],
            lambda a, b: a + b,
            ("nested", "int64"),
        ),
        # Primitive numeric on the left; only addition and multiplication are
        # supported:
        ([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")),
        ([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")),
        # Primitive numeric on the left with different types:
        (
            [[3, 4], [7]],
            lambda a, b: a + b,
            ("uint8", "list"),
        ),
        (
            [[2, 4], [12]],
            lambda a, b: a * b,
            ("uint8", "list"),
        ),
    ],
)
def test_list_and_numeric_arithmetic_same_size(
    expected: Any,
    expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],
    column_names: tuple[str, str],
) -> None:
    """Check list<->numeric arithmetic through both the Expr and the Series API."""
    df = pl.DataFrame(
        [
            pl.Series("list", [[1, 2], [3]]),
            pl.Series("int64", [2, 3], dtype=pl.Int64()),
            pl.Series("uint8", [2, 4], dtype=pl.UInt8()),
            pl.Series("nested", [[[1, 2]], [[5]]]),
        ]
    )

    left_name, right_name = column_names
    # The expected output carries the name of the left-hand column.
    expected_series = pl.Series(left_name, expected)

    # Expr-based arithmetic:
    via_expr = df.select(expr(pl.col(left_name), pl.col(right_name)))
    assert_frame_equal(via_expr, expected_series.to_frame())

    # Direct arithmetic on the Series:
    via_series = expr(df[left_name], df[right_name])
    assert_series_equal(via_series, expected_series)
659
660
661
@pytest.mark.parametrize(
    ("a", "b", "expected"),
    [
        # Null in the numeric operand (second row):
        ([[1, 2], [3]], [1, None], [[2, 3], [None]]),
        # Null in the numeric operand (first row) against a nested list:
        ([[[1, 2]], [[3]]], [None, 1], [[[None, None]], [[4]]]),
        # Null inside a nested list:
        ([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]),
    ],
)
def test_list_and_numeric_arithmetic_nulls(
    a: list[Any], b: list[Any], expected: list[Any]
) -> None:
    """Check null handling when a list series and a numeric series are added."""
    lists = pl.Series(a)
    numbers = pl.Series(b)
    # The expected output is built with the dtype of the list operand.
    want = pl.Series(expected, dtype=lists.dtype)
    want_i64 = want._recursive_cast_to_dtype(pl.Int64())

    # Same dtype:
    assert_series_equal(lists + numbers, want)

    # Different dtype:
    assert_series_equal(
        lists._recursive_cast_to_dtype(pl.Int32())
        + numbers._recursive_cast_to_dtype(pl.Int64()),
        want_i64,
    )

    # Swap sides:
    assert_series_equal(numbers + lists, want)
    assert_series_equal(
        numbers._recursive_cast_to_dtype(pl.Int32())
        + lists._recursive_cast_to_dtype(pl.Int64()),
        want_i64,
    )
696
697
698
def test_list_and_numeric_arithmetic_error_cases() -> None:
    """Check the errors raised by invalid list<->primitive arithmetic."""
    length_message = "series of different lengths: got 3 and 2"
    three_lists = [[1, 2], [3, 4], [5, 6]]

    # Different series length:
    with pytest.raises(InvalidOperationError, match=length_message):
        _ = pl.Series("a", three_lists) + pl.Series("b", [1, 2])

    with pytest.raises(InvalidOperationError, match=length_message):
        _ = pl.Series("a", three_lists) / pl.Series("b", [1, None])

    # Wrong types:
    dtype_message = "add operation not supported for dtypes"
    with pytest.raises(InvalidOperationError, match=dtype_message):
        _ = pl.Series("a", [[1, 2], [3, 4]]) + pl.Series("b", ["hello", "world"])
714
715
716
@pytest.mark.parametrize("broadcast", [True, False])
@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Float64()])
def test_list_arithmetic_div_ops_zero_denominator(
    broadcast: bool, dtype: pl.DataType
) -> None:
    """
    Check division-like operators on lists when the denominator is zero.

    Notes
    -----
    * truediv (/) on integers upcasts to Float64
    * Otherwise, we test floordiv (//) and module/rem (%)
    * On integers, 0-denominator is expected to output NULL
    * On floats, 0-denominator has different outputs, e.g. NaN, Inf, depending
      on a few factors (e.g. whether the numerator is also 0).

    With `broadcast=True` the non-`s` operand has a single row, otherwise it
    has as many rows as `s`.
    """
    s = pl.Series([[0], [1], [None], None]).cast(pl.List(dtype))

    n = 1 if broadcast else s.len()

    inf = float("inf")
    nan = float("nan")

    def repeated(value: Any) -> pl.Series:
        # Series holding `value` in each of its `n` rows.
        return pl.Series([value]).new_from_index(0, n)

    def as_float64(rows: list[Any]) -> pl.Series:
        # Expected truediv output: always List(Float64).
        return pl.Series(rows, dtype=pl.List(pl.Float64))

    def as_input_dtype(int_rows: list[Any], float_rows: list[Any]) -> pl.Series:
        # Expected floordiv / rem output: same dtype as `s`, with the rows
        # depending on whether `s` holds integers or floats.
        rows = float_rows if dtype.is_float() else int_rows
        return pl.Series(rows, dtype=s.dtype)

    # A non-zero primitive denominator passes the list values through.
    assert_series_equal(
        s / repeated(1),
        as_float64([[0.0], [1.0], [None], None]),
    )

    # The same set of checks runs for list<->primitive (plain ints) and for
    # list<->list (single-element lists).
    for one, zero in ((1, 0), ([1], [0])):
        # truediv
        assert_series_equal(
            repeated(one) / s,
            as_float64([[inf], [1.0], [None], None]),
        )

        # For list<->primitive this zero-denominator case was previously not
        # covered; only `s / 1` was checked.
        assert_series_equal(
            s / repeated(zero),
            as_float64([[nan], [inf], [None], None]),
        )

        # floordiv
        assert_series_equal(
            repeated(one) // s,
            as_input_dtype(
                [[None], [1], [None], None],
                [[inf], [1.0], [None], None],
            ),
        )

        assert_series_equal(
            s // repeated(zero),
            as_input_dtype(
                [[None], [None], [None], None],
                [[nan], [inf], [None], None],
            ),
        )

        # rem
        assert_series_equal(
            repeated(one) % s,
            as_input_dtype(
                [[None], [0], [None], None],
                [[nan], [0.0], [None], None],
            ),
        )

        assert_series_equal(
            s % repeated(zero),
            as_input_dtype(
                [[None], [None], [None], None],
                [[nan], [nan], [None], None],
            ),
        )
843
844
845
def test_list_to_primitive_arithmetic() -> None:
    """
    Check arithmetic between a triple-nested list column and a primitive column.

    Input data
      * List type: List(List(List(Int16))) (triple-nested)
      * Numeric type: Int32

    Tests run
                             Broadcast    Operation
                             | L | R |
      * list<->primitive     |   |   |    floor_div
      * primitive<->list     |   |   |    floor_div
      * list<->primitive     |   | * |    subtract
      * primitive<->list     | * |   |    subtract
      * list<->primitive     | * |   |    subtract
      * primitive<->list     |   | * |    subtract

    Notes
      * In floor_div, we check that results from a 0 denominator are masked out
      * We choose floor_div and subtract as they emit different results when
        sides are swapped
    """
    nested_i16 = pl.List(pl.List(pl.List(pl.Int16)))
    nested_i32 = pl.List(pl.List(pl.List(pl.Int32)))

    # Create some non-zero start offsets and masked out rows.
    unsliced = pl.Series(
        [
            [[[None, None, None, None, None]]],  # sliced out
            # Nulls at every level XO
            [[[3, 7]], [[-3], [None], [], [], None], [], None],
            [[[1, 2, 3, 4, 5]]],  # masked out
            [[[3, 7]], [[0], [None], [], [], None]],
            [[[3, 7]]],
        ],
        dtype=nested_i16,
    )
    lhs = (
        unsliced.slice(1)
        .to_frame()
        .select(pl.when(pl.int_range(pl.len()) != 1).then(pl.first()))
        .to_series()
    )

    # Note to reader: This is what our LHS looks like
    assert_series_equal(
        lhs,
        pl.Series(
            [
                [[[3, 7]], [[-3], [None], [], [], None], [], None],
                None,
                [[[3, 7]], [[0], [None], [], [], None]],
                [[[3, 7]]],
            ],
            dtype=nested_i16,
        ),
    )

    #
    # Floor div, no broadcasting
    #
    floordiv_rhs = pl.Series([5, 1, 0, None], dtype=pl.Int32)

    assert len(lhs) == len(floordiv_rhs)

    paired = pl.select(l=lhs, r=floordiv_rhs)

    list_floordiv_primitive = (
        paired.select(pl.col("l") // pl.col("r")).to_series().alias("")
    )
    assert_series_equal(
        list_floordiv_primitive,
        pl.Series(
            [
                [[[0, 1]], [[-1], [None], [], [], None], [], None],
                None,
                [[[None, None]], [[None], [None], [], [], None]],
                [[[None, None]]],
            ],
            dtype=nested_i32,
        ),
    )

    # Flipped
    primitive_floordiv_list = (
        paired.select(pl.col("r") // pl.col("l")).to_series().alias("")
    )
    assert_series_equal(
        primitive_floordiv_list,
        pl.Series(
            [
                [[[1, 0]], [[-2], [None], [], [], None], [], None],
                None,
                [[[0, 0]], [[None], [None], [], [], None]],
                [[[None, None]]],
            ],
            dtype=nested_i32,
        ),
    )

    #
    # Subtraction with broadcasting of the primitive side
    #
    single_rhs = pl.Series([1], dtype=pl.Int32)
    list_frame = pl.select(l=lhs)

    list_minus_scalar = list_frame.select(pl.col("l") - single_rhs).to_series().alias("")
    assert_series_equal(
        list_minus_scalar,
        pl.Series(
            [
                [[[2, 6]], [[-4], [None], [], [], None], [], None],
                None,
                [[[2, 6]], [[-1], [None], [], [], None]],
                [[[2, 6]]],
            ],
            dtype=nested_i32,
        ),
    )

    # Flipped
    scalar_minus_list = list_frame.select(single_rhs - pl.col("l")).to_series().alias("")
    assert_series_equal(
        scalar_minus_list,
        pl.Series(
            [
                [[[-2, -6]], [[4], [None], [], [], None], [], None],
                None,
                [[[-2, -6]], [[1], [None], [], [], None]],
                [[[-2, -6]]],
            ],
            dtype=nested_i32,
        ),
    )

    #
    # Test broadcasting of the list side
    #
    single_lhs = lhs.slice(2, 1)

    # Note to reader: This is what our LHS looks like
    assert_series_equal(
        single_lhs,
        pl.Series(
            [
                [[[3, 7]], [[0], [None], [], [], None]],
            ],
            dtype=nested_i16,
        ),
    )

    assert len(single_lhs) == 1

    tall_rhs = pl.Series([1, 2, 3, None, 5], dtype=pl.Int32)
    primitive_frame = pl.select(r=tall_rhs)

    single_list_minus_primitive = (
        primitive_frame.select(single_lhs - pl.col("r")).to_series().alias("")
    )
    assert_series_equal(
        single_list_minus_primitive,
        pl.Series(
            [
                [[[2, 6]], [[-1], [None], [], [], None]],
                [[[1, 5]], [[-2], [None], [], [], None]],
                [[[0, 4]], [[-3], [None], [], [], None]],
                [[[None, None]], [[None], [None], [], [], None]],
                [[[-2, 2]], [[-5], [None], [], [], None]],
            ],
            dtype=nested_i32,
        ),
    )

    # Flipped
    primitive_minus_single_list = (
        primitive_frame.select(pl.col("r") - single_lhs).to_series().alias("")
    )
    assert_series_equal(
        primitive_minus_single_list,
        pl.Series(
            [
                [[[-2, -6]], [[1], [None], [], [], None]],
                [[[-1, -5]], [[2], [None], [], [], None]],
                [[[0, -4]], [[3], [None], [], [], None]],
                [[[None, None]], [[None], [None], [], [], None]],
                [[[2, -2]], [[5], [None], [], [], None]],
            ],
            dtype=nested_i32,
        ),
    )
1027
1028