Path: blob/main/crates/polars-compute/src/rolling/no_nulls/quantile.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use arrow::legacy::utils::CustomIterTools;2use num_traits::ToPrimitive;3use polars_error::polars_ensure;45use super::QuantileMethod::*;6use super::*;7use crate::rolling::quantile_filter::SealedRolling;89pub struct QuantileWindow<'a, T: NativeType> {10sorted: SortedBuf<'a, T>,11prob: f64,12method: QuantileMethod,13}1415impl<16'a,17T: NativeType18+ Float19+ std::iter::Sum20+ AddAssign21+ SubAssign22+ Div<Output = T>23+ NumCast24+ One25+ Zero26+ SealedRolling27+ Sub<Output = T>,28> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>29{30fn new(31slice: &'a [T],32start: usize,33end: usize,34params: Option<RollingFnParams>,35window_size: Option<usize>,36) -> Self {37let params = params.unwrap();38let RollingFnParams::Quantile(params) = params else {39unreachable!("expected Quantile params");40};4142Self {43sorted: SortedBuf::new(slice, start, end, window_size),44prob: params.prob,45method: params.method,46}47}4849unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {50self.sorted.update(start, end);51let length = self.sorted.len();5253let idx = match self.method {54Linear => {55// Maybe add a fast path for median case? They could branch depending on odd/even.56let length_f = length as f64;57let idx = ((length_f - 1.0) * self.prob).floor() as usize;5859let float_idx_top = (length_f - 1.0) * self.prob;60let top_idx = float_idx_top.ceil() as usize;61return if idx == top_idx {62Some(self.sorted.get(idx))63} else {64let proportion = T::from(float_idx_top - idx as f64).unwrap();65let mut vals = self.sorted.index_range(idx..top_idx + 1);66let vi = *vals.next().unwrap();67let vj = *vals.next().unwrap();6869Some(proportion * (vj - vi) + vi)70};71},72Midpoint => {73let length_f = length as f64;74let idx = (length_f * self.prob) as usize;75let idx = std::cmp::min(idx, length - 1);7677let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;78return if top_idx == idx {79Some(self.sorted.get(idx))80} else {81let top_idx = idx + 1;82let mut vals = self.sorted.index_range(idx..top_idx + 1);83let mid = *vals.next().unwrap();84let mid_plus_1 = *vals.next().unwrap();8586Some((mid + mid_plus_1) / (T::one() + T::one()))87};88},89Nearest => {90let idx = (((length as f64) - 1.0) * self.prob).round() as usize;91std::cmp::min(idx, length - 1)92},93Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,94Higher => {95let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;96std::cmp::min(idx, length - 1)97},98Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,99};100101Some(self.sorted.get(idx))102}103}104105pub fn rolling_quantile<T>(106values: &[T],107window_size: usize,108min_periods: usize,109center: bool,110weights: Option<&[f64]>,111params: Option<RollingFnParams>,112) -> PolarsResult<ArrayRef>113where114T: NativeType115+ IsFloat116+ Float117+ std::iter::Sum118+ AddAssign119+ SubAssign120+ Div<Output = T>121+ NumCast122+ One123+ Zero124+ SealedRolling125+ PartialOrd126+ Sub<Output = T>,127{128let offset_fn = match center {129true => det_offsets_center,130false => det_offsets,131};132match weights {133None => {134if !center {135let params = params.as_ref().unwrap();136let RollingFnParams::Quantile(params) = params else {137unreachable!("expected Quantile params");138};139let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(140params.method,141min_periods,142window_size,143values,144params.prob,145);146let validity = create_validity(min_periods, values.len(), window_size, offset_fn);147return Ok(Box::new(PrimitiveArray::new(148T::PRIMITIVE.into(),149out.into(),150validity.map(|b| b.into()),151)));152}153154rolling_apply_agg_window::<QuantileWindow<_>, _, _>(155values,156window_size,157min_periods,158offset_fn,159params,160)161},162Some(weights) => {163let wsum = weights.iter().sum();164polars_ensure!(165wsum != 0.0,166ComputeError: "Weighted quantile is undefined if weights sum to 0"167);168let params = params.unwrap();169let RollingFnParams::Quantile(params) = params else {170unreachable!("expected Quantile params");171};172173Ok(rolling_apply_weighted_quantile(174values,175params.prob,176params.method,177window_size,178min_periods,179offset_fn,180weights,181wsum,182))183},184}185}186187#[inline]188fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T189where190T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,191{192// There are a few ways to compute a weighted quantile but no "canonical" way.193// This is mostly taken from the Julia implementation which was readable and reasonable194// https://juliastats.org/StatsBase.jl/stable/scalarstats/#Quantile-and-Related-Functions-1195let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());196197// Once the cumulative weight crosses h, we've found our ind{ex/ices}. The definition may look198// odd but it's the equivalent of taking h = p * (n - 1) + 1 if your data is indexed from 1.199let h: f64 = p * (wsum - buf[0].1) + buf[0].1;200for &(v, w) in buf.iter() {201if s > h {202break;203}204(s_old, v_old, vk) = (s, vk, v);205s += w;206}207match (h == s_old, method) {208(true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter209(_, Lower) => v_old,210(_, Higher) => vk,211(_, Nearest) => {212if s - h > h - s_old {213v_old214} else {215vk216}217},218(_, Equiprobable) => {219let threshold = (wsum * p).ceil() - 1.0;220if s > threshold { vk } else { v_old }221},222(_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),223// This is seemingly the canonical way to do it.224(_, Linear) => {225v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)226},227}228}229230#[allow(clippy::too_many_arguments)]231fn rolling_apply_weighted_quantile<T, Fo>(232values: &[T],233p: f64,234method: QuantileMethod,235window_size: usize,236min_periods: usize,237det_offsets_fn: Fo,238weights: &[f64],239wsum: f64,240) -> ArrayRef241where242Fo: Fn(Idx, WindowSize, Len) -> (Start, End),243T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,244{245assert_eq!(weights.len(), window_size);246// Keep nonzero weights and their indices to know which values we need each iteration.247let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();248let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];249let len = values.len();250let out = (0..len)251.map(|idx| {252// Don't need end. Window size is constant and we computed offsets from start above.253let (start, _) = det_offsets_fn(idx, window_size, len);254255// Sorting is not ideal, see https://github.com/tobiasschoch/wquantile for something faster256unsafe {257buf.iter_mut()258.zip(nz_idx_wts.iter())259.for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));260}261buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));262compute_wq(&buf, p, wsum, method)263})264.collect_trusted::<Vec<T>>();265266let validity = create_validity(min_periods, len, window_size, det_offsets_fn);267Box::new(PrimitiveArray::new(268T::PRIMITIVE.into(),269out.into(),270validity.map(|b| b.into()),271))272}273274#[cfg(test)]275mod test {276use super::*;277278#[test]279fn test_rolling_median() {280let values = &[1.0, 2.0, 3.0, 4.0];281let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {282prob: 0.5,283method: Linear,284}));285let out = rolling_quantile(values, 2, 2, false, None, med_pars).unwrap();286let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();287let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();288assert_eq!(out, &[None, Some(1.5), Some(2.5), Some(3.5)]);289290let out = rolling_quantile(values, 2, 1, false, None, med_pars).unwrap();291let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();292let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();293assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.5), Some(3.5)]);294295let out = rolling_quantile(values, 4, 1, false, None, med_pars).unwrap();296let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();297let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();298assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.0), Some(2.5)]);299300let out = rolling_quantile(values, 4, 1, true, None, med_pars).unwrap();301let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();302let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();303assert_eq!(out, &[Some(1.5), Some(2.0), Some(2.5), Some(3.0)]);304305let out = rolling_quantile(values, 4, 4, true, None, med_pars).unwrap();306let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();307let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();308assert_eq!(out, &[None, None, Some(2.5), None]);309}310311#[test]312fn test_rolling_quantile_limits() {313let values = &[1.0f64, 2.0, 3.0, 4.0];314315let methods = vec![316QuantileMethod::Lower,317QuantileMethod::Higher,318QuantileMethod::Nearest,319QuantileMethod::Midpoint,320QuantileMethod::Linear,321QuantileMethod::Equiprobable,322];323324for method in methods {325let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {326prob: 0.0,327method,328}));329let out1 = rolling_min(values, 2, 2, false, None, None).unwrap();330let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();331let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();332let out2 = rolling_quantile(values, 2, 2, false, None, min_pars).unwrap();333let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();334let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();335assert_eq!(out1, out2);336337let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {338prob: 1.0,339method,340}));341let out1 = rolling_max(values, 2, 2, false, None, None).unwrap();342let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();343let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();344let out2 = rolling_quantile(values, 2, 2, false, None, max_pars).unwrap();345let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();346let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();347assert_eq!(out1, out2);348}349}350}351352353