Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/functions/test_ewm_by.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date
4
5
import hypothesis.strategies as st
6
import pytest
7
from hypothesis import given
8
9
import polars as pl
10
from polars.testing import assert_frame_equal
11
from polars.testing.parametric import column, dataframes
12
13
14
@given(
    data=st.data(),
    half_life=st.integers(min_value=1, max_value=1000),
)
def test_ewm_by(data: st.DataObject, half_life: int) -> None:
    """Check ``ewm_mean_by`` on a row index against plain ``ewm_mean``.

    The row index produced by ``with_row_index`` grows by one per row, so the
    samples are evenly spaced. An exponentially weighted mean computed "by"
    that index with a half-life of ``half_life`` index units is asserted to
    equal ``ewm_mean`` with the same ``half_life``, ``ignore_nulls=False`` and
    ``adjust=False``.

    The check is run twice:

    * on rows in their original order;
    * after reordering the rows by value, so the ``by`` column reaches
      ``ewm_mean_by`` out of order. The original order is restored via
      ``index`` before comparing.
    """
    value_column = column(
        "values",
        strategy=st.floats(min_value=-100, max_value=100),
        dtype=pl.Float64,
    )
    frame = data.draw(dataframes([value_column], min_size=1))

    # Shared expression: EWM of "values" keyed on the integer row index,
    # with the half-life expressed in index units ("<n>i").
    by_index_ewm = pl.col("values").ewm_mean_by(
        by="index", half_life=f"{half_life}i"
    )

    in_order = frame.with_row_index().select(by_index_ewm)
    expected = frame.select(
        pl.col("values").ewm_mean(
            half_life=half_life, ignore_nulls=False, adjust=False
        )
    )
    assert_frame_equal(in_order, expected)

    # Sorting on "values" scrambles "index", so ewm_mean_by must cope with an
    # unsorted `by` column; sorting back on "index" realigns rows with
    # `expected`.
    reordered = (
        frame.with_row_index()
        .sort("values")
        .with_columns(by_index_ewm)
        .sort("index")
        .select("values")
    )
    assert_frame_equal(reordered, expected)
49
50
51
@pytest.mark.parametrize("length", [1, 3])
def test_length_mismatch_22084(length: int) -> None:
    """``ewm_mean_by`` must raise ``ShapeError`` when ``by`` has the wrong length.

    The series being averaged has two elements, while ``by`` holds ``length``
    copies of the same date. With ``length`` of 1 or 3, the two lengths never
    match. The numeric suffix presumably refers to upstream issue 22084;
    confirm against the issue tracker.
    """
    values = pl.Series([0, None])
    by_dates = pl.Series([date(2020, 1, 5) for _ in range(length)])
    with pytest.raises(pl.exceptions.ShapeError):
        values.ewm_mean_by(by_dates, half_life="4d")
57
58