Path: blob/main/py-polars/tests/unit/interop/numpy/test_ufunc_series.py
6939 views
from typing import Any, Callable, cast

import numpy as np
import pytest
from numpy.testing import assert_array_equal

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_series_equal


def test_ufunc() -> None:
    """
    NumPy ufuncs can be applied directly to a pl.Series.

    Covers the dtype of the result for several input dtypes, null handling,
    and a Series that consists of more than one chunk.
    """
    # test if output dtype is calculated correctly.
    s_float32 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float32)
    assert_series_equal(
        cast(pl.Series, np.multiply(s_float32, 4)),
        pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float32),
    )

    s_float64 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float64)
    assert_series_equal(
        cast(pl.Series, np.multiply(s_float64, 4)),
        pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float64),
    )

    s_uint8 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)
    assert_series_equal(
        cast(pl.Series, np.power(s_uint8, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt8),
    )
    assert_series_equal(
        cast(pl.Series, np.power(s_uint8, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )
    # An explicit ``dtype=`` passed to the ufunc decides the output dtype.
    assert_series_equal(
        cast(pl.Series, np.power(s_uint8, 2, dtype=np.uint16)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt16),
    )

    s_int8 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int8)
    assert_series_equal(
        cast(pl.Series, np.power(s_int8, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.Int8),
    )
    assert_series_equal(
        cast(pl.Series, np.power(s_int8, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )
    assert_series_equal(
        cast(pl.Series, np.power(s_int8, 2, dtype=np.int16)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.Int16),
    )

    s_uint32 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt32)
    assert_series_equal(
        cast(pl.Series, np.power(s_uint32, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt32),
    )
    assert_series_equal(
        cast(pl.Series, np.power(s_uint32, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )

    s_int32 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int32)
    assert_series_equal(
        cast(pl.Series, np.power(s_int32, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.Int32),
    )
    assert_series_equal(
        cast(pl.Series, np.power(s_int32, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )

    s_uint64 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt64)
    assert_series_equal(
        cast(pl.Series, np.power(s_uint64, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt64),
    )
    assert_series_equal(
        cast(pl.Series, np.power(s_uint64, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )

    s_int64 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int64)
    assert_series_equal(
        cast(pl.Series, np.power(s_int64, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.Int64),
    )
    assert_series_equal(
        cast(pl.Series, np.power(s_int64, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )

    # test if null bitmask is preserved
    a1 = pl.Series("a", [1.0, None, 3.0])
    b1 = cast(pl.Series, np.exp(a1))
    assert b1.null_count() == 1

    # test if it works with chunked series.
    a2 = pl.Series("a", [1.0, None, 3.0])
    b2 = pl.Series("b", [4.0, 5.0, None])
    a2.append(b2)
    # Guard the precondition: the ufunc below must see a two-chunk Series.
    assert a2.n_chunks() == 2
    c2 = np.multiply(a2, 3)
    assert_series_equal(
        cast(pl.Series, c2),
        pl.Series("a", [3.0, None, 9.0, 12.0, 15.0, None]),
    )

    # Test if nulls propagate through ufuncs
    a3 = pl.Series("a", [None, None, 3, 3])
    b3 = pl.Series("b", [None, 3, None, 3])
    assert_series_equal(
        cast(pl.Series, np.maximum(a3, b3)), pl.Series("a", [None, None, None, 3])
    )


def test_numpy_string_array() -> None:
    """``np.char.capitalize`` accepts a String pl.Series and returns an array."""
    s_str = pl.Series("a", ["aa", "bb", "cc", "dd"], dtype=pl.String)
    assert_array_equal(
        np.char.capitalize(s_str),
        np.array(["Aa", "Bb", "Cc", "Dd"], dtype="<U2"),
    )


def make_add_one() -> Callable[[pl.Series], pl.Series]:
    """
    Build a numba generalized ufunc that adds 1.0 to every element.

    The calling test is skipped when numba cannot be imported.
    """
    numba = pytest.importorskip("numba")

    @numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n)->(n)")  # type: ignore[misc]
    def add_one(arr: Any, result: Any) -> None:
        # The kernel returns nothing; it fills the ``result`` buffer by index.
        for i in range(len(arr)):
            result[i] = arr[i] + 1.0

    return add_one  # type: ignore[no-any-return]


def test_generalized_ufunc() -> None:
    """A generalized ufunc can be called on a pl.Series."""
    add_one = make_add_one()
    s_float = pl.Series("f", [1.0, 2.0, 3.0])
    result = add_one(s_float)
    assert_series_equal(result, pl.Series("f", [2.0, 3.0, 4.0]))


def test_generalized_ufunc_missing_data() -> None:
    """
    If a pl.Series is missing data, using a generalized ufunc is not allowed.

    While this particular example isn't necessarily a semantic issue, consider
    a mean() function running on integers: it will give wrong results if the
    input is missing data, since NumPy has no way to model missing slots. In
    the general case, we can't assume the function will handle missing data
    correctly.
    """
    add_one = make_add_one()
    s_float = pl.Series("f", [1.0, 2.0, 3.0, None], dtype=pl.Float64)
    with pytest.raises(
        ComputeError,
        match="can't pass a Series with missing data to a generalized ufunc",
    ):
        add_one(s_float)


def make_divide_by_sum() -> Callable[[pl.Series, pl.Series], pl.Series]:
    """
    Build a numba generalized ufunc dividing its second input by a sum.

    The kernel has signature ``(n),(m)->(m)``: every element of the second
    input is divided by the sum of the first input, so the output has the
    length of the second input. The calling test is skipped when numba cannot
    be imported.
    """
    numba = pytest.importorskip("numba")
    float64 = numba.float64

    @numba.guvectorize([(float64[:], float64[:], float64[:])], "(n),(m)->(m)")  # type: ignore[misc]
    def divide_by_sum(arr: Any, arr2: Any, result: Any) -> None:
        total = arr.sum()
        # The kernel returns nothing; it fills the ``result`` buffer by index.
        for i in range(len(arr2)):
            result[i] = arr2[i] / total

    return divide_by_sum  # type: ignore[no-any-return]


def test_generalized_ufunc_different_output_size() -> None:
    """
    It's possible to call a generalized ufunc that takes pl.Series of different sizes.

    The result has the correct size.
    """
    divide_by_sum = make_divide_by_sum()

    series = pl.Series("s", [1.0, 3.0], dtype=pl.Float64)
    series2 = pl.Series("s2", [8.0, 16.0, 32.0], dtype=pl.Float64)
    assert_series_equal(
        divide_by_sum(series, series2),
        pl.Series("s", [2.0, 4.0, 8.0], dtype=pl.Float64),
    )
    assert_series_equal(
        divide_by_sum(series2, series),
        pl.Series("s2", [1.0 / 56, 3.0 / 56], dtype=pl.Float64),
    )