# Path: blob/main/py-polars/tests/unit/operations/arithmetic/test_list.py
# 8421 views
from __future__ import annotations12import operator3from typing import TYPE_CHECKING, Any45import pytest67import polars as pl8from polars.exceptions import InvalidOperationError, ShapeError9from polars.testing import assert_frame_equal, assert_series_equal10from tests.unit.operations.arithmetic.utils import (11BROADCAST_SERIES_COMBINATIONS,12EXEC_OP_COMBINATIONS,13)1415if TYPE_CHECKING:16from collections.abc import Callable1718from polars._typing import PolarsDataType192021@pytest.mark.parametrize(22"list_side", ["left", "left3", "both", "right3", "right", "none"]23)24@pytest.mark.parametrize(25"broadcast_series",26BROADCAST_SERIES_COMBINATIONS,27)28@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)29@pytest.mark.slow30def test_list_arithmetic_values(31list_side: str,32broadcast_series: Callable[33[pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]34],35exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],36) -> None:37"""38Tests value correctness.3940This test checks for output value correctness (a + b == c) across different41codepaths, by wrapping the values (a, b, c) in different combinations of42list / primitive columns.43"""44import operator as op4546dtypes: list[Any] = [pl.Null, pl.Null, pl.Null]47dtype: Any = pl.Null4849def materialize_list(v: Any) -> pl.Series:50return pl.Series(51[[None, v, None]],52dtype=pl.List(dtype),53)5455def materialize_list3(v: Any) -> pl.Series:56return pl.Series(57[[[[None, v], None], None]],58dtype=pl.List(pl.List(pl.List(dtype))),59)6061def materialize_primitive(v: Any) -> pl.Series:62return pl.Series([v], dtype=dtype)6364def materialize_series(65l: Any, # noqa: E74166r: Any,67o: Any,68) -> tuple[pl.Series, pl.Series, pl.Series]:69nonlocal dtype7071dtype = dtypes[0]72l = { # noqa: E74173"left": materialize_list,74"left3": materialize_list3,75"both": materialize_list,76"right": materialize_primitive,77"right3": materialize_primitive,78"none": materialize_primitive,79}[list_side](l) # fmt: 
skip8081dtype = dtypes[1]82r = {83"left": materialize_primitive,84"left3": materialize_primitive,85"both": materialize_list,86"right": materialize_list,87"right3": materialize_list3,88"none": materialize_primitive,89}[list_side](r) # fmt: skip9091dtype = dtypes[2]92o = {93"left": materialize_list,94"left3": materialize_list3,95"both": materialize_list,96"right": materialize_list,97"right3": materialize_list3,98"none": materialize_primitive,99}[list_side](o) # fmt: skip100101assert l.len() == 1102assert r.len() == 1103assert o.len() == 1104105return broadcast_series(l, r, o)106107# Signed108dtypes = [pl.Int8, pl.Int8, pl.Int8]109110l, r, o = materialize_series(2, 3, 5) # noqa: E741111assert_series_equal(exec_op(l, r, op.add), o)112113l, r, o = materialize_series(-5, 127, 124) # noqa: E741114assert_series_equal(exec_op(l, r, op.sub), o)115116l, r, o = materialize_series(-5, 127, -123) # noqa: E741117assert_series_equal(exec_op(l, r, op.mul), o)118119l, r, o = materialize_series(-5, 3, -2) # noqa: E741120assert_series_equal(exec_op(l, r, op.floordiv), o)121122l, r, o = materialize_series(-5, 3, 1) # noqa: E741123assert_series_equal(exec_op(l, r, op.mod), o)124125dtypes = [pl.UInt8, pl.UInt8, pl.Float64]126l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741127assert_series_equal(exec_op(l, r, op.truediv), o)128129# Unsigned130dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]131132l, r, o = materialize_series(2, 3, 5) # noqa: E741133assert_series_equal(exec_op(l, r, op.add), o)134135l, r, o = materialize_series(2, 3, 255) # noqa: E741136assert_series_equal(exec_op(l, r, op.sub), o)137138l, r, o = materialize_series(2, 128, 0) # noqa: E741139assert_series_equal(exec_op(l, r, op.mul), o)140141l, r, o = materialize_series(5, 2, 2) # noqa: E741142assert_series_equal(exec_op(l, r, op.floordiv), o)143144l, r, o = materialize_series(5, 2, 1) # noqa: E741145assert_series_equal(exec_op(l, r, op.mod), o)146147dtypes = [pl.UInt8, pl.UInt8, pl.Float64]148l, r, o = 
materialize_series(2, 128, 0.015625) # noqa: E741149assert_series_equal(exec_op(l, r, op.truediv), o)150151# Floats. Note we pick Float32 to ensure there is no accidental upcasting152# to Float64.153dtypes = [pl.Float32, pl.Float32, pl.Float32]154l, r, o = materialize_series(1.7, 2.3, 4.0) # noqa: E741155assert_series_equal(exec_op(l, r, op.add), o)156157l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999) # noqa: E741158assert_series_equal(exec_op(l, r, op.sub), o)159160l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997) # noqa: E741161assert_series_equal(exec_op(l, r, op.mul), o)162163l, r, o = materialize_series(7.0, 3.0, 2.0) # noqa: E741164assert_series_equal(exec_op(l, r, op.floordiv), o)165166l, r, o = materialize_series(-5.0, 3.0, 1.0) # noqa: E741167assert_series_equal(exec_op(l, r, op.mod), o)168169l, r, o = materialize_series(2.0, 128.0, 0.015625) # noqa: E741170assert_series_equal(exec_op(l, r, op.truediv), o)171172#173# Tests for zero behavior174#175176# Integer177178dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]179180l, r, o = materialize_series(1, 0, None) # noqa: E741181assert_series_equal(exec_op(l, r, op.floordiv), o)182assert_series_equal(exec_op(l, r, op.mod), o)183184l, r, o = materialize_series(0, 0, None) # noqa: E741185assert_series_equal(exec_op(l, r, op.floordiv), o)186assert_series_equal(exec_op(l, r, op.mod), o)187188dtypes = [pl.UInt8, pl.UInt8, pl.Float64]189190l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741191assert_series_equal(exec_op(l, r, op.truediv), o)192193l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741194assert_series_equal(exec_op(l, r, op.truediv), o)195196# Float197198dtypes = [pl.Float32, pl.Float32, pl.Float32]199200l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741201assert_series_equal(exec_op(l, r, op.floordiv), o)202203l, r, o = materialize_series(1, 0, float("nan")) # noqa: E741204assert_series_equal(exec_op(l, r, op.mod), o)205206l, r, o = materialize_series(1, 0, 
float("inf")) # noqa: E741207assert_series_equal(exec_op(l, r, op.truediv), o)208209l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741210assert_series_equal(exec_op(l, r, op.floordiv), o)211212l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741213assert_series_equal(exec_op(l, r, op.mod), o)214215l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741216assert_series_equal(exec_op(l, r, op.truediv), o)217218#219# Tests for NULL behavior220#221222for dtype, truediv_dtype in [ # type: ignore[misc]223[pl.Int8, pl.Float64],224[pl.Float32, pl.Float32],225]:226for vals in [227[None, None, None],228[0, None, None],229[None, 0, None],230[0, None, None],231[None, 0, None],232[3, None, None],233[None, 3, None],234]:235dtypes = 3 * [dtype]236237l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741238assert_series_equal(exec_op(l, r, op.add), o)239assert_series_equal(exec_op(l, r, op.sub), o)240assert_series_equal(exec_op(l, r, op.mul), o)241assert_series_equal(exec_op(l, r, op.floordiv), o)242assert_series_equal(exec_op(l, r, op.mod), o)243dtypes[2] = truediv_dtype # type: ignore[has-type]244l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741245assert_series_equal(exec_op(l, r, op.truediv), o)246247# Type upcasting for Boolean and Null248249# Check boolean upcasting250dtypes = [pl.Boolean, pl.Boolean, pl.get_index_type()]251l, r, o = materialize_series(True, True, 2) # noqa: E741252assert_series_equal(exec_op(l, r, op.add), o)253254dtypes = [pl.Boolean, pl.Boolean, pl.Float64]255l, r, o = materialize_series(True, True, 1.0) # noqa: E741256assert_series_equal(exec_op(l, r, op.truediv), o)257258dtypes = [pl.Boolean, pl.UInt8, pl.UInt8]259260l, r, o = materialize_series(True, 3, 4) # noqa: E741261assert_series_equal(exec_op(l, r, op.add), o)262263l, r, o = materialize_series(True, 3, 254) # noqa: E741264assert_series_equal(exec_op(l, r, op.sub), o)265266l, r, o = materialize_series(True, 3, 3) # noqa: 
E741267assert_series_equal(exec_op(l, r, op.mul), o)268269l, r, o = materialize_series(True, 3, 0) # noqa: E741270if list_side != "none":271# TODO: We get an error on non-lists with this:272# "floor_div operation not supported for dtype `bool`"273assert_series_equal(exec_op(l, r, op.floordiv), o)274275l, r, o = materialize_series(True, 3, 1) # noqa: E741276assert_series_equal(exec_op(l, r, op.mod), o)277278dtypes = [pl.Boolean, pl.UInt8, pl.Float64]279l, r, o = materialize_series(True, 128, 0.0078125) # noqa: E741280assert_series_equal(exec_op(l, r, op.truediv), o)281282# Check Null upcasting283dtypes = [pl.Null, pl.UInt8, pl.UInt8]284l, r, o = materialize_series(None, 3, None) # noqa: E741285assert_series_equal(exec_op(l, r, op.add), o)286assert_series_equal(exec_op(l, r, op.sub), o)287assert_series_equal(exec_op(l, r, op.mul), o)288if list_side != "none":289assert_series_equal(exec_op(l, r, op.floordiv), o)290assert_series_equal(exec_op(l, r, op.mod), o)291292dtypes = [pl.Null, pl.UInt8, pl.Float64]293l, r, o = materialize_series(None, 3, None) # noqa: E741294assert_series_equal(exec_op(l, r, op.truediv), o)295296297@pytest.mark.parametrize(298("lhs_dtype", "rhs_dtype", "expected_dtype"),299[300(pl.List(pl.Int64), pl.Int64, pl.List(pl.Float64)),301(pl.List(pl.Float32), pl.Float32, pl.List(pl.Float32)),302(pl.List(pl.Duration("us")), pl.Int64, pl.List(pl.Duration("us"))),303],304)305def test_list_truediv_schema(306lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType307) -> None:308schema = {"lhs": lhs_dtype, "rhs": rhs_dtype}309df = pl.DataFrame({"lhs": [[None, 10]], "rhs": 2}, schema=schema)310result = df.lazy().select(pl.col("lhs").truediv("rhs")).collect_schema()["lhs"]311assert result == expected_dtype312313314@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)315def test_list_add_supertype(316exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],317) -> None:318import operator as op319320a = pl.Series("a", [[1], [2]], 
dtype=pl.List(pl.Int8))321b = pl.Series("b", [[1], [999]], dtype=pl.List(pl.Int64))322323assert_series_equal(324exec_op(a, b, op.add),325pl.Series("a", [[2], [1001]], dtype=pl.List(pl.Int64)),326)327328329@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)330@pytest.mark.parametrize(331"broadcast_series",332BROADCAST_SERIES_COMBINATIONS,333)334@pytest.mark.slow335def test_list_numeric_op_validity_combination(336broadcast_series: Callable[337[pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]338],339exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],340) -> None:341import operator as op342343a = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=pl.List(pl.Int32))344b = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=pl.List(pl.Int64))345# expected result346e = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=pl.List(pl.Int64))347348assert_series_equal(349exec_op(a, b, op.add),350e,351)352353a = pl.Series("a", [[1]], dtype=pl.List(pl.Int32))354b = pl.Series("b", [None], dtype=pl.Int64)355e = pl.Series("a", [[None]], dtype=pl.List(pl.Int64))356357a, b, e = broadcast_series(a, b, e)358assert_series_equal(exec_op(a, b, op.add), e)359360a = pl.Series("a", [None], dtype=pl.List(pl.Int32))361b = pl.Series("b", [1], dtype=pl.Int64)362e = pl.Series("a", [None], dtype=pl.List(pl.Int64))363364a, b, e = broadcast_series(a, b, e)365assert_series_equal(exec_op(a, b, op.add), e)366367a = pl.Series("a", [None], dtype=pl.List(pl.Int32))368b = pl.Series("b", [0], dtype=pl.Int64)369e = pl.Series("a", [None], dtype=pl.List(pl.Int64))370371a, b, e = broadcast_series(a, b, e)372assert_series_equal(exec_op(a, b, op.floordiv), e)373374375def test_list_add_alignment() -> None:376a = pl.Series("a", [[1, 1], [1, 1, 1]])377b = pl.Series("b", [[1, 1, 1], [1, 1]])378379df = pl.DataFrame([a, b])380381with pytest.raises(ShapeError):382df.select(x=pl.col("a") + pl.col("b"))383384# Test masking and slicing385a = pl.Series("a", 
[[1, 1, 1], [1], [1, 1], [1, 1, 1]])386b = pl.Series("b", [[1, 1], [1], [1, 1, 1], [1]])387c = pl.Series("c", [1, 1, 1, 1])388p = pl.Series("p", [True, True, False, False])389390df = pl.DataFrame([a, b, c, p]).filter("p").slice(1)391392for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]:393assert_series_equal(394df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2]])395)396397df = df.vstack(df)398399for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]:400assert_series_equal(401df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2], [2]])402)403404405@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)406@pytest.mark.slow407def test_list_add_empty_lists(408exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],409) -> None:410l = pl.Series( # noqa: E741411"x",412[[[[]], []], []],413)414r = pl.Series([1])415416assert_series_equal(417exec_op(l, r, operator.add),418pl.Series("x", [[[[]], []], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))),419)420421l = pl.Series( # noqa: E741422"x",423[[[[]], None], []],424)425r = pl.Series([1])426427assert_series_equal(428exec_op(l, r, operator.add),429pl.Series("x", [[[[]], None], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))),430)431432433@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)434def test_list_to_list_arithmetic_double_nesting_raises_error(435exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],436) -> None:437s = pl.Series(dtype=pl.List(pl.List(pl.Int32)))438439with pytest.raises(440InvalidOperationError,441match="cannot add two list columns with non-numeric inner types",442):443exec_op(s, s, operator.add)444445446@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)447def test_list_add_height_mismatch(448exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],449) -> None:450s = pl.Series([[1], [2], [3]], dtype=pl.List(pl.Int32))451452# TODO: Make the error type consistently a ShapeError453with pytest.raises(454(ShapeError, 
InvalidOperationError),455match="length",456):457exec_op(s, pl.Series([1, 1]), operator.add)458459460@pytest.mark.parametrize(461"op",462[463operator.add,464operator.sub,465operator.mul,466operator.floordiv,467operator.mod,468operator.truediv,469],470)471@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)472@pytest.mark.slow473def test_list_date_to_numeric_arithmetic_raises_error(474op: Callable[[Any], Any], exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series]475) -> None:476l = pl.Series([1], dtype=pl.Date) # noqa: E741477r = pl.Series([[1]], dtype=pl.List(pl.Int32))478479exec_op(l.to_physical(), r, op)480481# TODO(_): Ideally this always raises InvalidOperationError. The TypeError482# is being raised by checks on the Python side that should be moved to Rust.483with pytest.raises((InvalidOperationError, TypeError)):484exec_op(l, r, op)485486487@pytest.mark.parametrize(488("expected", "expr", "column_names"),489[490([[2, 4], [6]], lambda a, b: a + b, ("a", "a")),491([[0, 0], [0]], lambda a, b: a - b, ("a", "a")),492([[1, 4], [9]], lambda a, b: a * b, ("a", "a")),493([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")),494([[0, 0], [0]], lambda a, b: a % b, ("a", "a")),495(496[[3, 4], [7]],497lambda a, b: a + b,498("a", "uint8"),499),500],501)502def test_list_arithmetic_same_size(503expected: Any,504expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],505column_names: tuple[str, str],506) -> None:507df = pl.DataFrame(508[509pl.Series("a", [[1, 2], [3]]),510pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())),511pl.Series("nested", [[[1, 2]], [[3]]]),512pl.Series(513"nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8()))514),515]516)517# Expr-based arithmetic:518assert_frame_equal(519df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),520pl.Series(column_names[0], expected).to_frame(),521)522# Direct arithmetic on the Series:523assert_series_equal(524expr(df[column_names[0]], 
df[column_names[1]]),525pl.Series(column_names[0], expected),526)527528529@pytest.mark.parametrize(530("a", "b", "expected"),531[532([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]),533([[2], None, [5]], [None, [3], [2]], [None, None, [7]]),534],535)536def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None:537series_a = pl.Series(a)538series_b = pl.Series(b)539series_expected = pl.Series(expected)540541# Same dtype:542assert_series_equal(series_a + series_b, series_expected)543544# Different dtype:545assert_series_equal(546series_a._recursive_cast_to_dtype(pl.Int32())547+ series_b._recursive_cast_to_dtype(pl.Int64()),548series_expected._recursive_cast_to_dtype(pl.Int64()),549)550551552def test_list_arithmetic_error_cases() -> None:553# Different series length:554with pytest.raises(InvalidOperationError, match="different lengths"):555_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]])556with pytest.raises(InvalidOperationError, match="different lengths"):557_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None])558559# Different list length:560with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):561_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]])562563with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):564_ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None])565566567@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)568def test_list_arithmetic_invalid_dtypes(569exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],570) -> None:571import operator as op572573a = pl.Series([[1, 2]])574b = pl.Series(["hello"])575576# Wrong types:577with pytest.raises(578InvalidOperationError, match="add operation not supported for dtypes"579):580exec_op(a, b, op.add)581582a = pl.Series("a", [[1]])583b = pl.Series("b", [[[1]]])584585# list<->list is restricted to 1 level of nesting586with 
pytest.raises(587InvalidOperationError,588match="cannot add two list columns with non-numeric inner types",589):590exec_op(a, b, op.add)591592# Ensure dtype is validated to be `List` at all nesting levels instead of panicking.593a = pl.Series([[[1]], [[1]]], dtype=pl.List(pl.Array(pl.Int64, 1)))594b = pl.Series([1], dtype=pl.Int64)595596with pytest.raises(597InvalidOperationError, match="dtype was not list on all nesting levels"598):599exec_op(a, b, op.add)600601with pytest.raises(602InvalidOperationError, match="dtype was not list on all nesting levels"603):604exec_op(b, a, op.add)605606607@pytest.mark.parametrize(608("expected", "expr", "column_names"),609[610# All 5 arithmetic operations:611([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")),612([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")),613([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")),614([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")),615([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")),616# Different types:617(618[[3, 4], [7]],619lambda a, b: a + b,620("list", "uint8"),621),622# Extra nesting + different types:623(624[[[3, 4]], [[8]]],625lambda a, b: a + b,626("nested", "int64"),627),628# Primitive numeric on the left; only addition and multiplication are629# supported:630([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")),631([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")),632# Primitive numeric on the left with different types:633(634[[3, 4], [7]],635lambda a, b: a + b,636("uint8", "list"),637),638(639[[2, 4], [12]],640lambda a, b: a * b,641("uint8", "list"),642),643],644)645def test_list_and_numeric_arithmetic_same_size(646expected: Any,647expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],648column_names: tuple[str, str],649) -> None:650df = pl.DataFrame(651[652pl.Series("list", [[1, 2], [3]]),653pl.Series("int64", [2, 3], dtype=pl.Int64()),654pl.Series("uint8", [2, 4], dtype=pl.UInt8()),655pl.Series("nested", [[[1, 2]], 
[[5]]]),656]657)658# Expr-based arithmetic:659assert_frame_equal(660df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),661pl.Series(column_names[0], expected).to_frame(),662)663# Direct arithmetic on the Series:664assert_series_equal(665expr(df[column_names[0]], df[column_names[1]]),666pl.Series(column_names[0], expected),667)668669670@pytest.mark.parametrize(671("a", "b", "expected"),672[673# Null on numeric on the right:674([[1, 2], [3]], [1, None], [[2, 3], [None]]),675# Null on list on the left:676([[[1, 2]], [[3]]], [None, 1], [[[None, None]], [[4]]]),677# Extra nesting:678([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]),679],680)681def test_list_and_numeric_arithmetic_nulls(682a: list[Any], b: list[Any], expected: list[Any]683) -> None:684series_a = pl.Series(a)685series_b = pl.Series(b)686series_expected = pl.Series(expected, dtype=series_a.dtype)687688# Same dtype:689assert_series_equal(series_a + series_b, series_expected)690691# Different dtype:692assert_series_equal(693series_a._recursive_cast_to_dtype(pl.Int32())694+ series_b._recursive_cast_to_dtype(pl.Int64()),695series_expected._recursive_cast_to_dtype(pl.Int64()),696)697698# Swap sides:699assert_series_equal(series_b + series_a, series_expected)700assert_series_equal(701series_b._recursive_cast_to_dtype(pl.Int32())702+ series_a._recursive_cast_to_dtype(pl.Int64()),703series_expected._recursive_cast_to_dtype(pl.Int64()),704)705706707def test_list_and_numeric_arithmetic_error_cases() -> None:708# Different series length:709with pytest.raises(710InvalidOperationError, match="series of different lengths: got 3 and 2"711):712_ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) + pl.Series("b", [1, 2])713with pytest.raises(714InvalidOperationError, match="series of different lengths: got 3 and 2"715):716_ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) / pl.Series("b", [1, None])717718# Wrong types:719with pytest.raises(720InvalidOperationError, match="add operation not supported for 
dtypes"721):722_ = pl.Series("a", [[1, 2], [3, 4]]) + pl.Series("b", ["hello", "world"])723724725@pytest.mark.parametrize("broadcast", [True, False])726@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Float64()])727def test_list_arithmetic_div_ops_zero_denominator(728broadcast: bool, dtype: pl.DataType729) -> None:730# Notes731# * truediv (/) on integers upcasts to Float64732# * Otherwise, we test floordiv (//) and module/rem (%)733# * On integers, 0-denominator is expected to output NULL734# * On floats, 0-denominator has different outputs, e.g. NaN, Inf, depending735# on a few factors (e.g. whether the numerator is also 0).736737s = pl.Series([[0], [1], [None], None]).cast(pl.List(dtype))738739n = 1 if broadcast else s.len()740741# list<->primitive742743# truediv744assert_series_equal(745pl.Series([1]).new_from_index(0, n) / s,746pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)),747)748749assert_series_equal(750s / pl.Series([1]).new_from_index(0, n),751pl.Series([[0.0], [1.0], [None], None], dtype=pl.List(pl.Float64)),752)753754# floordiv755assert_series_equal(756pl.Series([1]).new_from_index(0, n) // s,757(758pl.Series([[None], [1], [None], None], dtype=s.dtype)759if not dtype.is_float()760else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype)761),762)763764assert_series_equal(765s // pl.Series([0]).new_from_index(0, n),766(767pl.Series([[None], [None], [None], None], dtype=s.dtype)768if not dtype.is_float()769else pl.Series(770[[float("nan")], [float("inf")], [None], None], dtype=s.dtype771)772),773)774775# rem776assert_series_equal(777pl.Series([1]).new_from_index(0, n) % s,778(779pl.Series([[None], [0], [None], None], dtype=s.dtype)780if not dtype.is_float()781else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype)782),783)784785assert_series_equal(786s % pl.Series([0]).new_from_index(0, n),787(788pl.Series([[None], [None], [None], None], dtype=s.dtype)789if not dtype.is_float()790else 
pl.Series(791[[float("nan")], [float("nan")], [None], None], dtype=s.dtype792)793),794)795796# list<->list797798# truediv799assert_series_equal(800pl.Series([[1]]).new_from_index(0, n) / s,801pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)),802)803804assert_series_equal(805s / pl.Series([[0]]).new_from_index(0, n),806pl.Series(807[[float("nan")], [float("inf")], [None], None], dtype=pl.List(pl.Float64)808),809)810811# floordiv812assert_series_equal(813pl.Series([[1]]).new_from_index(0, n) // s,814(815pl.Series([[None], [1], [None], None], dtype=s.dtype)816if not dtype.is_float()817else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype)818),819)820821assert_series_equal(822s // pl.Series([[0]]).new_from_index(0, n),823(824pl.Series([[None], [None], [None], None], dtype=s.dtype)825if not dtype.is_float()826else pl.Series(827[[float("nan")], [float("inf")], [None], None], dtype=s.dtype828)829),830)831832# rem833assert_series_equal(834pl.Series([[1]]).new_from_index(0, n) % s,835(836pl.Series([[None], [0], [None], None], dtype=s.dtype)837if not dtype.is_float()838else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype)839),840)841842assert_series_equal(843s % pl.Series([[0]]).new_from_index(0, n),844(845pl.Series([[None], [None], [None], None], dtype=s.dtype)846if not dtype.is_float()847else pl.Series(848[[float("nan")], [float("nan")], [None], None], dtype=s.dtype849)850),851)852853854def test_list_to_primitive_arithmetic() -> None:855# Input data856# * List type: List(List(List(Int16))) (triple-nested)857# * Numeric type: Int32858#859# Tests run860# Broadcast Operation861# | L | R |862# * list<->primitive | | | floor_div863# * primitive<->list | | | floor_div864# * list<->primitive | | * | subtract865# * primitive<->list | * | | subtract866# * list<->primitive | * | | subtract867# * primitive<->list | | * | subtract868#869# Notes870# * In floor_div, we check that results from a 0 denominator are masked out871# * We 
choose floor_div and subtract as they emit different results when872# sides are swapped873874# Create some non-zero start offsets and masked out rows.875lhs = (876pl.Series(877[878[[[None, None, None, None, None]]], # sliced out879# Nulls at every level XO880[[[3, 7]], [[-3], [None], [], [], None], [], None],881[[[1, 2, 3, 4, 5]]], # masked out882[[[3, 7]], [[0], [None], [], [], None]],883[[[3, 7]]],884],885dtype=pl.List(pl.List(pl.List(pl.Int16))),886)887.slice(1)888.to_frame()889.select(pl.when(pl.int_range(pl.len()) != 1).then(pl.first()))890.to_series()891)892893# Note to reader: This is what our LHS looks like894assert_series_equal(895lhs,896pl.Series(897[898[[[3, 7]], [[-3], [None], [], [], None], [], None],899None,900[[[3, 7]], [[0], [None], [], [], None]],901[[[3, 7]]],902],903dtype=pl.List(pl.List(pl.List(pl.Int16))),904),905)906907class _:908# Floor div, no broadcasting909rhs = pl.Series([5, 1, 0, None], dtype=pl.Int32)910911assert len(lhs) == len(rhs)912913expect = pl.Series(914[915[[[0, 1]], [[-1], [None], [], [], None], [], None],916None,917[[[None, None]], [[None], [None], [], [], None]],918[[[None, None]]],919],920dtype=pl.List(pl.List(pl.List(pl.Int32))),921)922923out = (924pl.select(l=lhs, r=rhs)925.select(pl.col("l") // pl.col("r"))926.to_series()927.alias("")928)929930assert_series_equal(out, expect)931932# Flipped933934expect = pl.Series( # noqa: PIE794935[936[[[1, 0]], [[-2], [None], [], [], None], [], None],937None,938[[[0, 0]], [[None], [None], [], [], None]],939[[[None, None]]],940],941dtype=pl.List(pl.List(pl.List(pl.Int32))),942)943944out = ( # noqa: PIE794945pl.select(l=lhs, r=rhs)946.select(pl.col("r") // pl.col("l"))947.to_series()948.alias("")949)950951assert_series_equal(out, expect)952953class _: # type: ignore[no-redef]954# Subtraction with broadcasting955rhs = pl.Series([1], dtype=pl.Int32)956957expect = pl.Series(958[959[[[2, 6]], [[-4], [None], [], [], None], [], None],960None,961[[[2, 6]], [[-1], [None], [], [], None]],962[[[2, 
6]]],963],964dtype=pl.List(pl.List(pl.List(pl.Int32))),965)966967out = pl.select(l=lhs).select(pl.col("l") - rhs).to_series().alias("")968969assert_series_equal(out, expect)970971# Flipped972973expect = pl.Series( # noqa: PIE794974[975[[[-2, -6]], [[4], [None], [], [], None], [], None],976None,977[[[-2, -6]], [[1], [None], [], [], None]],978[[[-2, -6]]],979],980dtype=pl.List(pl.List(pl.List(pl.Int32))),981)982983out = pl.select(l=lhs).select(rhs - pl.col("l")).to_series().alias("") # noqa: PIE794984985assert_series_equal(out, expect)986987# Test broadcasting of the list side988lhs = lhs.slice(2, 1)989# Note to reader: This is what our LHS looks like990assert_series_equal(991lhs,992pl.Series(993[994[[[3, 7]], [[0], [None], [], [], None]],995],996dtype=pl.List(pl.List(pl.List(pl.Int16))),997),998)9991000assert len(lhs) == 110011002class _: # type: ignore[no-redef]1003rhs = pl.Series([1, 2, 3, None, 5], dtype=pl.Int32)10041005expect = pl.Series(1006[1007[[[2, 6]], [[-1], [None], [], [], None]],1008[[[1, 5]], [[-2], [None], [], [], None]],1009[[[0, 4]], [[-3], [None], [], [], None]],1010[[[None, None]], [[None], [None], [], [], None]],1011[[[-2, 2]], [[-5], [None], [], [], None]],1012],1013dtype=pl.List(pl.List(pl.List(pl.Int32))),1014)10151016out = pl.select(r=rhs).select(lhs - pl.col("r")).to_series().alias("")10171018assert_series_equal(out, expect)10191020# Flipped10211022expect = pl.Series( # noqa: PIE7941023[1024[[[-2, -6]], [[1], [None], [], [], None]],1025[[[-1, -5]], [[2], [None], [], [], None]],1026[[[0, -4]], [[3], [None], [], [], None]],1027[[[None, None]], [[None], [None], [], [], None]],1028[[[2, -2]], [[5], [None], [], [], None]],1029],1030dtype=pl.List(pl.List(pl.List(pl.Int32))),1031)10321033out = pl.select(r=rhs).select(pl.col("r") - lhs).to_series().alias("") # noqa: PIE79410341035assert_series_equal(out, expect)103610371038def test_list_boolean_arithmetic_23146() -> None:1039"""Test that boolean arithmetic in lists works and returns UInt32."""1040# 
Boolean list + Boolean list with single value1041result = pl.select(pl.lit([True]) + [True])1042expected = pl.DataFrame({"literal": [[2]]}, schema={"literal": pl.List(pl.UInt32)})1043assert_frame_equal(result, expected)10441045# Boolean list + Boolean list with multiple values1046result = pl.select(pl.lit([True, False]) + [True, True])1047expected = pl.DataFrame(1048{"literal": [[2, 1]]}, schema={"literal": pl.List(pl.UInt32)}1049)1050assert_frame_equal(result, expected)10511052# Boolean list + Scalar integer (supertype with Int32)1053result = pl.select(pl.lit([True, False]) + 1)1054expected = pl.DataFrame(1055{"literal": [[2, 1]]}, schema={"literal": pl.List(pl.Int32)}1056)1057assert_frame_equal(result, expected)10581059# Boolean list arithmetic on DataFrame columns1060df = pl.DataFrame(1061{"a": [[True, False]], "b": [[1, 2]]},1062schema={"a": pl.List(pl.Boolean), "b": pl.List(pl.Int64)},1063)1064result = df.select(pl.col("a") + pl.col("b"))1065expected = pl.DataFrame({"a": [[2, 2]]}, schema={"a": pl.List(pl.Int64)})1066assert_frame_equal(result, expected)10671068# Division test1069df = pl.DataFrame({"a": [[True, False]], "b": [[128, 128]]})1070result = df.select(pl.col("a") / pl.col("b"))1071assert result.schema == {"a": pl.List(pl.Float64)}107210731074