Path: blob/main/crates/polars-compute/src/rolling/no_nulls/sum.rs
8433 views
#![allow(unsafe_op_in_unsafe_fn)]1use super::super::sum::SumWindow;2use super::*;34pub fn rolling_sum<T>(5values: &[T],6window_size: usize,7min_periods: usize,8center: bool,9weights: Option<&[f64]>,10_params: Option<RollingFnParams>,11) -> PolarsResult<ArrayRef>12where13T: NativeType14+ std::iter::Sum15+ NumCast16+ Mul<Output = T>17+ AddAssign18+ SubAssign19+ IsFloat20+ Num21+ PartialOrd,22{23match (center, weights) {24(true, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _, _>(25values,26window_size,27min_periods,28det_offsets_center,29None,30),31(false, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _, _>(32values,33window_size,34min_periods,35det_offsets,36None,37),38(true, Some(weights)) => {39let weights = no_nulls::coerce_weights(weights);40no_nulls::rolling_apply_weights(41values,42window_size,43min_periods,44det_offsets_center,45no_nulls::compute_sum_weights,46&weights,47center,48)49},50(false, Some(weights)) => {51let weights = no_nulls::coerce_weights(weights);52no_nulls::rolling_apply_weights(53values,54window_size,55min_periods,56det_offsets,57no_nulls::compute_sum_weights,58&weights,59center,60)61},62}63}6465#[cfg(test)]66mod test {67use super::*;68#[test]69fn test_rolling_sum() {70let values = &[1.0f64, 2.0, 3.0, 4.0];7172let out = rolling_sum(values, 2, 2, false, None, None).unwrap();73let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();74let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();75assert_eq!(out, &[None, Some(3.0), Some(5.0), Some(7.0)]);7677let out = rolling_sum(values, 2, 1, false, None, None).unwrap();78let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();79let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();80assert_eq!(out, &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]);8182let out = rolling_sum(values, 4, 1, false, None, None).unwrap();83let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();84let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();85assert_eq!(out, &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]);8687let out = rolling_sum(values, 4, 1, true, None, None).unwrap();88let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();89let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();90assert_eq!(out, &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]);9192let out = rolling_sum(values, 4, 4, true, None, None).unwrap();93let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();94let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();95assert_eq!(out, &[None, None, Some(10.0), None]);9697// test nan handling.98let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];99let out = rolling_sum(values, 3, 3, false, None, None).unwrap();100let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();101let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();102103assert_eq!(104format!("{:?}", out.as_slice()),105format!(106"{:?}",107&[108None,109None,110Some(6.0),111Some(f64::nan()),112Some(f64::nan()),113Some(f64::nan()),114Some(18.0)115]116)117);118}119}120121122