Path: blob/main/crates/polars-compute/src/rolling/no_nulls/moment.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use num_traits::{FromPrimitive, ToPrimitive};2use polars_error::polars_ensure;34pub use super::super::moment::*;5use super::*;67pub struct MomentWindow<'a, T, M: StateUpdate> {8slice: &'a [T],9moment: M,10last_start: usize,11last_end: usize,12params: Option<RollingFnParams>,13}1415impl<T: ToPrimitive + Copy, M: StateUpdate> MomentWindow<'_, T, M> {16fn compute_var(&mut self, start: usize, end: usize) {17self.moment = M::new(self.params);18for value in &self.slice[start..end] {19let value: f64 = NumCast::from(*value).unwrap();20self.moment.insert_one(value);21}22}23}2425impl<'a, T: NativeType + IsFloat + Float + ToPrimitive + FromPrimitive, M: StateUpdate>26RollingAggWindowNoNulls<'a, T> for MomentWindow<'a, T, M>27{28fn new(29slice: &'a [T],30start: usize,31end: usize,32params: Option<RollingFnParams>,33_window_size: Option<usize>,34) -> Self {35let mut out = Self {36slice,37moment: M::new(params),38last_start: start,39last_end: end,40params,41};42out.compute_var(start, end);43out44}4546unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {47let recompute_var = if start >= self.last_end {48true49} else {50// remove elements that should leave the window51let mut recompute_var = false;52for idx in self.last_start..start {53// SAFETY: we are in bounds54let leaving_value = *self.slice.get_unchecked(idx);5556// if the leaving value is nan we need to recompute the window57if T::is_float() && !leaving_value.is_finite() {58recompute_var = true;59break;60}61let leaving_value: f64 = NumCast::from(leaving_value).unwrap();62self.moment.remove_one(leaving_value);63}64recompute_var65};6667self.last_start = start;6869// we traverse all values and compute70if recompute_var {71self.compute_var(start, end);72} else {73for idx in self.last_end..end {74let entering_value = *self.slice.get_unchecked(idx);75let entering_value: f64 = NumCast::from(entering_value).unwrap();7677self.moment.insert_one(entering_value);78}79}80self.last_end = end;81self.moment.finalize().map(|v| T::from_f64(v).unwrap())82}83}8485pub fn rolling_var<T>(86values: &[T],87window_size: usize,88min_periods: usize,89center: bool,90weights: Option<&[f64]>,91params: Option<RollingFnParams>,92) -> PolarsResult<ArrayRef>93where94T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,95{96let offset_fn = match center {97true => det_offsets_center,98false => det_offsets,99};100match weights {101None => rolling_apply_agg_window::<MomentWindow<_, VarianceMoment>, _, _>(102values,103window_size,104min_periods,105offset_fn,106params,107),108Some(weights) => {109// Validate and standardize the weights like we do for the mean. This definition is fine110// because frequency weights and unbiasing don't make sense for rolling operations.111let mut wts = no_nulls::coerce_weights(weights);112let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);113polars_ensure!(114wsum != T::zero(),115ComputeError: "Weighted variance is undefined if weights sum to 0"116);117wts.iter_mut().for_each(|w| *w = *w / wsum);118super::rolling_apply_weights(119values,120window_size,121min_periods,122offset_fn,123compute_var_weights,124&wts,125center,126)127},128}129}130131pub fn rolling_skew<T>(132values: &[T],133window_size: usize,134min_periods: usize,135center: bool,136params: Option<RollingFnParams>,137) -> PolarsResult<ArrayRef>138where139T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,140{141let offset_fn = match center {142true => det_offsets_center,143false => det_offsets,144};145rolling_apply_agg_window::<MomentWindow<_, SkewMoment>, _, _>(146values,147window_size,148min_periods,149offset_fn,150params,151)152}153154pub fn rolling_kurtosis<T>(155values: &[T],156window_size: usize,157min_periods: usize,158center: bool,159params: Option<RollingFnParams>,160) -> PolarsResult<ArrayRef>161where162T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,163{164let offset_fn = match center {165true => det_offsets_center,166false => det_offsets,167};168rolling_apply_agg_window::<MomentWindow<_, KurtosisMoment>, _, _>(169values,170window_size,171min_periods,172offset_fn,173params,174)175}176177#[cfg(test)]178mod test {179use super::*;180181#[test]182fn test_rolling_var() {183let values = &[1.0f64, 5.0, 3.0, 4.0];184185let out = rolling_var(values, 2, 2, false, None, None).unwrap();186let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();187let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();188assert_eq!(out, &[None, Some(8.0), Some(2.0), Some(0.5)]);189190let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));191let out = rolling_var(values, 2, 2, false, None, testpars).unwrap();192let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();193let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();194assert_eq!(out, &[None, Some(4.0), Some(1.0), Some(0.25)]);195196let out = rolling_var(values, 2, 1, false, None, None).unwrap();197let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();198let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();199// we cannot compare nans, so we compare the string values200assert_eq!(201format!("{:?}", out.as_slice()),202format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)])203);204// test nan handling.205let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];206let out = rolling_var(values, 3, 3, false, None, None).unwrap();207let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();208let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();209// we cannot compare nans, so we compare the string values210assert_eq!(211format!("{:?}", out.as_slice()),212format!(213"{:?}",214&[215None,216None,217Some(52.33333333333333),218Some(f64::nan()),219Some(f64::nan()),220Some(f64::nan()),221Some(1.0)222]223)224);225}226}227228229