Path: blob/main/py-polars/tests/unit/operations/arithmetic/test_array.py
6940 views
from __future__ import annotations

import operator as op
from typing import TYPE_CHECKING, Any, Callable

import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_series_equal
from tests.unit.operations.arithmetic.utils import (
    BROADCAST_SERIES_COMBINATIONS,
    EXEC_OP_COMBINATIONS,
)

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


@pytest.mark.parametrize(
    "array_side", ["left", "left3", "both", "both3", "right3", "right", "none"]
)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.slow
def test_array_arithmetic_values(
    array_side: str,
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """
    Tests value correctness.

    This test checks for output value correctness (a + b == c) across different
    codepaths, by wrapping the values (a, b, c) in different combinations of
    array / primitive columns.

    `array_side` selects which operands are wrapped in an array: a plain name
    uses a 1-level `Array(dtype, 3)`, a name ending in `3` uses a 3-level nested
    array of width 2, and `"none"` keeps every operand primitive.
    """
    # `dtypes` holds the (left, right, output) dtypes used by the next call to
    # `materialize_series`; it is re-assigned before each group of checks.
    dtypes: list[Any] = [pl.Null, pl.Null, pl.Null]
    # `dtype` is the dtype the `materialize_*` helpers read; it is set by
    # `materialize_series` before each helper call.
    dtype: Any = pl.Null

    def materialize_array(v: Any) -> pl.Series:
        """Wrap `v` as the middle element of a single-row `Array(dtype, 3)`."""
        return pl.Series(
            [[None, v, None]],
            dtype=pl.Array(dtype, 3),
        )

    def materialize_array3(v: Any) -> pl.Series:
        """Place `v` in a single-row, 3-level nested array (width 2 per level)."""
        return pl.Series(
            [[[[None, v], None], None]],
            dtype=pl.Array(pl.Array(pl.Array(dtype, 2), 2), 2),
        )

    def materialize_primitive(v: Any) -> pl.Series:
        """Build a single-row primitive series holding `v`."""
        return pl.Series([v], dtype=dtype)

    def materialize_series(
        l: Any,  # noqa: E741
        r: Any,
        o: Any,
    ) -> tuple[pl.Series, pl.Series, pl.Series]:
        """
        Build the (left, right, expected output) series for the current case.

        Each value is wrapped according to `array_side` using the matching entry
        of `dtypes`, and the three series are then passed through
        `broadcast_series`.
        """
        nonlocal dtype

        dtype = dtypes[0]
        l = {  # noqa: E741
            "left":   materialize_array,
            "left3":  materialize_array3,
            "both":   materialize_array,
            "both3":  materialize_array3,
            "right":  materialize_primitive,
            "right3": materialize_primitive,
            "none":   materialize_primitive,
        }[array_side](l)  # fmt: skip

        dtype = dtypes[1]
        r = {
            "left":   materialize_primitive,
            "left3":  materialize_primitive,
            "both":   materialize_array,
            "both3":  materialize_array3,
            "right":  materialize_array,
            "right3": materialize_array3,
            "none":   materialize_primitive,
        }[array_side](r)  # fmt: skip

        dtype = dtypes[2]
        o = {
            "left":   materialize_array,
            "left3":  materialize_array3,
            "both":   materialize_array,
            "both3":  materialize_array3,
            "right":  materialize_array,
            "right3": materialize_array3,
            "none":   materialize_primitive,
        }[array_side](o)  # fmt: skip

        assert l.len() == 1
        assert r.len() == 1
        assert o.len() == 1

        return broadcast_series(l, r, o)

    # Signed
    dtypes = [pl.Int8, pl.Int8, pl.Int8]

    l, r, o = materialize_series(2, 3, 5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(-5, 127, 124)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(-5, 127, -123)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(-5, 3, -2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(-5, 3, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    # Use signed operands here so that this check is not a duplicate of the
    # unsigned true-division check below.
    dtypes = [pl.Int8, pl.Int8, pl.Float64]
    l, r, o = materialize_series(-2, 64, -0.03125)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Unsigned
    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(2, 3, 5)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(2, 3, 255)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(2, 128, 0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(5, 2, 2)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(5, 2, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(2, 128, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Floats. Note we pick Float32 to ensure there is no accidental upcasting
    # to Float64.
    dtypes = [pl.Float32, pl.Float32, pl.Float32]
    l, r, o = materialize_series(1.7, 2.3, 4.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(7.0, 3.0, 2.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(-5.0, 3.0, 1.0)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(2.0, 128.0, 0.015625)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    #
    # Tests for zero behavior
    #

    # Integer

    dtypes = [pl.UInt8, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(1, 0, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(0, 0, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.UInt8, pl.UInt8, pl.Float64]

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Float

    dtypes = [pl.Float32, pl.Float32, pl.Float32]

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(1, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(1, 0, float("inf"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    l, r, o = materialize_series(0, 0, float("nan"))  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    #
    # Tests for NULL behavior
    #

    for dtype, truediv_dtype in [  # type: ignore[misc]
        [pl.Int8, pl.Float64],
        [pl.Float32, pl.Float32],
    ]:
        for vals in [
            [None, None, None],
            [0, None, None],
            [None, 0, None],
            [3, None, None],
            [None, 3, None],
        ]:
            dtypes = 3 * [dtype]

            l, r, o = materialize_series(*vals)  # type: ignore[misc]  # noqa: E741
            assert_series_equal(exec_op(l, r, op.add), o)
            assert_series_equal(exec_op(l, r, op.sub), o)
            assert_series_equal(exec_op(l, r, op.mul), o)
            assert_series_equal(exec_op(l, r, op.floordiv), o)
            assert_series_equal(exec_op(l, r, op.mod), o)
            # True division may produce a different output dtype than the
            # other operators, so re-materialize with that output dtype.
            dtypes[2] = truediv_dtype  # type: ignore[has-type]
            l, r, o = materialize_series(*vals)  # type: ignore[misc]  # noqa: E741
            assert_series_equal(exec_op(l, r, op.truediv), o)

    # Type upcasting for Boolean and Null

    # Check boolean upcasting
    dtypes = [pl.Boolean, pl.UInt8, pl.UInt8]

    l, r, o = materialize_series(True, 3, 4)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)

    l, r, o = materialize_series(True, 3, 254)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.sub), o)

    l, r, o = materialize_series(True, 3, 3)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mul), o)

    l, r, o = materialize_series(True, 3, 0)  # noqa: E741
    if array_side != "none":
        # TODO: FIXME: We get an error on non-lists with this:
        # "floor_div operation not supported for dtype `bool`"
        assert_series_equal(exec_op(l, r, op.floordiv), o)

    l, r, o = materialize_series(True, 3, 1)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.Boolean, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(True, 128, 0.0078125)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)

    # Check Null upcasting
    dtypes = [pl.Null, pl.UInt8, pl.UInt8]
    l, r, o = materialize_series(None, 3, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.add), o)
    assert_series_equal(exec_op(l, r, op.sub), o)
    assert_series_equal(exec_op(l, r, op.mul), o)
    if array_side != "none":
        assert_series_equal(exec_op(l, r, op.floordiv), o)
    assert_series_equal(exec_op(l, r, op.mod), o)

    dtypes = [pl.Null, pl.UInt8, pl.Float64]
    l, r, o = materialize_series(None, 3, None)  # noqa: E741
    assert_series_equal(exec_op(l, r, op.truediv), o)


@pytest.mark.parametrize(
    ("lhs_dtype", "rhs_dtype", "expected_dtype"),
    [
        (pl.Array(pl.Int64, 2), pl.Int64, pl.Array(pl.Float64, 2)),
        (pl.Array(pl.Float32, 2), pl.Float32, pl.Array(pl.Float32, 2)),
        (pl.Array(pl.Duration("us"), 2), pl.Int64, pl.Array(pl.Duration("us"), 2)),
    ],
)
def test_array_truediv_schema(
    lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType
) -> None:
    """Check the lazily resolved output dtype of array true division."""
    schema = {"lhs": lhs_dtype, "rhs": rhs_dtype}
    df = pl.DataFrame({"lhs": [[None, 10]], "rhs": 2}, schema=schema)
    result = df.lazy().select(pl.col("lhs").truediv("rhs")).collect_schema()["lhs"]
    assert result == expected_dtype


def test_array_literal_broadcast() -> None:
    """Check arithmetic between an array column and an array literal."""
    df = pl.DataFrame({"A": [[0.1, 0.2], [0.3, 0.4]]}).cast(pl.Array(float, 2))

    lit = pl.lit([3, 5], pl.Array(float, 2))
    assert df.select(
        mul=pl.all() * lit,
        div=pl.all() / lit,
        add=pl.all() + lit,
        sub=pl.all() - lit,
        div_=lit / pl.all(),
        add_=lit + pl.all(),
        sub_=lit - pl.all(),
        mul_=lit * pl.all(),
    ).to_dict(as_series=False) == {
        "mul": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
        "div": [[0.03333333333333333, 0.04], [0.09999999999999999, 0.08]],
        "add": [[3.1, 5.2], [3.3, 5.4]],
        "sub": [[-2.9, -4.8], [-2.7, -4.6]],
        "div_": [[30.0, 25.0], [10.0, 12.5]],
        "add_": [[3.1, 5.2], [3.3, 5.4]],
        "sub_": [[2.9, 4.8], [2.7, 4.6]],
        "mul_": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
    }


def test_array_arith_double_nested_shape() -> None:
    """Check that adding nested arrays of different shapes raises."""
    # Ensure the implementation doesn't just naively add the leaf arrays without
    # checking the dimension. In this example both arrays have the leaf stride as
    # 6, however one is (3, 2) while the other is (2, 3).
    a = pl.Series([[[1, 1], [1, 1], [1, 1]]], dtype=pl.Array(pl.Array(pl.Int64, 2), 3))
    b = pl.Series([[[1, 1, 1], [1, 1, 1]]], dtype=pl.Array(pl.Array(pl.Int64, 3), 2))

    with pytest.raises(InvalidOperationError, match="differing dtypes"):
        a + b


@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
@pytest.mark.parametrize(
    "broadcast_series",
    BROADCAST_SERIES_COMBINATIONS,
)
@pytest.mark.slow
def test_array_numeric_op_validity_combination(
    broadcast_series: Callable[
        [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series]
    ],
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Check how row-level and element-level nulls combine in the output."""
    array_dtype = pl.Array(pl.Int64, 1)

    a = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=array_dtype)
    b = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=array_dtype)
    # expected result
    e = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=array_dtype)

    assert_series_equal(
        exec_op(a, b, op.add),
        e,
    )

    a = pl.Series("a", [[1]], dtype=array_dtype)
    b = pl.Series("b", [None], dtype=pl.Int64)
    e = pl.Series("a", [[None]], dtype=array_dtype)

    a, b, e = broadcast_series(a, b, e)
    assert_series_equal(exec_op(a, b, op.add), e)

    a = pl.Series("a", [None], dtype=array_dtype)
    b = pl.Series("b", [1], dtype=pl.Int64)
    e = pl.Series("a", [None], dtype=array_dtype)

    a, b, e = broadcast_series(a, b, e)
    assert_series_equal(exec_op(a, b, op.add), e)

    a = pl.Series("a", [None], dtype=array_dtype)
    b = pl.Series("b", [0], dtype=pl.Int64)
    e = pl.Series("a", [None], dtype=array_dtype)

    a, b, e = broadcast_series(a, b, e)
    assert_series_equal(exec_op(a, b, op.floordiv), e)

    # >1 level nested array
    a = pl.Series(
        # row 1: [ [1, NULL], NULL ]
        # row 2: NULL
        [[[1, None], None], None],
        dtype=pl.Array(pl.Array(pl.Int64, 2), 2),
    )
    b = pl.Series(
        [[[0, 0], [0, 0]], [[0, 0], [0, 0]]],
        dtype=pl.Array(pl.Array(pl.Int64, 2), 2),
    )
    e = a  # added 0
    assert_series_equal(exec_op(a, b, op.add), e)


def test_array_elementwise_arithmetic_19682() -> None:
    """Check addition on a 2D array, with a scalar series, and on a width-0 array."""
    dt = pl.Array(pl.Int64, (2, 3))

    a = pl.Series("a", [[[1, 2, 3], [4, 5, 6]]], dt)
    sc = pl.Series("a", [1])
    zfa = pl.Series("a", [[]], pl.Array(pl.Int64, 0))

    assert_series_equal(a + a, pl.Series("a", [[[2, 4, 6], [8, 10, 12]]], dt))
    assert_series_equal(a + sc, pl.Series("a", [[[2, 3, 4], [5, 6, 7]]], dt))
    assert_series_equal(sc + a, pl.Series("a", [[[2, 3, 4], [5, 6, 7]]], dt))
    assert_series_equal(zfa + zfa, pl.Series("a", [[]], pl.Array(pl.Int64, 0)))


@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_array_add_supertype(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Check that adding Int8 and Int64 arrays yields an Int64 array."""
    a = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int8, 1))
    b = pl.Series("b", [[1], [999]], dtype=pl.Array(pl.Int64, 1))

    assert_series_equal(
        exec_op(a, b, op.add),
        pl.Series("a", [[2], [1001]], dtype=pl.Array(pl.Int64, 1)),
    )


@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS)
def test_array_arithmetic_dtype_mismatch(
    exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series],
) -> None:
    """Check the errors raised for mismatched widths and non-array nesting."""
    a = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int64, 1))
    b = pl.Series("b", [[1, 1], [999, 999]], dtype=pl.Array(pl.Int64, 2))

    with pytest.raises(InvalidOperationError, match="differing dtypes"):
        exec_op(a, b, op.add)

    a = pl.Series([[[1]], [[1]]], dtype=pl.Array(pl.List(pl.Int64), 1))
    b = pl.Series([1], dtype=pl.Int64)

    with pytest.raises(
        InvalidOperationError, match="dtype was not array on all nesting levels"
    ):
        exec_op(a, a, op.add)

    with pytest.raises(
        InvalidOperationError, match="dtype was not array on all nesting levels"
    ):
        exec_op(a, b, op.add)

    with pytest.raises(
        InvalidOperationError, match="dtype was not array on all nesting levels"
    ):
        exec_op(b, a, op.add)