Path: blob/main/crates/polars-compute/src/rolling/sum.rs
7884 views
use std::ops::{Add, AddAssign, Sub, SubAssign};12use super::no_nulls::RollingAggWindowNoNulls;3use super::nulls::RollingAggWindowNulls;4use super::*;56pub struct SumWindow<'a, T, S> {7slice: &'a [T],8validity: Option<&'a Bitmap>,9sum: S,10err_add: S,11err_sub: S,12non_finite_count: usize, // NaN or infinity.13pos_inf_count: usize,14neg_inf_count: usize,15pub(super) null_count: usize,16last_start: usize,17last_end: usize,18}1920impl<'a, T, S> SumWindow<'a, T, S>21where22T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,23S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,24{25fn new_impl(slice: &'a [T], validity: Option<&'a Bitmap>) -> Self {26Self {27slice,28validity,29sum: S::zeroed(),30err_add: S::zeroed(),31err_sub: S::zeroed(),32non_finite_count: 0,33pos_inf_count: 0,34neg_inf_count: 0,35null_count: 0,36last_start: 0,37last_end: 0,38}39}4041fn reset(&mut self) {42self.sum = S::zeroed();43self.err_add = S::zeroed();44self.err_sub = S::zeroed();45self.non_finite_count = 0;46self.pos_inf_count = 0;47self.neg_inf_count = 0;48self.null_count = 0;49}5051fn add_finite_kahan(&mut self, val: T) {52let val: S = NumCast::from(val).unwrap();53let y = val - self.err_add;54let new_sum = self.sum + y;55self.err_add = (new_sum - self.sum) - y;56self.sum = new_sum;57}5859fn sub_finite_kahan(&mut self, val: T) {60let val: S = NumCast::from(T::zeroed() - val).unwrap();61let y = val - self.err_sub;62let new_sum = self.sum + y;63self.err_sub = (new_sum - self.sum) - y;64self.sum = new_sum;65}6667fn add(&mut self, val: T) {68if T::is_float() {69if val.is_finite() {70self.add_finite_kahan(val);71} else {72self.non_finite_count += 1;73self.pos_inf_count += (val > T::zeroed()) as usize;74self.neg_inf_count += (val < T::zeroed()) as usize;75}76} else {77let val: S = NumCast::from(val).unwrap();78self.sum += val;79}80}8182fn sub(&mut self, val: T) {83if T::is_float() {84if val.is_finite() {85self.sub_finite_kahan(val);86} else {87self.non_finite_count -= 1;88self.pos_inf_count -= (val > T::zeroed()) as usize;89self.neg_inf_count -= (val < T::zeroed()) as usize;90}91} else {92let val: S = NumCast::from(val).unwrap();93self.sum -= val;94}95}9697fn finalize(&self) -> Option<T> {98if self.non_finite_count == 0 {99NumCast::from(self.sum)100} else if self.non_finite_count == self.pos_inf_count {101Some(T::pos_inf_value())102} else if self.non_finite_count == self.neg_inf_count {103Some(T::neg_inf_value())104} else {105Some(T::nan_value())106}107}108}109110impl<'a, T, S> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T, S>111where112T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,113S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,114{115fn new(116slice: &'a [T],117start: usize,118end: usize,119_params: Option<RollingFnParams>,120_window_size: Option<usize>,121) -> Self {122let mut out = Self::new_impl(slice, None);123unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };124out125}126127// # Safety128// The start, end range must be in-bounds.129unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {130if start >= self.last_end {131self.reset();132self.last_start = start;133self.last_end = start;134}135136for val in &self.slice[self.last_start..start] {137self.sub(*val);138}139140for val in &self.slice[self.last_end..end] {141self.add(*val);142}143144self.last_start = start;145self.last_end = end;146self.finalize()147}148}149150impl<'a, T, S> RollingAggWindowNulls<'a, T> for SumWindow<'a, T, S>151where152T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,153S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,154{155unsafe fn new(156slice: &'a [T],157validity: &'a Bitmap,158start: usize,159end: usize,160_params: Option<RollingFnParams>,161_window_size: Option<usize>,162) -> Self {163let mut out = Self::new_impl(slice, Some(validity));164unsafe { RollingAggWindowNulls::update(&mut out, start, end) };165out166}167168// # Safety169// The start, end range must be in-bounds.170unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {171let validity = unsafe { self.validity.unwrap_unchecked() };172173if start >= self.last_end {174self.reset();175self.last_start = start;176self.last_end = start;177}178179for idx in self.last_start..start {180let valid = unsafe { validity.get_bit_unchecked(idx) };181if valid {182self.sub(unsafe { *self.slice.get_unchecked(idx) });183} else {184self.null_count -= 1;185}186}187188for idx in self.last_end..end {189let valid = unsafe { validity.get_bit_unchecked(idx) };190if valid {191self.add(unsafe { *self.slice.get_unchecked(idx) });192} else {193self.null_count += 1;194}195}196197self.last_start = start;198self.last_end = end;199self.finalize()200}201202fn is_valid(&self, min_periods: usize) -> bool {203((self.last_end - self.last_start) - self.null_count) >= min_periods204}205}206207208