Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_ewm.py
6939 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import hypothesis.strategies as st
6
import numpy as np
7
import pytest
8
from hypothesis import given
9
10
import polars as pl
11
from polars.expr.expr import _prepare_alpha
12
from polars.testing import assert_series_equal
13
from polars.testing.parametric import series
14
15
16
def test_ewm_mean() -> None:
    """Compare `Series.ewm_mean` with hard-coded expected values.

    Covers `alpha` and `com` decay, `adjust` on and off, `min_samples`, and
    input containing nulls under both `ignore_nulls` settings.
    """
    null_modes = (True, False)

    # Input without nulls: each expected result is asserted under both
    # `ignore_nulls` settings.
    short = pl.Series([2, 5, 3])

    want = pl.Series([2.0, 4.0, 3.4285714285714284])
    for ignore_nulls in null_modes:
        got = short.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=ignore_nulls)
        assert_series_equal(got, want)

    want = pl.Series([2.0, 3.8, 3.421053])
    for ignore_nulls in null_modes:
        got = short.ewm_mean(com=2.0, adjust=True, ignore_nulls=ignore_nulls)
        assert_series_equal(got, want)

    want = pl.Series([2.0, 3.5, 3.25])
    for ignore_nulls in null_modes:
        got = short.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=ignore_nulls)
        assert_series_equal(got, want)

    # `min_samples` greater than one: the expected leading entries are null.
    longer = pl.Series([2, 3, 5, 7, 4])
    min_samples_cases = (
        (2, pl.Series([None, 2.666667, 4.0, 5.6, 4.774194])),
        (3, pl.Series([None, None, 4.0, 5.6, 4.774194])),
    )
    for min_samples, want in min_samples_cases:
        for ignore_nulls in null_modes:
            got = longer.ewm_mean(
                alpha=0.5,
                adjust=True,
                min_samples=min_samples,
                ignore_nulls=ignore_nulls,
            )
            assert_series_equal(got, want)

    # Input with nulls: here the expected values differ per
    # (`adjust`, `ignore_nulls`) combination.
    with_nulls = pl.Series([None, 1.0, 5.0, 7.0, None, 2.0, 5.0, 4])
    null_cases = (
        (
            True,
            True,
            [
                None,
                1.0,
                3.6666666666666665,
                5.571428571428571,
                None,
                3.6666666666666665,
                4.354838709677419,
                4.174603174603175,
            ],
        ),
        (
            True,
            False,
            [
                None,
                1.0,
                3.666666666666667,
                5.571428571428571,
                None,
                3.08695652173913,
                4.2,
                4.092436974789916,
            ],
        ),
        (False, True, [None, 1.0, 3.0, 5.0, None, 3.5, 4.25, 4.125]),
        (False, False, [None, 1.0, 3.0, 5.0, None, 3.0, 4.0, 4.0]),
    )
    for adjust, ignore_nulls, want_values in null_cases:
        got = with_nulls.ewm_mean(alpha=0.5, adjust=adjust, ignore_nulls=ignore_nulls)
        assert_series_equal(got, pl.Series(want_values))
95
96
97
def test_ewm_mean_leading_nulls() -> None:
    """Check which leading `ewm_mean` results are null for a given `min_samples`."""
    # Input without nulls: the result should hold `min_samples - 1` nulls.
    for min_samples in [1, 2, 3]:
        smoothed = pl.Series([1, 2, 3, 4]).ewm_mean(
            com=3, min_samples=min_samples, ignore_nulls=False
        )
        assert smoothed.null_count() == min_samples - 1

    # Input that starts with a null.
    leading_null_cases = (
        (1, [None, 1.0, 1.0, 1.0]),
        (2, [None, None, 1.0, 1.0]),
    )
    for min_samples, want in leading_null_cases:
        smoothed = pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean(
            alpha=0.5, min_samples=min_samples, ignore_nulls=True
        )
        assert smoothed.to_list() == want
111
112
113
def test_ewm_mean_min_samples() -> None:
    """Check `ewm_mean` with `min_samples` of 1 and 2 on inputs containing nulls."""
    # A single valid value followed only by nulls.
    lone_value = pl.Series([1.0, None, None, None])
    lone_cases = (
        (1, [1.0, None, None, None]),
        (2, [None, None, None, None]),
    )
    for min_samples, want_list in lone_cases:
        got = lone_value.ewm_mean(
            alpha=0.5, min_samples=min_samples, ignore_nulls=True
        )
        assert got.to_list() == want_list

    # Valid values separated by nulls: only the first expected entry depends
    # on `min_samples`.
    interleaved = pl.Series([1.0, None, 2.0, None, 3.0])
    for min_samples, first in ((1, 1.0), (2, None)):
        got = interleaved.ewm_mean(
            alpha=0.5, min_samples=min_samples, ignore_nulls=True
        )
        want = pl.Series(
            [
                first,
                None,
                1.6666666666666665,
                None,
                2.4285714285714284,
            ]
        )
        assert_series_equal(got, want)
149
150
151
def test_ewm_std_var() -> None:
    """Check `ewm_var` against fixed values and against the square of `ewm_std`."""
    data = pl.Series("a", [2, 5, 3])
    want_var = pl.Series("a", [0.0, 4.5, 1.9285714285714288])

    variance = data.ewm_var(alpha=0.5, ignore_nulls=False)
    deviation = data.ewm_std(alpha=0.5, ignore_nulls=False)

    # the variance should equal the squared standard deviation
    squared_deviation = deviation**2
    assert np.allclose(variance, squared_deviation, rtol=1e-16)
    assert_series_equal(variance, want_var)
159
160
161
def test_ewm_std_var_with_nulls() -> None:
    """Check `ewm_var` / `ewm_std` on input with a null, for both `ignore_nulls` modes."""
    data = pl.Series("a", [2, 5, None, 3])

    # only the value after the null differs between the two modes
    cases = (
        (True, 1.9285714285714288),
        (False, 1.7307692307692308),
    )
    for ignore_nulls, last_var in cases:
        want = pl.Series("a", [0.0, 4.5, None, last_var])
        variance = data.ewm_var(alpha=0.5, ignore_nulls=ignore_nulls)
        deviation = data.ewm_std(alpha=0.5, ignore_nulls=ignore_nulls)
        assert_series_equal(variance, want)
        assert_series_equal(deviation**2, want)
175
176
177
def test_ewm_param_validation() -> None:
    """Check that conflicting or out-of-range EWM decay arguments raise `ValueError`."""
    values = pl.Series("values", range(10))

    # each entry pairs the expected error-message pattern with a call that
    # is expected to fail
    failing_calls = [
        (
            "mutually exclusive",
            lambda: values.ewm_std(com=0.5, alpha=0.5, ignore_nulls=False),
        ),
        (
            "mutually exclusive",
            lambda: values.ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False),
        ),
        (
            "mutually exclusive",
            lambda: values.ewm_var(alpha=0.5, span=1.5, ignore_nulls=False),
        ),
        (
            "require `com` >= 0",
            lambda: values.ewm_std(com=-0.5, ignore_nulls=False),
        ),
        (
            "require `span` >= 1",
            lambda: values.ewm_mean(span=0.5, ignore_nulls=False),
        ),
        (
            "require `half_life` > 0",
            lambda: values.ewm_var(half_life=0, ignore_nulls=False),
        ),
    ]
    for pattern, call in failing_calls:
        with pytest.raises(ValueError, match=pattern):
            call()

    # `alpha` values on and just beyond both ends of the accepted interval
    rejected_alphas = (-0.5, -0.0000001, 0.0, 1.0000001, 1.5)
    for alpha in rejected_alphas:
        with pytest.raises(ValueError, match="require 0 < `alpha` <= 1"):
            values.ewm_std(alpha=alpha, ignore_nulls=False)
201
202
203
# https://github.com/pola-rs/polars/issues/4951
@pytest.mark.may_fail_auto_streaming
@pytest.mark.may_fail_cloud  # reason: chunking
def test_ewm_with_multiple_chunks() -> None:
    """Run `ewm_std` over a frame with two chunks and check the total null count."""
    rows = [
        ("w", 6.0, 1.0),
        ("x", 5.0, 2.0),
        ("y", 4.0, 3.0),
        ("z", 3.0, 4.0),
    ]
    log_diffs = pl.col(pl.Float64).log().diff().name.prefix("ld_")
    single_chunk = pl.DataFrame(
        data=rows,
        schema=["a", "b", "c"],
        orient="row",
    ).with_columns(log_diffs)
    assert single_chunk.n_chunks() == 1

    # NOTE: the chunk count after `select` is asserted only as a precondition;
    # the point of the test is to run `ewm_std` on a frame with two chunks
    two_chunks = single_chunk.select(["ld_b", "ld_c"])
    assert two_chunks.n_chunks() == 2

    smoothed = two_chunks.with_columns(
        pl.all().ewm_std(com=20, ignore_nulls=False).name.prefix("ewm_"),
    )
    total_nulls = smoothed.null_count().sum_horizontal()[0]
    assert total_nulls == 4
230
231
232
def alpha_guard(**decay_param: float) -> bool:
    """
    Protects against unnecessary noise in small number regime.

    Returns True when the first decay value passed is falsy (e.g. zero).
    Otherwise the keyword arguments are forwarded to `_prepare_alpha` and the
    result is True only if the resulting alpha is more than 1e-6 away from
    the bound it rounds to.
    """
    first_value = next(iter(decay_param.values()))
    if not first_value:
        return True

    alpha = _prepare_alpha(**decay_param)
    if round(alpha):
        # alpha rounds to a non-zero integer: measure the distance below 1
        margin = 1 - alpha
    else:
        # alpha rounds to zero: measure the distance above 0
        margin = alpha
    return margin > 1e-6
238
239
240
@given(
    s=series(
        min_size=4,
        dtype=pl.Float64,
        allow_null=True,
        strategy=st.floats(min_value=-1e8, max_value=1e8),
    ),
    half_life=st.floats(min_value=0, max_value=4, exclude_min=True).filter(
        lambda x: alpha_guard(half_life=x)
    ),
    com=st.floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)),
    span=st.floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)),
    ignore_nulls=st.booleans(),
    adjust=st.booleans(),
    bias=st.booleans(),
)
def test_ewm_methods(
    s: pl.Series,
    com: float | None,
    span: float | None,
    half_life: float | None,
    ignore_nulls: bool,
    adjust: bool,
    bias: bool,
) -> None:
    """
    Validate a large set of varied EWM calculations against pandas.

    For each generated series and each decay parameter (`com`, `span`,
    `half_life`), the polars `ewm_mean`, `ewm_std` and `ewm_var` results are
    compared with the pandas `ewm(...)` equivalents using an absolute
    tolerance of 1e-07.
    """
    # convert the parametrically-generated series to pandas once, then use that
    # as a reference implementation for comparison (after normalising NaN/None);
    # the conversion does not depend on the decay parameter
    p = s.to_pandas()

    for decay_param in [{"com": com}, {"span": span}, {"half_life": half_life}]:
        alpha = _prepare_alpha(**decay_param)

        # pandas names the half-life argument `halflife`
        pd_decay_param = {
            ("halflife" if name == "half_life" else name): value
            for name, value in decay_param.items()
        }

        # note: skip min_samples < 2, due to pandas-side inconsistency:
        # https://github.com/pola-rs/polars/issues/5006#issuecomment-1259477178
        for mp in range(2, len(s), len(s) // 3):
            # consolidate ewm parameters; pandas uses `min_periods` and
            # `ignore_na` where polars uses `min_samples` and `ignore_nulls`
            pl_params: dict[str, Any] = {
                "min_samples": mp,
                "adjust": adjust,
                "ignore_nulls": ignore_nulls,
                **decay_param,
            }
            pd_params: dict[str, Any] = {
                "min_periods": mp,
                "adjust": adjust,
                "ignore_na": ignore_nulls,
                **pd_decay_param,
            }

            # mean:
            ewm_mean_pl = s.ewm_mean(**pl_params).fill_nan(None)
            ewm_mean_pd = pl.Series(p.ewm(**pd_params).mean())
            if alpha == 1:
                # apply fill-forward to nulls to match pandas
                # https://github.com/pola-rs/polars/pull/5011#issuecomment-1262318124
                ewm_mean_pl = ewm_mean_pl.fill_null(strategy="forward")

            assert_series_equal(ewm_mean_pl, ewm_mean_pd, abs_tol=1e-07)

            # std:
            ewm_std_pl = s.ewm_std(bias=bias, **pl_params).fill_nan(None)
            ewm_std_pd = pl.Series(p.ewm(**pd_params).std(bias=bias))
            assert_series_equal(ewm_std_pl, ewm_std_pd, abs_tol=1e-07)

            # var:
            ewm_var_pl = s.ewm_var(bias=bias, **pl_params).fill_nan(None)
            ewm_var_pd = pl.Series(p.ewm(**pd_params).var(bias=bias))
            assert_series_equal(ewm_var_pl, ewm_var_pd, abs_tol=1e-07)
314
315