Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/sum.rs
7884 views
1
use std::ops::{Add, AddAssign, Sub, SubAssign};
2
3
use super::no_nulls::RollingAggWindowNoNulls;
4
use super::nulls::RollingAggWindowNulls;
5
use super::*;
6
7
pub struct SumWindow<'a, T, S> {
8
slice: &'a [T],
9
validity: Option<&'a Bitmap>,
10
sum: S,
11
err_add: S,
12
err_sub: S,
13
non_finite_count: usize, // NaN or infinity.
14
pos_inf_count: usize,
15
neg_inf_count: usize,
16
pub(super) null_count: usize,
17
last_start: usize,
18
last_end: usize,
19
}
20
21
impl<'a, T, S> SumWindow<'a, T, S>
22
where
23
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
24
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
25
{
26
fn new_impl(slice: &'a [T], validity: Option<&'a Bitmap>) -> Self {
27
Self {
28
slice,
29
validity,
30
sum: S::zeroed(),
31
err_add: S::zeroed(),
32
err_sub: S::zeroed(),
33
non_finite_count: 0,
34
pos_inf_count: 0,
35
neg_inf_count: 0,
36
null_count: 0,
37
last_start: 0,
38
last_end: 0,
39
}
40
}
41
42
fn reset(&mut self) {
43
self.sum = S::zeroed();
44
self.err_add = S::zeroed();
45
self.err_sub = S::zeroed();
46
self.non_finite_count = 0;
47
self.pos_inf_count = 0;
48
self.neg_inf_count = 0;
49
self.null_count = 0;
50
}
51
52
fn add_finite_kahan(&mut self, val: T) {
53
let val: S = NumCast::from(val).unwrap();
54
let y = val - self.err_add;
55
let new_sum = self.sum + y;
56
self.err_add = (new_sum - self.sum) - y;
57
self.sum = new_sum;
58
}
59
60
fn sub_finite_kahan(&mut self, val: T) {
61
let val: S = NumCast::from(T::zeroed() - val).unwrap();
62
let y = val - self.err_sub;
63
let new_sum = self.sum + y;
64
self.err_sub = (new_sum - self.sum) - y;
65
self.sum = new_sum;
66
}
67
68
fn add(&mut self, val: T) {
69
if T::is_float() {
70
if val.is_finite() {
71
self.add_finite_kahan(val);
72
} else {
73
self.non_finite_count += 1;
74
self.pos_inf_count += (val > T::zeroed()) as usize;
75
self.neg_inf_count += (val < T::zeroed()) as usize;
76
}
77
} else {
78
let val: S = NumCast::from(val).unwrap();
79
self.sum += val;
80
}
81
}
82
83
fn sub(&mut self, val: T) {
84
if T::is_float() {
85
if val.is_finite() {
86
self.sub_finite_kahan(val);
87
} else {
88
self.non_finite_count -= 1;
89
self.pos_inf_count -= (val > T::zeroed()) as usize;
90
self.neg_inf_count -= (val < T::zeroed()) as usize;
91
}
92
} else {
93
let val: S = NumCast::from(val).unwrap();
94
self.sum -= val;
95
}
96
}
97
98
fn finalize(&self) -> Option<T> {
99
if self.non_finite_count == 0 {
100
NumCast::from(self.sum)
101
} else if self.non_finite_count == self.pos_inf_count {
102
Some(T::pos_inf_value())
103
} else if self.non_finite_count == self.neg_inf_count {
104
Some(T::neg_inf_value())
105
} else {
106
Some(T::nan_value())
107
}
108
}
109
}
110
111
impl<'a, T, S> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T, S>
112
where
113
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
114
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
115
{
116
fn new(
117
slice: &'a [T],
118
start: usize,
119
end: usize,
120
_params: Option<RollingFnParams>,
121
_window_size: Option<usize>,
122
) -> Self {
123
let mut out = Self::new_impl(slice, None);
124
unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };
125
out
126
}
127
128
// # Safety
129
// The start, end range must be in-bounds.
130
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
131
if start >= self.last_end {
132
self.reset();
133
self.last_start = start;
134
self.last_end = start;
135
}
136
137
for val in &self.slice[self.last_start..start] {
138
self.sub(*val);
139
}
140
141
for val in &self.slice[self.last_end..end] {
142
self.add(*val);
143
}
144
145
self.last_start = start;
146
self.last_end = end;
147
self.finalize()
148
}
149
}
150
151
impl<'a, T, S> RollingAggWindowNulls<'a, T> for SumWindow<'a, T, S>
152
where
153
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
154
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
155
{
156
unsafe fn new(
157
slice: &'a [T],
158
validity: &'a Bitmap,
159
start: usize,
160
end: usize,
161
_params: Option<RollingFnParams>,
162
_window_size: Option<usize>,
163
) -> Self {
164
let mut out = Self::new_impl(slice, Some(validity));
165
unsafe { RollingAggWindowNulls::update(&mut out, start, end) };
166
out
167
}
168
169
// # Safety
170
// The start, end range must be in-bounds.
171
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
172
let validity = unsafe { self.validity.unwrap_unchecked() };
173
174
if start >= self.last_end {
175
self.reset();
176
self.last_start = start;
177
self.last_end = start;
178
}
179
180
for idx in self.last_start..start {
181
let valid = unsafe { validity.get_bit_unchecked(idx) };
182
if valid {
183
self.sub(unsafe { *self.slice.get_unchecked(idx) });
184
} else {
185
self.null_count -= 1;
186
}
187
}
188
189
for idx in self.last_end..end {
190
let valid = unsafe { validity.get_bit_unchecked(idx) };
191
if valid {
192
self.add(unsafe { *self.slice.get_unchecked(idx) });
193
} else {
194
self.null_count += 1;
195
}
196
}
197
198
self.last_start = start;
199
self.last_end = end;
200
self.finalize()
201
}
202
203
fn is_valid(&self, min_periods: usize) -> bool {
204
((self.last_end - self.last_start) - self.null_count) >= min_periods
205
}
206
}
207
208