Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/mod.rs
8422 views
1
mod mean;
2
mod min_max;
3
mod moment;
4
mod quantile;
5
mod rank;
6
mod sum;
7
8
use arrow::legacy::utils::CustomIterTools;
9
pub use mean::*;
10
pub use min_max::*;
11
pub use moment::*;
12
pub use quantile::*;
13
pub use rank::*;
14
pub use sum::*;
15
16
use super::*;
17
18
pub trait RollingAggWindowNulls<T: NativeType, Out: NativeType = T> {
19
type This<'a>: RollingAggWindowNulls<T, Out>;
20
21
/// # Safety
22
/// `start` and `end` must be in bounds for `slice` and `validity`
23
fn new<'a>(
24
slice: &'a [T],
25
validity: &'a Bitmap,
26
start: usize,
27
end: usize,
28
params: Option<RollingFnParams>,
29
window_size: Option<usize>,
30
) -> Self::This<'a>;
31
32
/// # Safety
33
/// `start` and `end` must be in bounds of `slice` and `bitmap`
34
unsafe fn update(&mut self, new_start: usize, new_end: usize);
35
36
/// Get the aggregate of the current window relative to the value at `idx`.
37
fn get_agg(&self, idx: usize) -> Option<Out>;
38
39
/// Returns the length of the underlying input.
40
fn slice_len(&self) -> usize;
41
42
fn is_valid(&self, min_periods: usize) -> bool;
43
}
44
45
// Use an aggregation window that maintains the state
46
pub(super) fn rolling_apply_agg_window<Agg, T, Out, Fo>(
47
values: &[T],
48
validity: &Bitmap,
49
window_size: usize,
50
min_periods: usize,
51
det_offsets_fn: Fo,
52
params: Option<RollingFnParams>,
53
) -> ArrayRef
54
where
55
Agg: RollingAggWindowNulls<T, Out>,
56
T: NativeType,
57
Out: NativeType,
58
Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,
59
{
60
let len = values.len();
61
let (start, end) = det_offsets_fn(0, window_size, len);
62
let mut agg_window = Agg::new(values, validity, start, end, params, Some(window_size));
63
64
let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)
65
.unwrap_or_else(|| {
66
let mut validity = MutableBitmap::with_capacity(len);
67
validity.extend_constant(len, true);
68
validity
69
});
70
71
let out = (0..len)
72
.map(|idx| {
73
let (start, end) = det_offsets_fn(idx, window_size, len);
74
// SAFETY:
75
// we are in bounds
76
unsafe { agg_window.update(start, end) };
77
match agg_window.get_agg(idx) {
78
Some(val) => {
79
if agg_window.is_valid(min_periods) {
80
val
81
} else {
82
// SAFETY: we are in bounds
83
unsafe { validity.set_unchecked(idx, false) };
84
Out::default()
85
}
86
},
87
None => {
88
// SAFETY: we are in bounds
89
unsafe { validity.set_unchecked(idx, false) };
90
Out::default()
91
},
92
}
93
})
94
.collect_trusted::<Vec<_>>();
95
96
Box::new(PrimitiveArray::new(
97
Out::PRIMITIVE.into(),
98
out.into(),
99
Some(validity.into()),
100
))
101
}
102
103
#[cfg(test)]
mod test {
    use arrow::array::{Array, Int32Array};
    use arrow::datatypes::ArrowDataType;
    use polars_buffer::Buffer;
    use polars_utils::min_max::MaxIgnoreNan;

    use super::*;
    use crate::rolling::min_max::MinMaxWindow;

    /// Downcasts a rolling kernel's output to `PrimitiveArray<f64>` and collects it,
    /// mapping null slots to `None`.
    fn to_f64_vec(out: ArrayRef) -> Vec<Option<f64>> {
        let arr = out
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        arr.iter().map(|v| v.copied()).collect()
    }

    /// Builds a `Float64` array from `values` with the given validity mask.
    fn f64_array(values: Vec<f64>, validity: &[bool]) -> PrimitiveArray<f64> {
        PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(values),
            Some(Bitmap::from(validity)),
        )
    }

    fn get_null_arr() -> PrimitiveArray<f64> {
        // 1, None, -1, 4
        f64_array(vec![1.0, 0.0, -1.0, 4.0], &[true, false, true, true])
    }

    #[test]
    fn test_rolling_sum_nulls() {
        let arr = &f64_array(vec![1.0, 2.0, 3.0, 4.0], &[true, false, true, true]);

        let sums = to_f64_vec(rolling_sum(arr, 2, 2, false, None, None));
        assert_eq!(sums, &[None, None, None, Some(7.0)]);

        let sums = to_f64_vec(rolling_sum(arr, 2, 1, false, None, None));
        assert_eq!(sums, &[Some(1.0), Some(1.0), Some(3.0), Some(7.0)]);

        let sums = to_f64_vec(rolling_sum(arr, 4, 1, false, None, None));
        assert_eq!(sums, &[Some(1.0), Some(1.0), Some(4.0), Some(8.0)]);

        let sums = to_f64_vec(rolling_sum(arr, 4, 1, true, None, None));
        assert_eq!(sums, &[Some(1.0), Some(4.0), Some(8.0), Some(7.0)]);

        let sums = to_f64_vec(rolling_sum(arr, 4, 4, true, None, None));
        assert_eq!(sums, &[None, None, None, None]);
    }

    #[test]
    fn test_rolling_mean_nulls() {
        let input = get_null_arr();
        let arr = &input;

        let means = to_f64_vec(rolling_mean(arr, 2, 2, false, None, None));
        assert_eq!(means, &[None, None, None, Some(1.5)]);

        let means = to_f64_vec(rolling_mean(arr, 2, 1, false, None, None));
        assert_eq!(means, &[Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]);

        let means = to_f64_vec(rolling_mean(arr, 4, 1, false, None, None));
        assert_eq!(means, &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]);
    }

    #[test]
    fn test_rolling_var_nulls() {
        let input = get_null_arr();
        let arr = &input;
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));

        // Default parameters (sample variance).
        let vars = to_f64_vec(rolling_var(arr, 3, 1, false, None, None));
        assert_eq!(vars, &[None, None, Some(2.0), Some(12.5)]);

        // ddof = 0 (population variance).
        let vars = to_f64_vec(rolling_var(arr, 3, 1, false, None, testpars));
        assert_eq!(vars, &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]);

        let vars = to_f64_vec(rolling_var(arr, 4, 1, false, None, None));
        assert_eq!(vars, &[None, None, Some(2.0), Some(6.333333333333334)]);

        let vars = to_f64_vec(rolling_var(arr, 4, 1, false, None, testpars));
        assert_eq!(
            vars,
            &[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)]
        );
    }

    #[test]
    fn test_rolling_max_no_nulls() {
        let ascending = &f64_array(vec![1.0, 2.0, 3.0, 4.0], &[true; 4]);

        let maxes = to_f64_vec(rolling_max(ascending, 4, 1, false, None, None));
        assert_eq!(maxes, &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);

        let maxes = to_f64_vec(rolling_max(ascending, 2, 2, false, None, None));
        assert_eq!(maxes, &[None, Some(2.0), Some(3.0), Some(4.0)]);

        let maxes = to_f64_vec(rolling_max(ascending, 4, 4, false, None, None));
        assert_eq!(maxes, &[None, None, None, Some(4.0)]);

        let descending = &f64_array(vec![4.0, 3.0, 2.0, 1.0], &[true; 4]);

        let maxes = to_f64_vec(rolling_max(descending, 2, 1, false, None, None));
        assert_eq!(maxes, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);

        // The no-nulls kernel must agree with the null-aware one on all-valid input.
        let maxes = to_f64_vec(
            super::no_nulls::rolling_max(descending.values().as_slice(), 2, 1, false, None, None)
                .unwrap(),
        );
        assert_eq!(maxes, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);
    }

    #[test]
    fn test_rolling_extrema_nulls() {
        let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
        let window_size = 3;
        let min_periods = 3;
        let validity = Bitmap::new_with_value(true, vals.len());
        let input = Int32Array::new(ArrowDataType::Int32, vals.into(), Some(validity));

        let out = rolling_apply_agg_window::<MinMaxWindow<i32, MaxIgnoreNan>, _, _, _>(
            input.values().as_slice(),
            input.validity().as_ref().unwrap(),
            window_size,
            min_periods,
            det_offsets,
            None,
        );
        let result = out.as_any().downcast_ref::<Int32Array>().unwrap();

        // The first `window_size - 1` outputs do not reach `min_periods` values.
        assert_eq!(result.null_count(), 2);
        assert_eq!(
            &result.values().as_slice()[2..],
            &[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3]
        );
    }
}
