Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_ewm_by.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, timedelta
4
from typing import TYPE_CHECKING
5
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import InvalidOperationError
10
from polars.testing import assert_frame_equal, assert_series_equal
11
12
if TYPE_CHECKING:
13
from polars._typing import PolarsIntegerType, TimeUnit
14
15
from zoneinfo import ZoneInfo
16
17
18
@pytest.mark.parametrize("sort", [True, False])
def test_ewma_by_date(sort: bool) -> None:
    """Check `ewm_mean_by` on a LazyFrame keyed by a Date column.

    The same expected frame is asserted for both parametrizations, i.e.
    with and without an explicit sort on ``times`` beforehand. Rows whose
    time or value is null are expected to be null in the output.
    """
    times = [None, *(date(2020, 1, day) for day in (4, 11, 16, 18))]
    lf = pl.LazyFrame({"values": [3.0, 1.0, 2.0, None, 4.0], "times": times})
    if sort:
        lf = lf.sort("times")

    ewm = pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2))
    lazy_result = lf.select(ewm)
    collected = lazy_result.collect()

    expected = pl.DataFrame(
        {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]}
    )
    assert_frame_equal(collected, expected)
    # The lazily resolved schema and the materialized schema must both
    # report Float64.
    assert lazy_result.collect_schema()["values"] == pl.Float64
    assert collected.schema["values"] == pl.Float64
43
44
45
def test_ewma_by_date_constant() -> None:
    """Check that a constant integer column stays constant under `ewm_mean_by`.

    The input values are integers; the expected output is a float column
    of ones.
    """
    times = [date(2020, 1, day) for day in (4, 11, 16)]
    frame = pl.DataFrame({"values": [1, 1, 1], "times": times})

    ewm = pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2))
    result = frame.select(ewm)

    expected = pl.DataFrame({"values": [1.0, 1, 1]})
    assert_frame_equal(result, expected)
61
62
63
def test_ewma_f32() -> None:
    """Check that `ewm_mean_by` on a Float32 column returns Float32.

    Uses the same data as the Date-keyed test, but with ``values``
    overridden to Float32 in both the input and the expected frame.
    """
    f32_override = {"values": pl.Float32}
    times = [None, *(date(2020, 1, day) for day in (4, 11, 16, 18))]
    lf = pl.LazyFrame(
        {"values": [3.0, 1.0, 2.0, None, 4.0], "times": times},
        schema_overrides=f32_override,
    )

    ewm = pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2))
    lazy_result = lf.select(ewm)
    collected = lazy_result.collect()

    expected = pl.DataFrame(
        {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]},
        schema_overrides=f32_override,
    )
    assert_frame_equal(collected, expected)
    # Both the planned and the materialized dtype must stay Float32.
    assert lazy_result.collect_schema()["values"] == pl.Float32
    assert collected.schema["values"] == pl.Float32
87
88
89
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
@pytest.mark.parametrize("time_zone", [None, "UTC"])
def test_ewma_by_datetime(time_unit: TimeUnit, time_zone: str | None) -> None:
    """Check `ewm_mean_by` keyed by a Datetime column.

    Parametrized over every time unit and over naive / UTC time zones; the
    same expected values are asserted for each combination.
    """
    times = [None, *(datetime(2020, 1, day) for day in (4, 11, 16, 18))]
    frame = pl.DataFrame(
        {"values": [3.0, 1.0, 2.0, None, 4.0], "times": times},
        schema_overrides={"times": pl.Datetime(time_unit, time_zone)},
    )

    ewm = pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2))
    result = frame.select(ewm)

    expected = pl.DataFrame(
        {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]}
    )
    assert_frame_equal(result, expected)
112
113
114
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_ewma_by_datetime_tz_aware(time_unit: TimeUnit) -> None:
    """Check `ewm_mean_by` keyed by a time-zone-aware Datetime column.

    With an ``Asia/Kathmandu`` column, a half-life of ``"2d"`` is expected
    to be rejected with an `InvalidOperationError`, while ``"48h0ns"`` is
    expected to produce the usual values.
    """
    zone_name = "Asia/Kathmandu"
    tzinfo = ZoneInfo(zone_name)
    times = [
        None,
        *(datetime(2020, 1, day, tzinfo=tzinfo) for day in (4, 11, 16, 18)),
    ]
    frame = pl.DataFrame(
        {"values": [3.0, 1.0, 2.0, None, 4.0], "times": times},
        schema_overrides={"times": pl.Datetime(time_unit, zone_name)},
    )

    # The "2d" half-life must be rejected for this tz-aware column.
    msg = "expected `half_life` to be a constant duration"
    with pytest.raises(InvalidOperationError, match=msg):
        frame.select(
            pl.col("values").ewm_mean_by("times", half_life="2d"),
        )

    # The same span written in hours is accepted.
    result = frame.select(
        pl.col("values").ewm_mean_by("times", half_life="48h0ns"),
    )
    expected = pl.DataFrame(
        {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]}
    )
    assert_frame_equal(result, expected)
143
144
145
@pytest.mark.parametrize("data_type", [pl.Int64, pl.Int32, pl.UInt64, pl.UInt32])
def test_ewma_by_index(data_type: PolarsIntegerType) -> None:
    """Check `ewm_mean_by` keyed by an integer column with an ``"2i"`` half-life.

    Parametrized over several signed and unsigned integer dtypes for the
    ``times`` column; the output is expected to be Float64 in every case.
    """
    lf = pl.LazyFrame(
        {
            "values": [3.0, 1.0, 2.0, None, 4.0],
            "times": [None, 4, 11, 16, 18],
        },
        schema_overrides={"times": data_type},
    )

    ewm = pl.col("values").ewm_mean_by("times", half_life="2i")
    lazy_result = lf.select(ewm)
    collected = lazy_result.collect()

    expected = pl.DataFrame(
        {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]}
    )
    assert_frame_equal(collected, expected)
    # Planned and materialized dtypes must agree on Float64.
    assert lazy_result.collect_schema()["values"] == pl.Float64
    assert collected.schema["values"] == pl.Float64
169
170
171
def test_ewma_by_empty() -> None:
    """Check that `ewm_mean_by` on an empty Float64 column returns an empty one."""
    f64_override = {"values": pl.Float64}
    empty = pl.DataFrame({"values": []}, schema_overrides=f64_override)

    # `with_row_index` supplies the "index" column used as the `by` key.
    indexed = empty.with_row_index()
    result = indexed.select(
        pl.col("values").ewm_mean_by("index", half_life="2i"),
    )

    expected = pl.DataFrame({"values": []}, schema_overrides=f64_override)
    assert_frame_equal(result, expected)
178
179
180
def test_ewma_by_if_unsorted() -> None:
    """Check `ewm_mean_by` when the ``by`` column is not sorted.

    The expected output keeps the input row order. The sorted input is
    then checked against the sorted expectation.
    """
    frame = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]})
    ewm = pl.col("values").ewm_mean_by("by", half_life="2i")
    expected = pl.DataFrame({"values": [2.5, 2.0], "by": [3, 1]})

    first_pass = frame.with_columns(ewm)
    assert_frame_equal(first_pass, expected)

    # NOTE(review): the unsorted case is evaluated a second time with the
    # same inputs; kept as-is since the intent of the repeat is not visible
    # here.
    second_pass = frame.with_columns(ewm)
    assert_frame_equal(second_pass, expected)

    sorted_result = frame.sort("by").with_columns(ewm)
    assert_frame_equal(sorted_result, expected.sort("by"))
197
198
199
def test_ewma_by_invalid() -> None:
    """Check the errors raised by `ewm_mean_by` for invalid inputs.

    Covers a negative half-life and a list-typed ``values`` column; both
    are expected to raise `InvalidOperationError`.
    """
    # Case 1: negative half-life.
    int_frame = pl.DataFrame({"values": [1, 2]}).with_row_index()
    negative_half_life = pl.col("values").ewm_mean_by("index", half_life="-2i")
    with pytest.raises(InvalidOperationError, match="half_life cannot be negative"):
        int_frame.select(negative_half_life)

    # Case 2: `values` holds lists rather than numbers.
    list_frame = pl.DataFrame({"values": [[1, 2], [3, 4]]}).with_row_index()
    on_lists = pl.col("values").ewm_mean_by("index", half_life="2i")
    with pytest.raises(
        InvalidOperationError, match=r"expected series to be Float64, Float32, .*"
    ):
        list_frame.select(on_lists)
212
213
214
def test_ewma_by_warn_two_chunks() -> None:
    """Check `ewm_mean_by` on a frame built by concatenating without rechunking.

    NOTE(review): despite the name, no warning is asserted in this test;
    only the computed values are checked.
    """
    base = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]})
    # `rechunk=False` is passed so the two inputs are not merged by concat.
    frame = pl.concat([base, base], rechunk=False)

    ewm = pl.col("values").ewm_mean_by("by", half_life="2i")
    expected = pl.DataFrame({"values": [2.5, 2.0, 2.5, 2], "by": [3, 1, 3, 1]})

    unsorted_result = frame.with_columns(ewm)
    assert_frame_equal(unsorted_result, expected)

    sorted_result = frame.sort("by").with_columns(ewm)
    assert_frame_equal(sorted_result, expected.sort("by"))
227
228
229
def test_ewma_by_multiple_chunks() -> None:
    """Check `Series.ewm_mean_by` when inputs are built via `append`.

    Two scenarios share one expected result: a null in the ``by`` series
    and a null in the values series, each introduced by the appended part.
    """
    expected = pl.Series([1.0, 1.292893, None])
    null_tail = pl.Series([None], dtype=pl.Int64)

    # times contains null
    by_with_null = pl.Series([1, 2]).append(null_tail)
    full_values = pl.Series([1, 2]).append(pl.Series([3]))
    assert_series_equal(
        full_values.ewm_mean_by(by_with_null, half_life="2i"), expected
    )

    # values contains null
    full_by = pl.Series([1, 2]).append(pl.Series([3]))
    values_with_null = pl.Series([1, 2]).append(null_tail)
    assert_series_equal(
        values_with_null.ewm_mean_by(full_by, half_life="2i"), expected
    )
242
243