Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/sum.rs
8421 views
1
use std::ops::{Add, AddAssign, Sub, SubAssign};
2
3
use super::no_nulls::RollingAggWindowNoNulls;
4
use super::nulls::RollingAggWindowNulls;
5
use super::*;
6
7
pub struct SumWindow<'a, T, S> {
8
slice: &'a [T],
9
validity: Option<&'a Bitmap>,
10
sum: S,
11
err_add: S,
12
err_sub: S,
13
non_finite_count: usize, // NaN or infinity.
14
pos_inf_count: usize,
15
neg_inf_count: usize,
16
pub(super) null_count: usize,
17
pub(super) start: usize,
18
pub(super) end: usize,
19
}
20
21
impl<'a, T, S> SumWindow<'a, T, S>
22
where
23
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
24
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
25
{
26
fn new_impl(slice: &'a [T], validity: Option<&'a Bitmap>) -> Self {
27
Self {
28
slice,
29
validity,
30
sum: S::zeroed(),
31
err_add: S::zeroed(),
32
err_sub: S::zeroed(),
33
non_finite_count: 0,
34
pos_inf_count: 0,
35
neg_inf_count: 0,
36
null_count: 0,
37
start: 0,
38
end: 0,
39
}
40
}
41
42
fn reset(&mut self) {
43
self.sum = S::zeroed();
44
self.err_add = S::zeroed();
45
self.err_sub = S::zeroed();
46
self.non_finite_count = 0;
47
self.pos_inf_count = 0;
48
self.neg_inf_count = 0;
49
self.null_count = 0;
50
}
51
52
fn add_finite_kahan(&mut self, val: T) {
53
let val: S = NumCast::from(val).unwrap();
54
let y = val - self.err_add;
55
let new_sum = self.sum + y;
56
self.err_add = (new_sum - self.sum) - y;
57
self.sum = new_sum;
58
}
59
60
fn sub_finite_kahan(&mut self, val: T) {
61
let val: S = NumCast::from(T::zeroed() - val).unwrap();
62
let y = val - self.err_sub;
63
let new_sum = self.sum + y;
64
self.err_sub = (new_sum - self.sum) - y;
65
self.sum = new_sum;
66
}
67
68
fn add(&mut self, val: T) {
69
if T::is_float() {
70
if val.is_finite() {
71
self.add_finite_kahan(val);
72
} else {
73
self.non_finite_count += 1;
74
self.pos_inf_count += (val > T::zeroed()) as usize;
75
self.neg_inf_count += (val < T::zeroed()) as usize;
76
}
77
} else {
78
let val: S = NumCast::from(val).unwrap();
79
self.sum += val;
80
}
81
}
82
83
fn sub(&mut self, val: T) {
84
if T::is_float() {
85
if val.is_finite() {
86
self.sub_finite_kahan(val);
87
} else {
88
self.non_finite_count -= 1;
89
self.pos_inf_count -= (val > T::zeroed()) as usize;
90
self.neg_inf_count -= (val < T::zeroed()) as usize;
91
}
92
} else {
93
let val: S = NumCast::from(val).unwrap();
94
self.sum -= val;
95
}
96
}
97
98
fn get_sum(&self) -> Option<T> {
99
if self.non_finite_count == 0 {
100
NumCast::from(self.sum)
101
} else if self.non_finite_count == self.pos_inf_count {
102
Some(T::pos_inf_value())
103
} else if self.non_finite_count == self.neg_inf_count {
104
Some(T::neg_inf_value())
105
} else {
106
Some(T::nan_value())
107
}
108
}
109
}
110
111
impl<T, S> RollingAggWindowNoNulls<T> for SumWindow<'_, T, S>
112
where
113
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
114
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
115
{
116
type This<'a> = SumWindow<'a, T, S>;
117
118
fn new<'a>(
119
slice: &'a [T],
120
start: usize,
121
end: usize,
122
_params: Option<RollingFnParams>,
123
_window_size: Option<usize>,
124
) -> Self::This<'a> {
125
let mut out = SumWindow::new_impl(slice, None);
126
unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };
127
out
128
}
129
130
// # Safety
131
// The start, end range must be in-bounds.
132
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
133
if new_start >= self.end {
134
self.reset();
135
self.start = new_start;
136
self.end = new_start;
137
}
138
139
for val in &self.slice[self.start..new_start] {
140
self.sub(*val);
141
}
142
143
for val in &self.slice[self.end..new_end] {
144
self.add(*val);
145
}
146
147
self.start = new_start;
148
self.end = new_end;
149
}
150
151
fn get_agg(&self, _idx: usize) -> Option<T> {
152
self.get_sum()
153
}
154
155
fn slice_len(&self) -> usize {
156
self.slice.len()
157
}
158
}
159
160
impl<T, S> RollingAggWindowNulls<T> for SumWindow<'_, T, S>
161
where
162
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
163
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
164
{
165
type This<'a> = SumWindow<'a, T, S>;
166
167
fn new<'a>(
168
slice: &'a [T],
169
validity: &'a Bitmap,
170
start: usize,
171
end: usize,
172
_params: Option<RollingFnParams>,
173
_window_size: Option<usize>,
174
) -> Self::This<'a> {
175
assert!(start <= slice.len() && end <= slice.len() && start <= end);
176
let mut out = SumWindow::new_impl(slice, Some(validity));
177
// SAFETY: We bounds checked `start` and `end`.
178
unsafe { RollingAggWindowNulls::update(&mut out, start, end) };
179
out
180
}
181
182
// # Safety
183
// The start, end range must be in-bounds.
184
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
185
let validity = unsafe { self.validity.unwrap_unchecked() };
186
187
if new_start >= self.end {
188
self.reset();
189
self.start = new_start;
190
self.end = new_start;
191
}
192
193
for idx in self.start..new_start {
194
let valid = unsafe { validity.get_bit_unchecked(idx) };
195
if valid {
196
self.sub(unsafe { *self.slice.get_unchecked(idx) });
197
} else {
198
self.null_count -= 1;
199
}
200
}
201
202
for idx in self.end..new_end {
203
let valid = unsafe { validity.get_bit_unchecked(idx) };
204
if valid {
205
self.add(unsafe { *self.slice.get_unchecked(idx) });
206
} else {
207
self.null_count += 1;
208
}
209
}
210
211
self.start = new_start;
212
self.end = new_end;
213
}
214
215
fn get_agg(&self, _idx: usize) -> Option<T> {
216
self.get_sum()
217
}
218
219
fn is_valid(&self, min_periods: usize) -> bool {
220
((self.end - self.start) - self.null_count) >= min_periods
221
}
222
223
fn slice_len(&self) -> usize {
224
self.slice.len()
225
}
226
}
227
228