Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/ewm/mean.rs
7884 views
1
use arrow::array::{Array, PrimitiveArray};
2
use arrow::types::NativeType;
3
4
use crate::ewm::EwmStateUpdate;
5
6
pub fn ewm_mean<I, T>(
7
xs: I,
8
alpha: T,
9
adjust: bool,
10
min_periods: usize,
11
ignore_nulls: bool,
12
) -> PrimitiveArray<T>
13
where
14
I: IntoIterator<Item = Option<T>>,
15
T: num_traits::Float + NativeType + std::ops::MulAssign,
16
{
17
let mut state: EwmMeanState<T> = EwmMeanState::new(alpha, adjust, min_periods, ignore_nulls);
18
state.update_iter(xs).collect()
19
}
20
21
/// Running state of an exponentially weighted moving mean.
///
/// The state can be fed values incrementally, so a long input can be
/// processed in several batches while producing the same results as a
/// single pass.
#[derive(Debug, Clone)]
pub struct EwmMeanState<T> {
    /// Current value of the weighted mean; zero until the first non-null
    /// value has been seen.
    weighted_mean: T,
    /// Weight carried by the history accumulated so far.
    weight: T,
    /// Smoothing factor; the history weight is scaled by `1 - alpha` on decay.
    alpha: T,
    /// Number of non-null values consumed so far.
    non_null_count: usize,
    /// Whether new values enter with weight one (`true`) or `alpha` (`false`).
    adjust: bool,
    /// Non-null values required before a non-null result is emitted.
    min_periods: usize,
    /// Whether null inputs leave the history weight untouched.
    ignore_nulls: bool,
}
30
31
impl<T> EwmMeanState<T>
32
where
33
T: NativeType + num_traits::Float + std::ops::MulAssign,
34
{
35
pub fn new(alpha: T, adjust: bool, min_periods: usize, ignore_nulls: bool) -> Self {
36
Self {
37
weighted_mean: T::zero(),
38
weight: T::zero(),
39
alpha,
40
non_null_count: 0,
41
adjust,
42
min_periods: min_periods.max(1),
43
ignore_nulls,
44
}
45
}
46
47
pub fn update(&mut self, values: &PrimitiveArray<T>) -> PrimitiveArray<T> {
48
self.update_iter(values.iter().map(|x| x.copied()))
49
.collect()
50
}
51
52
pub fn update_iter<I>(&mut self, values: I) -> impl Iterator<Item = Option<T>>
53
where
54
I: IntoIterator<Item = Option<T>>,
55
{
56
let new_value_weight = if self.adjust { T::one() } else { self.alpha };
57
58
values.into_iter().map(move |opt_v| {
59
if self.non_null_count == 0
60
&& let Some(v) = opt_v
61
{
62
// Initialize
63
self.non_null_count = 1;
64
self.weighted_mean = v;
65
self.weight = T::one();
66
} else {
67
if opt_v.is_some() || !self.ignore_nulls {
68
self.weight *= T::one() - self.alpha;
69
}
70
71
if let Some(new_v) = opt_v {
72
let new_weight = self.weight + new_value_weight;
73
74
self.weighted_mean = self.weighted_mean
75
+ (new_v - self.weighted_mean) * (new_value_weight / new_weight);
76
77
self.weight = if self.adjust {
78
self.weight + T::one()
79
} else {
80
T::one()
81
};
82
83
self.non_null_count += 1;
84
}
85
}
86
87
(opt_v.is_some() && self.non_null_count >= self.min_periods)
88
.then_some(self.weighted_mean)
89
})
90
}
91
}
92
93
impl<T> EwmStateUpdate for EwmMeanState<T>
94
where
95
T: NativeType + num_traits::Float + std::ops::MulAssign,
96
{
97
fn ewm_state_update(&mut self, values: &dyn Array) -> Box<dyn Array> {
98
let values: &PrimitiveArray<T> = values.as_any().downcast_ref().unwrap();
99
100
let out: PrimitiveArray<T> = self.update(values);
101
102
out.boxed()
103
}
104
}
105
106
#[cfg(test)]
mod test {
    use super::super::assert_allclose;
    use super::*;

    /// Smoothing factor used by every case in this module.
    const ALPHA: f64 = 0.5;
    /// Tolerance handed to `assert_allclose!` in every case.
    const EPS: f64 = 1e-15;

    #[test]
    fn test_ewm_mean_without_null() {
        let xs: Vec<Option<f64>> = vec![Some(1.0), Some(2.0), Some(3.0)];

        for adjust in [false, true] {
            // Expected means once every value is emitted (min_periods <= 1).
            let full: [Option<f64>; 3] = if adjust {
                [
                    Some(1.0),
                    Some(1.666_666_666_666_666_7),
                    Some(2.428_571_428_571_428_4),
                ]
            } else {
                [Some(1.0), Some(1.5), Some(2.25)]
            };
            // With min_periods = 2 only the first slot is masked out.
            let mut masked = full;
            masked[0] = None;

            // Without nulls in the input, `ignore_nulls` must not matter.
            for ignore_nulls in [false, true] {
                for min_periods in [0, 1] {
                    let result = ewm_mean(xs.clone(), ALPHA, adjust, min_periods, ignore_nulls);
                    let expected = PrimitiveArray::from(full);
                    assert_allclose!(result, expected, EPS);
                }

                let result = ewm_mean(xs.clone(), ALPHA, adjust, 2, ignore_nulls);
                let expected = PrimitiveArray::from(masked);
                assert_allclose!(result, expected, EPS);
            }
        }
    }

    #[test]
    fn test_ewm_mean_with_null() {
        let xs1 = vec![
            None,
            None,
            Some(5.0f64),
            Some(7.0f64),
            None,
            Some(2.0f64),
            Some(1.0f64),
            Some(4.0f64),
        ];

        // (adjust, ignore_nulls, expected output)
        let cases: [(bool, bool, [Option<f64>; 8]); 4] = [
            (
                true,
                true,
                [
                    None,
                    None,
                    Some(5.0),
                    Some(6.333_333_333_333_333),
                    None,
                    Some(3.857_142_857_142_857),
                    Some(2.333_333_333_333_333_5),
                    Some(3.193_548_387_096_774),
                ],
            ),
            (
                true,
                false,
                [
                    None,
                    None,
                    Some(5.0),
                    Some(6.333_333_333_333_333),
                    None,
                    Some(3.181_818_181_818_181_7),
                    Some(1.888_888_888_888_888_8),
                    Some(3.033_898_305_084_745_7),
                ],
            ),
            (
                false,
                true,
                [
                    None,
                    None,
                    Some(5.0),
                    Some(6.0),
                    None,
                    Some(4.0),
                    Some(2.5),
                    Some(3.25),
                ],
            ),
            (
                false,
                false,
                [
                    None,
                    None,
                    Some(5.0),
                    Some(6.0),
                    None,
                    Some(3.333_333_333_333_333_5),
                    Some(2.166_666_666_666_667),
                    Some(3.083_333_333_333_333_5),
                ],
            ),
        ];

        for (adjust, ignore_nulls, expected) in cases {
            assert_allclose!(
                ewm_mean(xs1.clone(), ALPHA, adjust, 0, ignore_nulls),
                PrimitiveArray::from(expected),
                EPS
            );
        }
    }
}
214
215