Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/legacy/kernels/ewm/average.rs
6940 views
1
use std::ops::{AddAssign, MulAssign};
2
3
use num_traits::Float;
4
5
use crate::array::PrimitiveArray;
6
use crate::legacy::utils::CustomIterTools;
7
use crate::trusted_len::TrustedLen;
8
use crate::types::NativeType;
9
10
pub fn ewm_mean<I, T>(
11
xs: I,
12
alpha: T,
13
adjust: bool,
14
min_periods: usize,
15
ignore_nulls: bool,
16
) -> PrimitiveArray<T>
17
where
18
I: IntoIterator<Item = Option<T>>,
19
I::IntoIter: TrustedLen,
20
T: Float + NativeType + AddAssign + MulAssign,
21
{
22
let new_wt = if adjust { T::one() } else { alpha };
23
let old_wt_factor = T::one() - alpha;
24
let mut old_wt = T::one();
25
let mut weighted_avg = None;
26
let mut non_null_cnt = 0usize;
27
28
xs.into_iter()
29
.enumerate()
30
.map(|(i, opt_x)| {
31
if opt_x.is_some() {
32
non_null_cnt += 1;
33
}
34
match (i, weighted_avg) {
35
(0, _) | (_, None) => weighted_avg = opt_x,
36
(_, Some(w_avg)) => {
37
if opt_x.is_some() || !ignore_nulls {
38
old_wt *= old_wt_factor;
39
if let Some(x) = opt_x {
40
if w_avg != x {
41
weighted_avg =
42
Some((old_wt * w_avg + new_wt * x) / (old_wt + new_wt));
43
}
44
old_wt = if adjust { old_wt + new_wt } else { T::one() };
45
}
46
}
47
},
48
}
49
match (non_null_cnt < min_periods, opt_x.is_some()) {
50
(_, false) => None,
51
(true, true) => None,
52
(false, true) => weighted_avg,
53
}
54
})
55
.collect_trusted()
56
}
57
58
#[cfg(test)]
mod test {
    use super::super::assert_allclose;
    use super::*;

    const ALPHA: f64 = 0.5;
    // Absolute tolerance shared by every comparison in this module.
    const EPS: f64 = 1e-15;

    #[test]
    fn test_ewm_mean_without_null() {
        let xs: Vec<Option<f64>> = vec![Some(1.0), Some(2.0), Some(3.0)];
        for adjust in [false, true] {
            for ignore_nulls in [false, true] {
                for min_periods in [0, 1] {
                    let result = ewm_mean(xs.clone(), ALPHA, adjust, min_periods, ignore_nulls);
                    let expected = match adjust {
                        false => PrimitiveArray::from([Some(1.0f64), Some(1.5f64), Some(2.25f64)]),
                        true => PrimitiveArray::from([
                            Some(1.0),
                            Some(1.666_666_666_666_666_7),
                            Some(2.428_571_428_571_428_4),
                        ]),
                    };
                    assert_allclose!(result, expected, EPS);
                }
                let result = ewm_mean(xs.clone(), ALPHA, adjust, 2, ignore_nulls);
                let expected = match adjust {
                    false => PrimitiveArray::from([None, Some(1.5f64), Some(2.25f64)]),
                    true => PrimitiveArray::from([
                        None,
                        Some(1.666_666_666_666_666_7),
                        Some(2.428_571_428_571_428_4),
                    ]),
                };
                assert_allclose!(result, expected, EPS);
            }
        }
    }

    #[test]
    fn test_ewm_mean_with_null() {
        let xs1 = vec![
            None,
            None,
            Some(5.0f64),
            Some(7.0f64),
            None,
            Some(2.0f64),
            Some(1.0f64),
            Some(4.0f64),
        ];
        assert_allclose!(
            ewm_mean(xs1.clone(), 0.5, true, 0, true),
            PrimitiveArray::from([
                None,
                None,
                Some(5.0),
                Some(6.333_333_333_333_333),
                None,
                Some(3.857_142_857_142_857),
                Some(2.333_333_333_333_333_5),
                Some(3.193_548_387_096_774),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_mean(xs1.clone(), 0.5, true, 0, false),
            PrimitiveArray::from([
                None,
                None,
                Some(5.0),
                Some(6.333_333_333_333_333),
                None,
                Some(3.181_818_181_818_181_7),
                Some(1.888_888_888_888_888_8),
                Some(3.033_898_305_084_745_7),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_mean(xs1.clone(), 0.5, false, 0, true),
            PrimitiveArray::from([
                None,
                None,
                Some(5.0),
                Some(6.0),
                None,
                Some(4.0),
                Some(2.5),
                Some(3.25),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_mean(xs1, 0.5, false, 0, false),
            PrimitiveArray::from([
                None,
                None,
                Some(5.0),
                Some(6.0),
                None,
                Some(3.333_333_333_333_333_5),
                Some(2.166_666_666_666_667),
                Some(3.083_333_333_333_333_5),
            ]),
            EPS
        );
    }

    // `min_periods` counts only non-null inputs: leading nulls delay the first
    // emitted value, and null positions remain null regardless of the count.
    // Expected values are the `min_periods = 0` results above with every
    // position whose running non-null count is below `min_periods` nulled out.
    #[test]
    fn test_ewm_mean_min_periods_with_null() {
        let xs1 = vec![
            None,
            None,
            Some(5.0f64),
            Some(7.0f64),
            None,
            Some(2.0f64),
            Some(1.0f64),
            Some(4.0f64),
        ];
        // Non-null counts at indices 2, 3, 5, 6, 7 are 1, 2, 3, 4, 5.
        assert_allclose!(
            ewm_mean(xs1.clone(), 0.5, true, 2, true),
            PrimitiveArray::from([
                None,
                None,
                None,
                Some(6.333_333_333_333_333),
                None,
                Some(3.857_142_857_142_857),
                Some(2.333_333_333_333_333_5),
                Some(3.193_548_387_096_774),
            ]),
            EPS
        );
        assert_allclose!(
            ewm_mean(xs1.clone(), 0.5, false, 3, false),
            PrimitiveArray::from([
                None,
                None,
                None,
                None,
                None,
                Some(3.333_333_333_333_333_5),
                Some(2.166_666_666_666_667),
                Some(3.083_333_333_333_333_5),
            ]),
            EPS
        );
        // A threshold above the total number of non-null inputs yields all nulls.
        assert_allclose!(
            ewm_mean(xs1, 0.5, true, 6, true),
            PrimitiveArray::from([None::<f64>; 8]),
            EPS
        );
    }

    #[test]
    fn test_ewm_mean_all_null() {
        let xs: Vec<Option<f64>> = vec![None; 4];
        for adjust in [false, true] {
            for ignore_nulls in [false, true] {
                let result = ewm_mean(xs.clone(), ALPHA, adjust, 0, ignore_nulls);
                assert_allclose!(result, PrimitiveArray::from([None::<f64>; 4]), EPS);
            }
        }
    }
}
166
167