Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/moment.rs
8327 views
1
use num_traits::{FromPrimitive, ToPrimitive};
2
3
use super::no_nulls::RollingAggWindowNoNulls;
4
use super::nulls::RollingAggWindowNulls;
5
use super::*;
6
use crate::moment::{KurtosisState, SkewState, VarState};
7
8
pub trait StateUpdate {
9
fn new(params: Option<RollingFnParams>) -> Self;
10
fn reset(&mut self);
11
fn insert_one(&mut self, x: f64);
12
fn remove_one(&mut self, x: f64);
13
fn finalize(&self) -> Option<f64>;
14
}
15
16
pub struct VarianceMoment {
17
state: VarState,
18
ddof: u8,
19
}
20
21
impl StateUpdate for VarianceMoment {
22
fn new(params: Option<RollingFnParams>) -> Self {
23
let ddof = if let Some(RollingFnParams::Var(params)) = params {
24
params.ddof
25
} else {
26
1
27
};
28
29
Self {
30
state: VarState::default(),
31
ddof,
32
}
33
}
34
35
#[inline(always)]
36
fn reset(&mut self) {
37
self.state = VarState::default();
38
}
39
40
#[inline(always)]
41
fn insert_one(&mut self, x: f64) {
42
self.state.insert_one(x);
43
}
44
45
#[inline(always)]
46
fn remove_one(&mut self, x: f64) {
47
self.state.remove_one(x);
48
}
49
50
#[inline(always)]
51
fn finalize(&self) -> Option<f64> {
52
self.state.finalize(self.ddof)
53
}
54
}
55
56
pub struct KurtosisMoment {
57
state: KurtosisState,
58
fisher: bool,
59
bias: bool,
60
}
61
62
impl StateUpdate for KurtosisMoment {
63
fn new(params: Option<RollingFnParams>) -> Self {
64
let (fisher, bias) = if let Some(RollingFnParams::Kurtosis { fisher, bias }) = params {
65
(fisher, bias)
66
} else {
67
(false, false)
68
};
69
70
Self {
71
state: KurtosisState::default(),
72
fisher,
73
bias,
74
}
75
}
76
77
#[inline(always)]
78
fn reset(&mut self) {
79
self.state = KurtosisState::default();
80
}
81
82
#[inline(always)]
83
fn insert_one(&mut self, x: f64) {
84
self.state.insert_one(x);
85
}
86
87
#[inline(always)]
88
fn remove_one(&mut self, x: f64) {
89
self.state.remove_one(x);
90
}
91
92
#[inline(always)]
93
fn finalize(&self) -> Option<f64> {
94
self.state.finalize(self.fisher, self.bias)
95
}
96
}
97
98
pub struct SkewMoment {
99
state: SkewState,
100
bias: bool,
101
}
102
103
impl StateUpdate for SkewMoment {
104
fn new(params: Option<RollingFnParams>) -> Self {
105
let bias = if let Some(RollingFnParams::Skew { bias }) = params {
106
bias
107
} else {
108
false
109
};
110
111
Self {
112
state: SkewState::default(),
113
bias,
114
}
115
}
116
117
#[inline(always)]
118
fn reset(&mut self) {
119
self.state = SkewState::default();
120
}
121
122
#[inline(always)]
123
fn insert_one(&mut self, x: f64) {
124
self.state.insert_one(x);
125
}
126
127
#[inline(always)]
128
fn remove_one(&mut self, x: f64) {
129
self.state.remove_one(x);
130
}
131
132
#[inline(always)]
133
fn finalize(&self) -> Option<f64> {
134
self.state.finalize(self.bias)
135
}
136
}
137
138
pub struct MomentWindow<'a, T, M: StateUpdate> {
139
slice: &'a [T],
140
validity: Option<&'a Bitmap>,
141
moment: M,
142
non_finite_count: usize, // NaN or infinity.
143
null_count: usize,
144
start: usize,
145
end: usize,
146
}
147
148
impl<'a, T, M> MomentWindow<'a, T, M>
149
where
150
T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
151
M: StateUpdate,
152
{
153
fn new_impl(
154
slice: &'a [T],
155
validity: Option<&'a Bitmap>,
156
params: Option<RollingFnParams>,
157
) -> Self {
158
Self {
159
slice,
160
validity,
161
moment: M::new(params),
162
non_finite_count: 0,
163
null_count: 0,
164
start: 0,
165
end: 0,
166
}
167
}
168
169
#[inline(always)]
170
fn reset(&mut self) {
171
self.moment.reset();
172
self.non_finite_count = 0;
173
self.null_count = 0;
174
}
175
176
#[inline(always)]
177
fn insert(&mut self, val: T) {
178
if val.is_finite() {
179
self.moment.insert_one(NumCast::from(val).unwrap());
180
} else {
181
self.moment.insert_one(0.0); // A hack to replicate ddof null behavior.
182
self.non_finite_count += 1;
183
}
184
}
185
186
#[inline(always)]
187
fn remove(&mut self, val: T) {
188
if val.is_finite() {
189
self.moment.remove_one(NumCast::from(val).unwrap());
190
} else {
191
self.moment.remove_one(0.0); // A hack to replicate ddof null behavior.
192
self.non_finite_count -= 1;
193
}
194
}
195
196
#[inline(always)]
197
fn get_moment(&self) -> Option<T> {
198
if self.non_finite_count > 0 {
199
self.moment
200
.finalize()
201
.map(|_v| T::from_f64(f64::NAN).unwrap())
202
} else {
203
self.moment.finalize().map(|v| T::from_f64(v).unwrap())
204
}
205
}
206
}
207
208
impl<T, M> RollingAggWindowNoNulls<T> for MomentWindow<'_, T, M>
209
where
210
T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
211
M: StateUpdate,
212
{
213
type This<'a> = MomentWindow<'a, T, M>;
214
215
fn new<'a>(
216
slice: &'a [T],
217
start: usize,
218
end: usize,
219
params: Option<RollingFnParams>,
220
_window_size: Option<usize>,
221
) -> Self::This<'a> {
222
let mut out = MomentWindow::new_impl(slice, None, params);
223
unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };
224
out
225
}
226
227
// # Safety
228
// The start, end range must be in-bounds.
229
#[inline]
230
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
231
if new_start >= self.end {
232
self.reset();
233
self.start = new_start;
234
self.end = new_start;
235
}
236
237
for val in &self.slice[self.start..new_start] {
238
self.remove(*val);
239
}
240
241
for val in &self.slice[self.end..new_end] {
242
self.insert(*val);
243
}
244
245
self.start = new_start;
246
self.end = new_end;
247
}
248
249
fn get_agg(&self, _idx: usize) -> Option<T> {
250
self.get_moment()
251
}
252
253
fn slice_len(&self) -> usize {
254
self.slice.len()
255
}
256
}
257
258
impl<T, M> RollingAggWindowNulls<T> for MomentWindow<'_, T, M>
259
where
260
T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
261
M: StateUpdate,
262
{
263
type This<'a> = MomentWindow<'a, T, M>;
264
265
fn new<'a>(
266
slice: &'a [T],
267
validity: &'a Bitmap,
268
start: usize,
269
end: usize,
270
params: Option<RollingFnParams>,
271
_window_size: Option<usize>,
272
) -> Self::This<'a> {
273
assert!(start <= slice.len() && end <= slice.len() && start <= end);
274
let mut out = MomentWindow::new_impl(slice, Some(validity), params);
275
// SAFETY: We bounds checked `start` and `end`.
276
unsafe { RollingAggWindowNulls::update(&mut out, start, end) };
277
out
278
}
279
280
// # Safety
281
// The start, end range must be in-bounds.
282
#[inline]
283
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
284
let validity = unsafe { self.validity.unwrap_unchecked() };
285
286
if new_start >= self.end {
287
self.reset();
288
self.start = new_start;
289
self.end = new_start;
290
}
291
292
for idx in self.start..new_start {
293
let valid = unsafe { validity.get_bit_unchecked(idx) };
294
if valid {
295
self.remove(unsafe { *self.slice.get_unchecked(idx) });
296
} else {
297
self.null_count -= 1;
298
}
299
}
300
301
for idx in self.end..new_end {
302
let valid = unsafe { validity.get_bit_unchecked(idx) };
303
if valid {
304
self.insert(unsafe { *self.slice.get_unchecked(idx) });
305
} else {
306
self.null_count += 1;
307
}
308
}
309
310
self.start = new_start;
311
self.end = new_end;
312
}
313
314
fn get_agg(&self, _idx: usize) -> Option<T> {
315
self.get_moment()
316
}
317
318
#[inline(always)]
319
fn is_valid(&self, min_periods: usize) -> bool {
320
((self.end - self.start) - self.null_count) >= min_periods
321
}
322
323
fn slice_len(&self) -> usize {
324
self.slice.len()
325
}
326
}
327
328