Path: blob/main/crates/polars-compute/src/rolling/no_nulls/quantile.rs
8424 views
#![allow(unsafe_op_in_unsafe_fn)]1use arrow::legacy::utils::CustomIterTools;2use num_traits::ToPrimitive;3use polars_error::polars_ensure;45use super::QuantileMethod::*;6use super::*;7use crate::rolling::quantile_filter::SealedRolling;89pub struct QuantileWindow<'a, T: NativeType> {10sorted: SortedBuf<'a, T>,11prob: f64,12method: QuantileMethod,13}1415impl<16T: NativeType17+ Float18+ std::iter::Sum19+ AddAssign20+ SubAssign21+ Div<Output = T>22+ NumCast23+ One24+ Zero25+ SealedRolling26+ Sub<Output = T>,27> RollingAggWindowNoNulls<T> for QuantileWindow<'_, T>28{29type This<'a> = QuantileWindow<'a, T>;3031fn new<'a>(32slice: &'a [T],33start: usize,34end: usize,35params: Option<RollingFnParams>,36window_size: Option<usize>,37) -> Self::This<'a> {38let params = params.unwrap();39let RollingFnParams::Quantile(params) = params else {40unreachable!("expected Quantile params");41};4243QuantileWindow {44sorted: SortedBuf::new(slice, start, end, window_size),45prob: params.prob,46method: params.method,47}48}4950unsafe fn update(&mut self, start: usize, end: usize) {51self.sorted.update(start, end);52}5354fn get_agg(&self, _idx: usize) -> Option<T> {55let length = self.sorted.len();56if length == 0 {57return None;58}59let idx = match self.method {60Linear => {61// Maybe add a fast path for median case? They could branch depending on odd/even.62let length_f = length as f64;63let idx = ((length_f - 1.0) * self.prob).floor() as usize;6465let float_idx_top = (length_f - 1.0) * self.prob;66let top_idx = float_idx_top.ceil() as usize;67return if idx == top_idx {68Some(self.sorted.get(idx))69} else {70let proportion = T::from(float_idx_top - idx as f64).unwrap();71let vi = self.sorted.get(idx);72let vj = self.sorted.get(idx + 1);73Some(proportion * (vj - vi) + vi)74};75},76Midpoint => {77let length_f = length as f64;7879let idx = ((length_f - 1.0) * self.prob).floor() as usize;80let idx = std::cmp::min(idx, length - 1);81let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;8283return if top_idx == idx {84Some(self.sorted.get(idx))85} else {86let mid = self.sorted.get(idx);87let mid_plus_1 = self.sorted.get(idx + 1);88Some((mid + mid_plus_1) / (T::one() + T::one()))89};90},91Nearest => {92let idx = (((length as f64) - 1.0) * self.prob).round() as usize;93std::cmp::min(idx, length - 1)94},95Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,96Higher => {97let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;98std::cmp::min(idx, length - 1)99},100Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,101};102103Some(self.sorted.get(idx))104}105106fn slice_len(&self) -> usize {107self.sorted.slice_len()108}109}110111pub fn rolling_quantile<T>(112values: &[T],113window_size: usize,114min_periods: usize,115center: bool,116weights: Option<&[f64]>,117params: Option<RollingFnParams>,118) -> PolarsResult<ArrayRef>119where120T: NativeType121+ IsFloat122+ Float123+ std::iter::Sum124+ AddAssign125+ SubAssign126+ Div<Output = T>127+ NumCast128+ One129+ Zero130+ SealedRolling131+ PartialOrd132+ Sub<Output = T>,133{134let offset_fn = match center {135true => det_offsets_center,136false => det_offsets,137};138match weights {139None => {140if !center {141let params = params.as_ref().unwrap();142let RollingFnParams::Quantile(params) = params else {143unreachable!("expected Quantile params");144};145let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(146params.method,147min_periods,148window_size,149values,150params.prob,151);152let validity = create_validity(min_periods, values.len(), window_size, offset_fn);153return Ok(Box::new(PrimitiveArray::new(154T::PRIMITIVE.into(),155out.into(),156validity.map(|b| b.into()),157)));158}159160rolling_apply_agg_window::<QuantileWindow<_>, _, _, _>(161values,162window_size,163min_periods,164offset_fn,165params,166)167},168Some(weights) => {169let wsum = weights.iter().sum();170polars_ensure!(171wsum != 0.0,172ComputeError: "Weighted quantile is undefined if weights sum to 0"173);174let params = params.unwrap();175let RollingFnParams::Quantile(params) = params else {176unreachable!("expected Quantile params");177};178179Ok(rolling_apply_weighted_quantile(180values,181params.prob,182params.method,183window_size,184min_periods,185offset_fn,186weights,187wsum,188))189},190}191}192193#[inline]194fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T195where196T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,197{198// There are a few ways to compute a weighted quantile but no "canonical" way.199// This is mostly taken from the Julia implementation which was readable and reasonable200// https://juliastats.org/StatsBase.jl/stable/scalarstats/#Quantile-and-Related-Functions-1201let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());202203// Once the cumulative weight crosses h, we've found our ind{ex/ices}. The definition may look204// odd but it's the equivalent of taking h = p * (n - 1) + 1 if your data is indexed from 1.205let h: f64 = p * (wsum - buf[0].1) + buf[0].1;206for &(v, w) in buf.iter() {207if s > h {208break;209}210(s_old, v_old, vk) = (s, vk, v);211s += w;212}213match (h == s_old, method) {214(true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter215(_, Lower) => v_old,216(_, Higher) => vk,217(_, Nearest) => {218if s - h > h - s_old {219v_old220} else {221vk222}223},224(_, Equiprobable) => {225let threshold = (wsum * p).ceil() - 1.0;226if s > threshold { vk } else { v_old }227},228(_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),229// This is seemingly the canonical way to do it.230(_, Linear) => {231v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)232},233}234}235236#[allow(clippy::too_many_arguments)]237fn rolling_apply_weighted_quantile<T, Fo>(238values: &[T],239p: f64,240method: QuantileMethod,241window_size: usize,242min_periods: usize,243det_offsets_fn: Fo,244weights: &[f64],245wsum: f64,246) -> ArrayRef247where248Fo: Fn(Idx, WindowSize, Len) -> (Start, End),249T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,250{251assert_eq!(weights.len(), window_size);252// Keep nonzero weights and their indices to know which values we need each iteration.253let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();254let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];255let len = values.len();256let out = (0..len)257.map(|idx| {258// Don't need end. Window size is constant and we computed offsets from start above.259let (start, _) = det_offsets_fn(idx, window_size, len);260261// Sorting is not ideal, see https://github.com/tobiasschoch/wquantile for something faster262unsafe {263buf.iter_mut()264.zip(nz_idx_wts.iter())265.for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));266}267buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));268compute_wq(&buf, p, wsum, method)269})270.collect_trusted::<Vec<T>>();271272let validity = create_validity(min_periods, len, window_size, det_offsets_fn);273Box::new(PrimitiveArray::new(274T::PRIMITIVE.into(),275out.into(),276validity.map(|b| b.into()),277))278}279280#[cfg(test)]281mod test {282use super::*;283284#[test]285fn test_rolling_median() {286let values = &[1.0, 2.0, 3.0, 4.0];287let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {288prob: 0.5,289method: Linear,290}));291let out = rolling_quantile(values, 2, 2, false, None, med_pars).unwrap();292let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();293let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();294assert_eq!(out, &[None, Some(1.5), Some(2.5), Some(3.5)]);295296let out = rolling_quantile(values, 2, 1, false, None, med_pars).unwrap();297let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();298let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();299assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.5), Some(3.5)]);300301let out = rolling_quantile(values, 4, 1, false, None, med_pars).unwrap();302let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();303let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();304assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.0), Some(2.5)]);305306let out = rolling_quantile(values, 4, 1, true, None, med_pars).unwrap();307let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();308let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();309assert_eq!(out, &[Some(1.5), Some(2.0), Some(2.5), Some(3.0)]);310311let out = rolling_quantile(values, 4, 4, true, None, med_pars).unwrap();312let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();313let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();314assert_eq!(out, &[None, None, Some(2.5), None]);315}316317#[test]318fn test_rolling_quantile_limits() {319let values = &[1.0f64, 2.0, 3.0, 4.0];320321let methods = vec![322QuantileMethod::Lower,323QuantileMethod::Higher,324QuantileMethod::Nearest,325QuantileMethod::Midpoint,326QuantileMethod::Linear,327QuantileMethod::Equiprobable,328];329330for method in methods {331let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {332prob: 0.0,333method,334}));335let out1 = rolling_min(values, 2, 2, false, None, None).unwrap();336let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();337let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();338let out2 = rolling_quantile(values, 2, 2, false, None, min_pars).unwrap();339let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();340let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();341assert_eq!(out1, out2);342343let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {344prob: 1.0,345method,346}));347let out1 = rolling_max(values, 2, 2, false, None, None).unwrap();348let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();349let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();350let out2 = rolling_quantile(values, 2, 2, false, None, max_pars).unwrap();351let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();352let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();353assert_eq!(out1, out2);354}355}356}357358359