Path: blob/main/crates/polars-arrow/src/legacy/kernels/ewm/average.rs
6940 views
use std::ops::{AddAssign, MulAssign};12use num_traits::Float;34use crate::array::PrimitiveArray;5use crate::legacy::utils::CustomIterTools;6use crate::trusted_len::TrustedLen;7use crate::types::NativeType;89pub fn ewm_mean<I, T>(10xs: I,11alpha: T,12adjust: bool,13min_periods: usize,14ignore_nulls: bool,15) -> PrimitiveArray<T>16where17I: IntoIterator<Item = Option<T>>,18I::IntoIter: TrustedLen,19T: Float + NativeType + AddAssign + MulAssign,20{21let new_wt = if adjust { T::one() } else { alpha };22let old_wt_factor = T::one() - alpha;23let mut old_wt = T::one();24let mut weighted_avg = None;25let mut non_null_cnt = 0usize;2627xs.into_iter()28.enumerate()29.map(|(i, opt_x)| {30if opt_x.is_some() {31non_null_cnt += 1;32}33match (i, weighted_avg) {34(0, _) | (_, None) => weighted_avg = opt_x,35(_, Some(w_avg)) => {36if opt_x.is_some() || !ignore_nulls {37old_wt *= old_wt_factor;38if let Some(x) = opt_x {39if w_avg != x {40weighted_avg =41Some((old_wt * w_avg + new_wt * x) / (old_wt + new_wt));42}43old_wt = if adjust { old_wt + new_wt } else { T::one() };44}45}46},47}48match (non_null_cnt < min_periods, opt_x.is_some()) {49(_, false) => None,50(true, true) => None,51(false, true) => weighted_avg,52}53})54.collect_trusted()55}5657#[cfg(test)]58mod test {59use super::super::assert_allclose;60use super::*;61const ALPHA: f64 = 0.5;62const EPS: f64 = 1e-15;6364#[test]65fn test_ewm_mean_without_null() {66let xs: Vec<Option<f64>> = vec![Some(1.0), Some(2.0), Some(3.0)];67for adjust in [false, true] {68for ignore_nulls in [false, true] {69for min_periods in [0, 1] {70let result = ewm_mean(xs.clone(), ALPHA, adjust, min_periods, ignore_nulls);71let expected = match adjust {72false => PrimitiveArray::from([Some(1.0f64), Some(1.5f64), Some(2.25f64)]),73true => PrimitiveArray::from([74Some(1.0),75Some(1.666_666_666_666_666_7),76Some(2.428_571_428_571_428_4),77]),78};79assert_allclose!(result, expected, 1e-15);80}81let result = ewm_mean(xs.clone(), ALPHA, adjust, 2, ignore_nulls);82let expected = match adjust {83false => PrimitiveArray::from([None, Some(1.5f64), Some(2.25f64)]),84true => PrimitiveArray::from([85None,86Some(1.666_666_666_666_666_7),87Some(2.428_571_428_571_428_4),88]),89};90assert_allclose!(result, expected, EPS);91}92}93}9495#[test]96fn test_ewm_mean_with_null() {97let xs1 = vec![98None,99None,100Some(5.0f64),101Some(7.0f64),102None,103Some(2.0f64),104Some(1.0f64),105Some(4.0f64),106];107assert_allclose!(108ewm_mean(xs1.clone(), 0.5, true, 0, true),109PrimitiveArray::from([110None,111None,112Some(5.0),113Some(6.333_333_333_333_333),114None,115Some(3.857_142_857_142_857),116Some(2.333_333_333_333_333_5),117Some(3.193_548_387_096_774),118]),119EPS120);121assert_allclose!(122ewm_mean(xs1.clone(), 0.5, true, 0, false),123PrimitiveArray::from([124None,125None,126Some(5.0),127Some(6.333_333_333_333_333),128None,129Some(3.181_818_181_818_181_7),130Some(1.888_888_888_888_888_8),131Some(3.033_898_305_084_745_7),132]),133EPS134);135assert_allclose!(136ewm_mean(xs1.clone(), 0.5, false, 0, true),137PrimitiveArray::from([138None,139None,140Some(5.0),141Some(6.0),142None,143Some(4.0),144Some(2.5),145Some(3.25),146]),147EPS148);149assert_allclose!(150ewm_mean(xs1, 0.5, false, 0, false),151PrimitiveArray::from([152None,153None,154Some(5.0),155Some(6.0),156None,157Some(3.333_333_333_333_333_5),158Some(2.166_666_666_666_667),159Some(3.083_333_333_333_333_5),160]),161EPS162);163}164}165166167