Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/moment.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
3
use num_traits::{FromPrimitive, ToPrimitive};
4
5
pub use super::super::moment::*;
6
use super::*;
7
8
pub struct MomentWindow<'a, T, M: StateUpdate> {
9
slice: &'a [T],
10
validity: &'a Bitmap,
11
moment: Option<M>,
12
last_start: usize,
13
last_end: usize,
14
null_count: usize,
15
params: Option<RollingFnParams>,
16
}
17
18
impl<T: NativeType + ToPrimitive, M: StateUpdate> MomentWindow<'_, T, M> {
19
// compute sum from the entire window
20
unsafe fn compute_moment_and_null_count(&mut self, start: usize, end: usize) {
21
self.moment = None;
22
let mut idx = start;
23
self.null_count = 0;
24
for value in &self.slice[start..end] {
25
let valid = self.validity.get_bit_unchecked(idx);
26
if valid {
27
let value: f64 = NumCast::from(*value).unwrap();
28
self.moment
29
.get_or_insert_with(|| M::new(self.params))
30
.insert_one(value);
31
} else {
32
self.null_count += 1;
33
}
34
idx += 1;
35
}
36
}
37
}
38
39
impl<'a, T: NativeType + ToPrimitive + IsFloat + FromPrimitive, M: StateUpdate>
40
RollingAggWindowNulls<'a, T> for MomentWindow<'a, T, M>
41
{
42
unsafe fn new(
43
slice: &'a [T],
44
validity: &'a Bitmap,
45
start: usize,
46
end: usize,
47
params: Option<RollingFnParams>,
48
_window_size: Option<usize>,
49
) -> Self {
50
let mut out = Self {
51
slice,
52
validity,
53
moment: None,
54
last_start: start,
55
last_end: end,
56
null_count: 0,
57
params,
58
};
59
out.compute_moment_and_null_count(start, end);
60
out
61
}
62
63
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
64
let recompute_var = if start >= self.last_end {
65
true
66
} else {
67
// remove elements that should leave the window
68
let mut recompute_var = false;
69
for idx in self.last_start..start {
70
// SAFETY:
71
// we are in bounds
72
let valid = self.validity.get_bit_unchecked(idx);
73
if valid {
74
let leaving_value = *self.slice.get_unchecked(idx);
75
76
// if the leaving value is nan we need to recompute the window
77
if T::is_float() && !leaving_value.is_finite() {
78
recompute_var = true;
79
break;
80
}
81
let leaving_value: f64 = NumCast::from(leaving_value).unwrap();
82
if let Some(v) = self.moment.as_mut() {
83
v.remove_one(leaving_value)
84
}
85
} else {
86
// null value leaving the window
87
self.null_count -= 1;
88
89
// self.sum is None and the leaving value is None
90
// if the entering value is valid, we might get a new sum.
91
if self.moment.is_none() {
92
recompute_var = true;
93
break;
94
}
95
}
96
}
97
recompute_var
98
};
99
100
self.last_start = start;
101
102
// we traverse all values and compute
103
if recompute_var {
104
self.compute_moment_and_null_count(start, end);
105
} else {
106
for idx in self.last_end..end {
107
let valid = self.validity.get_bit_unchecked(idx);
108
109
if valid {
110
let entering_value = *self.slice.get_unchecked(idx);
111
let entering_value: f64 = NumCast::from(entering_value).unwrap();
112
self.moment
113
.get_or_insert_with(|| M::new(self.params))
114
.insert_one(entering_value);
115
} else {
116
// null value entering the window
117
self.null_count += 1;
118
}
119
}
120
}
121
self.last_end = end;
122
self.moment.as_ref().and_then(|v| {
123
let out = v.finalize();
124
out.map(|v| T::from_f64(v).unwrap())
125
})
126
}
127
128
fn is_valid(&self, min_periods: usize) -> bool {
129
((self.last_end - self.last_start) - self.null_count) >= min_periods
130
}
131
}
132
133
pub fn rolling_var<T>(
134
arr: &PrimitiveArray<T>,
135
window_size: usize,
136
min_periods: usize,
137
center: bool,
138
weights: Option<&[f64]>,
139
params: Option<RollingFnParams>,
140
) -> ArrayRef
141
where
142
T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,
143
{
144
if weights.is_some() {
145
panic!("weights not yet supported on array with null values")
146
}
147
let offsets_fn = if center {
148
det_offsets_center
149
} else {
150
det_offsets
151
};
152
rolling_apply_agg_window::<MomentWindow<_, VarianceMoment>, _, _>(
153
arr.values().as_slice(),
154
arr.validity().as_ref().unwrap(),
155
window_size,
156
min_periods,
157
offsets_fn,
158
params,
159
)
160
}
161
162
pub fn rolling_skew<T>(
163
arr: &PrimitiveArray<T>,
164
window_size: usize,
165
min_periods: usize,
166
center: bool,
167
params: Option<RollingFnParams>,
168
) -> ArrayRef
169
where
170
T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,
171
{
172
let offsets_fn = if center {
173
det_offsets_center
174
} else {
175
det_offsets
176
};
177
rolling_apply_agg_window::<MomentWindow<_, SkewMoment>, _, _>(
178
arr.values().as_slice(),
179
arr.validity().as_ref().unwrap(),
180
window_size,
181
min_periods,
182
offsets_fn,
183
params,
184
)
185
}
186
187
pub fn rolling_kurtosis<T>(
188
arr: &PrimitiveArray<T>,
189
window_size: usize,
190
min_periods: usize,
191
center: bool,
192
params: Option<RollingFnParams>,
193
) -> ArrayRef
194
where
195
T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,
196
{
197
let offsets_fn = if center {
198
det_offsets_center
199
} else {
200
det_offsets
201
};
202
rolling_apply_agg_window::<MomentWindow<_, KurtosisMoment>, _, _>(
203
arr.values().as_slice(),
204
arr.validity().as_ref().unwrap(),
205
window_size,
206
min_periods,
207
offsets_fn,
208
params,
209
)
210
}
211
212