Path: blob/main/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
8406 views
from __future__ import annotations12from typing import TYPE_CHECKING, Any, cast34import numpy as np5import pytest67import polars as pl8from polars.testing import assert_frame_equal, assert_series_equal910if TYPE_CHECKING:11from collections.abc import Callable121314def test_ufunc() -> None:15df = pl.DataFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)])16out = df.select(17np.power(pl.col("a"), 2).alias("power_uint8"), # type: ignore[call-overload]18np.power(pl.col("a"), 2.0).alias("power_float64"), # type: ignore[call-overload]19np.power(pl.col("a"), 2, dtype=np.uint16).alias("power_uint16"), # type: ignore[call-overload]20)21expected = pl.DataFrame(22[23pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8),24pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),25pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16),26]27)28assert_frame_equal(out, expected)29assert out.dtypes == expected.dtypes303132def test_ufunc_expr_not_first() -> None:33"""Check numpy ufunc expressions also work if expression not the first argument."""34df = pl.DataFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)])35out = df.select(36np.power(2.0, cast("Any", pl.col("a"))).alias("power"),37(2.0 / cast("Any", pl.col("a"))).alias("divide_scalar"),38)39expected = pl.DataFrame(40[41pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64),42pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64),43]44)45assert_frame_equal(out, expected)464748def test_lazy_ufunc() -> None:49ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)])50out = ldf.select(51np.power(cast("Any", pl.col("a")), 2).alias("power_uint8"),52np.power(cast("Any", pl.col("a")), 2.0).alias("power_float64"),53np.power(cast("Any", pl.col("a")), 2, dtype=np.uint16).alias("power_uint16"),54)55expected = pl.DataFrame(56[57pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8),58pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),59pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16),60]61)62assert_frame_equal(out.collect(), expected)636465def test_lazy_ufunc_expr_not_first() -> None:66"""Check numpy ufunc expressions also work if expression not the first argument."""67ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)])68out = ldf.select(69np.power(2.0, cast("Any", pl.col("a"))).alias("power"),70(2.0 / cast("Any", pl.col("a"))).alias("divide_scalar"),71)72expected = pl.DataFrame(73[74pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64),75pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64),76]77)78assert_frame_equal(out.collect(), expected)798081def test_ufunc_recognition() -> None:82df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [1.1, 2.2, 3.3, 4.4]})83assert_frame_equal(df.select(np.exp(pl.col("b"))), df.select(pl.col("b").exp()))848586# https://github.com/pola-rs/polars/issues/677087def test_ufunc_multiple_expressions() -> None:88df = pl.DataFrame(89{90"v": [91-4.293,92-2.4659,93-1.8378,94-0.2821,95-4.5649,96-3.8128,97-7.4274,983.3443,993.8604,100-4.2200,101],102"u": [103-11.2268,1046.3478,1057.1681,1063.4986,1072.7320,108-1.0695,109-10.1408,11011.2327,1116.6623,112-8.1412,113],114}115)116expected = np.arctan2(df.get_column("v"), df.get_column("u"))117result = df.select(np.arctan2(pl.col("v"), pl.col("u")))[:, 0] # type: ignore[call-overload]118assert_series_equal(expected, result) # type: ignore[arg-type]119120121def test_repeated_name_ufunc_17472() -> None:122"""If a ufunc takes multiple inputs has a repeating name, this works."""123df = pl.DataFrame({"a": [6.0]})124result = df.select(np.divide(pl.col("a"), pl.col("a"))) # type: ignore[call-overload]125expected = pl.DataFrame({"a": [1.0]})126assert_frame_equal(expected, result)127128129def test_grouped_ufunc() -> None:130df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [0.1, 0.1, -0.1, -0.1]})131df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))132133134def test_generalized_ufunc_scalar() -> None:135numba = pytest.importorskip("numba", exc_type=ImportError)136137@numba.guvectorize([(numba.int64[:], numba.int64[:])], "(n)->()") # type: ignore[misc, untyped-decorator]138def my_custom_sum(arr, result) -> None: # type: ignore[no-untyped-def] # noqa: ANN001139total = 0140for value in arr:141total += value142result[0] = total143144# Make type checkers happy:145custom_sum = cast("Callable[[object], object]", my_custom_sum)146147# Demonstrate NumPy as the canonical expected behavior:148assert custom_sum(np.array([10, 2, 3], dtype=np.int64)) == 15149150# Direct call of the gufunc:151df = pl.DataFrame({"values": [10, 2, 3]})152assert custom_sum(df.get_column("values")) == 15153154# Indirect call of the gufunc:155indirect = df.select(156pl.col("values").map_batches(157custom_sum, returns_scalar=True, return_dtype=pl.self_dtype()158)159)160assert_frame_equal(indirect, pl.DataFrame({"values": 15}))161indirect = df.select(162pl.col("values").map_batches(163lambda s: pl.Series([custom_sum(s)]),164returns_scalar=False,165return_dtype=pl.self_dtype(),166)167)168assert_frame_equal(indirect, pl.DataFrame({"values": [15]}))169170# group_by()171df = pl.DataFrame({"labels": ["a", "b", "a", "b"], "values": [10, 2, 3, 30]})172indirect = (173df.group_by("labels")174.agg(175pl.col("values").map_batches(176custom_sum, returns_scalar=True, return_dtype=pl.self_dtype()177)178)179.sort("labels")180)181assert_frame_equal(182indirect, pl.DataFrame({"labels": ["a", "b"], "values": [13, 32]})183)184185186def make_gufunc_mean() -> Callable[[pl.Series], pl.Series]:187numba = pytest.importorskip("numba", exc_type=ImportError)188189@numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n)->(n)") # type: ignore[misc, untyped-decorator]190def gufunc_mean(arr: Any, result: Any) -> None:191mean = arr.mean()192for i in range(len(arr)):193result[i] = mean + i194195return gufunc_mean # type: ignore[no-any-return]196197198def test_generalized_ufunc() -> None:199gufunc_mean = make_gufunc_mean()200df = pl.DataFrame({"s": [1.0, 2.0, 3.0]})201result = df.select([pl.col("s").map_batches(gufunc_mean).alias("result")])202expected = pl.DataFrame({"result": [2.0, 3.0, 4.0]})203assert_frame_equal(result, expected)204205206def test_grouped_generalized_ufunc() -> None:207gufunc_mean = make_gufunc_mean()208df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [1.0, 2.0, 3.0, 4.0]})209result = (210df.group_by("id")211.agg(pl.col("values").map_batches(gufunc_mean, return_dtype=pl.self_dtype()))212.sort("id")213)214expected = pl.DataFrame({"id": ["a", "b"], "values": [[1.5, 2.5], [3.5, 4.5]]})215assert_frame_equal(result, expected)216217218def test_ufunc_chain() -> None:219df = pl.DataFrame(220data={"A": [2, 10, 11, 12, 3, 10, 11, 12], "counter": [1, 2, 3, 4, 5, 6, 7, 8]}221)222result = df.rolling(index_column="counter", period="2i").agg(223(np.log(pl.col("A"))).mean().alias("mean_numpy"),224(pl.col("A")).log().mean().alias("mean_polars"),225)226assert_series_equal(result["mean_numpy"], result["mean_polars"].alias("mean_numpy"))227228229