Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/sum.rs
8433 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use super::super::sum::SumWindow;
3
use super::*;
4
5
pub fn rolling_sum<T>(
6
values: &[T],
7
window_size: usize,
8
min_periods: usize,
9
center: bool,
10
weights: Option<&[f64]>,
11
_params: Option<RollingFnParams>,
12
) -> PolarsResult<ArrayRef>
13
where
14
T: NativeType
15
+ std::iter::Sum
16
+ NumCast
17
+ Mul<Output = T>
18
+ AddAssign
19
+ SubAssign
20
+ IsFloat
21
+ Num
22
+ PartialOrd,
23
{
24
match (center, weights) {
25
(true, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _, _>(
26
values,
27
window_size,
28
min_periods,
29
det_offsets_center,
30
None,
31
),
32
(false, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _, _>(
33
values,
34
window_size,
35
min_periods,
36
det_offsets,
37
None,
38
),
39
(true, Some(weights)) => {
40
let weights = no_nulls::coerce_weights(weights);
41
no_nulls::rolling_apply_weights(
42
values,
43
window_size,
44
min_periods,
45
det_offsets_center,
46
no_nulls::compute_sum_weights,
47
&weights,
48
center,
49
)
50
},
51
(false, Some(weights)) => {
52
let weights = no_nulls::coerce_weights(weights);
53
no_nulls::rolling_apply_weights(
54
values,
55
window_size,
56
min_periods,
57
det_offsets,
58
no_nulls::compute_sum_weights,
59
&weights,
60
center,
61
)
62
},
63
}
64
}
65
66
#[cfg(test)]
mod test {
    use super::*;

    /// Runs an unweighted `rolling_sum` over `values` and collects the
    /// resulting `PrimitiveArray<f64>` into a vector of optional values.
    fn unweighted_sum(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        center: bool,
    ) -> Vec<Option<f64>> {
        let array = rolling_sum(values, window_size, min_periods, center, None, None).unwrap();
        let typed = array
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        typed.into_iter().map(|v| v.copied()).collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_sum() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        // Trailing window of 2 that requires both observations.
        assert_eq!(
            unweighted_sum(values, 2, 2, false),
            &[None, Some(3.0), Some(5.0), Some(7.0)]
        );

        // Trailing window of 2 that accepts a single observation.
        assert_eq!(
            unweighted_sum(values, 2, 1, false),
            &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]
        );

        // Trailing window spanning the whole input.
        assert_eq!(
            unweighted_sum(values, 4, 1, false),
            &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]
        );

        // Centered window spanning the whole input.
        assert_eq!(
            unweighted_sum(values, 4, 1, true),
            &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]
        );

        // Centered window that requires all four observations.
        assert_eq!(
            unweighted_sum(values, 4, 4, true),
            &[None, None, Some(10.0), None]
        );

        // test nan handling.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = unweighted_sum(values, 3, 3, false);

        // NaN never compares equal to itself, so compare the debug renderings.
        let rendered = format!("{:?}", out.as_slice());
        let expected = format!(
            "{:?}",
            &[
                None,
                None,
                Some(6.0),
                Some(f64::nan()),
                Some(f64::nan()),
                Some(f64::nan()),
                Some(18.0)
            ]
        );
        assert_eq!(rendered, expected);
    }
}
121
122