# Path: blob/main/py-polars/tests/unit/operations/arithmetic/test_array.py
# 8391 views
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_series_equal
from tests.unit.operations.arithmetic.utils import (
    BROADCAST_SERIES_COMBINATIONS,
    EXEC_OP_COMBINATIONS,
)

if TYPE_CHECKING:
    from collections.abc import Callable

    from polars._typing import PolarsDataType


@pytest.mark.parametrize(
    "array_side", ["left", "left3", "both", "both3", "right3", "right", "none"]
)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_array_arithmetic_values(
    array_side: str,
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Tests value correctness.

    This test checks for output value correctness (a + b == c) across different
    codepaths, by wrapping the values (a, b, c) in different combinations of
    array / primitive columns.

    ``array_side`` selects which inputs are wrapped in an array. A ``3`` suffix
    means a three-level nested ``Array(Array(Array(..., 2), 2), 2)`` instead of a
    single ``Array(..., 3)``. ``"none"`` keeps all three values as primitive
    columns. The expected output ``c`` is wrapped whenever either input is.
    """
    import operator as op

    # Leaf dtypes for (lhs, rhs, expected). Reassigned before each group of
    # checks below.
    dtypes: list[Any] = [pl.Null, pl.Null, pl.Null]
    # Leaf dtype read by the ``materialize_*`` helpers at call time.
    # ``materialize_series`` sets it from ``dtypes`` before building each operand.
    dtype: Any = pl.Null

    def materialize_array(v: Any) -> pl.Series:
        # One row: a width-3 array with ``v`` in the middle and nulls around it.
        return pl.Series(
            [[None, v, None]],
            dtype=pl.Array(dtype, 3),
        )

    def materialize_array3(v: Any) -> pl.Series:
        # One row: a 2x2x2 nested array holding ``v`` at a single leaf, with a
        # null element at each nesting level.
        return pl.Series(
            [[[[None, v], None], None]],
            dtype=pl.Array(pl.Array(pl.Array(dtype, 2), 2), 2),
        )

    def materialize_primitive(v: Any) -> pl.Series:
        # One row holding the bare value.
        return pl.Series([v], dtype=dtype)

    def materialize_series(
        l: Any,  # noqa: E741
        r: Any,
        o: Any,
    ) -> tuple[pl.Series, pl.Series, pl.Series]:
        """
        Build single-row (lhs, rhs, expected) series and broadcast them.

        Each value is wrapped according to ``array_side``, using the matching
        leaf dtype from ``dtypes``, and the triple is then passed through
        ``broadcast_series``. As a side effect the enclosing ``dtype`` is left
        set to ``dtypes[2]``.
        """
        nonlocal dtype

        dtype = dtypes[0]
        l = {  # noqa: E741
            "left": materialize_array,
            "left3": materialize_array3,
            "both": materialize_array,
            "both3": materialize_array3,
            "right": materialize_primitive,
            "right3": materialize_primitive,
            "none": materialize_primitive,
        }[array_side](l)  # fmt: skip

        dtype = dtypes[1]
        r = {
            "left": materialize_primitive,
            "left3": materialize_primitive,
            "both": materialize_array,
            "both3": materialize_array3,
            "right": materialize_array,
            "right3": materialize_array3,
            "none": materialize_primitive,
        }[array_side](r)  # fmt: skip

        dtype = dtypes[2]
        o = {
            "left": materialize_array,
            "left3": materialize_array3,
            "both": materialize_array,
            "both3": materialize_array3,
            "right": materialize_array,
            "right3": materialize_array3,
            "none": materialize_primitive,
        }[array_side](o)  # fmt: skip

        assert l.len() == 1
        assert r.len() == 1
        assert o.len() == 1

        return broadcast_series(l, r, o)

    # Signed
    dtypes = [pl.Int8, pl.Int8, pl.Int8]

    l, r, o = materialize_series(2, 3, 5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    # Overflowing Int8 results wrap around: -5 - 127 = -132 -> 124.
    l, r, o = materialize_series(-5, 127, 124)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    # -5 * 127 = -635 -> -123 after wrapping.
    l, r, o = materialize_series(-5, 127, -123)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    # Floor division rounds toward negative infinity.
    l, r, o = materialize_series(-5, 3, -2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    # Modulo takes the sign of the divisor.
    l, r, o = materialize_series(-5, 3, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    # True division of integers produces Float64.
    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(2, 128, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Unsigned
    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(2, 3, 5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    # 2 - 3 wraps to 255.
    l, r, o = materialize_series(2, 3, 255)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    # 2 * 128 = 256 wraps to 0.
    l, r, o = materialize_series(2, 128, 0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(5, 2, 2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(5, 2, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(2, 128, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Floats. Note we pick Float32 to ensure there is no accidental upcasting
    # to Float64.
    dtypes = [pl.Float32, pl.Float32, pl.Float32]
    l, r, o = materialize_series(1.7, 2.3, 4.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(7.0, 3.0, 2.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(-5.0, 3.0, 1.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(2.0, 128.0, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    #
    # Tests for zero behavior
    #

    # Integer: floordiv / mod by zero yield null; truediv yields inf / nan.

    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(1, 0, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(0, 0, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Float: division by zero yields inf / nan rather than null.

    dtypes = [pl.Float32, pl.Float32, pl.Float32]

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(1, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    #
    # Tests for NULL behavior
    #

    # Use a loop variable distinct from ``dtype``. ``materialize_series``
    # rebinds ``dtype`` (via ``nonlocal``) to ``dtypes[2]``. After the truediv
    # check below it would therefore hold ``truediv_dtype``, and the next
    # iteration would silently test that dtype instead of the intended one.
    for leaf_dtype, truediv_dtype in [  # type: ignore[misc]
        [pl.Int8, pl.Float64],
        [pl.Float32, pl.Float32],
    ]:
        # Every (lhs, rhs) pair has at least one null operand, so the result
        # is always null.
        for vals in [
            [None, None, None],
            [0, None, None],
            [None, 0, None],
            [3, None, None],
            [None, 3, None],
        ]:
            dtypes = 3 * [leaf_dtype]

            l, r, o = materialize_series(*vals)  # type: ignore[misc]  # noqa: E741
            assert_series_equal(exec_op(l, r, op.add), o)
            assert_series_equal(exec_op(l, r, op.sub), o)
            assert_series_equal(exec_op(l, r, op.mul), o)
            assert_series_equal(exec_op(l, r, op.floordiv), o)
            assert_series_equal(exec_op(l, r, op.mod), o)
            dtypes[2] = truediv_dtype  # type: ignore[has-type]
            l, r, o = materialize_series(*vals)  # type: ignore[misc]  # noqa: E741
            assert_series_equal(exec_op(l, r, op.truediv), o)

    # Type upcasting for Boolean and Null

    # Check boolean upcasting
    dtypes = [pl.Boolean, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(True, 3, 4)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    # True - 3 = -2 wraps to 254 in UInt8.
    l, r, o = materialize_series(True, 3, 254)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(True, 3, 3)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(True, 3, 0)  # noqa: E741
    if array_side != "none":
        # TODO: We get an error on non-array (primitive) operands with this:
        # "floor_div operation not supported for dtype `bool`"
        assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(True, 3, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.Boolean, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(True, 128, 0.0078125)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Check Null upcasting
    dtypes = [pl.Null, pl.UInt8, pl.UInt8]
    l, r, o = materialize_series(None, 3, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)
    assert_series_equal(exec_op(l, r, op.sub), o)
    assert_series_equal(exec_op(l, r, op.mul), o)
    if array_side != "none":
        assert_series_equal(exec_op(l, r, op.floordiv), o)
        assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.Null, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(None, 3, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)


@pytest.mark.parametrize(
    ("lhs_dtype", "rhs_dtype", "expected_dtype"),
    [
        (pl.Array(pl.Int64, 2), pl.Int64, pl.Array(pl.Float64, 2)),
        (pl.Array(pl.Float32, 2), pl.Float32, pl.Array(pl.Float32, 2)),
        (pl.Array(pl.Duration("us"), 2), pl.Int64, pl.Array(pl.Duration("us"), 2)),
    ],
)
def test_array_truediv_schema(
    lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType
) -> None:
    """
    Check the resolved output dtype of array ``truediv`` by a primitive column.

    The cases cover three leaf dtypes:

    * Int64 leaves become Float64.
    * Float32 leaves stay Float32.
    * Duration leaves divided by Int64 stay Duration.
    """
    schema = {"lhs": lhs_dtype, "rhs": rhs_dtype}
    df = pl.DataFrame({"lhs": [[None, 10]], "rhs": 2}, schema=schema)
    # Only the lazily resolved schema is inspected; nothing is computed.
    result = df.lazy().select(pl.col("lhs").truediv("rhs")).collect_schema()["lhs"]
    assert result == expected_dtype


def test_array_literal_broadcast() -> None:
    """
    Arithmetic between an array column and an array literal.

    The single literal value is applied to every row, with the literal on
    either side of the operator.
    """
    df = pl.DataFrame({"A": [[0.1, 0.2], [0.3, 0.4]]}).cast(pl.Array(float, 2))

    lit = pl.lit([3, 5], pl.Array(float, 2))
    assert df.select(
        mul=pl.all() * lit,
        div=pl.all() / lit,
        add=pl.all() + lit,
        sub=pl.all() - lit,
        div_=lit / pl.all(),
        add_=lit + pl.all(),
        sub_=lit - pl.all(),
        mul_=lit * pl.all(),
    ).to_dict(as_series=False) == {
        "mul": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
        "div": [[0.03333333333333333, 0.04], [0.09999999999999999, 0.08]],
        "add": [[3.1, 5.2], [3.3, 5.4]],
        "sub": [[-2.9, -4.8], [-2.7, -4.6]],
        "div_": [[30.0, 25.0], [10.0, 12.5]],
        "add_": [[3.1, 5.2], [3.3, 5.4]],
        "sub_": [[2.9, 4.8], [2.7, 4.6]],
        "mul_": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
    }


def test_array_arith_double_nested_shape() -> None:
    # Ensure the implementation doesn't just naively add the leaf arrays without
    # checking the dimension. In this example both arrays have the leaf stride as
    # 6, however one is (3, 2) while the other is (2, 3).
    a = pl.Series([[[1, 1], [1, 1], [1, 1]]], dtype=pl.Array(pl.Array(pl.Int64, 2), 3))
    b = pl.Series([[[1, 1, 1], [1, 1, 1]]], dtype=pl.Array(pl.Array(pl.Int64, 3), 2))

    with pytest.raises(InvalidOperationError, match="differing dtypes"):
        a + b


@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.slow
def test_array_numeric_op_validity_combination(
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Check null propagation at every nesting level.

    A null outer row or a null inner element in either operand yields a null
    at the same position in the result.
    """
    import operator as op

    array_dtype = pl.Array(pl.Int64, 1)

    # Rows 3-6 put a null row / null element on each side in turn.
    a = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=array_dtype)
    b = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=array_dtype)
    # expected result
    e = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=array_dtype)

    assert_series_equal(
        exec_op(a, b, op.add),
        e,
    )

    # Array + null primitive: the null reaches every element, not the row.
    a = pl.Series("a", [[1]], dtype=array_dtype)
    b = pl.Series("b", [None], dtype=pl.Int64)
    e = pl.Series("a", [[None]], dtype=array_dtype)

    a, b, e = broadcast_series(a, b, e)
    assert_series_equal(exec_op(a, b, op.add), e)

    # Null array row + primitive stays a null row.
    a = pl.Series("a", [None], dtype=array_dtype)
    b = pl.Series("b", [1], dtype=pl.Int64)
    e = pl.Series("a", [None], dtype=array_dtype)

    a, b, e = broadcast_series(a, b, e)
    assert_series_equal(exec_op(a, b, op.add), e)

    # Null array row // 0 stays a null row.
    a = pl.Series("a", [None], dtype=array_dtype)
    b = pl.Series("b", [0], dtype=pl.Int64)
    e = pl.Series("a", [None], dtype=array_dtype)

    a, b, e = broadcast_series(a, b, e)
    assert_series_equal(exec_op(a, b, op.floordiv), e)

    # >1 level nested array
    a = pl.Series(
        # row 1: [ [1, NULL], NULL ]
        # row 2: NULL
        [[[1, None], None], None],
        dtype=pl.Array(pl.Array(pl.Int64, 2), 2),
    )
    b = pl.Series(
        [[[0, 0], [0, 0]], [[0, 0], [0, 0]]],
        dtype=pl.Array(pl.Array(pl.Int64, 2), 2),
    )
    e = a  # added 0
    assert_series_equal(exec_op(a, b, op.add), e)


def test_array_elementwise_arithmetic_19682() -> None:
    """
    Arithmetic on 2x3 arrays and on zero-width arrays.

    Covers array + array, array + primitive, primitive + array, and
    zero-width + zero-width. The numeric suffix presumably refers to a GitHub
    issue; verify against the tracker.
    """
    dt = pl.Array(pl.Int64, (2, 3))

    a = pl.Series("a", [[[1, 2, 3], [4, 5, 6]]], dt)
    sc = pl.Series("a", [1])
    zfa = pl.Series("a", [[]], pl.Array(pl.Int64, 0))

    assert_series_equal(a + a, pl.Series("a", [[[2, 4, 6], [8, 10, 12]]], dt))
    assert_series_equal(a + sc, pl.Series("a", [[[2, 3, 4], [5, 6, 7]]], dt))
    assert_series_equal(sc + a, pl.Series("a", [[[2, 3, 4], [5, 6, 7]]], dt))
    assert_series_equal(zfa + zfa, pl.Series("a", [[]], pl.Array(pl.Int64, 0)))


@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_array_add_supertype(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Adding Int8 and Int64 arrays yields an Int64 array.

    The result 1001 does not fit in Int8, so the supertype must be used.
    """
    import operator as op

    a = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int8, 1))
    b = pl.Series("b", [[1], [999]], dtype=pl.Array(pl.Int64, 1))

    assert_series_equal(
        exec_op(a, b, op.add),
        pl.Series("a", [[2], [1001]], dtype=pl.Array(pl.Int64, 1)),
    )


@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_array_arithmetic_dtype_mismatch(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Array arithmetic raises ``InvalidOperationError`` on incompatible dtypes.

    Two cases are covered:

    * Arrays of differing widths.
    * An operand whose nesting contains a non-array (List) level, whether it is
      paired with itself or with a primitive on either side.
    """
    import operator as op

    a = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int64, 1))
    b = pl.Series("b", [[1, 1], [999, 999]], dtype=pl.Array(pl.Int64, 2))

    with pytest.raises(InvalidOperationError, match="differing dtypes"):
        exec_op(a, b, op.add)

    a = pl.Series([[[1]], [[1]]], dtype=pl.Array(pl.List(pl.Int64), 1))
    b = pl.Series([1], dtype=pl.Int64)

    with pytest.raises(
        InvalidOperationError, match="dtype was not array on all nesting levels"
    ):
        exec_op(a, a, op.add)

    with pytest.raises(
        InvalidOperationError, match="dtype was not array on all nesting levels"
    ):
        exec_op(a, b, op.add)

    with pytest.raises(
        InvalidOperationError, match="dtype was not array on all nesting levels"
    ):
        exec_op(b, a, op.add)