Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/moment.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use num_traits::{FromPrimitive, ToPrimitive};
3
use polars_error::polars_ensure;
4
5
pub use super::super::moment::*;
6
use super::*;
7
8
pub struct MomentWindow<'a, T, M: StateUpdate> {
9
slice: &'a [T],
10
moment: M,
11
last_start: usize,
12
last_end: usize,
13
params: Option<RollingFnParams>,
14
}
15
16
impl<T: ToPrimitive + Copy, M: StateUpdate> MomentWindow<'_, T, M> {
17
fn compute_var(&mut self, start: usize, end: usize) {
18
self.moment = M::new(self.params);
19
for value in &self.slice[start..end] {
20
let value: f64 = NumCast::from(*value).unwrap();
21
self.moment.insert_one(value);
22
}
23
}
24
}
25
26
impl<'a, T: NativeType + IsFloat + Float + ToPrimitive + FromPrimitive, M: StateUpdate>
27
RollingAggWindowNoNulls<'a, T> for MomentWindow<'a, T, M>
28
{
29
fn new(
30
slice: &'a [T],
31
start: usize,
32
end: usize,
33
params: Option<RollingFnParams>,
34
_window_size: Option<usize>,
35
) -> Self {
36
let mut out = Self {
37
slice,
38
moment: M::new(params),
39
last_start: start,
40
last_end: end,
41
params,
42
};
43
out.compute_var(start, end);
44
out
45
}
46
47
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
48
let recompute_var = if start >= self.last_end {
49
true
50
} else {
51
// remove elements that should leave the window
52
let mut recompute_var = false;
53
for idx in self.last_start..start {
54
// SAFETY: we are in bounds
55
let leaving_value = *self.slice.get_unchecked(idx);
56
57
// if the leaving value is nan we need to recompute the window
58
if T::is_float() && !leaving_value.is_finite() {
59
recompute_var = true;
60
break;
61
}
62
let leaving_value: f64 = NumCast::from(leaving_value).unwrap();
63
self.moment.remove_one(leaving_value);
64
}
65
recompute_var
66
};
67
68
self.last_start = start;
69
70
// we traverse all values and compute
71
if recompute_var {
72
self.compute_var(start, end);
73
} else {
74
for idx in self.last_end..end {
75
let entering_value = *self.slice.get_unchecked(idx);
76
let entering_value: f64 = NumCast::from(entering_value).unwrap();
77
78
self.moment.insert_one(entering_value);
79
}
80
}
81
self.last_end = end;
82
self.moment.finalize().map(|v| T::from_f64(v).unwrap())
83
}
84
}
85
86
pub fn rolling_var<T>(
87
values: &[T],
88
window_size: usize,
89
min_periods: usize,
90
center: bool,
91
weights: Option<&[f64]>,
92
params: Option<RollingFnParams>,
93
) -> PolarsResult<ArrayRef>
94
where
95
T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
96
{
97
let offset_fn = match center {
98
true => det_offsets_center,
99
false => det_offsets,
100
};
101
match weights {
102
None => rolling_apply_agg_window::<MomentWindow<_, VarianceMoment>, _, _>(
103
values,
104
window_size,
105
min_periods,
106
offset_fn,
107
params,
108
),
109
Some(weights) => {
110
// Validate and standardize the weights like we do for the mean. This definition is fine
111
// because frequency weights and unbiasing don't make sense for rolling operations.
112
let mut wts = no_nulls::coerce_weights(weights);
113
let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
114
polars_ensure!(
115
wsum != T::zero(),
116
ComputeError: "Weighted variance is undefined if weights sum to 0"
117
);
118
wts.iter_mut().for_each(|w| *w = *w / wsum);
119
super::rolling_apply_weights(
120
values,
121
window_size,
122
min_periods,
123
offset_fn,
124
compute_var_weights,
125
&wts,
126
center,
127
)
128
},
129
}
130
}
131
132
pub fn rolling_skew<T>(
133
values: &[T],
134
window_size: usize,
135
min_periods: usize,
136
center: bool,
137
params: Option<RollingFnParams>,
138
) -> PolarsResult<ArrayRef>
139
where
140
T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
141
{
142
let offset_fn = match center {
143
true => det_offsets_center,
144
false => det_offsets,
145
};
146
rolling_apply_agg_window::<MomentWindow<_, SkewMoment>, _, _>(
147
values,
148
window_size,
149
min_periods,
150
offset_fn,
151
params,
152
)
153
}
154
155
pub fn rolling_kurtosis<T>(
156
values: &[T],
157
window_size: usize,
158
min_periods: usize,
159
center: bool,
160
params: Option<RollingFnParams>,
161
) -> PolarsResult<ArrayRef>
162
where
163
T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
164
{
165
let offset_fn = match center {
166
true => det_offsets_center,
167
false => det_offsets,
168
};
169
rolling_apply_agg_window::<MomentWindow<_, KurtosisMoment>, _, _>(
170
values,
171
window_size,
172
min_periods,
173
offset_fn,
174
params,
175
)
176
}
177
178
#[cfg(test)]
mod test {
    use super::*;

    /// Runs an unweighted, non-centered `rolling_var` and collects the result
    /// into a plain vector of optional values.
    fn collect_rolling_var(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        params: Option<RollingFnParams>,
    ) -> Vec<Option<f64>> {
        let array = rolling_var(values, window_size, min_periods, false, None, params).unwrap();
        array
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_var() {
        let values = [1.0f64, 5.0, 3.0, 4.0];

        // Default parameters.
        let out = collect_rolling_var(&values, 2, 2, None);
        assert_eq!(out, &[None, Some(8.0), Some(2.0), Some(0.5)]);

        // Explicit `ddof: 0`.
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
        let out = collect_rolling_var(&values, 2, 2, testpars);
        assert_eq!(out, &[None, Some(4.0), Some(1.0), Some(0.25)]);

        // `min_periods` of one.
        let out = collect_rolling_var(&values, 2, 1, None);
        // we cannot compare nans, so we compare the string values
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)])
        );

        // test nan handling.
        let values = [-10.0, 2.0, 3.0, f64::NAN, 5.0, 6.0, 7.0];
        let out = collect_rolling_var(&values, 3, 3, None);
        // we cannot compare nans, so we compare the string values
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!(
                "{:?}",
                &[
                    None,
                    None,
                    Some(52.33333333333333),
                    Some(f64::NAN),
                    Some(f64::NAN),
                    Some(f64::NAN),
                    Some(1.0)
                ]
            )
        );
    }
}
228
229