Path: blob/main/crates/polars-compute/src/rolling/sum.rs
8421 views
use std::ops::{Add, AddAssign, Sub, SubAssign};12use super::no_nulls::RollingAggWindowNoNulls;3use super::nulls::RollingAggWindowNulls;4use super::*;56pub struct SumWindow<'a, T, S> {7slice: &'a [T],8validity: Option<&'a Bitmap>,9sum: S,10err_add: S,11err_sub: S,12non_finite_count: usize, // NaN or infinity.13pos_inf_count: usize,14neg_inf_count: usize,15pub(super) null_count: usize,16pub(super) start: usize,17pub(super) end: usize,18}1920impl<'a, T, S> SumWindow<'a, T, S>21where22T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,23S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,24{25fn new_impl(slice: &'a [T], validity: Option<&'a Bitmap>) -> Self {26Self {27slice,28validity,29sum: S::zeroed(),30err_add: S::zeroed(),31err_sub: S::zeroed(),32non_finite_count: 0,33pos_inf_count: 0,34neg_inf_count: 0,35null_count: 0,36start: 0,37end: 0,38}39}4041fn reset(&mut self) {42self.sum = S::zeroed();43self.err_add = S::zeroed();44self.err_sub = S::zeroed();45self.non_finite_count = 0;46self.pos_inf_count = 0;47self.neg_inf_count = 0;48self.null_count = 0;49}5051fn add_finite_kahan(&mut self, val: T) {52let val: S = NumCast::from(val).unwrap();53let y = val - self.err_add;54let new_sum = self.sum + y;55self.err_add = (new_sum - self.sum) - y;56self.sum = new_sum;57}5859fn sub_finite_kahan(&mut self, val: T) {60let val: S = NumCast::from(T::zeroed() - val).unwrap();61let y = val - self.err_sub;62let new_sum = self.sum + y;63self.err_sub = (new_sum - self.sum) - y;64self.sum = new_sum;65}6667fn add(&mut self, val: T) {68if T::is_float() {69if val.is_finite() {70self.add_finite_kahan(val);71} else {72self.non_finite_count += 1;73self.pos_inf_count += (val > T::zeroed()) as usize;74self.neg_inf_count += (val < T::zeroed()) as usize;75}76} else {77let val: S = NumCast::from(val).unwrap();78self.sum += val;79}80}8182fn sub(&mut self, val: T) {83if T::is_float() {84if val.is_finite() {85self.sub_finite_kahan(val);86} else {87self.non_finite_count -= 1;88self.pos_inf_count -= (val > T::zeroed()) as usize;89self.neg_inf_count -= (val < T::zeroed()) as usize;90}91} else {92let val: S = NumCast::from(val).unwrap();93self.sum -= val;94}95}9697fn get_sum(&self) -> Option<T> {98if self.non_finite_count == 0 {99NumCast::from(self.sum)100} else if self.non_finite_count == self.pos_inf_count {101Some(T::pos_inf_value())102} else if self.non_finite_count == self.neg_inf_count {103Some(T::neg_inf_value())104} else {105Some(T::nan_value())106}107}108}109110impl<T, S> RollingAggWindowNoNulls<T> for SumWindow<'_, T, S>111where112T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,113S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,114{115type This<'a> = SumWindow<'a, T, S>;116117fn new<'a>(118slice: &'a [T],119start: usize,120end: usize,121_params: Option<RollingFnParams>,122_window_size: Option<usize>,123) -> Self::This<'a> {124let mut out = SumWindow::new_impl(slice, None);125unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };126out127}128129// # Safety130// The start, end range must be in-bounds.131unsafe fn update(&mut self, new_start: usize, new_end: usize) {132if new_start >= self.end {133self.reset();134self.start = new_start;135self.end = new_start;136}137138for val in &self.slice[self.start..new_start] {139self.sub(*val);140}141142for val in &self.slice[self.end..new_end] {143self.add(*val);144}145146self.start = new_start;147self.end = new_end;148}149150fn get_agg(&self, _idx: usize) -> Option<T> {151self.get_sum()152}153154fn slice_len(&self) -> usize {155self.slice.len()156}157}158159impl<T, S> RollingAggWindowNulls<T> for SumWindow<'_, T, S>160where161T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,162S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,163{164type This<'a> = SumWindow<'a, T, S>;165166fn new<'a>(167slice: &'a [T],168validity: &'a Bitmap,169start: usize,170end: usize,171_params: Option<RollingFnParams>,172_window_size: Option<usize>,173) -> Self::This<'a> {174assert!(start <= slice.len() && end <= slice.len() && start <= end);175let mut out = SumWindow::new_impl(slice, Some(validity));176// SAFETY: We bounds checked `start` and `end`.177unsafe { RollingAggWindowNulls::update(&mut out, start, end) };178out179}180181// # Safety182// The start, end range must be in-bounds.183unsafe fn update(&mut self, new_start: usize, new_end: usize) {184let validity = unsafe { self.validity.unwrap_unchecked() };185186if new_start >= self.end {187self.reset();188self.start = new_start;189self.end = new_start;190}191192for idx in self.start..new_start {193let valid = unsafe { validity.get_bit_unchecked(idx) };194if valid {195self.sub(unsafe { *self.slice.get_unchecked(idx) });196} else {197self.null_count -= 1;198}199}200201for idx in self.end..new_end {202let valid = unsafe { validity.get_bit_unchecked(idx) };203if valid {204self.add(unsafe { *self.slice.get_unchecked(idx) });205} else {206self.null_count += 1;207}208}209210self.start = new_start;211self.end = new_end;212}213214fn get_agg(&self, _idx: usize) -> Option<T> {215self.get_sum()216}217218fn is_valid(&self, min_periods: usize) -> bool {219((self.end - self.start) - self.null_count) >= min_periods220}221222fn slice_len(&self) -> usize {223self.slice.len()224}225}226227228