# Path: blob/main/py-polars/tests/unit/operations/arithmetic/test_list.py
# 6940 views
from __future__ import annotations12import operator3from typing import TYPE_CHECKING, Any, Callable45import pytest67import polars as pl8from polars.exceptions import InvalidOperationError, ShapeError9from polars.testing import assert_frame_equal, assert_series_equal10from tests.unit.operations.arithmetic.utils import (11BROADCAST_SERIES_COMBINATIONS,12EXEC_OP_COMBINATIONS,13)1415if TYPE_CHECKING:16from polars._typing import PolarsDataType171819@pytest.mark.parametrize(20"list_side", ["left", "left3", "both", "right3", "right", "none"]21)22@pytest.mark.parametrize(23"broadcast_series",24BROADCAST_SERIES_COMBINATIONS,25)26@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)27@pytest.mark.slow28def test_list_arithmetic_values(29list_side: str,30broadcast_series: Callable[31[pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]32],33exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],34) -> None:35"""36Tests value correctness.3738This test checks for output value correctness (a + b == c) across different39codepaths, by wrapping the values (a, b, c) in different combinations of40list / primitive columns.41"""42import operator as op4344dtypes: list[Any] = [pl.Null, pl.Null, pl.Null]45dtype: Any = pl.Null4647def materialize_list(v: Any) -> pl.Series:48return pl.Series(49[[None, v, None]],50dtype=pl.List(dtype),51)5253def materialize_list3(v: Any) -> pl.Series:54return pl.Series(55[[[[None, v], None], None]],56dtype=pl.List(pl.List(pl.List(dtype))),57)5859def materialize_primitive(v: Any) -> pl.Series:60return pl.Series([v], dtype=dtype)6162def materialize_series(63l: Any, # noqa: E74164r: Any,65o: Any,66) -> tuple[pl.Series, pl.Series, pl.Series]:67nonlocal dtype6869dtype = dtypes[0]70l = { # noqa: E74171"left": materialize_list,72"left3": materialize_list3,73"both": materialize_list,74"right": materialize_primitive,75"right3": materialize_primitive,76"none": materialize_primitive,77}[list_side](l) # fmt: skip7879dtype = dtypes[1]80r = 
{81"left": materialize_primitive,82"left3": materialize_primitive,83"both": materialize_list,84"right": materialize_list,85"right3": materialize_list3,86"none": materialize_primitive,87}[list_side](r) # fmt: skip8889dtype = dtypes[2]90o = {91"left": materialize_list,92"left3": materialize_list3,93"both": materialize_list,94"right": materialize_list,95"right3": materialize_list3,96"none": materialize_primitive,97}[list_side](o) # fmt: skip9899assert l.len() == 1100assert r.len() == 1101assert o.len() == 1102103return broadcast_series(l, r, o)104105# Signed106dtypes = [pl.Int8, pl.Int8, pl.Int8]107108l, r, o = materialize_series(2, 3, 5) # noqa: E741109assert_series_equal(exec_op(l, r, op.add), o)110111l, r, o = materialize_series(-5, 127, 124) # noqa: E741112assert_series_equal(exec_op(l, r, op.sub), o)113114l, r, o = materialize_series(-5, 127, -123) # noqa: E741115assert_series_equal(exec_op(l, r, op.mul), o)116117l, r, o = materialize_series(-5, 3, -2) # noqa: E741118assert_series_equal(exec_op(l, r, op.floordiv), o)119120l, r, o = materialize_series(-5, 3, 1) # noqa: E741121assert_series_equal(exec_op(l, r, op.mod), o)122123dtypes = [pl.UInt8, pl.UInt8, pl.Float64]124l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741125assert_series_equal(exec_op(l, r, op.truediv), o)126127# Unsigned128dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]129130l, r, o = materialize_series(2, 3, 5) # noqa: E741131assert_series_equal(exec_op(l, r, op.add), o)132133l, r, o = materialize_series(2, 3, 255) # noqa: E741134assert_series_equal(exec_op(l, r, op.sub), o)135136l, r, o = materialize_series(2, 128, 0) # noqa: E741137assert_series_equal(exec_op(l, r, op.mul), o)138139l, r, o = materialize_series(5, 2, 2) # noqa: E741140assert_series_equal(exec_op(l, r, op.floordiv), o)141142l, r, o = materialize_series(5, 2, 1) # noqa: E741143assert_series_equal(exec_op(l, r, op.mod), o)144145dtypes = [pl.UInt8, pl.UInt8, pl.Float64]146l, r, o = materialize_series(2, 128, 0.015625) # noqa: 
E741147assert_series_equal(exec_op(l, r, op.truediv), o)148149# Floats. Note we pick Float32 to ensure there is no accidental upcasting150# to Float64.151dtypes = [pl.Float32, pl.Float32, pl.Float32]152l, r, o = materialize_series(1.7, 2.3, 4.0) # noqa: E741153assert_series_equal(exec_op(l, r, op.add), o)154155l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999) # noqa: E741156assert_series_equal(exec_op(l, r, op.sub), o)157158l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997) # noqa: E741159assert_series_equal(exec_op(l, r, op.mul), o)160161l, r, o = materialize_series(7.0, 3.0, 2.0) # noqa: E741162assert_series_equal(exec_op(l, r, op.floordiv), o)163164l, r, o = materialize_series(-5.0, 3.0, 1.0) # noqa: E741165assert_series_equal(exec_op(l, r, op.mod), o)166167l, r, o = materialize_series(2.0, 128.0, 0.015625) # noqa: E741168assert_series_equal(exec_op(l, r, op.truediv), o)169170#171# Tests for zero behavior172#173174# Integer175176dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]177178l, r, o = materialize_series(1, 0, None) # noqa: E741179assert_series_equal(exec_op(l, r, op.floordiv), o)180assert_series_equal(exec_op(l, r, op.mod), o)181182l, r, o = materialize_series(0, 0, None) # noqa: E741183assert_series_equal(exec_op(l, r, op.floordiv), o)184assert_series_equal(exec_op(l, r, op.mod), o)185186dtypes = [pl.UInt8, pl.UInt8, pl.Float64]187188l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741189assert_series_equal(exec_op(l, r, op.truediv), o)190191l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741192assert_series_equal(exec_op(l, r, op.truediv), o)193194# Float195196dtypes = [pl.Float32, pl.Float32, pl.Float32]197198l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741199assert_series_equal(exec_op(l, r, op.floordiv), o)200201l, r, o = materialize_series(1, 0, float("nan")) # noqa: E741202assert_series_equal(exec_op(l, r, op.mod), o)203204l, r, o = materialize_series(1, 0, float("inf")) # noqa: 
E741205assert_series_equal(exec_op(l, r, op.truediv), o)206207l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741208assert_series_equal(exec_op(l, r, op.floordiv), o)209210l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741211assert_series_equal(exec_op(l, r, op.mod), o)212213l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741214assert_series_equal(exec_op(l, r, op.truediv), o)215216#217# Tests for NULL behavior218#219220for dtype, truediv_dtype in [ # type: ignore[misc]221[pl.Int8, pl.Float64],222[pl.Float32, pl.Float32],223]:224for vals in [225[None, None, None],226[0, None, None],227[None, 0, None],228[0, None, None],229[None, 0, None],230[3, None, None],231[None, 3, None],232]:233dtypes = 3 * [dtype]234235l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741236assert_series_equal(exec_op(l, r, op.add), o)237assert_series_equal(exec_op(l, r, op.sub), o)238assert_series_equal(exec_op(l, r, op.mul), o)239assert_series_equal(exec_op(l, r, op.floordiv), o)240assert_series_equal(exec_op(l, r, op.mod), o)241dtypes[2] = truediv_dtype # type: ignore[has-type]242l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741243assert_series_equal(exec_op(l, r, op.truediv), o)244245# Type upcasting for Boolean and Null246247# Check boolean upcasting248dtypes = [pl.Boolean, pl.UInt8, pl.UInt8]249250l, r, o = materialize_series(True, 3, 4) # noqa: E741251assert_series_equal(exec_op(l, r, op.add), o)252253l, r, o = materialize_series(True, 3, 254) # noqa: E741254assert_series_equal(exec_op(l, r, op.sub), o)255256l, r, o = materialize_series(True, 3, 3) # noqa: E741257assert_series_equal(exec_op(l, r, op.mul), o)258259l, r, o = materialize_series(True, 3, 0) # noqa: E741260if list_side != "none":261# TODO: FIXME: We get an error on non-lists with this:262# "floor_div operation not supported for dtype `bool`"263assert_series_equal(exec_op(l, r, op.floordiv), o)264265l, r, o = materialize_series(True, 3, 1) # noqa: 
E741266assert_series_equal(exec_op(l, r, op.mod), o)267268dtypes = [pl.Boolean, pl.UInt8, pl.Float64]269l, r, o = materialize_series(True, 128, 0.0078125) # noqa: E741270assert_series_equal(exec_op(l, r, op.truediv), o)271272# Check Null upcasting273dtypes = [pl.Null, pl.UInt8, pl.UInt8]274l, r, o = materialize_series(None, 3, None) # noqa: E741275assert_series_equal(exec_op(l, r, op.add), o)276assert_series_equal(exec_op(l, r, op.sub), o)277assert_series_equal(exec_op(l, r, op.mul), o)278if list_side != "none":279assert_series_equal(exec_op(l, r, op.floordiv), o)280assert_series_equal(exec_op(l, r, op.mod), o)281282dtypes = [pl.Null, pl.UInt8, pl.Float64]283l, r, o = materialize_series(None, 3, None) # noqa: E741284assert_series_equal(exec_op(l, r, op.truediv), o)285286287@pytest.mark.parametrize(288("lhs_dtype", "rhs_dtype", "expected_dtype"),289[290(pl.List(pl.Int64), pl.Int64, pl.List(pl.Float64)),291(pl.List(pl.Float32), pl.Float32, pl.List(pl.Float32)),292(pl.List(pl.Duration("us")), pl.Int64, pl.List(pl.Duration("us"))),293],294)295def test_list_truediv_schema(296lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType297) -> None:298schema = {"lhs": lhs_dtype, "rhs": rhs_dtype}299df = pl.DataFrame({"lhs": [[None, 10]], "rhs": 2}, schema=schema)300result = df.lazy().select(pl.col("lhs").truediv("rhs")).collect_schema()["lhs"]301assert result == expected_dtype302303304@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)305def test_list_add_supertype(306exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],307) -> None:308import operator as op309310a = pl.Series("a", [[1], [2]], dtype=pl.List(pl.Int8))311b = pl.Series("b", [[1], [999]], dtype=pl.List(pl.Int64))312313assert_series_equal(314exec_op(a, b, op.add),315pl.Series("a", [[2], [1001]], dtype=pl.List(pl.Int64)),316)317318319@pytest.mark.parametrize("exec_op", 
EXEC_OP_COMBINATIONS)320@pytest.mark.parametrize(321"broadcast_series",322BROADCAST_SERIES_COMBINATIONS,323)324@pytest.mark.slow325def test_list_numeric_op_validity_combination(326broadcast_series: Callable[327[pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]328],329exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],330) -> None:331import operator as op332333a = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=pl.List(pl.Int32))334b = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=pl.List(pl.Int64))335# expected result336e = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=pl.List(pl.Int64))337338assert_series_equal(339exec_op(a, b, op.add),340e,341)342343a = pl.Series("a", [[1]], dtype=pl.List(pl.Int32))344b = pl.Series("b", [None], dtype=pl.Int64)345e = pl.Series("a", [[None]], dtype=pl.List(pl.Int64))346347a, b, e = broadcast_series(a, b, e)348assert_series_equal(exec_op(a, b, op.add), e)349350a = pl.Series("a", [None], dtype=pl.List(pl.Int32))351b = pl.Series("b", [1], dtype=pl.Int64)352e = pl.Series("a", [None], dtype=pl.List(pl.Int64))353354a, b, e = broadcast_series(a, b, e)355assert_series_equal(exec_op(a, b, op.add), e)356357a = pl.Series("a", [None], dtype=pl.List(pl.Int32))358b = pl.Series("b", [0], dtype=pl.Int64)359e = pl.Series("a", [None], dtype=pl.List(pl.Int64))360361a, b, e = broadcast_series(a, b, e)362assert_series_equal(exec_op(a, b, op.floordiv), e)363364365def test_list_add_alignment() -> None:366a = pl.Series("a", [[1, 1], [1, 1, 1]])367b = pl.Series("b", [[1, 1, 1], [1, 1]])368369df = pl.DataFrame([a, b])370371with pytest.raises(ShapeError):372df.select(x=pl.col("a") + pl.col("b"))373374# Test masking and slicing375a = pl.Series("a", [[1, 1, 1], [1], [1, 1], [1, 1, 1]])376b = pl.Series("b", [[1, 1], [1], [1, 1, 1], [1]])377c = pl.Series("c", [1, 1, 1, 1])378p = pl.Series("p", [True, True, False, False])379380df = pl.DataFrame([a, b, c, p]).filter("p").slice(1)381382for rhs 
in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]:383assert_series_equal(384df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2]])385)386387df = df.vstack(df)388389for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]:390assert_series_equal(391df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2], [2]])392)393394395@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)396@pytest.mark.slow397def test_list_add_empty_lists(398exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],399) -> None:400l = pl.Series( # noqa: E741401"x",402[[[[]], []], []],403)404r = pl.Series([1])405406assert_series_equal(407exec_op(l, r, operator.add),408pl.Series("x", [[[[]], []], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))),409)410411l = pl.Series( # noqa: E741412"x",413[[[[]], None], []],414)415r = pl.Series([1])416417assert_series_equal(418exec_op(l, r, operator.add),419pl.Series("x", [[[[]], None], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))),420)421422423@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)424def test_list_to_list_arithmetic_double_nesting_raises_error(425exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],426) -> None:427s = pl.Series(dtype=pl.List(pl.List(pl.Int32)))428429with pytest.raises(430InvalidOperationError,431match="cannot add two list columns with non-numeric inner types",432):433exec_op(s, s, operator.add)434435436@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)437def test_list_add_height_mismatch(438exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],439) -> None:440s = pl.Series([[1], [2], [3]], dtype=pl.List(pl.Int32))441442# TODO: Make the error type consistently a ShapeError443with pytest.raises(444(ShapeError, InvalidOperationError),445match="length",446):447exec_op(s, pl.Series([1, 1]), 
operator.add)448449450@pytest.mark.parametrize(451"op",452[453operator.add,454operator.sub,455operator.mul,456operator.floordiv,457operator.mod,458operator.truediv,459],460)461@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)462@pytest.mark.slow463def test_list_date_to_numeric_arithmetic_raises_error(464op: Callable[[Any], Any], exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series]465) -> None:466l = pl.Series([1], dtype=pl.Date) # noqa: E741467r = pl.Series([[1]], dtype=pl.List(pl.Int32))468469exec_op(l.to_physical(), r, op)470471# TODO(_): Ideally this always raises InvalidOperationError. The TypeError472# is being raised by checks on the Python side that should be moved to Rust.473with pytest.raises((InvalidOperationError, TypeError)):474exec_op(l, r, op)475476477@pytest.mark.parametrize(478("expected", "expr", "column_names"),479[480([[2, 4], [6]], lambda a, b: a + b, ("a", "a")),481([[0, 0], [0]], lambda a, b: a - b, ("a", "a")),482([[1, 4], [9]], lambda a, b: a * b, ("a", "a")),483([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")),484([[0, 0], [0]], lambda a, b: a % b, ("a", "a")),485(486[[3, 4], [7]],487lambda a, b: a + b,488("a", "uint8"),489),490],491)492def test_list_arithmetic_same_size(493expected: Any,494expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],495column_names: tuple[str, str],496) -> None:497df = pl.DataFrame(498[499pl.Series("a", [[1, 2], [3]]),500pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())),501pl.Series("nested", [[[1, 2]], [[3]]]),502pl.Series(503"nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8()))504),505]506)507# Expr-based arithmetic:508assert_frame_equal(509df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),510pl.Series(column_names[0], expected).to_frame(),511)512# Direct arithmetic on the Series:513assert_series_equal(514expr(df[column_names[0]], df[column_names[1]]),515pl.Series(column_names[0], expected),516)517518519@pytest.mark.parametrize(520("a", 
"b", "expected"),521[522([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]),523([[2], None, [5]], [None, [3], [2]], [None, None, [7]]),524],525)526def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None:527series_a = pl.Series(a)528series_b = pl.Series(b)529series_expected = pl.Series(expected)530531# Same dtype:532assert_series_equal(series_a + series_b, series_expected)533534# Different dtype:535assert_series_equal(536series_a._recursive_cast_to_dtype(pl.Int32())537+ series_b._recursive_cast_to_dtype(pl.Int64()),538series_expected._recursive_cast_to_dtype(pl.Int64()),539)540541542def test_list_arithmetic_error_cases() -> None:543# Different series length:544with pytest.raises(InvalidOperationError, match="different lengths"):545_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]])546with pytest.raises(InvalidOperationError, match="different lengths"):547_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None])548549# Different list length:550with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):551_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]])552553with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):554_ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None])555556557@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)558def test_list_arithmetic_invalid_dtypes(559exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],560) -> None:561import operator as op562563a = pl.Series([[1, 2]])564b = pl.Series(["hello"])565566# Wrong types:567with pytest.raises(568InvalidOperationError, match="add operation not supported for dtypes"569):570exec_op(a, b, op.add)571572a = pl.Series("a", [[1]])573b = pl.Series("b", [[[1]]])574575# list<->list is restricted to 1 level of nesting576with pytest.raises(577InvalidOperationError,578match="cannot add two list columns with non-numeric inner types",579):580exec_op(a, b, 
op.add)581582# Ensure dtype is validated to be `List` at all nesting levels instead of panicking.583a = pl.Series([[[1]], [[1]]], dtype=pl.List(pl.Array(pl.Int64, 1)))584b = pl.Series([1], dtype=pl.Int64)585586with pytest.raises(587InvalidOperationError, match="dtype was not list on all nesting levels"588):589exec_op(a, b, op.add)590591with pytest.raises(592InvalidOperationError, match="dtype was not list on all nesting levels"593):594exec_op(b, a, op.add)595596597@pytest.mark.parametrize(598("expected", "expr", "column_names"),599[600# All 5 arithmetic operations:601([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")),602([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")),603([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")),604([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")),605([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")),606# Different types:607(608[[3, 4], [7]],609lambda a, b: a + b,610("list", "uint8"),611),612# Extra nesting + different types:613(614[[[3, 4]], [[8]]],615lambda a, b: a + b,616("nested", "int64"),617),618# Primitive numeric on the left; only addition and multiplication are619# supported:620([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")),621([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")),622# Primitive numeric on the left with different types:623(624[[3, 4], [7]],625lambda a, b: a + b,626("uint8", "list"),627),628(629[[2, 4], [12]],630lambda a, b: a * b,631("uint8", "list"),632),633],634)635def test_list_and_numeric_arithmetic_same_size(636expected: Any,637expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],638column_names: tuple[str, str],639) -> None:640df = pl.DataFrame(641[642pl.Series("list", [[1, 2], [3]]),643pl.Series("int64", [2, 3], dtype=pl.Int64()),644pl.Series("uint8", [2, 4], dtype=pl.UInt8()),645pl.Series("nested", [[[1, 2]], [[5]]]),646]647)648# Expr-based arithmetic:649assert_frame_equal(650df.select(expr(pl.col(column_names[0]), 
pl.col(column_names[1]))),651pl.Series(column_names[0], expected).to_frame(),652)653# Direct arithmetic on the Series:654assert_series_equal(655expr(df[column_names[0]], df[column_names[1]]),656pl.Series(column_names[0], expected),657)658659660@pytest.mark.parametrize(661("a", "b", "expected"),662[663# Null on numeric on the right:664([[1, 2], [3]], [1, None], [[2, 3], [None]]),665# Null on list on the left:666([[[1, 2]], [[3]]], [None, 1], [[[None, None]], [[4]]]),667# Extra nesting:668([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]),669],670)671def test_list_and_numeric_arithmetic_nulls(672a: list[Any], b: list[Any], expected: list[Any]673) -> None:674series_a = pl.Series(a)675series_b = pl.Series(b)676series_expected = pl.Series(expected, dtype=series_a.dtype)677678# Same dtype:679assert_series_equal(series_a + series_b, series_expected)680681# Different dtype:682assert_series_equal(683series_a._recursive_cast_to_dtype(pl.Int32())684+ series_b._recursive_cast_to_dtype(pl.Int64()),685series_expected._recursive_cast_to_dtype(pl.Int64()),686)687688# Swap sides:689assert_series_equal(series_b + series_a, series_expected)690assert_series_equal(691series_b._recursive_cast_to_dtype(pl.Int32())692+ series_a._recursive_cast_to_dtype(pl.Int64()),693series_expected._recursive_cast_to_dtype(pl.Int64()),694)695696697def test_list_and_numeric_arithmetic_error_cases() -> None:698# Different series length:699with pytest.raises(700InvalidOperationError, match="series of different lengths: got 3 and 2"701):702_ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) + pl.Series("b", [1, 2])703with pytest.raises(704InvalidOperationError, match="series of different lengths: got 3 and 2"705):706_ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) / pl.Series("b", [1, None])707708# Wrong types:709with pytest.raises(710InvalidOperationError, match="add operation not supported for dtypes"711):712_ = pl.Series("a", [[1, 2], [3, 4]]) + pl.Series("b", ["hello", 
"world"])713714715@pytest.mark.parametrize("broadcast", [True, False])716@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Float64()])717def test_list_arithmetic_div_ops_zero_denominator(718broadcast: bool, dtype: pl.DataType719) -> None:720# Notes721# * truediv (/) on integers upcasts to Float64722# * Otherwise, we test floordiv (//) and module/rem (%)723# * On integers, 0-denominator is expected to output NULL724# * On floats, 0-denominator has different outputs, e.g. NaN, Inf, depending725# on a few factors (e.g. whether the numerator is also 0).726727s = pl.Series([[0], [1], [None], None]).cast(pl.List(dtype))728729n = 1 if broadcast else s.len()730731# list<->primitive732733# truediv734assert_series_equal(735pl.Series([1]).new_from_index(0, n) / s,736pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)),737)738739assert_series_equal(740s / pl.Series([1]).new_from_index(0, n),741pl.Series([[0.0], [1.0], [None], None], dtype=pl.List(pl.Float64)),742)743744# floordiv745assert_series_equal(746pl.Series([1]).new_from_index(0, n) // s,747(748pl.Series([[None], [1], [None], None], dtype=s.dtype)749if not dtype.is_float()750else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype)751),752)753754assert_series_equal(755s // pl.Series([0]).new_from_index(0, n),756(757pl.Series([[None], [None], [None], None], dtype=s.dtype)758if not dtype.is_float()759else pl.Series(760[[float("nan")], [float("inf")], [None], None], dtype=s.dtype761)762),763)764765# rem766assert_series_equal(767pl.Series([1]).new_from_index(0, n) % s,768(769pl.Series([[None], [0], [None], None], dtype=s.dtype)770if not dtype.is_float()771else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype)772),773)774775assert_series_equal(776s % pl.Series([0]).new_from_index(0, n),777(778pl.Series([[None], [None], [None], None], dtype=s.dtype)779if not dtype.is_float()780else pl.Series(781[[float("nan")], [float("nan")], [None], None], dtype=s.dtype782)783),784)785786# 
list<->list787788# truediv789assert_series_equal(790pl.Series([[1]]).new_from_index(0, n) / s,791pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)),792)793794assert_series_equal(795s / pl.Series([[0]]).new_from_index(0, n),796pl.Series(797[[float("nan")], [float("inf")], [None], None], dtype=pl.List(pl.Float64)798),799)800801# floordiv802assert_series_equal(803pl.Series([[1]]).new_from_index(0, n) // s,804(805pl.Series([[None], [1], [None], None], dtype=s.dtype)806if not dtype.is_float()807else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype)808),809)810811assert_series_equal(812s // pl.Series([[0]]).new_from_index(0, n),813(814pl.Series([[None], [None], [None], None], dtype=s.dtype)815if not dtype.is_float()816else pl.Series(817[[float("nan")], [float("inf")], [None], None], dtype=s.dtype818)819),820)821822# rem823assert_series_equal(824pl.Series([[1]]).new_from_index(0, n) % s,825(826pl.Series([[None], [0], [None], None], dtype=s.dtype)827if not dtype.is_float()828else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype)829),830)831832assert_series_equal(833s % pl.Series([[0]]).new_from_index(0, n),834(835pl.Series([[None], [None], [None], None], dtype=s.dtype)836if not dtype.is_float()837else pl.Series(838[[float("nan")], [float("nan")], [None], None], dtype=s.dtype839)840),841)842843844def test_list_to_primitive_arithmetic() -> None:845# Input data846# * List type: List(List(List(Int16))) (triple-nested)847# * Numeric type: Int32848#849# Tests run850# Broadcast Operation851# | L | R |852# * list<->primitive | | | floor_div853# * primitive<->list | | | floor_div854# * list<->primitive | | * | subtract855# * primitive<->list | * | | subtract856# * list<->primitive | * | | subtract857# * primitive<->list | | * | subtract858#859# Notes860# * In floor_div, we check that results from a 0 denominator are masked out861# * We choose floor_div and subtract as they emit different results when862# sides are swapped863864# 
Create some non-zero start offsets and masked out rows.865lhs = (866pl.Series(867[868[[[None, None, None, None, None]]], # sliced out869# Nulls at every level XO870[[[3, 7]], [[-3], [None], [], [], None], [], None],871[[[1, 2, 3, 4, 5]]], # masked out872[[[3, 7]], [[0], [None], [], [], None]],873[[[3, 7]]],874],875dtype=pl.List(pl.List(pl.List(pl.Int16))),876)877.slice(1)878.to_frame()879.select(pl.when(pl.int_range(pl.len()) != 1).then(pl.first()))880.to_series()881)882883# Note to reader: This is what our LHS looks like884assert_series_equal(885lhs,886pl.Series(887[888[[[3, 7]], [[-3], [None], [], [], None], [], None],889None,890[[[3, 7]], [[0], [None], [], [], None]],891[[[3, 7]]],892],893dtype=pl.List(pl.List(pl.List(pl.Int16))),894),895)896897class _:898# Floor div, no broadcasting899rhs = pl.Series([5, 1, 0, None], dtype=pl.Int32)900901assert len(lhs) == len(rhs)902903expect = pl.Series(904[905[[[0, 1]], [[-1], [None], [], [], None], [], None],906None,907[[[None, None]], [[None], [None], [], [], None]],908[[[None, None]]],909],910dtype=pl.List(pl.List(pl.List(pl.Int32))),911)912913out = (914pl.select(l=lhs, r=rhs)915.select(pl.col("l") // pl.col("r"))916.to_series()917.alias("")918)919920assert_series_equal(out, expect)921922# Flipped923924expect = pl.Series( # noqa: PIE794925[926[[[1, 0]], [[-2], [None], [], [], None], [], None],927None,928[[[0, 0]], [[None], [None], [], [], None]],929[[[None, None]]],930],931dtype=pl.List(pl.List(pl.List(pl.Int32))),932)933934out = ( # noqa: PIE794935pl.select(l=lhs, r=rhs)936.select(pl.col("r") // pl.col("l"))937.to_series()938.alias("")939)940941assert_series_equal(out, expect)942943class _: # type: ignore[no-redef]944# Subtraction with broadcasting945rhs = pl.Series([1], dtype=pl.Int32)946947expect = pl.Series(948[949[[[2, 6]], [[-4], [None], [], [], None], [], None],950None,951[[[2, 6]], [[-1], [None], [], [], None]],952[[[2, 6]]],953],954dtype=pl.List(pl.List(pl.List(pl.Int32))),955)956957out = 
pl.select(l=lhs).select(pl.col("l") - rhs).to_series().alias("")958959assert_series_equal(out, expect)960961# Flipped962963expect = pl.Series( # noqa: PIE794964[965[[[-2, -6]], [[4], [None], [], [], None], [], None],966None,967[[[-2, -6]], [[1], [None], [], [], None]],968[[[-2, -6]]],969],970dtype=pl.List(pl.List(pl.List(pl.Int32))),971)972973out = pl.select(l=lhs).select(rhs - pl.col("l")).to_series().alias("") # noqa: PIE794974975assert_series_equal(out, expect)976977# Test broadcasting of the list side978lhs = lhs.slice(2, 1)979# Note to reader: This is what our LHS looks like980assert_series_equal(981lhs,982pl.Series(983[984[[[3, 7]], [[0], [None], [], [], None]],985],986dtype=pl.List(pl.List(pl.List(pl.Int16))),987),988)989990assert len(lhs) == 1991992class _: # type: ignore[no-redef]993rhs = pl.Series([1, 2, 3, None, 5], dtype=pl.Int32)994995expect = pl.Series(996[997[[[2, 6]], [[-1], [None], [], [], None]],998[[[1, 5]], [[-2], [None], [], [], None]],999[[[0, 4]], [[-3], [None], [], [], None]],1000[[[None, None]], [[None], [None], [], [], None]],1001[[[-2, 2]], [[-5], [None], [], [], None]],1002],1003dtype=pl.List(pl.List(pl.List(pl.Int32))),1004)10051006out = pl.select(r=rhs).select(lhs - pl.col("r")).to_series().alias("")10071008assert_series_equal(out, expect)10091010# Flipped10111012expect = pl.Series( # noqa: PIE7941013[1014[[[-2, -6]], [[1], [None], [], [], None]],1015[[[-1, -5]], [[2], [None], [], [], None]],1016[[[0, -4]], [[3], [None], [], [], None]],1017[[[None, None]], [[None], [None], [], [], None]],1018[[[2, -2]], [[5], [None], [], [], None]],1019],1020dtype=pl.List(pl.List(pl.List(pl.Int32))),1021)10221023out = pl.select(r=rhs).select(pl.col("r") - lhs).to_series().alias("") # noqa: PIE79410241025assert_series_equal(out, expect)102610271028