Path: blob/main/py-polars/tests/unit/operations/test_ewm_by.py
6939 views
from __future__ import annotations

from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING
from zoneinfo import ZoneInfo

import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsIntegerType, TimeUnit

# Input values shared by the "basic" tests below. Row 0 is paired with a null
# `by` key in every such test; row 3 holds a null value.
_VALUES = [3.0, 1.0, 2.0, None, 4.0]

# Expected `ewm_mean_by` output for `_VALUES` with a half-life of 2 units and
# `by` keys at offsets [None, 4, 11, 16, 18]. A null key or a null value gives a
# null result. The last entry equals decaying from key 11 over a 7-unit gap,
# i.e. the null value at key 16 does not update the running average.
_EXPECTED = [None, 1.0, 1.9116116523516815, None, 3.815410804703363]


@pytest.mark.parametrize("sort", [True, False])
def test_ewma_by_date(sort: bool) -> None:
    """EWM mean keyed by a ``Date`` column on a LazyFrame, sorted or as given."""
    df = pl.LazyFrame(
        {
            "values": _VALUES,
            "times": [
                None,
                date(2020, 1, 4),
                date(2020, 1, 11),
                date(2020, 1, 16),
                date(2020, 1, 18),
            ],
        }
    )
    if sort:
        df = df.sort("times")
    result = df.select(
        pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)),
    )
    expected = pl.DataFrame({"values": _EXPECTED})
    assert_frame_equal(result.collect(), expected)
    # The schema resolved without executing must agree with the executed one.
    assert result.collect_schema()["values"] == pl.Float64
    assert result.collect().schema["values"] == pl.Float64


def test_ewma_by_date_constant() -> None:
    """A constant integer input yields a constant Float64 output."""
    df = pl.DataFrame(
        {
            "values": [1, 1, 1],
            "times": [
                date(2020, 1, 4),
                date(2020, 1, 11),
                date(2020, 1, 16),
            ],
        }
    )
    result = df.select(
        pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)),
    )
    expected = pl.DataFrame({"values": [1.0, 1, 1]})
    assert_frame_equal(result, expected)


def test_ewma_f32() -> None:
    """A Float32 input column keeps the Float32 dtype in the output."""
    df = pl.LazyFrame(
        {
            "values": _VALUES,
            "times": [
                None,
                date(2020, 1, 4),
                date(2020, 1, 11),
                date(2020, 1, 16),
                date(2020, 1, 18),
            ],
        },
        schema_overrides={"values": pl.Float32},
    )
    result = df.select(
        pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)),
    )
    expected = pl.DataFrame(
        {"values": _EXPECTED},
        schema_overrides={"values": pl.Float32},
    )
    assert_frame_equal(result.collect(), expected)
    assert result.collect_schema()["values"] == pl.Float32
    assert result.collect().schema["values"] == pl.Float32


@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
@pytest.mark.parametrize("time_zone", [None, "UTC"])
def test_ewma_by_datetime(time_unit: TimeUnit, time_zone: str | None) -> None:
    """EWM mean keyed by a ``Datetime`` column for every time unit, naive or UTC."""
    df = pl.DataFrame(
        {
            "values": _VALUES,
            "times": [
                None,
                datetime(2020, 1, 4),
                datetime(2020, 1, 11),
                datetime(2020, 1, 16),
                datetime(2020, 1, 18),
            ],
        },
        schema_overrides={"times": pl.Datetime(time_unit, time_zone)},
    )
    result = df.select(
        pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)),
    )
    expected = pl.DataFrame({"values": _EXPECTED})
    assert_frame_equal(result, expected)


@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_ewma_by_datetime_tz_aware(time_unit: TimeUnit) -> None:
    """With a non-UTC time zone, only fixed-length half-lives are accepted.

    A calendar duration such as ``"2d"`` raises ``InvalidOperationError``.
    The equivalent fixed duration ``"48h0ns"`` is accepted.
    """
    tzinfo = ZoneInfo("Asia/Kathmandu")
    df = pl.DataFrame(
        {
            "values": _VALUES,
            "times": [
                None,
                datetime(2020, 1, 4, tzinfo=tzinfo),
                datetime(2020, 1, 11, tzinfo=tzinfo),
                datetime(2020, 1, 16, tzinfo=tzinfo),
                datetime(2020, 1, 18, tzinfo=tzinfo),
            ],
        },
        schema_overrides={"times": pl.Datetime(time_unit, "Asia/Kathmandu")},
    )
    msg = "expected `half_life` to be a constant duration"
    with pytest.raises(InvalidOperationError, match=msg):
        df.select(
            pl.col("values").ewm_mean_by("times", half_life="2d"),
        )

    result = df.select(
        pl.col("values").ewm_mean_by("times", half_life="48h0ns"),
    )
    expected = pl.DataFrame({"values": _EXPECTED})
    assert_frame_equal(result, expected)


@pytest.mark.parametrize("data_type", [pl.Int64, pl.Int32, pl.UInt64, pl.UInt32])
def test_ewma_by_index(data_type: PolarsIntegerType) -> None:
    """EWM mean keyed by an integer column, using an index half-life (``"2i"``)."""
    df = pl.LazyFrame(
        {
            "values": _VALUES,
            "times": [
                None,
                4,
                11,
                16,
                18,
            ],
        },
        schema_overrides={"times": data_type},
    )
    result = df.select(
        pl.col("values").ewm_mean_by("times", half_life="2i"),
    )
    expected = pl.DataFrame({"values": _EXPECTED})
    assert_frame_equal(result.collect(), expected)
    assert result.collect_schema()["values"] == pl.Float64
    assert result.collect().schema["values"] == pl.Float64


def test_ewma_by_empty() -> None:
    """An empty input column yields an empty Float64 result."""
    df = pl.DataFrame({"values": []}, schema_overrides={"values": pl.Float64})
    result = df.with_row_index().select(
        pl.col("values").ewm_mean_by("index", half_life="2i"),
    )
    expected = pl.DataFrame({"values": []}, schema_overrides={"values": pl.Float64})
    assert_frame_equal(result, expected)


def test_ewma_by_if_unsorted() -> None:
    """An unsorted ``by`` column is handled and results keep the input row order."""
    df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]})
    result = df.with_columns(
        pl.col("values").ewm_mean_by("by", half_life="2i"),
    )
    expected = pl.DataFrame({"values": [2.5, 2.0], "by": [3, 1]})
    assert_frame_equal(result, expected)

    # Pre-sorting the frame must give the same per-row values.
    result = df.sort("by").with_columns(
        pl.col("values").ewm_mean_by("by", half_life="2i"),
    )
    assert_frame_equal(result, expected.sort("by"))


def test_ewma_by_invalid() -> None:
    """A negative half-life and a non-numeric value column both raise."""
    df = pl.DataFrame({"values": [1, 2]})
    with pytest.raises(InvalidOperationError, match="half_life cannot be negative"):
        df.with_row_index().select(
            pl.col("values").ewm_mean_by("index", half_life="-2i"),
        )
    df = pl.DataFrame({"values": [[1, 2], [3, 4]]})
    with pytest.raises(
        InvalidOperationError, match=r"expected series to be Float64, Float32, .*"
    ):
        df.with_row_index().select(
            pl.col("values").ewm_mean_by("index", half_life="2i"),
        )


def test_ewma_by_warn_two_chunks() -> None:
    """Unsorted ``by`` keys spread over two chunks give per-chunk-correct results.

    NOTE(review): despite the name, no warning is asserted here.
    """
    df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]})
    df = pl.concat([df, df], rechunk=False)
    # Guard the precondition: the frame really is split across two chunks.
    assert df.n_chunks() == 2

    result = df.with_columns(
        pl.col("values").ewm_mean_by("by", half_life="2i"),
    )
    expected = pl.DataFrame({"values": [2.5, 2.0, 2.5, 2], "by": [3, 1, 3, 1]})
    assert_frame_equal(result, expected)
    result = df.sort("by").with_columns(
        pl.col("values").ewm_mean_by("by", half_life="2i"),
    )
    assert_frame_equal(result, expected.sort("by"))


def test_ewma_by_multiple_chunks() -> None:
    """Nulls in either the ``by`` series or the values series, across chunks."""
    # times contains null
    times = pl.Series([1, 2]).append(pl.Series([None], dtype=pl.Int64))
    values = pl.Series([1, 2]).append(pl.Series([3]))
    assert times.n_chunks() == 2
    assert values.n_chunks() == 2
    result = values.ewm_mean_by(times, half_life="2i")
    expected = pl.Series([1.0, 1.292893, None])
    assert_series_equal(result, expected)

    # values contains null; a null value also yields null, so the expected
    # output is the same as above.
    times = pl.Series([1, 2]).append(pl.Series([3]))
    values = pl.Series([1, 2]).append(pl.Series([None], dtype=pl.Int64))
    assert times.n_chunks() == 2
    assert values.n_chunks() == 2
    result = values.ewm_mean_by(times, half_life="2i")
    assert_series_equal(result, expected)