# Path: blob/main/py-polars/tests/unit/operations/rolling/test_map.py
# 6940 views
"""Unit tests for ``rolling_map`` on polars Series and expressions."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pytest

import polars as pl
from polars.testing import assert_series_equal
from tests.unit.conftest import FLOAT_DTYPES, INTEGER_DTYPES

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


@pytest.mark.parametrize(
    ("input", "output"),
    [
        ([1, 5], [1, 6]),
        ([1], [1]),
    ],
)
def test_rolling_map_window_size_9160(input: list[int], output: list[int]) -> None:
    """Sum over a window of 2 with ``min_samples=1``.

    Covers series that are as long as, or shorter than, the window.
    The numeric suffix presumably refers to an upstream issue number.
    """
    s = pl.Series(input)
    result = s.rolling_map(lambda x: sum(x), window_size=2, min_samples=1)
    expected = pl.Series(output)
    assert_series_equal(result, expected)


def testing_rolling_map_window_size_with_nulls() -> None:
    """Sum over a window of 3 with ``min_samples=3`` on data containing a null.

    Only the final window (3, 4, 5) is expected to produce a value; every
    earlier position is expected to be null.
    """
    s = pl.Series([0, 1, None, 3, 4, 5])
    result = s.rolling_map(lambda x: sum(x), window_size=3, min_samples=3)
    expected = pl.Series([None, None, None, None, None, 12])
    assert_series_equal(result, expected)


def test_rolling_map_clear_reuse_series_state_10681() -> None:
    """Rolling minimum evaluated per group through ``.over("a")``.

    The first row of each group is expected to be null (``min_samples=2``),
    and values from one group must not show up in the other group's windows.
    The numeric suffix presumably refers to an upstream issue number.
    """
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 1, 2, 2, 2, 2],
            "b": [0.0, 1.0, 11.0, 7.0, 4.0, 2.0, 3.0, 8.0],
        }
    )

    result = df.select(
        pl.col("b")
        .rolling_map(lambda s: s.min(), window_size=3, min_samples=2)
        .over("a")
        .alias("min")
    )

    expected = pl.Series("min", [None, 0.0, 0.0, 1.0, None, 2.0, 2.0, 2.0])
    assert_series_equal(result.to_series(), expected)


def test_rolling_map_np_nansum() -> None:
    """Pass ``np.nansum`` as the window function on data containing a NaN."""
    s = pl.Series("a", [11.0, 2.0, 9.0, float("nan"), 8.0])

    result = s.rolling_map(np.nansum, 3)

    expected = pl.Series("a", [None, None, 22.0, 11.0, 17.0])
    assert_series_equal(result, expected)


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_rolling_map_std(dtype: PolarsDataType) -> None:
    """Rolling standard deviation over a window of 3 for both float dtypes."""
    s = pl.Series("A", [1.0, 2.0, 9.0, 2.0, 13.0], dtype=dtype)
    result = s.rolling_map(function=lambda s: s.std(), window_size=3)

    expected = pl.Series("A", [None, None, 4.358899, 4.041452, 5.567764], dtype=dtype)
    assert_series_equal(result, expected)


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_rolling_map_std_weights(dtype: PolarsDataType) -> None:
    """Rolling standard deviation over a weighted window of 3."""
    s = pl.Series("A", [1.0, 2.0, 9.0, 2.0, 13.0], dtype=dtype)

    result = s.rolling_map(
        function=lambda s: s.std(), window_size=3, weights=[1.0, 2.0, 3.0]
    )

    expected = pl.Series("A", [None, None, 14.224392, 8.326664, 18.929694], dtype=dtype)
    assert_series_equal(result, expected)


@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_rolling_map_sum_int(dtype: PolarsDataType) -> None:
    """Rolling sum on integer input keeps the input dtype.

    Also checks that the schema resolved lazily matches the schema of the
    collected result.
    """
    s = pl.Series("A", [1, 2, 9, 2, 13], dtype=dtype)

    result = s.rolling_map(function=lambda s: s.sum(), window_size=3)

    expected = pl.Series("A", [None, None, 12, 13, 24], dtype=dtype)
    assert_series_equal(result, expected)

    q = (
        s.to_frame()
        .lazy()
        .select(pl.col("A").rolling_map(function=lambda s: s.sum(), window_size=3))
    )
    assert q.collect_schema() == q.collect().schema


@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
def test_rolling_map_sum_int_cast_to_float(dtype: PolarsDataType) -> None:
    """Weighted rolling sum on integer input is expected to yield Float64.

    Windows that include the null are expected to be null. Also checks that
    the lazily resolved schema matches the collected schema.
    """
    s = pl.Series("A", [1, 2, 9, None, 13], dtype=dtype)

    result = s.rolling_map(
        function=lambda s: s.sum(), window_size=3, weights=[1.0, 2.0, 3.0]
    )

    expected = pl.Series("A", [None, None, 32.0, None, None], dtype=pl.Float64)
    assert_series_equal(result, expected)

    q = (
        s.to_frame()
        .lazy()
        .select(
            pl.col("A").rolling_map(
                function=lambda s: s.sum(), window_size=3, weights=[1.0, 2.0, 3.0]
            )
        )
    )
    assert q.collect_schema() == q.collect().schema


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_rolling_map_sum_float(dtype: PolarsDataType) -> None:
    """Rolling sum on float input keeps the input dtype.

    Also checks that the schema resolved lazily matches the schema of the
    collected result.
    """
    s = pl.Series("A", [1, 2, 9, 2, 13], dtype=dtype)

    result = s.rolling_map(function=lambda s: s.sum(), window_size=3)

    expected = pl.Series("A", [None, None, 12.0, 13.0, 24.0], dtype=dtype)
    assert_series_equal(result, expected)

    q = (
        s.to_frame()
        .lazy()
        .select(pl.col("A").rolling_map(function=lambda s: s.sum(), window_size=3))
    )
    assert q.collect_schema() == q.collect().schema


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_rolling_map_sum_float_weights(dtype: PolarsDataType) -> None:
    """Weighted rolling sum on float input keeps the input dtype.

    Also checks that the schema resolved lazily matches the schema of the
    collected result.
    """
    s = pl.Series("A", [1, 2, 9, 2, 13], dtype=dtype)

    result = s.rolling_map(
        function=lambda s: s.sum(), window_size=3, weights=[1.0, 2.0, 3.0]
    )

    expected = pl.Series("A", [None, None, 32.0, 26.0, 52.0], dtype=dtype)
    assert_series_equal(result, expected)

    q = (
        s.to_frame()
        .lazy()
        .select(
            pl.col("A").rolling_map(
                function=lambda s: s.sum(), window_size=3, weights=[1.0, 2.0, 3.0]
            )
        )
    )
    assert q.collect_schema() == q.collect().schema


def test_rolling_map_rolling_sum() -> None:
    """``rolling_map`` with a sum must agree with the built-in ``rolling_sum``.

    Both calls use the same weights, ``min_samples`` and centered window.
    """
    s = pl.Series("A", list(range(5)), dtype=pl.Float64)

    result = s.rolling_map(
        function=lambda s: s.sum(),
        window_size=3,
        weights=[1.0, 2.1, 3.2],
        min_samples=2,
        center=True,
    )

    expected = s.rolling_sum(
        window_size=3, weights=[1.0, 2.1, 3.2], min_samples=2, center=True
    )
    assert_series_equal(result, expected)


def test_rolling_map_rolling_std() -> None:
    """``rolling_map`` with a std must agree with the built-in ``rolling_std``.

    Both calls use the same window size, ``min_samples`` and ``center=False``.
    """
    s = pl.Series("A", list(range(6)), dtype=pl.Float64)

    result = s.rolling_map(
        function=lambda s: s.std(),
        window_size=4,
        min_samples=3,
        center=False,
    )

    expected = s.rolling_std(window_size=4, min_samples=3, center=False)
    assert_series_equal(result, expected)