Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/mean.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use super::*;
3
4
pub struct MeanWindow<'a, T> {
5
sum: SumWindow<'a, T, f64>,
6
}
7
8
impl<
9
'a,
10
T: NativeType
11
+ IsFloat
12
+ Add<Output = T>
13
+ Sub<Output = T>
14
+ NumCast
15
+ Div<Output = T>
16
+ AddAssign
17
+ SubAssign
18
+ PartialOrd,
19
> RollingAggWindowNulls<'a, T> for MeanWindow<'a, T>
20
{
21
unsafe fn new(
22
slice: &'a [T],
23
validity: &'a Bitmap,
24
start: usize,
25
end: usize,
26
params: Option<RollingFnParams>,
27
window_size: Option<usize>,
28
) -> Self {
29
Self {
30
sum: SumWindow::new(slice, validity, start, end, params, window_size),
31
}
32
}
33
34
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
35
let sum = self.sum.update(start, end);
36
let len = end - start;
37
if self.sum.null_count == len {
38
None
39
} else {
40
sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap())
41
}
42
}
43
fn is_valid(&self, min_periods: usize) -> bool {
44
self.sum.is_valid(min_periods)
45
}
46
}
47
48
pub fn rolling_mean<T>(
49
arr: &PrimitiveArray<T>,
50
window_size: usize,
51
min_periods: usize,
52
center: bool,
53
weights: Option<&[f64]>,
54
_params: Option<RollingFnParams>,
55
) -> ArrayRef
56
where
57
T: NativeType
58
+ IsFloat
59
+ PartialOrd
60
+ Add<Output = T>
61
+ Sub<Output = T>
62
+ NumCast
63
+ AddAssign
64
+ SubAssign
65
+ Div<Output = T>,
66
{
67
if weights.is_some() {
68
panic!("weights not yet supported on array with null values")
69
}
70
if center {
71
rolling_apply_agg_window::<MeanWindow<_>, _, _>(
72
arr.values().as_slice(),
73
arr.validity().as_ref().unwrap(),
74
window_size,
75
min_periods,
76
det_offsets_center,
77
None,
78
)
79
} else {
80
rolling_apply_agg_window::<MeanWindow<_>, _, _>(
81
arr.values().as_slice(),
82
arr.validity().as_ref().unwrap(),
83
window_size,
84
min_periods,
85
det_offsets,
86
None,
87
)
88
}
89
}
90
91