# Path: blob/main/py-polars/tests/unit/operations/test_ewm.py
from __future__ import annotations

from typing import Any

import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import given

import polars as pl
from polars.expr.expr import _prepare_alpha
from polars.testing import assert_series_equal
from polars.testing.parametric import series


def test_ewm_mean() -> None:
    """Check `ewm_mean` against precomputed values for several parameter sets."""
    s = pl.Series([2, 5, 3])

    # `s` contains no nulls, so `ignore_nulls` must not change any result
    for ignore_nulls in (True, False):
        expected = pl.Series([2.0, 4.0, 3.4285714285714284])
        assert_series_equal(
            s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=ignore_nulls), expected
        )

        expected = pl.Series([2.0, 3.8, 3.421053])
        assert_series_equal(
            s.ewm_mean(com=2.0, adjust=True, ignore_nulls=ignore_nulls), expected
        )

        expected = pl.Series([2.0, 3.5, 3.25])
        assert_series_equal(
            s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=ignore_nulls), expected
        )

    s = pl.Series([2, 3, 5, 7, 4])

    # again no nulls; the leading outputs before `min_samples` is reached are null
    for ignore_nulls in (True, False):
        expected = pl.Series([None, 2.666667, 4.0, 5.6, 4.774194])
        assert_series_equal(
            s.ewm_mean(
                alpha=0.5, adjust=True, min_samples=2, ignore_nulls=ignore_nulls
            ),
            expected,
        )

        expected = pl.Series([None, None, 4.0, 5.6, 4.774194])
        assert_series_equal(
            s.ewm_mean(
                alpha=0.5, adjust=True, min_samples=3, ignore_nulls=ignore_nulls
            ),
            expected,
        )

    # with interior nulls, `ignore_nulls` does change the results
    s = pl.Series([None, 1.0, 5.0, 7.0, None, 2.0, 5.0, 4])

    expected = pl.Series(
        [
            None,
            1.0,
            3.6666666666666665,
            5.571428571428571,
            None,
            3.6666666666666665,
            4.354838709677419,
            4.174603174603175,
        ],
    )
    assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected)
    expected = pl.Series(
        [
            None,
            1.0,
            3.666666666666667,
            5.571428571428571,
            None,
            3.08695652173913,
            4.2,
            4.092436974789916,
        ]
    )
    assert_series_equal(
        s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected
    )

    expected = pl.Series([None, 1.0, 3.0, 5.0, None, 3.5, 4.25, 4.125])
    assert_series_equal(
        s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected
    )

    expected = pl.Series([None, 1.0, 3.0, 5.0, None, 3.0, 4.0, 4.0])
    assert_series_equal(
        s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected
    )


def test_ewm_mean_leading_nulls() -> None:
    """Leading outputs stay null until `min_samples` non-null values are seen."""
    for min_samples in [1, 2, 3]:
        assert (
            pl.Series([1, 2, 3, 4])
            .ewm_mean(com=3, min_samples=min_samples, ignore_nulls=False)
            .null_count()
            == min_samples - 1
        )
    assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean(
        alpha=0.5, min_samples=1, ignore_nulls=True
    ).to_list() == [None, 1.0, 1.0, 1.0]
    assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean(
        alpha=0.5, min_samples=2, ignore_nulls=True
    ).to_list() == [None, None, 1.0, 1.0]


def test_ewm_mean_min_samples() -> None:
    """Null inputs produce null outputs and do not count towards `min_samples`."""
    series = pl.Series([1.0, None, None, None])

    ewm_mean = series.ewm_mean(alpha=0.5, min_samples=1, ignore_nulls=True)
    assert ewm_mean.to_list() == [1.0, None, None, None]
    ewm_mean = series.ewm_mean(alpha=0.5, min_samples=2, ignore_nulls=True)
    assert ewm_mean.to_list() == [None, None, None, None]

    series = pl.Series([1.0, None, 2.0, None, 3.0])

    ewm_mean = series.ewm_mean(alpha=0.5, min_samples=1, ignore_nulls=True)
    assert_series_equal(
        ewm_mean,
        pl.Series(
            [
                1.0,
                None,
                1.6666666666666665,
                None,
                2.4285714285714284,
            ]
        ),
    )
    ewm_mean = series.ewm_mean(alpha=0.5, min_samples=2, ignore_nulls=True)
    assert_series_equal(
        ewm_mean,
        pl.Series(
            [
                None,
                None,
                1.6666666666666665,
                None,
                2.4285714285714284,
            ]
        ),
    )


def test_ewm_std_var() -> None:
    """`ewm_var` matches precomputed values and equals `ewm_std` squared."""
    series = pl.Series("a", [2, 5, 3])

    var = series.ewm_var(alpha=0.5, ignore_nulls=False)
    std = series.ewm_std(alpha=0.5, ignore_nulls=False)
    expected = pl.Series("a", [0.0, 4.5, 1.9285714285714288])
    # NOTE(review): rtol=1e-16 is below float64 machine epsilon, so in practice
    # np.allclose's default atol (1e-08) is what bounds this comparison
    assert np.allclose(var, std**2, rtol=1e-16)
    assert_series_equal(var, expected)


def test_ewm_std_var_with_nulls() -> None:
    """`ewm_var`/`ewm_std` with an interior null, for both `ignore_nulls` modes."""
    series = pl.Series("a", [2, 5, None, 3])

    var = series.ewm_var(alpha=0.5, ignore_nulls=True)
    std = series.ewm_std(alpha=0.5, ignore_nulls=True)
    expected = pl.Series("a", [0.0, 4.5, None, 1.9285714285714288])
    assert_series_equal(var, expected)
    assert_series_equal(std**2, expected)

    var = series.ewm_var(alpha=0.5, ignore_nulls=False)
    std = series.ewm_std(alpha=0.5, ignore_nulls=False)
    expected = pl.Series("a", [0.0, 4.5, None, 1.7307692307692308])
    assert_series_equal(var, expected)
    assert_series_equal(std**2, expected)


def test_ewm_param_validation() -> None:
    """Invalid or conflicting decay parameters raise `ValueError`."""
    s = pl.Series("values", range(10))

    with pytest.raises(ValueError, match="mutually exclusive"):
        s.ewm_std(com=0.5, alpha=0.5, ignore_nulls=False)

    with pytest.raises(ValueError, match="mutually exclusive"):
        s.ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False)

    with pytest.raises(ValueError, match="mutually exclusive"):
        s.ewm_var(alpha=0.5, span=1.5, ignore_nulls=False)

    with pytest.raises(ValueError, match="require `com` >= 0"):
        s.ewm_std(com=-0.5, ignore_nulls=False)

    with pytest.raises(ValueError, match="require `span` >= 1"):
        s.ewm_mean(span=0.5, ignore_nulls=False)

    with pytest.raises(ValueError, match="require `half_life` > 0"):
        s.ewm_var(half_life=0, ignore_nulls=False)

    for alpha in (-0.5, -0.0000001, 0.0, 1.0000001, 1.5):
        with pytest.raises(ValueError, match="require 0 < `alpha` <= 1"):
            s.ewm_std(alpha=alpha, ignore_nulls=False)


# https://github.com/pola-rs/polars/issues/4951
@pytest.mark.may_fail_auto_streaming
@pytest.mark.may_fail_cloud  # reason: chunking
def test_ewm_with_multiple_chunks() -> None:
    """EWM methods must handle a multi-chunk frame (regression for #4951)."""
    df0 = pl.DataFrame(
        data=[
            ("w", 6.0, 1.0),
            ("x", 5.0, 2.0),
            ("y", 4.0, 3.0),
            ("z", 3.0, 4.0),
        ],
        schema=["a", "b", "c"],
        orient="row",
    ).with_columns(
        pl.col(pl.Float64).log().diff().name.prefix("ld_"),
    )
    assert df0.n_chunks() == 1

    # NOTE: We aren't testing whether `select` creates two chunks;
    # we just need two chunks to properly test the EWM methods (`ewm_std` here)
    df1 = df0.select(["ld_b", "ld_c"])
    assert df1.n_chunks() == 2

    ewm_std = df1.with_columns(
        pl.all().ewm_std(com=20, ignore_nulls=False).name.prefix("ewm_"),
    )
    # each `ld_*` column has one null from `diff`, and its `ewm_*` copy has one too
    assert ewm_std.null_count().sum_horizontal()[0] == 4


def alpha_guard(**decay_param: float) -> bool:
    """Protects against unnecessary noise in small number regime.

    Takes a single decay keyword (e.g. ``com=...``). A zero/falsy value is always
    accepted. Otherwise the derived smoothing factor ``alpha`` is rejected when
    it lies within 1e-6 of 0 (if it rounds to 0) or of 1 (if it rounds to 1).
    """
    if not next(iter(decay_param.values())):
        return True
    alpha = _prepare_alpha(**decay_param)
    return ((1 - alpha) if round(alpha) else alpha) > 1e-6


@given(
    s=series(
        min_size=4,
        dtype=pl.Float64,
        allow_null=True,
        strategy=st.floats(min_value=-1e8, max_value=1e8),
    ),
    half_life=st.floats(min_value=0, max_value=4, exclude_min=True).filter(
        lambda x: alpha_guard(half_life=x)
    ),
    com=st.floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)),
    span=st.floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)),
    ignore_nulls=st.booleans(),
    adjust=st.booleans(),
    bias=st.booleans(),
)
def test_ewm_methods(
    s: pl.Series,
    com: float | None,
    span: float | None,
    half_life: float | None,
    ignore_nulls: bool,
    adjust: bool,
    bias: bool,
) -> None:
    """Compare polars `ewm_mean`/`ewm_std`/`ewm_var` against pandas as reference."""
    # convert the parametrically-generated series to pandas once (it is invariant
    # across all parameter sets), then use that as a reference implementation
    # for comparison (after normalising NaN/None)
    p = s.to_pandas()

    # validate a large set of varied EWM calculations
    for decay_param in [{"com": com}, {"span": span}, {"half_life": half_life}]:
        alpha = _prepare_alpha(**decay_param)

        # pandas spells `half_life` as `halflife`; `com` and `span` match
        pd_decay_param = {
            ("halflife" if name == "half_life" else name): value
            for name, value in decay_param.items()
        }

        # note: skip min_samples < 2, due to pandas-side inconsistency:
        # https://github.com/pola-rs/polars/issues/5006#issuecomment-1259477178
        for mp in range(2, len(s), len(s) // 3):
            # consolidate ewm parameters (pandas uses different keyword names)
            pl_params: dict[str, Any] = {
                "min_samples": mp,
                "adjust": adjust,
                "ignore_nulls": ignore_nulls,
                **decay_param,
            }
            pd_params: dict[str, Any] = {
                "min_periods": mp,
                "adjust": adjust,
                "ignore_na": ignore_nulls,
                **pd_decay_param,
            }
            pd_ewm = p.ewm(**pd_params)

            # mean:
            ewm_mean_pl = s.ewm_mean(**pl_params).fill_nan(None)
            ewm_mean_pd = pl.Series(pd_ewm.mean())
            if alpha == 1:
                # apply fill-forward to nulls to match pandas
                # https://github.com/pola-rs/polars/pull/5011#issuecomment-1262318124
                ewm_mean_pl = ewm_mean_pl.fill_null(strategy="forward")

            assert_series_equal(ewm_mean_pl, ewm_mean_pd, abs_tol=1e-07)

            # std:
            ewm_std_pl = s.ewm_std(bias=bias, **pl_params).fill_nan(None)
            ewm_std_pd = pl.Series(pd_ewm.std(bias=bias))
            assert_series_equal(ewm_std_pl, ewm_std_pd, abs_tol=1e-07)

            # var:
            ewm_var_pl = s.ewm_var(bias=bias, **pl_params).fill_nan(None)
            ewm_var_pd = pl.Series(pd_ewm.var(bias=bias))
            assert_series_equal(ewm_var_pl, ewm_var_pd, abs_tol=1e-07)