// Path: blob/main/crates/polars-compute/src/rolling/nulls/sum.rs
#![allow(unsafe_op_in_unsafe_fn)]1use super::*;23pub struct SumWindow<'a, T, S> {4slice: &'a [T],5validity: &'a Bitmap,6sum: S,7err: S,8non_finite_count: usize, // NaN or infinity.9pos_inf_count: usize,10neg_inf_count: usize,11pub(super) null_count: usize,12last_start: usize,13last_end: usize,14}1516impl<T, S> SumWindow<'_, T, S>17where18T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,19S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,20{21fn add_finite_kahan(&mut self, val: T) {22let val: S = NumCast::from(val).unwrap();23let y = val - self.err;24let new_sum = self.sum + y;25self.err = (new_sum - self.sum) - y;26self.sum = new_sum;27}2829fn add(&mut self, val: T) {30if T::is_float() {31if val.is_finite() {32self.add_finite_kahan(val);33} else {34self.non_finite_count += 1;35self.pos_inf_count += (val > T::zeroed()) as usize;36self.neg_inf_count += (val < T::zeroed()) as usize;37}38} else {39let val: S = NumCast::from(val).unwrap();40self.sum += val;41}42}4344fn sub(&mut self, val: T) {45if T::is_float() {46if val.is_finite() {47self.add_finite_kahan(T::zeroed() - val);48} else {49self.non_finite_count -= 1;50self.pos_inf_count -= (val > T::zeroed()) as usize;51self.neg_inf_count -= (val < T::zeroed()) as usize;52}53} else {54let val: S = NumCast::from(val).unwrap();55self.sum -= val;56}57}58}5960impl<'a, T, S> RollingAggWindowNulls<'a, T> for SumWindow<'a, T, S>61where62T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,63S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,64{65unsafe fn new(66slice: &'a [T],67validity: &'a Bitmap,68start: usize,69end: usize,70_params: Option<RollingFnParams>,71_window_size: Option<usize>,72) -> Self {73let mut out = Self {74slice,75validity,76sum: S::zeroed(),77err: S::zeroed(),78non_finite_count: 0,79pos_inf_count: 0,80neg_inf_count: 0,81last_start: 0,82last_end: 0,83null_count: 0,84};85out.update(start, end);86out87}8889// # 
Safety90// The start, end range must be in-bounds.91unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {92if start >= self.last_end {93self.sum = S::zeroed();94self.err = S::zeroed();95self.non_finite_count = 0;96self.pos_inf_count = 0;97self.neg_inf_count = 0;98self.null_count = 0;99self.last_start = start;100self.last_end = start;101}102103for idx in self.last_start..start {104let valid = self.validity.get_bit_unchecked(idx);105if valid {106self.sub(unsafe { *self.slice.get_unchecked(idx) });107} else {108self.null_count -= 1;109}110}111112for idx in self.last_end..end {113let valid = self.validity.get_bit_unchecked(idx);114if valid {115self.add(unsafe { *self.slice.get_unchecked(idx) });116} else {117self.null_count += 1;118}119}120121self.last_start = start;122self.last_end = end;123if self.non_finite_count == 0 {124NumCast::from(self.sum)125} else if self.non_finite_count == self.pos_inf_count {126Some(T::pos_inf_value())127} else if self.non_finite_count == self.neg_inf_count {128Some(T::neg_inf_value())129} else {130Some(T::nan_value())131}132}133134fn is_valid(&self, min_periods: usize) -> bool {135((self.last_end - self.last_start) - self.null_count) >= min_periods136}137}138139pub fn rolling_sum<T>(140arr: &PrimitiveArray<T>,141window_size: usize,142min_periods: usize,143center: bool,144weights: Option<&[f64]>,145_params: Option<RollingFnParams>,146) -> ArrayRef147where148T: NativeType149+ IsFloat150+ PartialOrd151+ Add<Output = T>152+ Sub<Output = T>153+ SubAssign154+ AddAssign155+ NumCast,156{157if weights.is_some() {158panic!("weights not yet supported on array with null values")159}160if center {161rolling_apply_agg_window::<SumWindow<T, T>, _, _>(162arr.values().as_slice(),163arr.validity().as_ref().unwrap(),164window_size,165min_periods,166det_offsets_center,167None,168)169} else {170rolling_apply_agg_window::<SumWindow<T, T>, _, 
_>(171arr.values().as_slice(),172arr.validity().as_ref().unwrap(),173window_size,174min_periods,175det_offsets,176None,177)178}179}180181182