Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/sum.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use super::*;
3
4
pub struct SumWindow<'a, T, S> {
5
slice: &'a [T],
6
validity: &'a Bitmap,
7
sum: S,
8
err: S,
9
non_finite_count: usize, // NaN or infinity.
10
pos_inf_count: usize,
11
neg_inf_count: usize,
12
pub(super) null_count: usize,
13
last_start: usize,
14
last_end: usize,
15
}
16
17
impl<T, S> SumWindow<'_, T, S>
18
where
19
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
20
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
21
{
22
fn add_finite_kahan(&mut self, val: T) {
23
let val: S = NumCast::from(val).unwrap();
24
let y = val - self.err;
25
let new_sum = self.sum + y;
26
self.err = (new_sum - self.sum) - y;
27
self.sum = new_sum;
28
}
29
30
fn add(&mut self, val: T) {
31
if T::is_float() {
32
if val.is_finite() {
33
self.add_finite_kahan(val);
34
} else {
35
self.non_finite_count += 1;
36
self.pos_inf_count += (val > T::zeroed()) as usize;
37
self.neg_inf_count += (val < T::zeroed()) as usize;
38
}
39
} else {
40
let val: S = NumCast::from(val).unwrap();
41
self.sum += val;
42
}
43
}
44
45
fn sub(&mut self, val: T) {
46
if T::is_float() {
47
if val.is_finite() {
48
self.add_finite_kahan(T::zeroed() - val);
49
} else {
50
self.non_finite_count -= 1;
51
self.pos_inf_count -= (val > T::zeroed()) as usize;
52
self.neg_inf_count -= (val < T::zeroed()) as usize;
53
}
54
} else {
55
let val: S = NumCast::from(val).unwrap();
56
self.sum -= val;
57
}
58
}
59
}
60
61
impl<'a, T, S> RollingAggWindowNulls<'a, T> for SumWindow<'a, T, S>
62
where
63
T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
64
S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
65
{
66
unsafe fn new(
67
slice: &'a [T],
68
validity: &'a Bitmap,
69
start: usize,
70
end: usize,
71
_params: Option<RollingFnParams>,
72
_window_size: Option<usize>,
73
) -> Self {
74
let mut out = Self {
75
slice,
76
validity,
77
sum: S::zeroed(),
78
err: S::zeroed(),
79
non_finite_count: 0,
80
pos_inf_count: 0,
81
neg_inf_count: 0,
82
last_start: 0,
83
last_end: 0,
84
null_count: 0,
85
};
86
out.update(start, end);
87
out
88
}
89
90
// # Safety
91
// The start, end range must be in-bounds.
92
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
93
if start >= self.last_end {
94
self.sum = S::zeroed();
95
self.err = S::zeroed();
96
self.non_finite_count = 0;
97
self.pos_inf_count = 0;
98
self.neg_inf_count = 0;
99
self.null_count = 0;
100
self.last_start = start;
101
self.last_end = start;
102
}
103
104
for idx in self.last_start..start {
105
let valid = self.validity.get_bit_unchecked(idx);
106
if valid {
107
self.sub(unsafe { *self.slice.get_unchecked(idx) });
108
} else {
109
self.null_count -= 1;
110
}
111
}
112
113
for idx in self.last_end..end {
114
let valid = self.validity.get_bit_unchecked(idx);
115
if valid {
116
self.add(unsafe { *self.slice.get_unchecked(idx) });
117
} else {
118
self.null_count += 1;
119
}
120
}
121
122
self.last_start = start;
123
self.last_end = end;
124
if self.non_finite_count == 0 {
125
NumCast::from(self.sum)
126
} else if self.non_finite_count == self.pos_inf_count {
127
Some(T::pos_inf_value())
128
} else if self.non_finite_count == self.neg_inf_count {
129
Some(T::neg_inf_value())
130
} else {
131
Some(T::nan_value())
132
}
133
}
134
135
fn is_valid(&self, min_periods: usize) -> bool {
136
((self.last_end - self.last_start) - self.null_count) >= min_periods
137
}
138
}
139
140
pub fn rolling_sum<T>(
141
arr: &PrimitiveArray<T>,
142
window_size: usize,
143
min_periods: usize,
144
center: bool,
145
weights: Option<&[f64]>,
146
_params: Option<RollingFnParams>,
147
) -> ArrayRef
148
where
149
T: NativeType
150
+ IsFloat
151
+ PartialOrd
152
+ Add<Output = T>
153
+ Sub<Output = T>
154
+ SubAssign
155
+ AddAssign
156
+ NumCast,
157
{
158
if weights.is_some() {
159
panic!("weights not yet supported on array with null values")
160
}
161
if center {
162
rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
163
arr.values().as_slice(),
164
arr.validity().as_ref().unwrap(),
165
window_size,
166
min_periods,
167
det_offsets_center,
168
None,
169
)
170
} else {
171
rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
172
arr.values().as_slice(),
173
arr.validity().as_ref().unwrap(),
174
window_size,
175
min_periods,
176
det_offsets,
177
None,
178
)
179
}
180
}
181
182