Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/arithmetic/test_list.py
8421 views
1
from __future__ import annotations
2
3
import operator
4
from typing import TYPE_CHECKING, Any
5
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import InvalidOperationError, ShapeError
10
from polars.testing import assert_frame_equal, assert_series_equal
11
from tests.unit.operations.arithmetic.utils import (
12
BROADCAST_SERIES_COMBINATIONS,
13
EXEC_OP_COMBINATIONS,
14
)
15
16
if TYPE_CHECKING:
17
from collections.abc import Callable
18
19
from polars._typing import PolarsDataType
20
21
22
@pytest.mark.parametrize(
    "list_side", ["left", "left3", "both", "right3", "right", "none"]
)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_list_arithmetic_values(
    list_side: str,
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Tests value correctness.

    This test checks for output value correctness (a + b == c) across different
    codepaths, by wrapping the values (a, b, c) in different combinations of
    list / primitive columns.

    `list_side` selects the wrapping: "left"/"right" wrap that operand in a
    single-level list, "left3"/"right3" in a triple-nested list, "both" wraps
    both operands, and "none" keeps all three values primitive. The expected
    output is always wrapped to match the list operand(s).
    """
    import operator as op

    # (lhs, rhs, output) inner dtypes used by the next `materialize_series` call.
    dtypes: list[Any] = [pl.Null, pl.Null, pl.Null]
    # Dtype read by the `materialize_*` helpers; rebound by `materialize_series`
    # before each helper call.
    dtype: Any = pl.Null

    def materialize_list(v: Any) -> pl.Series:
        # One row: a list with `v` surrounded by nulls.
        return pl.Series(
            [[None, v, None]],
            dtype=pl.List(dtype),
        )

    def materialize_list3(v: Any) -> pl.Series:
        # One row: `v` inside a triple-nested list, with nulls at each level.
        return pl.Series(
            [[[[None, v], None], None]],
            dtype=pl.List(pl.List(pl.List(dtype))),
        )

    def materialize_primitive(v: Any) -> pl.Series:
        # One row: the bare value.
        return pl.Series([v], dtype=dtype)

    def materialize_series(
        l: Any,  # noqa: E741
        r: Any,
        o: Any,
    ) -> tuple[pl.Series, pl.Series, pl.Series]:
        # Wrap (lhs, rhs, output) according to `list_side` using the dtypes in
        # `dtypes`, then hand them to `broadcast_series`.
        nonlocal dtype

        dtype = dtypes[0]
        l = {  # noqa: E741
            "left": materialize_list,
            "left3": materialize_list3,
            "both": materialize_list,
            "right": materialize_primitive,
            "right3": materialize_primitive,
            "none": materialize_primitive,
        }[list_side](l)  # fmt: skip

        dtype = dtypes[1]
        r = {
            "left": materialize_primitive,
            "left3": materialize_primitive,
            "both": materialize_list,
            "right": materialize_list,
            "right3": materialize_list3,
            "none": materialize_primitive,
        }[list_side](r)  # fmt: skip

        dtype = dtypes[2]
        o = {
            "left": materialize_list,
            "left3": materialize_list3,
            "both": materialize_list,
            "right": materialize_list,
            "right3": materialize_list3,
            "none": materialize_primitive,
        }[list_side](o)  # fmt: skip

        assert l.len() == 1
        assert r.len() == 1
        assert o.len() == 1

        return broadcast_series(l, r, o)

    # Signed
    dtypes = [pl.Int8, pl.Int8, pl.Int8]

    l, r, o = materialize_series(2, 3, 5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(-5, 127, 124)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(-5, 127, -123)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(-5, 3, -2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(-5, 3, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    # Signed true division upcasts to Float64; use a negative numerator so the
    # signed path is actually exercised (previously this reused UInt8 inputs).
    dtypes = [pl.Int8, pl.Int8, pl.Float64]
    l, r, o = materialize_series(-5, 2, -2.5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Unsigned
    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(2, 3, 5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    # Wrapping arithmetic: 2 - 3 == 255 in UInt8.
    l, r, o = materialize_series(2, 3, 255)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    # Wrapping arithmetic: 2 * 128 == 0 in UInt8.
    l, r, o = materialize_series(2, 128, 0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(5, 2, 2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(5, 2, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(2, 128, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Floats. Note we pick Float32 to ensure there is no accidental upcasting
    # to Float64.
    dtypes = [pl.Float32, pl.Float32, pl.Float32]
    l, r, o = materialize_series(1.7, 2.3, 4.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(7.0, 3.0, 2.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(-5.0, 3.0, 1.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(2.0, 128.0, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    #
    # Tests for zero behavior
    #

    # Integer: floordiv / mod by zero produce NULL, truediv produces inf / NaN.

    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(1, 0, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(0, 0, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Float

    dtypes = [pl.Float32, pl.Float32, pl.Float32]

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(1, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    #
    # Tests for NULL behavior
    #

    # A null on either side yields a null output for every operation; truediv
    # has its own output dtype (integers upcast to Float64).
    for dtype, truediv_dtype in [  # type: ignore[misc]
        [pl.Int8, pl.Float64],
        [pl.Float32, pl.Float32],
    ]:
        for vals in [
            [None, None, None],
            [0, None, None],
            [None, 0, None],
            [3, None, None],
            [None, 3, None],
        ]:
            dtypes = 3 * [dtype]

            l, r, o = materialize_series(*vals)  # type: ignore[misc] # noqa: E741
            assert_series_equal(exec_op(l, r, op.add), o)
            assert_series_equal(exec_op(l, r, op.sub), o)
            assert_series_equal(exec_op(l, r, op.mul), o)
            assert_series_equal(exec_op(l, r, op.floordiv), o)
            assert_series_equal(exec_op(l, r, op.mod), o)
            dtypes[2] = truediv_dtype  # type: ignore[has-type]
            l, r, o = materialize_series(*vals)  # type: ignore[misc] # noqa: E741
            assert_series_equal(exec_op(l, r, op.truediv), o)

    # Type upcasting for Boolean and Null

    # Check boolean upcasting
    dtypes = [pl.Boolean, pl.Boolean, pl.get_index_type()]
    l, r, o = materialize_series(True, True, 2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    dtypes = [pl.Boolean, pl.Boolean, pl.Float64]
    l, r, o = materialize_series(True, True, 1.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    dtypes = [pl.Boolean, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(True, 3, 4)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(True, 3, 254)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(True, 3, 3)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(True, 3, 0)  # noqa: E741
    if list_side != "none":
        # TODO: We get an error on non-lists with this:
        # "floor_div operation not supported for dtype `bool`"
        assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(True, 3, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.Boolean, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(True, 128, 0.0078125)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Check Null upcasting
    dtypes = [pl.Null, pl.UInt8, pl.UInt8]
    l, r, o = materialize_series(None, 3, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)
    assert_series_equal(exec_op(l, r, op.sub), o)
    assert_series_equal(exec_op(l, r, op.mul), o)
    if list_side != "none":
        assert_series_equal(exec_op(l, r, op.floordiv), o)
        assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.Null, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(None, 3, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)
297
298
@pytest.mark.parametrize(
    ("lhs_dtype", "rhs_dtype", "expected_dtype"),
    [
        (pl.List(pl.Int64), pl.Int64, pl.List(pl.Float64)),
        (pl.List(pl.Float32), pl.Float32, pl.List(pl.Float32)),
        (pl.List(pl.Duration("us")), pl.Int64, pl.List(pl.Duration("us"))),
    ],
)
def test_list_truediv_schema(
    lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType
) -> None:
    """The lazily-resolved dtype of `list / primitive` matches `expected_dtype`."""
    frame = pl.DataFrame(
        {"lhs": [[None, 10]], "rhs": 2},
        schema={"lhs": lhs_dtype, "rhs": rhs_dtype},
    )
    # Only the schema of the lazy plan is resolved; nothing is computed.
    quotient = pl.col("lhs").truediv("rhs")
    resolved_schema = frame.lazy().select(quotient).collect_schema()
    assert resolved_schema["lhs"] == expected_dtype
313
314
315
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_add_supertype(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """List(Int8) + List(Int64) is computed in, and returns, List(Int64)."""
    narrow = pl.Series("a", [[1], [2]], dtype=pl.List(pl.Int8))
    wide = pl.Series("b", [[1], [999]], dtype=pl.List(pl.Int64))

    # 999 and 1001 are outside Int8's range, so this only holds if the sum is
    # evaluated in the Int64 supertype. The result keeps the left name "a".
    total = exec_op(narrow, wide, operator.add)
    assert_series_equal(
        total,
        pl.Series("a", [[2], [1001]], dtype=pl.List(pl.Int64)),
    )
328
329
330
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.slow
def test_list_numeric_op_validity_combination(
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Outer (row) and inner (element) nulls propagate through list arithmetic."""
    # list<->list: an outer null on either side gives an outer null, an inner
    # null on either side gives an inner null.
    left_lists = pl.Series(
        "a", [[1], [2], None, [None], [11], [1111]], dtype=pl.List(pl.Int32)
    )
    right_lists = pl.Series(
        "b", [[1], [3], [11], [1111], None, [None]], dtype=pl.List(pl.Int64)
    )
    wanted = pl.Series(
        "a", [[2], [5], None, [None], None, [None]], dtype=pl.List(pl.Int64)
    )
    assert_series_equal(exec_op(left_lists, right_lists, operator.add), wanted)

    # list<->primitive, each case checked under every broadcast combination:
    # (list values, primitive value, expected list values, operation)
    scenarios = [
        # A null primitive nulls out the list's elements.
        ([[1]], None, [[None]], operator.add),
        # A null list stays null.
        ([None], 1, [None], operator.add),
        # A null list stays null even against a zero denominator.
        ([None], 0, [None], operator.floordiv),
    ]
    for list_values, scalar, expected_values, binop in scenarios:
        lhs, rhs, expected = broadcast_series(
            pl.Series("a", list_values, dtype=pl.List(pl.Int32)),
            pl.Series("b", [scalar], dtype=pl.Int64),
            pl.Series("a", expected_values, dtype=pl.List(pl.Int64)),
        )
        assert_series_equal(exec_op(lhs, rhs, binop), expected)
374
375
376
def test_list_add_alignment() -> None:
    """list + list requires equal inner lengths and honours masks and slices."""
    # Row 0 pairs lengths 2 and 3, so the addition must raise.
    unequal = pl.DataFrame(
        [
            pl.Series("a", [[1, 1], [1, 1, 1]]),
            pl.Series("b", [[1, 1, 1], [1, 1]]),
        ]
    )
    with pytest.raises(ShapeError):
        unequal.select(x=pl.col("a") + pl.col("b"))

    # Test masking and slicing: filtering on `p` keeps rows 0 and 1, and
    # `.slice(1)` then keeps only row 1, whose lists line up. The rows with
    # mismatched lengths that were filtered/sliced away must not be observed.
    sliced = (
        pl.DataFrame(
            [
                pl.Series("a", [[1, 1, 1], [1], [1, 1], [1, 1, 1]]),
                pl.Series("b", [[1, 1], [1], [1, 1, 1], [1]]),
                pl.Series("c", [1, 1, 1, 1]),
                pl.Series("p", [True, True, False, False]),
            ]
        )
        .filter("p")
        .slice(1)
    )

    right_operands = [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]

    # Check the single surviving row, and again after stacking it twice.
    for frame, expected_values in (
        (sliced, [[2]]),
        (sliced.vstack(sliced), [[2], [2]]),
    ):
        for rhs in right_operands:
            got = frame.select(x=pl.col("a") + rhs).to_series()
            assert_series_equal(got, pl.Series("x", expected_values))
397
398
df = df.vstack(df)
399
400
for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]:
401
assert_series_equal(
402
df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2], [2]])
403
)
404
405
406
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_list_add_empty_lists(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Adding a scalar to nested lists that hold no values keeps their shape."""
    # The inputs contain only empty lists (and, in the second case, a null
    # inner list); the values pass through unchanged and the expected leaf
    # dtype is Int64.
    expected_dtype = pl.List(pl.List(pl.List(pl.Int64)))

    for nested in ([[[[]], []], []], [[[[]], None], []]):
        lhs = pl.Series(  # noqa: E741
            "x",
            nested,
        )
        rhs = pl.Series([1])

        assert_series_equal(
            exec_op(lhs, rhs, operator.add),
            pl.Series("x", nested, dtype=expected_dtype),
        )
421
422
l = pl.Series( # noqa: E741
423
"x",
424
[[[[]], None], []],
425
)
426
r = pl.Series([1])
427
428
assert_series_equal(
429
exec_op(l, r, operator.add),
430
pl.Series("x", [[[[]], None], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))),
431
)
432
433
434
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_to_list_arithmetic_double_nesting_raises_error(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """list<->list arithmetic on List(List(Int32)) operands must raise."""
    doubly_nested = pl.Series(dtype=pl.List(pl.List(pl.Int32)))
    message = "cannot add two list columns with non-numeric inner types"

    with pytest.raises(InvalidOperationError, match=message):
        exec_op(doubly_nested, doubly_nested, operator.add)
445
446
447
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_add_height_mismatch(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """A 3-row list column plus a 2-row numeric column must raise."""
    three_rows = pl.Series([[1], [2], [3]], dtype=pl.List(pl.Int32))
    two_rows = pl.Series([1, 1])

    # TODO: Make the error type consistently a ShapeError
    expected_errors = (ShapeError, InvalidOperationError)
    with pytest.raises(expected_errors, match="length"):
        exec_op(three_rows, two_rows, operator.add)
459
460
461
@pytest.mark.parametrize(
    "op",
    [
        operator.add,
        operator.sub,
        operator.mul,
        operator.floordiv,
        operator.mod,
        operator.truediv,
    ],
)
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_list_date_to_numeric_arithmetic_raises_error(
    op: Callable[[Any, Any], Any],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Arithmetic between a Date column and a List(Int32) column must raise.

    `op` is one of the binary `operator` functions, hence the two-argument
    `Callable` annotation.
    """
    l = pl.Series([1], dtype=pl.Date)  # noqa: E741
    r = pl.Series([[1]], dtype=pl.List(pl.Int32))

    # Control: the same operation on the Date's physical representation must
    # succeed, so the failure below is attributable to the Date dtype itself.
    exec_op(l.to_physical(), r, op)

    # TODO(_): Ideally this always raises InvalidOperationError. The TypeError
    # is being raised by checks on the Python side that should be moved to Rust.
    with pytest.raises((InvalidOperationError, TypeError)):
        exec_op(l, r, op)
486
487
488
@pytest.mark.parametrize(
    ("expected", "expr", "column_names"),
    [
        ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")),
        ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")),
        ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")),
        ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")),
        ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")),
        (
            [[3, 4], [7]],
            lambda a, b: a + b,
            ("a", "uint8"),
        ),
    ],
)
def test_list_arithmetic_same_size(
    expected: Any,
    expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series | pl.Expr],
    column_names: tuple[str, str],
) -> None:
    """list<->list arithmetic on equal-length rows, via both Expr and Series.

    `expr` is applied once to two `pl.col` expressions (returning an Expr) and
    once to the two Series directly (returning a Series); both results must
    equal `expected`, named after the left-hand column.
    """
    # Not every column is referenced by the current parametrization.
    df = pl.DataFrame(
        [
            pl.Series("a", [[1, 2], [3]]),
            pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())),
            pl.Series("nested", [[[1, 2]], [[3]]]),
            pl.Series(
                "nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8()))
            ),
        ]
    )
    # Expr-based arithmetic:
    assert_frame_equal(
        df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),
        pl.Series(column_names[0], expected).to_frame(),
    )
    # Direct arithmetic on the Series:
    assert_series_equal(
        expr(df[column_names[0]], df[column_names[1]]),
        pl.Series(column_names[0], expected),
    )
528
529
530
@pytest.mark.parametrize(
    ("a", "b", "expected"),
    [
        ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]),
        ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]),
    ],
)
def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None:
    """list + list propagates outer (row) and inner (element) nulls."""
    lhs, rhs, want = (pl.Series(values) for values in (a, b, expected))

    # Same dtype on both sides.
    assert_series_equal(lhs + rhs, want)

    # Mixed inner dtypes: Int32 + Int64 is expected to produce Int64.
    lhs_i32 = lhs._recursive_cast_to_dtype(pl.Int32())
    rhs_i64 = rhs._recursive_cast_to_dtype(pl.Int64())
    assert_series_equal(lhs_i32 + rhs_i64, want._recursive_cast_to_dtype(pl.Int64()))
551
552
553
def test_list_arithmetic_error_cases() -> None:
    """list / list raises on mismatched row counts or mismatched inner lengths."""
    three_pairs = [[1, 2], [1, 2], [1, 2]]

    # Different series length (3 rows vs 2 rows):
    for rhs_values in ([[1, 2], [3, 4]], [[1, 2], None]):
        with pytest.raises(InvalidOperationError, match="different lengths"):
            _ = pl.Series("a", three_pairs) / pl.Series("b", rhs_values)

    # Different list length (inner length 2 vs 1 at row 0):
    for lhs_values, rhs_values in (
        (three_pairs, [[1]]),
        ([[1, 2], [2, 3]], [[1], None]),
    ):
        with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):
            _ = pl.Series("a", lhs_values) / pl.Series("b", rhs_values)
566
567
568
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_list_arithmetic_invalid_dtypes(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Unsupported operand dtypes raise InvalidOperationError rather than panic."""
    # Wrong types: a list column plus a string column.
    with pytest.raises(
        InvalidOperationError, match="add operation not supported for dtypes"
    ):
        exec_op(pl.Series([[1, 2]]), pl.Series(["hello"]), operator.add)

    # list<->list is restricted to 1 level of nesting
    single = pl.Series("a", [[1]])
    double = pl.Series("b", [[[1]]])
    with pytest.raises(
        InvalidOperationError,
        match="cannot add two list columns with non-numeric inner types",
    ):
        exec_op(single, double, operator.add)

    # Ensure dtype is validated to be `List` at all nesting levels instead of
    # panicking; checked with the list operand on either side.
    list_of_arrays = pl.Series([[[1]], [[1]]], dtype=pl.List(pl.Array(pl.Int64, 1)))
    scalar = pl.Series([1], dtype=pl.Int64)
    for lhs, rhs in ((list_of_arrays, scalar), (scalar, list_of_arrays)):
        with pytest.raises(
            InvalidOperationError, match="dtype was not list on all nesting levels"
        ):
            exec_op(lhs, rhs, operator.add)
606
607
608
@pytest.mark.parametrize(
    ("expected", "expr", "column_names"),
    [
        # All 5 arithmetic operations:
        ([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")),
        ([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")),
        ([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")),
        ([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")),
        ([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")),
        # Different types:
        (
            [[3, 4], [7]],
            lambda a, b: a + b,
            ("list", "uint8"),
        ),
        # Extra nesting + different types:
        (
            [[[3, 4]], [[8]]],
            lambda a, b: a + b,
            ("nested", "int64"),
        ),
        # Primitive numeric on the left; only addition and multiplication are
        # supported:
        ([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")),
        ([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")),
        # Primitive numeric on the left with different types:
        (
            [[3, 4], [7]],
            lambda a, b: a + b,
            ("uint8", "list"),
        ),
        (
            [[2, 4], [12]],
            lambda a, b: a * b,
            ("uint8", "list"),
        ),
    ],
)
def test_list_and_numeric_arithmetic_same_size(
    expected: Any,
    expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series | pl.Expr],
    column_names: tuple[str, str],
) -> None:
    """list<->numeric arithmetic (either side) on equal heights, Expr and Series.

    `expr` is applied once to two `pl.col` expressions (returning an Expr) and
    once to the two Series directly (returning a Series); both results must
    equal `expected`, named after the left-hand column.
    """
    df = pl.DataFrame(
        [
            pl.Series("list", [[1, 2], [3]]),
            pl.Series("int64", [2, 3], dtype=pl.Int64()),
            pl.Series("uint8", [2, 4], dtype=pl.UInt8()),
            pl.Series("nested", [[[1, 2]], [[5]]]),
        ]
    )
    # Expr-based arithmetic:
    assert_frame_equal(
        df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),
        pl.Series(column_names[0], expected).to_frame(),
    )
    # Direct arithmetic on the Series:
    assert_series_equal(
        expr(df[column_names[0]], df[column_names[1]]),
        pl.Series(column_names[0], expected),
    )
669
670
671
@pytest.mark.parametrize(
    ("a", "b", "expected"),
    [
        # Null on numeric on the right:
        ([[1, 2], [3]], [1, None], [[2, 3], [None]]),
        # Null on list on the left:
        ([[[1, 2]], [[3]]], [None, 1], [[[None, None]], [[4]]]),
        # Extra nesting:
        ([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]),
    ],
)
def test_list_and_numeric_arithmetic_nulls(
    a: list[Any], b: list[Any], expected: list[Any]
) -> None:
    """Nulls propagate in list + numeric addition, with the list on either side."""
    lists = pl.Series(a)
    numbers = pl.Series(b)
    want = pl.Series(expected, dtype=lists.dtype)
    want_i64 = want._recursive_cast_to_dtype(pl.Int64())

    # Addition is commutative, so both operand orders share one expectation.
    for left, right in ((lists, numbers), (numbers, lists)):
        # Same dtype:
        assert_series_equal(left + right, want)
        # Different dtype (Int32 + Int64):
        assert_series_equal(
            left._recursive_cast_to_dtype(pl.Int32())
            + right._recursive_cast_to_dtype(pl.Int64()),
            want_i64,
        )
706
707
708
def test_list_and_numeric_arithmetic_error_cases() -> None:
    """list<->numeric arithmetic raises on height or dtype mismatches."""
    three_rows = pl.Series("a", [[1, 2], [3, 4], [5, 6]])
    height_msg = "series of different lengths: got 3 and 2"

    # Different series length:
    with pytest.raises(InvalidOperationError, match=height_msg):
        _ = three_rows + pl.Series("b", [1, 2])
    with pytest.raises(InvalidOperationError, match=height_msg):
        _ = three_rows / pl.Series("b", [1, None])

    # Wrong types:
    two_rows = pl.Series("a", [[1, 2], [3, 4]])
    with pytest.raises(
        InvalidOperationError, match="add operation not supported for dtypes"
    ):
        _ = two_rows + pl.Series("b", ["hello", "world"])
724
725
726
@pytest.mark.parametrize("broadcast", [True, False])
@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Float64()])
def test_list_arithmetic_div_ops_zero_denominator(
    broadcast: bool, dtype: pl.DataType
) -> None:
    """Division-like list ops with zero denominators, for int and float inners.

    Notes
    * truediv (/) always yields List(Float64) here (integers upcast).
    * floordiv (//) and modulo/rem (%) keep the input dtype.
    * On integers, a 0 denominator is expected to output NULL.
    * On floats, a 0 denominator outputs NaN or Inf depending on the numerator
      (e.g. 0 // 0 is NaN, 1 // 0 is Inf, and x % 0 is NaN).
    """
    s = pl.Series([[0], [1], [None], None]).cast(pl.List(dtype))
    n = 1 if broadcast else s.len()

    def repeated(value: Any) -> pl.Series:
        # `value` repeated `n` times; a single row exercises broadcasting.
        return pl.Series([value]).new_from_index(0, n)

    def per_dtype(int_rows: list[Any], float_rows: list[Any]) -> pl.Series:
        # Expected result in `s.dtype`, picking the variant for the inner type.
        rows = float_rows if dtype.is_float() else int_rows
        return pl.Series(rows, dtype=s.dtype)

    inf = float("inf")
    nan = float("nan")
    float64_lists = pl.List(pl.Float64)

    for one, zero in ((repeated(1), repeated(0)), (repeated([1]), repeated([0]))):
        is_list_rhs = one.dtype.is_nested()
        # Section 1 is list<->primitive, section 2 is list<->list.

        # truediv
        assert_series_equal(
            one / s,
            pl.Series([[inf], [1.0], [None], None], dtype=float64_lists),
        )
        if is_list_rhs:
            assert_series_equal(
                s / zero,
                pl.Series([[nan], [inf], [None], None], dtype=float64_lists),
            )
        else:
            assert_series_equal(
                s / one,
                pl.Series([[0.0], [1.0], [None], None], dtype=float64_lists),
            )

        # floordiv
        assert_series_equal(
            one // s,
            per_dtype([[None], [1], [None], None], [[inf], [1.0], [None], None]),
        )
        assert_series_equal(
            s // zero,
            per_dtype([[None], [None], [None], None], [[nan], [inf], [None], None]),
        )

        # rem
        assert_series_equal(
            one % s,
            per_dtype([[None], [0], [None], None], [[nan], [0.0], [None], None]),
        )
        assert_series_equal(
            s % zero,
            per_dtype([[None], [None], [None], None], [[nan], [nan], [None], None]),
        )
853
854
855
def test_list_to_primitive_arithmetic() -> None:
    # Arithmetic between a triple-nested list column and a primitive column,
    # with non-zero start offsets, masked-out rows and nulls at every level.
    #
    # Input data
    # * List type: List(List(List(Int16))) (triple-nested)
    # * Numeric type: Int32
    #
    # Tests run
    #                       Broadcast    Operation
    #                       | L | R |
    # * list<->primitive    |   |   |    floor_div
    # * primitive<->list    |   |   |    floor_div
    # * list<->primitive    |   | * |    subtract
    # * primitive<->list    | * |   |    subtract
    # * list<->primitive    | * |   |    subtract
    # * primitive<->list    |   | * |    subtract
    #
    # Notes
    # * In floor_div, we check that results from a 0 denominator are masked out
    # * We choose floor_div and subtract as they emit different results when
    #   sides are swapped

    # Create some non-zero start offsets and masked out rows.
    # `.slice(1)` drops the first row (non-zero offset), and the `when/then`
    # without `otherwise` nulls out the row at position 1 after slicing.
    lhs = (
        pl.Series(
            [
                [[[None, None, None, None, None]]],  # sliced out
                # Nulls at every level
                [[[3, 7]], [[-3], [None], [], [], None], [], None],
                [[[1, 2, 3, 4, 5]]],  # masked out
                [[[3, 7]], [[0], [None], [], [], None]],
                [[[3, 7]]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int16))),
        )
        .slice(1)
        .to_frame()
        .select(pl.when(pl.int_range(pl.len()) != 1).then(pl.first()))
        .to_series()
    )

    # Note to reader: This is what our LHS looks like
    assert_series_equal(
        lhs,
        pl.Series(
            [
                [[[3, 7]], [[-3], [None], [], [], None], [], None],
                None,
                [[[3, 7]], [[0], [None], [], [], None]],
                [[[3, 7]]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int16))),
        ),
    )

    # Each `class _:` body below is only a throwaway namespace: it executes
    # immediately and gives every scenario its own `rhs`/`expect`/`out` names
    # (hence the `no-redef` / PIE794 suppressions).
    class _:
        # Floor div, no broadcasting
        # rhs row 2 is 0 and row 3 is null, so those rows' leaves come out null.
        rhs = pl.Series([5, 1, 0, None], dtype=pl.Int32)

        assert len(lhs) == len(rhs)

        # Expected leaf dtype is Int32 (Int16 lhs with Int32 rhs).
        expect = pl.Series(
            [
                [[[0, 1]], [[-1], [None], [], [], None], [], None],
                None,
                [[[None, None]], [[None], [None], [], [], None]],
                [[[None, None]]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int32))),
        )

        out = (
            pl.select(l=lhs, r=rhs)
            .select(pl.col("l") // pl.col("r"))
            .to_series()
            .alias("")
        )

        assert_series_equal(out, expect)

        # Flipped
        # Now the list elements are the denominators; the 0 leaf in row 2
        # gives a null, while 0 numerators give 0.

        expect = pl.Series(  # noqa: PIE794
            [
                [[[1, 0]], [[-2], [None], [], [], None], [], None],
                None,
                [[[0, 0]], [[None], [None], [], [], None]],
                [[[None, None]]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int32))),
        )

        out = (  # noqa: PIE794
            pl.select(l=lhs, r=rhs)
            .select(pl.col("r") // pl.col("l"))
            .to_series()
            .alias("")
        )

        assert_series_equal(out, expect)

    class _:  # type: ignore[no-redef]
        # Subtraction with broadcasting
        # A length-1 primitive `rhs` is broadcast against every list row.
        rhs = pl.Series([1], dtype=pl.Int32)

        expect = pl.Series(
            [
                [[[2, 6]], [[-4], [None], [], [], None], [], None],
                None,
                [[[2, 6]], [[-1], [None], [], [], None]],
                [[[2, 6]]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int32))),
        )

        out = pl.select(l=lhs).select(pl.col("l") - rhs).to_series().alias("")

        assert_series_equal(out, expect)

        # Flipped

        expect = pl.Series(  # noqa: PIE794
            [
                [[[-2, -6]], [[4], [None], [], [], None], [], None],
                None,
                [[[-2, -6]], [[1], [None], [], [], None]],
                [[[-2, -6]]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int32))),
        )

        out = pl.select(l=lhs).select(rhs - pl.col("l")).to_series().alias("")  # noqa: PIE794

        assert_series_equal(out, expect)

    # Test broadcasting of the list side
    # Keep only the third row, so the list operand has a single row (with a
    # non-zero offset into the underlying buffers).
    lhs = lhs.slice(2, 1)
    # Note to reader: This is what our LHS looks like
    assert_series_equal(
        lhs,
        pl.Series(
            [
                [[[3, 7]], [[0], [None], [], [], None]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int16))),
        ),
    )

    assert len(lhs) == 1

    class _:  # type: ignore[no-redef]
        # The single list row is broadcast against each row of `rhs`; the null
        # in `rhs` nulls out that row's leaves.
        rhs = pl.Series([1, 2, 3, None, 5], dtype=pl.Int32)

        expect = pl.Series(
            [
                [[[2, 6]], [[-1], [None], [], [], None]],
                [[[1, 5]], [[-2], [None], [], [], None]],
                [[[0, 4]], [[-3], [None], [], [], None]],
                [[[None, None]], [[None], [None], [], [], None]],
                [[[-2, 2]], [[-5], [None], [], [], None]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int32))),
        )

        out = pl.select(r=rhs).select(lhs - pl.col("r")).to_series().alias("")

        assert_series_equal(out, expect)

        # Flipped

        expect = pl.Series(  # noqa: PIE794
            [
                [[[-2, -6]], [[1], [None], [], [], None]],
                [[[-1, -5]], [[2], [None], [], [], None]],
                [[[0, -4]], [[3], [None], [], [], None]],
                [[[None, None]], [[None], [None], [], [], None]],
                [[[2, -2]], [[5], [None], [], [], None]],
            ],
            dtype=pl.List(pl.List(pl.List(pl.Int32))),
        )

        out = pl.select(r=rhs).select(pl.col("r") - lhs).to_series().alias("")  # noqa: PIE794

        assert_series_equal(out, expect)
1037
1038
1039
def test_list_boolean_arithmetic_23146() -> None:
    """Test that boolean arithmetic inside lists works and upcasts correctly.

    Boolean + Boolean yields the index dtype (`pl.get_index_type()`, UInt32 on
    default builds); Boolean combined with an integer resolves to the integer
    supertype; Boolean / integer yields Float64.
    """
    # Use the index dtype rather than a hard-coded UInt32 so the test also holds
    # on builds with a 64-bit index type, matching `test_list_arithmetic_values`.
    bool_sum_dtype = pl.List(pl.get_index_type())

    # Boolean list + Boolean list with single value
    result = pl.select(pl.lit([True]) + [True])
    expected = pl.DataFrame({"literal": [[2]]}, schema={"literal": bool_sum_dtype})
    assert_frame_equal(result, expected)

    # Boolean list + Boolean list with multiple values
    result = pl.select(pl.lit([True, False]) + [True, True])
    expected = pl.DataFrame(
        {"literal": [[2, 1]]}, schema={"literal": bool_sum_dtype}
    )
    assert_frame_equal(result, expected)

    # Boolean list + Scalar integer (supertype with Int32)
    result = pl.select(pl.lit([True, False]) + 1)
    expected = pl.DataFrame(
        {"literal": [[2, 1]]}, schema={"literal": pl.List(pl.Int32)}
    )
    assert_frame_equal(result, expected)

    # Boolean list arithmetic on DataFrame columns
    df = pl.DataFrame(
        {"a": [[True, False]], "b": [[1, 2]]},
        schema={"a": pl.List(pl.Boolean), "b": pl.List(pl.Int64)},
    )
    result = df.select(pl.col("a") + pl.col("b"))
    expected = pl.DataFrame({"a": [[2, 2]]}, schema={"a": pl.List(pl.Int64)})
    assert_frame_equal(result, expected)

    # Division test
    df = pl.DataFrame({"a": [[True, False]], "b": [[128, 128]]})
    result = df.select(pl.col("a") / pl.col("b"))
    assert result.schema == {"a": pl.List(pl.Float64)}
1073
1074