Path: blob/main/crates/polars-compute/src/rolling/no_nulls/sum.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use super::*;23pub struct SumWindow<'a, T, S> {4slice: &'a [T],5sum: S,6err: S,7non_finite_count: usize, // NaN or infinity.8pos_inf_count: usize,9neg_inf_count: usize,10last_start: usize,11last_end: usize,12}1314impl<T, S> SumWindow<'_, T, S>15where16T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,17S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,18{19fn add_finite_kahan(&mut self, val: T) {20let val: S = NumCast::from(val).unwrap();21let y = val - self.err;22let new_sum = self.sum + y;23self.err = (new_sum - self.sum) - y;24self.sum = new_sum;25}2627fn add(&mut self, val: T) {28if T::is_float() {29if val.is_finite() {30self.add_finite_kahan(val);31} else {32self.non_finite_count += 1;33self.pos_inf_count += (val > T::zeroed()) as usize;34self.neg_inf_count += (val < T::zeroed()) as usize;35}36} else {37let val: S = NumCast::from(val).unwrap();38self.sum += val;39}40}4142fn sub(&mut self, val: T) {43if T::is_float() {44if val.is_finite() {45self.add_finite_kahan(T::zeroed() - val);46} else {47self.non_finite_count -= 1;48self.pos_inf_count -= (val > T::zeroed()) as usize;49self.neg_inf_count -= (val < T::zeroed()) as usize;50}51} else {52let val: S = NumCast::from(val).unwrap();53self.sum -= val;54}55}56}5758impl<'a, T, S> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T, S>59where60T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,61S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,62{63fn new(64slice: &'a [T],65start: usize,66end: usize,67_params: Option<RollingFnParams>,68_window_size: Option<usize>,69) -> Self {70let mut out = Self {71slice,72sum: S::zeroed(),73err: S::zeroed(),74non_finite_count: 0,75pos_inf_count: 0,76neg_inf_count: 0,77last_start: 0,78last_end: 0,79};80unsafe { out.update(start, end) };81out82}8384// # Safety85// The start, end range must be in-bounds.86unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {87if start >= self.last_end {88self.sum = S::zeroed();89self.err = S::zeroed();90self.non_finite_count = 0;91self.pos_inf_count = 0;92self.neg_inf_count = 0;93self.last_start = start;94self.last_end = start;95}9697for val in &self.slice[self.last_start..start] {98self.sub(*val);99}100101for val in &self.slice[self.last_end..end] {102self.add(*val);103}104105self.last_start = start;106self.last_end = end;107if self.non_finite_count == 0 {108NumCast::from(self.sum)109} else if self.non_finite_count == self.pos_inf_count {110Some(T::pos_inf_value())111} else if self.non_finite_count == self.neg_inf_count {112Some(T::neg_inf_value())113} else {114Some(T::nan_value())115}116}117}118119pub fn rolling_sum<T>(120values: &[T],121window_size: usize,122min_periods: usize,123center: bool,124weights: Option<&[f64]>,125_params: Option<RollingFnParams>,126) -> PolarsResult<ArrayRef>127where128T: NativeType129+ std::iter::Sum130+ NumCast131+ Mul<Output = T>132+ AddAssign133+ SubAssign134+ IsFloat135+ Num136+ PartialOrd,137{138match (center, weights) {139(true, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(140values,141window_size,142min_periods,143det_offsets_center,144None,145),146(false, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(147values,148window_size,149min_periods,150det_offsets,151None,152),153(true, Some(weights)) => {154let weights = no_nulls::coerce_weights(weights);155no_nulls::rolling_apply_weights(156values,157window_size,158min_periods,159det_offsets_center,160no_nulls::compute_sum_weights,161&weights,162center,163)164},165(false, Some(weights)) => {166let weights = no_nulls::coerce_weights(weights);167no_nulls::rolling_apply_weights(168values,169window_size,170min_periods,171det_offsets,172no_nulls::compute_sum_weights,173&weights,174center,175)176},177}178}179180#[cfg(test)]181mod test {182use super::*;183#[test]184fn test_rolling_sum() {185let values = &[1.0f64, 2.0, 3.0, 4.0];186187let out = rolling_sum(values, 2, 2, false, None, None).unwrap();188let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();189let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();190assert_eq!(out, &[None, Some(3.0), Some(5.0), Some(7.0)]);191192let out = rolling_sum(values, 2, 1, false, None, None).unwrap();193let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();194let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();195assert_eq!(out, &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]);196197let out = rolling_sum(values, 4, 1, false, None, None).unwrap();198let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();199let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();200assert_eq!(out, &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]);201202let out = rolling_sum(values, 4, 1, true, None, None).unwrap();203let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();204let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();205assert_eq!(out, &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]);206207let out = rolling_sum(values, 4, 4, true, None, None).unwrap();208let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();209let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();210assert_eq!(out, &[None, None, Some(10.0), None]);211212// test nan handling.213let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];214let out = rolling_sum(values, 3, 3, false, None, None).unwrap();215let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();216let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();217218assert_eq!(219format!("{:?}", out.as_slice()),220format!(221"{:?}",222&[223None,224None,225Some(6.0),226Some(f64::nan()),227Some(f64::nan()),228Some(f64::nan()),229Some(18.0)230]231)232);233}234}235236237