// Path: blob/main/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs
// 6940 views
use std::ops::{AddAssign, DivAssign, MulAssign};12use num_traits::Float;34use crate::array::PrimitiveArray;5use crate::legacy::utils::CustomIterTools;6use crate::trusted_len::TrustedLen;7use crate::types::NativeType;89#[allow(clippy::too_many_arguments)]10fn ewm_cov_internal<I, T>(11xs: I,12ys: I,13alpha: T,14adjust: bool,15bias: bool,16min_periods: usize,17ignore_nulls: bool,18do_sqrt: bool,19) -> PrimitiveArray<T>20where21I: IntoIterator<Item = Option<T>>,22I::IntoIter: TrustedLen,23T: Float + NativeType + AddAssign + MulAssign + DivAssign,24{25let old_wt_factor = T::one() - alpha;26let new_wt = if adjust { T::one() } else { alpha };27let mut sum_wt = T::one();28let mut sum_wt2 = T::one();29let mut old_wt = T::one();3031let mut opt_mean_x = None;32let mut opt_mean_y = None;33let mut cov = T::zero();34let mut non_na_cnt = 0usize;35let min_periods_fixed = if min_periods == 0 { 1 } else { min_periods };3637let res = xs38.into_iter()39.zip(ys)40.enumerate()41.map(|(i, (opt_x, opt_y))| {42let is_observation = opt_x.is_some() && opt_y.is_some();43if is_observation {44non_na_cnt += 1;45}46match (i, opt_mean_x, opt_mean_y) {47(0, _, _) => {48if is_observation {49opt_mean_x = opt_x;50opt_mean_y = opt_y;51}52},53(_, Some(mean_x), Some(mean_y)) => {54if is_observation || !ignore_nulls {55sum_wt *= old_wt_factor;56sum_wt2 *= old_wt_factor * old_wt_factor;57old_wt *= old_wt_factor;58if is_observation {59let x = opt_x.unwrap();60let y = opt_y.unwrap();61let old_mean_x = mean_x;62let old_mean_y = mean_y;6364// avoid numerical errors on constant series65if mean_x != x {66opt_mean_x =67Some((old_wt * old_mean_x + new_wt * x) / (old_wt + new_wt));68}6970// avoid numerical errors on constant series71if mean_y != y {72opt_mean_y =73Some((old_wt * old_mean_y + new_wt * y) / (old_wt + new_wt));74}7576cov = ((old_wt77* (cov78+ ((old_mean_x - opt_mean_x.unwrap())79* (old_mean_y - opt_mean_y.unwrap()))))80+ (new_wt81* ((x - opt_mean_x.unwrap()) * (y - opt_mean_y.unwrap()))))82/ (old_wt + 
new_wt);8384sum_wt += new_wt;85sum_wt2 += new_wt * new_wt;86old_wt += new_wt;87if !adjust {88sum_wt /= old_wt;89sum_wt2 /= old_wt * old_wt;90old_wt = T::one();91}92}93}94},95_ => {96if is_observation {97opt_mean_x = opt_x;98opt_mean_y = opt_y;99}100},101}102match (non_na_cnt >= min_periods_fixed, bias, is_observation) {103(_, _, false) => None,104(false, _, true) => None,105(true, false, true) => {106if non_na_cnt == 1 {107Some(cov)108} else {109let numerator = sum_wt * sum_wt;110let denominator = numerator - sum_wt2;111if denominator > T::zero() {112Some((numerator / denominator) * cov)113} else {114None115}116}117},118(true, true, true) => Some(cov),119}120});121122if do_sqrt {123res.map(|opt_x| opt_x.map(|x| x.sqrt())).collect_trusted()124} else {125res.collect_trusted()126}127}128129pub fn ewm_cov<I, T>(130xs: I,131ys: I,132alpha: T,133adjust: bool,134bias: bool,135min_periods: usize,136ignore_nulls: bool,137) -> PrimitiveArray<T>138where139I: IntoIterator<Item = Option<T>>,140I::IntoIter: TrustedLen,141T: Float + NativeType + AddAssign + MulAssign + DivAssign,142{143ewm_cov_internal(144xs,145ys,146alpha,147adjust,148bias,149min_periods,150ignore_nulls,151false,152)153}154155pub fn ewm_var<I, T>(156xs: I,157alpha: T,158adjust: bool,159bias: bool,160min_periods: usize,161ignore_nulls: bool,162) -> PrimitiveArray<T>163where164I: IntoIterator<Item = Option<T>> + Clone,165I::IntoIter: TrustedLen,166T: Float + NativeType + AddAssign + MulAssign + DivAssign,167{168ewm_cov_internal(169xs.clone(),170xs,171alpha,172adjust,173bias,174min_periods,175ignore_nulls,176false,177)178}179180pub fn ewm_std<I, T>(181xs: I,182alpha: T,183adjust: bool,184bias: bool,185min_periods: usize,186ignore_nulls: bool,187) -> PrimitiveArray<T>188where189I: IntoIterator<Item = Option<T>> + Clone,190I::IntoIter: TrustedLen,191T: Float + NativeType + AddAssign + MulAssign + 
DivAssign,192{193ewm_cov_internal(194xs.clone(),195xs,196alpha,197adjust,198bias,199min_periods,200ignore_nulls,201true,202)203}204205#[cfg(test)]206mod test {207use super::super::assert_allclose;208use super::*;209const ALPHA: f64 = 0.5;210const EPS: f64 = 1e-15;211use std::f64::consts::SQRT_2;212213const XS: [Option<f64>; 7] = [214Some(1.0),215Some(5.0),216Some(7.0),217Some(1.0),218Some(2.0),219Some(1.0),220Some(4.0),221];222const YS: [Option<f64>; 7] = [None, Some(5.0), Some(7.0), None, None, Some(1.0), Some(4.0)];223224#[test]225fn test_ewm_var() {226assert_allclose!(227ewm_var(XS.to_vec(), ALPHA, true, true, 0, true),228PrimitiveArray::from([229Some(0.0),230Some(3.555_555_555_555_556),231Some(4.244_897_959_183_674),232Some(7.182_222_222_222_221),233Some(3.796_045_785_639_958),234Some(2.467_120_181_405_896),235Some(2.476_036_952_073_904_3),236]),237EPS238);239assert_allclose!(240ewm_var(XS.to_vec(), ALPHA, true, true, 0, false),241PrimitiveArray::from([242Some(0.0),243Some(3.555_555_555_555_556),244Some(4.244_897_959_183_674),245Some(7.182_222_222_222_221),246Some(3.796_045_785_639_958),247Some(2.467_120_181_405_896),248Some(2.476_036_952_073_904_3),249]),250EPS251);252assert_allclose!(253ewm_var(XS.to_vec(), ALPHA, true, false, 0, true),254PrimitiveArray::from([255Some(0.0),256Some(8.0),257Some(7.428_571_428_571_429),258Some(11.542_857_142_857_143),259Some(5.883_870_967_741_934_5),260Some(3.760_368_663_594_470_6),261Some(3.743_532_058_492_688_6),262]),263EPS264);265assert_allclose!(266ewm_var(XS.to_vec(), ALPHA, true, false, 0, false),267PrimitiveArray::from([268Some(0.0),269Some(8.0),270Some(7.428_571_428_571_429),271Some(11.542_857_142_857_143),272Some(5.883_870_967_741_934_5),273Some(3.760_368_663_594_470_6),274Some(3.743_532_058_492_688_6),275]),276EPS277);278assert_allclose!(279ewm_var(XS.to_vec(), ALPHA, false, true, 0, 
true),280PrimitiveArray::from([281Some(0.0),282Some(4.0),283Some(6.0),284Some(7.0),285Some(3.75),286Some(2.437_5),287Some(2.484_375),288]),289EPS290);291assert_allclose!(292ewm_var(XS.to_vec(), ALPHA, false, true, 0, false),293PrimitiveArray::from([294Some(0.0),295Some(4.0),296Some(6.0),297Some(7.0),298Some(3.75),299Some(2.437_5),300Some(2.484_375),301]),302EPS303);304assert_allclose!(305ewm_var(XS.to_vec(), ALPHA, false, true, 0, false),306PrimitiveArray::from([307Some(0.0),308Some(4.0),309Some(6.0),310Some(7.0),311Some(3.75),312Some(2.437_5),313Some(2.484_375),314]),315EPS316);317assert_allclose!(318ewm_var(XS.to_vec(), ALPHA, false, false, 0, true),319PrimitiveArray::from([320Some(0.0),321Some(8.0),322Some(9.600_000_000_000_001),323Some(10.666_666_666_666_666),324Some(5.647_058_823_529_411),325Some(3.659_824_046_920_821),326Some(3.727_472_527_472_527_6),327]),328EPS329);330assert_allclose!(331ewm_var(XS.to_vec(), ALPHA, false, false, 0, false),332PrimitiveArray::from([333Some(0.0),334Some(8.0),335Some(9.600_000_000_000_001),336Some(10.666_666_666_666_666),337Some(5.647_058_823_529_411),338Some(3.659_824_046_920_821),339Some(3.727_472_527_472_527_6),340]),341EPS342);343assert_allclose!(344ewm_var(YS.to_vec(), ALPHA, true, true, 0, true),345PrimitiveArray::from([346None,347Some(0.0),348Some(0.888_888_888_888_889),349None,350None,351Some(7.346_938_775_510_203),352Some(3.555_555_555_555_555_4),353]),354EPS355);356assert_allclose!(357ewm_var(YS.to_vec(), ALPHA, true, true, 0, false),358PrimitiveArray::from([359None,360Some(0.0),361Some(0.888_888_888_888_889),362None,363None,364Some(3.922_437_673_130_193_3),365Some(2.549_788_542_868_127_3),366]),367EPS368);369assert_allclose!(370ewm_var(YS.to_vec(), ALPHA, true, false, 0, true),371PrimitiveArray::from([372None,373Some(0.0),374Some(2.0),375None,376None,377Some(12.857_142_857_142_856),378Some(5.714_285_714_285_714),379]),380EPS381);382assert_allclose!(383ewm_var(YS.to_vec(), ALPHA, true, false, 0, 
false),384PrimitiveArray::from([385None,386Some(0.0),387Some(2.0),388None,389None,390Some(14.159_999_999_999_997),391Some(5.039_513_677_811_549_5),392]),393EPS394);395assert_allclose!(396ewm_var(YS.to_vec(), ALPHA, false, true, 0, true),397PrimitiveArray::from([398None,399Some(0.0),400Some(1.0),401None,402None,403Some(6.75),404Some(3.437_5),405]),406EPS407);408assert_allclose!(409ewm_var(YS.to_vec(), ALPHA, false, true, 0, false),410PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),411EPS412);413assert_allclose!(414ewm_var(YS.to_vec(), ALPHA, false, false, 0, true),415PrimitiveArray::from([416None,417Some(0.0),418Some(2.0),419None,420None,421Some(10.8),422Some(5.238_095_238_095_238),423]),424EPS425);426assert_allclose!(427ewm_var(YS.to_vec(), ALPHA, false, false, 0, false),428PrimitiveArray::from([429None,430Some(0.0),431Some(2.0),432None,433None,434Some(12.352_941_176_470_589),435Some(5.299_145_299_145_3),436]),437EPS438);439}440441#[test]442fn test_ewm_cov() {443assert_allclose!(444ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, true),445PrimitiveArray::from([446None,447Some(0.0),448Some(0.888_888_888_888_889),449None,450None,451Some(7.346_938_775_510_203),452Some(3.555_555_555_555_555_4)453]),454EPS455);456assert_allclose!(457ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, false),458PrimitiveArray::from([459None,460Some(0.0),461Some(0.888_888_888_888_889),462None,463None,464Some(3.922_437_673_130_193_3),465Some(2.549_788_542_868_127_3)466]),467EPS468);469assert_allclose!(470ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, true),471PrimitiveArray::from([472None,473Some(0.0),474Some(2.0),475None,476None,477Some(12.857_142_857_142_856),478Some(5.714_285_714_285_714)479]),480EPS481);482assert_allclose!(483ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, 
false),484PrimitiveArray::from([485None,486Some(0.0),487Some(2.0),488None,489None,490Some(14.159_999_999_999_997),491Some(5.039_513_677_811_549_5)492]),493EPS494);495assert_allclose!(496ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, true),497PrimitiveArray::from([498None,499Some(0.0),500Some(1.0),501None,502None,503Some(6.75),504Some(3.437_5)505]),506EPS507);508assert_allclose!(509ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, false),510PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]),511EPS512);513assert_allclose!(514ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, true),515PrimitiveArray::from([516None,517Some(0.0),518Some(2.0),519None,520None,521Some(10.8),522Some(5.238_095_238_095_238)523]),524EPS525);526assert_allclose!(527ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, false),528PrimitiveArray::from([529None,530Some(0.0),531Some(2.0),532None,533None,534Some(12.352_941_176_470_589),535Some(5.299_145_299_145_3)536]),537EPS538);539}540541#[test]542fn test_ewm_std() {543assert_allclose!(544ewm_std(XS.to_vec(), ALPHA, true, true, 0, true),545PrimitiveArray::from([546Some(0.0),547Some(1.885_618_083_164_126_7),548Some(2.060_315_014_550_851_3),549Some(2.679_966_832_298_904),550Some(1.948_344_370_392_451_5),551Some(1.570_706_904_997_204_2),552Some(1.573_542_802_746_053_2),553]),554EPS555);556assert_allclose!(557ewm_std(XS.to_vec(), ALPHA, true, true, 0, false),558PrimitiveArray::from([559Some(0.0),560Some(1.885_618_083_164_126_7),561Some(2.060_315_014_550_851_3),562Some(2.679_966_832_298_904),563Some(1.948_344_370_392_451_5),564Some(1.570_706_904_997_204_2),565Some(1.573_542_802_746_053_2),566]),567EPS568);569assert_allclose!(570ewm_std(XS.to_vec(), ALPHA, true, false, 0, 
true),571PrimitiveArray::from([572Some(0.0),573Some(2.828_427_124_746_190_3),574Some(2.725_540_575_476_987_5),575Some(3.397_478_056_273_085_3),576Some(2.425_669_179_369_259),577Some(1.939_167_002_502_484_5),578Some(1.934_820_937_061_796_6),579]),580EPS581);582assert_allclose!(583ewm_std(XS.to_vec(), ALPHA, true, false, 0, false),584PrimitiveArray::from([585Some(0.0),586Some(2.828_427_124_746_190_3),587Some(2.725_540_575_476_987_5),588Some(3.397_478_056_273_085_3),589Some(2.425_669_179_369_259),590Some(1.939_167_002_502_484_5),591Some(1.934_820_937_061_796_6),592]),593EPS594);595assert_allclose!(596ewm_std(XS.to_vec(), ALPHA, false, true, 0, true),597PrimitiveArray::from([598Some(0.0),599Some(2.0),600Some(2.449_489_742_783_178),601Some(2.645_751_311_064_590_7),602Some(1.936_491_673_103_708_5),603Some(1.561_249_499_599_599_6),604Some(1.576_190_026_614_811_4),605]),606EPS607);608assert_allclose!(609ewm_std(XS.to_vec(), ALPHA, false, true, 0, false),610PrimitiveArray::from([611Some(0.0),612Some(2.0),613Some(2.449_489_742_783_178),614Some(2.645_751_311_064_590_7),615Some(1.936_491_673_103_708_5),616Some(1.561_249_499_599_599_6),617Some(1.576_190_026_614_811_4),618]),619EPS620);621assert_allclose!(622ewm_std(XS.to_vec(), ALPHA, false, false, 0, true),623PrimitiveArray::from([624Some(0.0),625Some(2.828_427_124_746_190_3),626Some(3.098_386_676_965_933_6),627Some(3.265_986_323_710_904),628Some(2.376_354_103_144_018_3),629Some(1.913_066_660_344_281_2),630Some(1.930_666_342_865_210_7),631]),632EPS633);634assert_allclose!(635ewm_std(XS.to_vec(), ALPHA, false, false, 0, false),636PrimitiveArray::from([637Some(0.0),638Some(2.828_427_124_746_190_3),639Some(3.098_386_676_965_933_6),640Some(3.265_986_323_710_904),641Some(2.376_354_103_144_018_3),642Some(1.913_066_660_344_281_2),643Some(1.930_666_342_865_210_7),644]),645EPS646);647assert_allclose!(648ewm_std(YS.to_vec(), ALPHA, true, true, 0, 
true),649PrimitiveArray::from([650None,651Some(0.0),652Some(0.942_809_041_582_063_4),653None,654None,655Some(2.710_523_708_715_753_4),656Some(1.885_618_083_164_126_7),657]),658EPS659);660assert_allclose!(661ewm_std(YS.to_vec(), ALPHA, true, true, 0, false),662PrimitiveArray::from([663None,664Some(0.0),665Some(0.942_809_041_582_063_4),666None,667None,668Some(1.980_514_497_076_503),669Some(1.596_805_731_098_222),670]),671EPS672);673assert_allclose!(674ewm_std(YS.to_vec(), ALPHA, true, false, 0, true),675PrimitiveArray::from([676None,677Some(0.0),678Some(SQRT_2),679None,680None,681Some(3.585_685_828_003_181),682Some(2.390_457_218_668_787),683]),684EPS685);686assert_allclose!(687ewm_std(YS.to_vec(), ALPHA, true, false, 0, false),688PrimitiveArray::from([689None,690Some(0.0),691Some(SQRT_2),692None,693None,694Some(3.762_977_544_445_355_3),695Some(2.244_886_116_891_356),696]),697EPS698);699assert_allclose!(700ewm_std(YS.to_vec(), ALPHA, false, true, 0, true),701PrimitiveArray::from([702None,703Some(0.0),704Some(1.0),705None,706None,707Some(2.598_076_211_353_316),708Some(1.854_049_621_773_915_7),709]),710EPS711);712assert_allclose!(713ewm_std(YS.to_vec(), ALPHA, false, true, 0, false),714PrimitiveArray::from([715None,716Some(0.0),717Some(1.0),718None,719None,720Some(2.049_390_153_191_92),721Some(1.760_681_686_165_901),722]),723EPS724);725assert_allclose!(726ewm_std(YS.to_vec(), ALPHA, false, false, 0, true),727PrimitiveArray::from([728None,729Some(0.0),730Some(SQRT_2),731None,732None,733Some(3.286_335_345_030_997),734Some(2.288_688_541_085_317_5),735]),736EPS737);738assert_allclose!(739ewm_std(YS.to_vec(), ALPHA, false, false, 0, false),740PrimitiveArray::from([741None,742Some(0.0),743Some(SQRT_2),744None,745None,746Some(3.514_675_116_774_036_7),747Some(2.301_987_249_996_250_4),748]),749EPS750);751}752753#[test]754fn test_ewm_min_periods() {755assert_allclose!(756ewm_var(YS.to_vec(), ALPHA, true, true, 0, 
false),757PrimitiveArray::from([758None,759Some(0.0),760Some(0.888_888_888_888_889),761None,762None,763Some(3.922_437_673_130_193_3),764Some(2.549_788_542_868_127_3),765]),766EPS767);768assert_allclose!(769ewm_var(YS.to_vec(), ALPHA, true, true, 1, false),770PrimitiveArray::from([771None,772Some(0.0),773Some(0.888_888_888_888_889),774None,775None,776Some(3.922_437_673_130_193_3),777Some(2.549_788_542_868_127_3),778]),779EPS780);781assert_allclose!(782ewm_var(YS.to_vec(), ALPHA, true, true, 2, false),783PrimitiveArray::from([784None,785None,786Some(0.888_888_888_888_889),787None,788None,789Some(3.922_437_673_130_193_3),790Some(2.549_788_542_868_127_3),791]),792EPS793);794}795}796797798