Path: blob/main/crates/polars-compute/src/rolling/nulls/quantile.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use super::*;2use crate::rolling::quantile_filter::SealedRolling;34pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {5sorted: SortedBufNulls<'a, T>,6prob: f64,7method: QuantileMethod,8}910impl<11'a,12T: NativeType13+ IsFloat14+ Float15+ std::iter::Sum16+ AddAssign17+ SubAssign18+ Div<Output = T>19+ NumCast20+ One21+ Zero22+ SealedRolling23+ PartialOrd24+ Sub<Output = T>,25> RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>26{27unsafe fn new(28slice: &'a [T],29validity: &'a Bitmap,30start: usize,31end: usize,32params: Option<RollingFnParams>,33window_size: Option<usize>,34) -> Self {35let params = params.unwrap();36let RollingFnParams::Quantile(params) = params else {37unreachable!("expected Quantile params");38};39Self {40sorted: SortedBufNulls::new(slice, validity, start, end, window_size),41prob: params.prob,42method: params.method,43}44}4546unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {47let null_count = self.sorted.update(start, end);48let mut length = self.sorted.len();49// The min periods_issue will be taken care of when actually rolling50if null_count == length {51return None;52}53// Nulls are guaranteed to be at the front54length -= null_count;55let mut idx = match self.method {56QuantileMethod::Nearest => (((length as f64) - 1.0) * self.prob).round() as usize,57QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {58((length as f64 - 1.0) * self.prob).floor() as usize59},60QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,61QuantileMethod::Equiprobable => {62((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize63},64};6566idx = std::cmp::min(idx, length - 1);6768// we can unwrap because we sliced of the nulls69match self.method {70QuantileMethod::Midpoint => {71let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;7273debug_assert!(idx <= top_idx);74let v = if idx != top_idx {75let mut vals = self76.sorted77.index_range(idx + null_count..top_idx + null_count + 1);78let low = vals.next().unwrap().unwrap();79let high = vals.next().unwrap().unwrap();80(low + high) / T::from::<f64>(2.0f64).unwrap()81} else {82self.sorted.get(idx + null_count).unwrap()83};8485Some(v)86},87QuantileMethod::Linear => {88let float_idx = (length as f64 - 1.0) * self.prob;89let top_idx = f64::ceil(float_idx) as usize;9091if top_idx == idx {92Some(self.sorted.get(idx + null_count).unwrap())93} else {94let mut vals = self95.sorted96.index_range(idx + null_count..top_idx + null_count + 1);97let low = vals.next().unwrap().unwrap();98let high = vals.next().unwrap().unwrap();99100let proportion = T::from(float_idx - idx as f64).unwrap();101Some(proportion * (high - low) + low)102}103},104_ => Some(self.sorted.get(idx + null_count).unwrap()),105}106}107108fn is_valid(&self, min_periods: usize) -> bool {109self.sorted.is_valid(min_periods)110}111}112113pub fn rolling_quantile<T>(114arr: &PrimitiveArray<T>,115window_size: usize,116min_periods: usize,117center: bool,118weights: Option<&[f64]>,119params: Option<RollingFnParams>,120) -> ArrayRef121where122T: NativeType123+ IsFloat124+ Float125+ std::iter::Sum126+ AddAssign127+ SubAssign128+ Div<Output = T>129+ NumCast130+ One131+ Zero132+ SealedRolling133+ PartialOrd134+ Sub<Output = T>,135{136if weights.is_some() {137panic!("weights not yet supported on array with null values")138}139let offset_fn = match center {140true => det_offsets_center,141false => det_offsets,142};143/*144TODO: fix or remove the dancing links based rolling implementation145see https://github.com/pola-rs/polars/issues/23480146if !center {147let params = params.as_ref().unwrap();148let RollingFnParams::Quantile(params) = params else {149unreachable!("expected Quantile params");150};151152let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(153params.method,154min_periods,155window_size,156arr.clone(),157params.prob,158);159let out: PrimitiveArray<T> = out.into();160return Box::new(out);161}162*/163rolling_apply_agg_window::<QuantileWindow<_>, _, _>(164arr.values().as_slice(),165arr.validity().as_ref().unwrap(),166window_size,167min_periods,168offset_fn,169params,170)171}172173#[cfg(test)]174mod test {175use arrow::buffer::Buffer;176use arrow::datatypes::ArrowDataType;177178use super::*;179180#[test]181fn test_rolling_median_nulls() {182let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);183let arr = &PrimitiveArray::new(184ArrowDataType::Float64,185buf,186Some(Bitmap::from(&[true, false, true, true])),187);188let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {189prob: 0.5,190method: QuantileMethod::Linear,191}));192193let out = rolling_quantile(arr, 2, 2, false, None, med_pars);194let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();195let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();196assert_eq!(out, &[None, None, None, Some(3.5)]);197198let out = rolling_quantile(arr, 2, 1, false, None, med_pars);199let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();200let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();201assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.5)]);202203let out = rolling_quantile(arr, 4, 1, false, None, med_pars);204let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();205let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();206assert_eq!(out, &[Some(1.0), Some(1.0), Some(2.0), Some(3.0)]);207208let out = rolling_quantile(arr, 4, 1, true, None, med_pars);209let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();210let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();211assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(3.5)]);212213let out = rolling_quantile(arr, 4, 4, true, None, med_pars);214let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();215let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();216assert_eq!(out, &[None, None, None, None]);217}218219#[test]220fn test_rolling_quantile_nulls_limits() {221// compare quantiles to corresponding min/max/median values222let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);223let values = &PrimitiveArray::new(224ArrowDataType::Float64,225buf,226Some(Bitmap::from(&[true, false, false, true, true])),227);228229let methods = vec![230QuantileMethod::Lower,231QuantileMethod::Higher,232QuantileMethod::Nearest,233QuantileMethod::Midpoint,234QuantileMethod::Linear,235QuantileMethod::Equiprobable,236];237238for method in methods {239let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {240prob: 0.0,241method,242}));243let out1 = rolling_min(values, 2, 1, false, None, None);244let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();245let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();246let out2 = rolling_quantile(values, 2, 1, false, None, min_pars);247let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();248let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();249assert_eq!(out1, out2);250251let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {252prob: 1.0,253method,254}));255let out1 = rolling_max(values, 2, 1, false, None, None);256let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();257let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();258let out2 = rolling_quantile(values, 2, 1, false, None, max_pars);259let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();260let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();261assert_eq!(out1, out2);262}263}264}265266267