Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/mod.rs
6939 views
1
mod mean;
2
mod min_max;
3
mod moment;
4
mod quantile;
5
mod sum;
6
7
use arrow::legacy::utils::CustomIterTools;
8
pub use mean::*;
9
pub use min_max::*;
10
pub use moment::*;
11
pub use quantile::*;
12
pub use sum::*;
13
14
use super::*;
15
16
pub trait RollingAggWindowNulls<'a, T: NativeType> {
17
/// # Safety
18
/// `start` and `end` must be in bounds for `slice` and `validity`
19
unsafe fn new(
20
slice: &'a [T],
21
validity: &'a Bitmap,
22
start: usize,
23
end: usize,
24
params: Option<RollingFnParams>,
25
window_size: Option<usize>,
26
) -> Self;
27
28
/// # Safety
29
/// `start` and `end` must be in bounds of `slice` and `bitmap`
30
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
31
32
fn is_valid(&self, min_periods: usize) -> bool;
33
}
34
35
// Use an aggregation window that maintains the state
36
pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
37
values: &'a [T],
38
validity: &'a Bitmap,
39
window_size: usize,
40
min_periods: usize,
41
det_offsets_fn: Fo,
42
params: Option<RollingFnParams>,
43
) -> ArrayRef
44
where
45
Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,
46
Agg: RollingAggWindowNulls<'a, T>,
47
T: IsFloat + NativeType,
48
{
49
let len = values.len();
50
let (start, end) = det_offsets_fn(0, window_size, len);
51
// SAFETY; we are in bounds
52
let mut agg_window =
53
unsafe { Agg::new(values, validity, start, end, params, Some(window_size)) };
54
55
let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)
56
.unwrap_or_else(|| {
57
let mut validity = MutableBitmap::with_capacity(len);
58
validity.extend_constant(len, true);
59
validity
60
});
61
62
let out = (0..len)
63
.map(|idx| {
64
let (start, end) = det_offsets_fn(idx, window_size, len);
65
// SAFETY:
66
// we are in bounds
67
let agg = unsafe { agg_window.update(start, end) };
68
match agg {
69
Some(val) => {
70
if agg_window.is_valid(min_periods) {
71
val
72
} else {
73
// SAFETY: we are in bounds
74
unsafe { validity.set_unchecked(idx, false) };
75
T::default()
76
}
77
},
78
None => {
79
// SAFETY: we are in bounds
80
unsafe { validity.set_unchecked(idx, false) };
81
T::default()
82
},
83
}
84
})
85
.collect_trusted::<Vec<_>>();
86
87
Box::new(PrimitiveArray::new(
88
T::PRIMITIVE.into(),
89
out.into(),
90
Some(validity.into()),
91
))
92
}
93
94
#[cfg(test)]
mod test {
    use arrow::array::{Array, Int32Array};
    use arrow::buffer::Buffer;
    use arrow::datatypes::ArrowDataType;
    use polars_utils::min_max::MaxIgnoreNan;

    use super::*;
    use crate::rolling::min_max::MinMaxWindow;

    /// Downcast a rolling result to `PrimitiveArray<f64>` and collect its slots,
    /// with `None` standing for a null slot.
    fn f64_values(arr: &dyn Array) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    /// `[1, None, -1, 4]` as a `Float64` array; the null slot stores `0.0`.
    fn get_null_arr() -> PrimitiveArray<f64> {
        PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![1.0, 0.0, -1.0, 4.0]),
            Some(Bitmap::from(&[true, false, true, true])),
        )
    }

    #[test]
    fn test_rolling_sum_nulls() {
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![1.0, 2.0, 3.0, 4.0]),
            Some(Bitmap::from(&[true, false, true, true])),
        );

        let out = rolling_sum(arr, 2, 2, false, None, None);
        assert_eq!(f64_values(&*out), &[None, None, None, Some(7.0)]);

        let out = rolling_sum(arr, 2, 1, false, None, None);
        assert_eq!(f64_values(&*out), &[Some(1.0), Some(1.0), Some(3.0), Some(7.0)]);

        let out = rolling_sum(arr, 4, 1, false, None, None);
        assert_eq!(f64_values(&*out), &[Some(1.0), Some(1.0), Some(4.0), Some(8.0)]);

        // Centered window.
        let out = rolling_sum(arr, 4, 1, true, None, None);
        assert_eq!(f64_values(&*out), &[Some(1.0), Some(4.0), Some(8.0), Some(7.0)]);

        let out = rolling_sum(arr, 4, 4, true, None, None);
        assert_eq!(f64_values(&*out), &[None, None, None, None]);
    }

    #[test]
    fn test_rolling_mean_nulls() {
        let arr = &get_null_arr();

        let out = rolling_mean(arr, 2, 2, false, None, None);
        assert_eq!(f64_values(&*out), &[None, None, None, Some(1.5)]);

        let out = rolling_mean(arr, 2, 1, false, None, None);
        assert_eq!(f64_values(&*out), &[Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]);

        let out = rolling_mean(arr, 4, 1, false, None, None);
        assert_eq!(
            f64_values(&*out),
            &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]
        );
    }

    #[test]
    fn test_rolling_var_nulls() {
        let arr = &get_null_arr();

        // Default parameters.
        let out = rolling_var(arr, 3, 1, false, None, None);
        assert_eq!(f64_values(&*out), &[None, None, Some(2.0), Some(12.5)]);

        // Population variance (ddof = 0).
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
        let out = rolling_var(arr, 3, 1, false, None, testpars);
        assert_eq!(f64_values(&*out), &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]);

        let out = rolling_var(arr, 4, 1, false, None, None);
        assert_eq!(
            f64_values(&*out),
            &[None, None, Some(2.0), Some(6.333333333333334)]
        );

        let out = rolling_var(arr, 4, 1, false, None, testpars);
        assert_eq!(
            f64_values(&*out),
            &[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)]
        );
    }

    #[test]
    fn test_rolling_max_no_nulls() {
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![1.0, 2.0, 3.0, 4.0]),
            Some(Bitmap::from(&[true, true, true, true])),
        );

        let out = rolling_max(arr, 4, 1, false, None, None);
        assert_eq!(f64_values(&*out), &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);

        let out = rolling_max(arr, 2, 2, false, None, None);
        assert_eq!(f64_values(&*out), &[None, Some(2.0), Some(3.0), Some(4.0)]);

        let out = rolling_max(arr, 4, 4, false, None, None);
        assert_eq!(f64_values(&*out), &[None, None, None, Some(4.0)]);

        // Descending input.
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![4.0, 3.0, 2.0, 1.0]),
            Some(Bitmap::from(&[true, true, true, true])),
        );

        let out = rolling_max(arr, 2, 1, false, None, None);
        assert_eq!(f64_values(&*out), &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);

        // The no-nulls kernel must agree with the nullable kernel on this input.
        let out =
            super::no_nulls::rolling_max(arr.values().as_slice(), 2, 1, false, None, None).unwrap();
        assert_eq!(f64_values(&*out), &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);
    }

    #[test]
    fn test_rolling_extrema_nulls() {
        let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
        let validity = Bitmap::new_with_value(true, vals.len());
        let window_size = 3;
        let min_periods = 3;

        let arr = Int32Array::new(ArrowDataType::Int32, vals.into(), Some(validity));

        let out = rolling_apply_agg_window::<MinMaxWindow<i32, MaxIgnoreNan>, _, _>(
            arr.values().as_slice(),
            arr.validity().as_ref().unwrap(),
            window_size,
            min_periods,
            det_offsets,
            None,
        );
        let arr = out.as_any().downcast_ref::<Int32Array>().unwrap();
        // The first `min_periods - 1` slots cannot fill a window and are null.
        assert_eq!(arr.null_count(), 2);
        assert_eq!(
            &arr.values().as_slice()[2..],
            &[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3]
        );
    }
}
268
269