Path: blob/main/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
6939 views
from __future__ import annotations

from typing import Any, Callable, cast

import numpy as np
import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal


def test_ufunc() -> None:
    """Apply ``np.power`` to a UInt8 column in an eager ``select``."""
    df = pl.DataFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)])
    out = df.select(
        # Integer exponent: expected output stays UInt8.
        np.power(pl.col("a"), 2).alias("power_uint8"),  # type: ignore[call-overload]
        # Float exponent: expected output is Float64.
        np.power(pl.col("a"), 2.0).alias("power_float64"),  # type: ignore[call-overload]
        # Explicit NumPy ``dtype`` keyword: expected output is UInt16.
        np.power(pl.col("a"), 2, dtype=np.uint16).alias("power_uint16"),  # type: ignore[call-overload]
    )
    expected = pl.DataFrame(
        [
            pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8),
            pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
            pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16),
        ]
    )
    assert_frame_equal(out, expected)
    assert out.dtypes == expected.dtypes


def test_ufunc_expr_not_first() -> None:
    """Check numpy ufunc expressions also work if expression not the first argument."""
    df = pl.DataFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)])
    out = df.select(
        np.power(2.0, cast(Any, pl.col("a"))).alias("power"),
        (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"),
    )
    expected = pl.DataFrame(
        [
            pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64),
            pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64),
        ]
    )
    assert_frame_equal(out, expected)


def test_lazy_ufunc() -> None:
    """Apply ``np.power`` to a UInt8 column in a ``LazyFrame.select`` and collect."""
    ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)])
    out = ldf.select(
        np.power(cast(Any, pl.col("a")), 2).alias("power_uint8"),
        np.power(cast(Any, pl.col("a")), 2.0).alias("power_float64"),
        np.power(cast(Any, pl.col("a")), 2, dtype=np.uint16).alias("power_uint16"),
    )
    expected = pl.DataFrame(
        [
            pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8),
            pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64),
            pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16),
        ]
    )
    assert_frame_equal(out.collect(), expected)


def test_lazy_ufunc_expr_not_first() -> None:
    """Check numpy ufunc expressions also work if expression not the first argument."""
    ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)])
    out = ldf.select(
        np.power(2.0, cast(Any, pl.col("a"))).alias("power"),
        (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"),
    )
    expected = pl.DataFrame(
        [
            pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64),
            pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64),
        ]
    )
    assert_frame_equal(out.collect(), expected)


def test_ufunc_recognition() -> None:
    """Check ``np.exp`` on an expression gives the same frame as ``Expr.exp``."""
    df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [1.1, 2.2, 3.3, 4.4]})
    assert_frame_equal(df.select(np.exp(pl.col("b"))), df.select(pl.col("b").exp()))


# https://github.com/pola-rs/polars/issues/6770
def test_ufunc_multiple_expressions() -> None:
    """Check a two-input ufunc over two expressions matches the Series result."""
    df = pl.DataFrame(
        {
            "v": [
                -4.293,
                -2.4659,
                -1.8378,
                -0.2821,
                -4.5649,
                -3.8128,
                -7.4274,
                3.3443,
                3.8604,
                -4.2200,
            ],
            "u": [
                -11.2268,
                6.3478,
                7.1681,
                3.4986,
                2.7320,
                -1.0695,
                -10.1408,
                11.2327,
                6.6623,
                -8.1412,
            ],
        }
    )
    # Reference value: the same ufunc applied directly to the two Series.
    expected = np.arctan2(df.get_column("v"), df.get_column("u"))
    result = df.select(np.arctan2(pl.col("v"), pl.col("u")))[:, 0]  # type: ignore[call-overload]
    assert_series_equal(expected, result)  # type: ignore[arg-type]


def test_repeated_name_ufunc_17472() -> None:
    """Check a ufunc taking multiple inputs works when the inputs share a name."""
    df = pl.DataFrame({"a": [6.0]})
    result = df.select(np.divide(pl.col("a"), pl.col("a")))  # type: ignore[call-overload]
    expected = pl.DataFrame({"a": [1.0]})
    assert_frame_equal(expected, result)


def test_grouped_ufunc() -> None:
    """Smoke test: pipe an aggregated expression through ``np.expm1`` in ``agg``."""
    df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [0.1, 0.1, -0.1, -0.1]})
    # No assertion on the result: this only checks that the call does not raise.
    df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))


def test_generalized_ufunc_scalar() -> None:
    """Check a numba ``(n)->()`` gufunc directly and through ``map_batches``.

    The test is skipped when numba cannot be imported.
    """
    numba = pytest.importorskip("numba")

    @numba.guvectorize([(numba.int64[:], numba.int64[:])], "(n)->()")  # type: ignore[misc]
    def my_custom_sum(arr, result) -> None:  # type: ignore[no-untyped-def]  # noqa: ANN001
        total = 0
        for value in arr:
            total += value
        result[0] = total

    # Make type checkers happy:
    custom_sum = cast(Callable[[object], object], my_custom_sum)

    # Demonstrate NumPy as the canonical expected behavior:
    assert custom_sum(np.array([10, 2, 3], dtype=np.int64)) == 15

    # Direct call of the gufunc:
    df = pl.DataFrame({"values": [10, 2, 3]})
    assert custom_sum(df.get_column("values")) == 15

    # Indirect call of the gufunc:
    indirect = df.select(
        pl.col("values").map_batches(
            custom_sum, returns_scalar=True, return_dtype=pl.self_dtype()
        )
    )
    assert_frame_equal(indirect, pl.DataFrame({"values": 15}))
    indirect = df.select(
        pl.col("values").map_batches(
            lambda s: pl.Series([custom_sum(s)]),
            returns_scalar=False,
            return_dtype=pl.self_dtype(),
        )
    )
    assert_frame_equal(indirect, pl.DataFrame({"values": [15]}))

    # group_by()
    df = pl.DataFrame({"labels": ["a", "b", "a", "b"], "values": [10, 2, 3, 30]})
    indirect = (
        df.group_by("labels")
        .agg(
            pl.col("values").map_batches(
                custom_sum, returns_scalar=True, return_dtype=pl.self_dtype()
            )
        )
        .sort("labels")
    )
    assert_frame_equal(
        indirect, pl.DataFrame({"labels": ["a", "b"], "values": [13, 32]})
    )


def make_gufunc_mean() -> Callable[[pl.Series], pl.Series]:
    """Build a numba ``(n)->(n)`` gufunc writing ``arr.mean() + i`` at index ``i``.

    The calling test is skipped when numba cannot be imported.
    """
    numba = pytest.importorskip("numba")

    @numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n)->(n)")  # type: ignore[misc]
    def gufunc_mean(arr: Any, result: Any) -> None:
        mean = arr.mean()
        for i in range(len(arr)):
            result[i] = mean + i

    return gufunc_mean  # type: ignore[no-any-return]


def test_generalized_ufunc() -> None:
    """Check an ``(n)->(n)`` gufunc passed to ``map_batches`` in ``select``."""
    gufunc_mean = make_gufunc_mean()
    df = pl.DataFrame({"s": [1.0, 2.0, 3.0]})
    result = df.select([pl.col("s").map_batches(gufunc_mean).alias("result")])
    # The column mean is 2.0, so position i holds 2.0 + i.
    expected = pl.DataFrame({"result": [2.0, 3.0, 4.0]})
    assert_frame_equal(result, expected)


def test_grouped_generalized_ufunc() -> None:
    """Check an ``(n)->(n)`` gufunc passed to ``map_batches`` in ``group_by().agg()``."""
    gufunc_mean = make_gufunc_mean()
    df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [1.0, 2.0, 3.0, 4.0]})
    result = (
        df.group_by("id")
        .agg(pl.col("values").map_batches(gufunc_mean, return_dtype=pl.self_dtype()))
        .sort("id")
    )
    # Each group yields a list: group mean + position within the group.
    expected = pl.DataFrame({"id": ["a", "b"], "values": [[1.5, 2.5], [3.5, 4.5]]})
    assert_frame_equal(result, expected)


def test_ufunc_chain() -> None:
    """Check ``np.log(expr).mean()`` matches ``expr.log().mean()`` in a rolling agg."""
    df = pl.DataFrame(
        data={"A": [2, 10, 11, 12, 3, 10, 11, 12], "counter": [1, 2, 3, 4, 5, 6, 7, 8]}
    )
    result = df.rolling(index_column="counter", period="2i").agg(
        (np.log(pl.col("A"))).mean().alias("mean_numpy"),
        (pl.col("A")).log().mean().alias("mean_polars"),
    )
    assert_series_equal(result["mean_numpy"], result["mean_polars"].alias("mean_numpy"))