# Path: blob/main/py-polars/tests/unit/interop/numpy/test_ufunc_series.py
# (8424 views)
"""Tests for applying NumPy ufuncs and generalized ufuncs to pl.Series."""

from collections.abc import Callable
from typing import Any, cast

import numpy as np
import pytest
from numpy.testing import assert_array_equal

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_series_equal


def test_ufunc() -> None:
    """NumPy ufuncs on a pl.Series yield the expected dtype and preserve nulls."""
    # test if output dtype is calculated correctly.
    s_float32 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float32)
    assert_series_equal(
        cast("pl.Series", np.multiply(s_float32, 4)),
        pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float32),
    )

    s_float64 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float64)
    assert_series_equal(
        cast("pl.Series", np.multiply(s_float64, 4)),
        pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float64),
    )

    # For each integer dtype: an integer exponent keeps the input dtype, a
    # float exponent gives Float64, and an explicit `dtype=` is honoured.
    s_uint8 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)
    assert_series_equal(
        cast("pl.Series", np.power(s_uint8, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt8),
    )
    assert_series_equal(
        cast("pl.Series", np.power(s_uint8, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )
    assert_series_equal(
        cast("pl.Series", np.power(s_uint8, 2, dtype=np.uint16)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt16),
    )

    s_int8 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int8)
    assert_series_equal(
        cast("pl.Series", np.power(s_int8, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.Int8),
    )
    assert_series_equal(
        cast("pl.Series", np.power(s_int8, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )
    assert_series_equal(
        cast("pl.Series", np.power(s_int8, 2, dtype=np.int16)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.Int16),
    )

    s_uint32 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt32)
    assert_series_equal(
        cast("pl.Series", np.power(s_uint32, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt32),
    )
    assert_series_equal(
        cast("pl.Series", np.power(s_uint32, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )

    s_int32 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int32)
    assert_series_equal(
        cast("pl.Series", np.power(s_int32, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.Int32),
    )
    assert_series_equal(
        cast("pl.Series", np.power(s_int32, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )

    s_uint64 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt64)
    assert_series_equal(
        cast("pl.Series", np.power(s_uint64, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt64),
    )
    assert_series_equal(
        cast("pl.Series", np.power(s_uint64, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )

    s_int64 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int64)
    assert_series_equal(
        cast("pl.Series", np.power(s_int64, 2)),
        pl.Series("a", [1, 4, 9, 16], dtype=pl.Int64),
    )
    assert_series_equal(
        cast("pl.Series", np.power(s_int64, 2.0)),
        pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
    )

    # test if null bitmask is preserved
    a1 = pl.Series("a", [1.0, None, 3.0])
    b1 = cast("pl.Series", np.exp(a1))
    assert b1.null_count() == 1

    # test if it works with chunked series.
    a2 = pl.Series("a", [1.0, None, 3.0])
    b2 = pl.Series("b", [4.0, 5.0, None])
    # The return value is deliberately ignored: a2 itself must end up with
    # two chunks, which the assertion below verifies.
    a2.append(b2)
    assert a2.n_chunks() == 2
    c2 = np.multiply(a2, 3)
    assert_series_equal(
        cast("pl.Series", c2),
        pl.Series("a", [3.0, None, 9.0, 12.0, 15.0, None]),
    )

    # Test if nulls propagate through ufuncs
    a3 = pl.Series("a", [None, None, 3, 3])
    b3 = pl.Series("b", [None, 3, None, 3])
    assert_series_equal(
        cast("pl.Series", np.maximum(a3, b3)), pl.Series("a", [None, None, None, 3])
    )


def test_numpy_string_array() -> None:
    """np.char.capitalize works on a String pl.Series."""
    s_str = pl.Series("a", ["aa", "bb", "cc", "dd"], dtype=pl.String)
    assert_array_equal(
        np.char.capitalize(s_str),
        np.array(["Aa", "Bb", "Cc", "Dd"], dtype="<U2"),
    )


def make_add_one() -> Callable[[pl.Series], pl.Series]:
    """
    Create a generalized ufunc that adds 1.0 to every element of a float64 array.

    The calling test is skipped if numba cannot be imported.
    """
    numba = pytest.importorskip("numba", exc_type=ImportError)

    @numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n)->(n)")  # type: ignore[misc, untyped-decorator]
    def add_one(arr: Any, result: Any) -> None:
        # Explicit index loop: this body is compiled by numba and writes into
        # the preallocated output array.
        for i in range(len(arr)):
            result[i] = arr[i] + 1.0

    return add_one  # type: ignore[no-any-return]


def test_generalized_ufunc() -> None:
    """A generalized ufunc can be called on a pl.Series."""
    add_one = make_add_one()
    s_float = pl.Series("f", [1.0, 2.0, 3.0])
    result = add_one(s_float)
    assert_series_equal(result, pl.Series("f", [2.0, 3.0, 4.0]))


def test_generalized_ufunc_missing_data() -> None:
    """
    If a pl.Series is missing data, using a generalized ufunc is not allowed.

    While this particular example isn't necessarily a semantic issue, consider
    a mean() function running on integers: it will give wrong results if the
    input is missing data, since NumPy has no way to model missing slots. In
    the general case, we can't assume the function will handle missing data
    correctly.
    """
    add_one = make_add_one()
    s_float = pl.Series("f", [1.0, 2.0, 3.0, None], dtype=pl.Float64)
    with pytest.raises(
        ComputeError,
        match="can't pass a Series with missing data to a generalized ufunc",
    ):
        add_one(s_float)


def make_divide_by_sum() -> Callable[[pl.Series, pl.Series], pl.Series]:
    """
    Create a generalized ufunc dividing the second array by the sum of the first.

    The signature "(n),(m)->(m)" gives the output the length of the second
    input. The calling test is skipped if numba cannot be imported.
    """
    numba = pytest.importorskip("numba", exc_type=ImportError)
    float64 = numba.float64

    @numba.guvectorize([(float64[:], float64[:], float64[:])], "(n),(m)->(m)")  # type: ignore[misc, untyped-decorator]
    def divide_by_sum(arr: Any, arr2: Any, result: Any) -> None:
        total = arr.sum()
        for i in range(len(arr2)):
            result[i] = arr2[i] / total

    return divide_by_sum  # type: ignore[no-any-return]


def test_generalized_ufunc_different_output_size() -> None:
    """
    It's possible to call a generalized ufunc that takes pl.Series of different sizes.

    The result has the correct size.
    """
    divide_by_sum = make_divide_by_sum()

    series = pl.Series("s", [1.0, 3.0], dtype=pl.Float64)
    series2 = pl.Series("s2", [8.0, 16.0, 32.0], dtype=pl.Float64)
    assert_series_equal(
        divide_by_sum(series, series2),
        pl.Series("s", [2.0, 4.0, 8.0], dtype=pl.Float64),
    )
    assert_series_equal(
        divide_by_sum(series2, series),
        pl.Series("s2", [1.0 / 56, 3.0 / 56], dtype=pl.Float64),
    )