Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/sum.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use super::*;
3
4
/// Incremental state for a rolling sum over a slice that contains no nulls.
///
/// `T` is the element type of the input slice and `S` is the accumulator type.
pub struct SumWindow<'a, T, S> {
    /// The full input; windows are sub-ranges `last_start..last_end` of it.
    slice: &'a [T],
    /// Start (inclusive) of the window most recently aggregated.
    last_start: usize,
    /// End (exclusive) of the window most recently aggregated.
    last_end: usize,
    /// Running sum of the finite values currently inside the window.
    sum: S,
    /// Kahan compensation term carried between additions.
    err: S,
    /// Number of NaN or infinite values currently inside the window.
    non_finite_count: usize,
    /// How many of the non-finite values are `+inf`.
    pos_inf_count: usize,
    /// How many of the non-finite values are `-inf`.
    neg_inf_count: usize,
}
14
15
impl<T, S> SumWindow<'_, T, S>
16
where
17
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
18
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
19
{
20
fn add_finite_kahan(&mut self, val: T) {
21
let val: S = NumCast::from(val).unwrap();
22
let y = val - self.err;
23
let new_sum = self.sum + y;
24
self.err = (new_sum - self.sum) - y;
25
self.sum = new_sum;
26
}
27
28
fn add(&mut self, val: T) {
29
if T::is_float() {
30
if val.is_finite() {
31
self.add_finite_kahan(val);
32
} else {
33
self.non_finite_count += 1;
34
self.pos_inf_count += (val > T::zeroed()) as usize;
35
self.neg_inf_count += (val < T::zeroed()) as usize;
36
}
37
} else {
38
let val: S = NumCast::from(val).unwrap();
39
self.sum += val;
40
}
41
}
42
43
fn sub(&mut self, val: T) {
44
if T::is_float() {
45
if val.is_finite() {
46
self.add_finite_kahan(T::zeroed() - val);
47
} else {
48
self.non_finite_count -= 1;
49
self.pos_inf_count -= (val > T::zeroed()) as usize;
50
self.neg_inf_count -= (val < T::zeroed()) as usize;
51
}
52
} else {
53
let val: S = NumCast::from(val).unwrap();
54
self.sum -= val;
55
}
56
}
57
}
58
59
impl<'a, T, S> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T, S>
60
where
61
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
62
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
63
{
64
fn new(
65
slice: &'a [T],
66
start: usize,
67
end: usize,
68
_params: Option<RollingFnParams>,
69
_window_size: Option<usize>,
70
) -> Self {
71
let mut out = Self {
72
slice,
73
sum: S::zeroed(),
74
err: S::zeroed(),
75
non_finite_count: 0,
76
pos_inf_count: 0,
77
neg_inf_count: 0,
78
last_start: 0,
79
last_end: 0,
80
};
81
unsafe { out.update(start, end) };
82
out
83
}
84
85
// # Safety
86
// The start, end range must be in-bounds.
87
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
88
if start >= self.last_end {
89
self.sum = S::zeroed();
90
self.err = S::zeroed();
91
self.non_finite_count = 0;
92
self.pos_inf_count = 0;
93
self.neg_inf_count = 0;
94
self.last_start = start;
95
self.last_end = start;
96
}
97
98
for val in &self.slice[self.last_start..start] {
99
self.sub(*val);
100
}
101
102
for val in &self.slice[self.last_end..end] {
103
self.add(*val);
104
}
105
106
self.last_start = start;
107
self.last_end = end;
108
if self.non_finite_count == 0 {
109
NumCast::from(self.sum)
110
} else if self.non_finite_count == self.pos_inf_count {
111
Some(T::pos_inf_value())
112
} else if self.non_finite_count == self.neg_inf_count {
113
Some(T::neg_inf_value())
114
} else {
115
Some(T::nan_value())
116
}
117
}
118
}
119
120
pub fn rolling_sum<T>(
121
values: &[T],
122
window_size: usize,
123
min_periods: usize,
124
center: bool,
125
weights: Option<&[f64]>,
126
_params: Option<RollingFnParams>,
127
) -> PolarsResult<ArrayRef>
128
where
129
T: NativeType
130
+ std::iter::Sum
131
+ NumCast
132
+ Mul<Output = T>
133
+ AddAssign
134
+ SubAssign
135
+ IsFloat
136
+ Num
137
+ PartialOrd,
138
{
139
match (center, weights) {
140
(true, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
141
values,
142
window_size,
143
min_periods,
144
det_offsets_center,
145
None,
146
),
147
(false, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
148
values,
149
window_size,
150
min_periods,
151
det_offsets,
152
None,
153
),
154
(true, Some(weights)) => {
155
let weights = no_nulls::coerce_weights(weights);
156
no_nulls::rolling_apply_weights(
157
values,
158
window_size,
159
min_periods,
160
det_offsets_center,
161
no_nulls::compute_sum_weights,
162
&weights,
163
center,
164
)
165
},
166
(false, Some(weights)) => {
167
let weights = no_nulls::coerce_weights(weights);
168
no_nulls::rolling_apply_weights(
169
values,
170
window_size,
171
min_periods,
172
det_offsets,
173
no_nulls::compute_sum_weights,
174
&weights,
175
center,
176
)
177
},
178
}
179
}
180
181
#[cfg(test)]
mod test {
    use super::*;

    /// Runs the unweighted `rolling_sum` on `values` and returns the output as plain options.
    fn run(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        center: bool,
    ) -> Vec<Option<f64>> {
        let arr = rolling_sum(values, window_size, min_periods, center, None, None).unwrap();
        let arr = arr.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        arr.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_sum() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        // (window_size, min_periods, center, expected)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, Some(3.0), Some(5.0), Some(7.0)]),
            (2, 1, false, [Some(1.0), Some(3.0), Some(5.0), Some(7.0)]),
            (4, 1, false, [Some(1.0), Some(3.0), Some(6.0), Some(10.0)]),
            (4, 1, true, [Some(3.0), Some(6.0), Some(10.0), Some(9.0)]),
            (4, 4, true, [None, None, Some(10.0), None]),
        ];
        for &(window_size, min_periods, center, expected) in &cases {
            assert_eq!(run(values, window_size, min_periods, center), expected);
        }

        // A NaN poisons every window that contains it, and only those windows.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = run(values, 3, 3, false);
        let expected = [
            None,
            None,
            Some(6.0),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(18.0),
        ];
        // NaN != NaN, so compare the Debug renderings instead of the values.
        assert_eq!(format!("{:?}", out.as_slice()), format!("{:?}", &expected));
    }
}
236
237