Path: blob/main/crates/polars-compute/src/rolling/nulls/quantile.rs
8421 views
#![allow(unsafe_op_in_unsafe_fn)]1use super::*;2use crate::rolling::quantile_filter::SealedRolling;34pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {5sorted: SortedBufNulls<'a, T>,6prob: f64,7method: QuantileMethod,8}910impl<11T: NativeType12+ IsFloat13+ Float14+ std::iter::Sum15+ AddAssign16+ SubAssign17+ Div<Output = T>18+ NumCast19+ One20+ Zero21+ SealedRolling22+ PartialOrd23+ Sub<Output = T>,24> RollingAggWindowNulls<T> for QuantileWindow<'_, T>25{26type This<'a> = QuantileWindow<'a, T>;2728fn new<'a>(29slice: &'a [T],30validity: &'a Bitmap,31start: usize,32end: usize,33params: Option<RollingFnParams>,34window_size: Option<usize>,35) -> Self::This<'a> {36let params = params.unwrap();37let RollingFnParams::Quantile(params) = params else {38unreachable!("expected Quantile params");39};40QuantileWindow {41sorted: SortedBufNulls::new(slice, validity, start, end, window_size),42prob: params.prob,43method: params.method,44}45}4647unsafe fn update(&mut self, new_start: usize, new_end: usize) {48self.sorted.update(new_start, new_end);49}5051fn get_agg(&self, _idx: usize) -> Option<T> {52let mut length = self.sorted.len();53let null_count = self.sorted.null_count;5455// The min periods_issue will be taken care of when actually rolling56if null_count == length {57return None;58}59// Nulls are guaranteed to be at the front60length -= null_count;61let mut idx = match self.method {62QuantileMethod::Nearest => (((length as f64) - 1.0) * self.prob).round() as usize,63QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {64((length as f64 - 1.0) * self.prob).floor() as usize65},66QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,67QuantileMethod::Equiprobable => {68((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize69},70};7172idx = std::cmp::min(idx, length - 1);7374// we can unwrap because we sliced of the nulls75match self.method {76QuantileMethod::Midpoint => {77let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;7879debug_assert!(idx <= top_idx);80let v = if idx != top_idx {81let low = self.sorted.get(idx + null_count).unwrap();82let high = self.sorted.get(idx + null_count + 1).unwrap();83(low + high) / T::from::<f64>(2.0f64).unwrap()84} else {85self.sorted.get(idx + null_count).unwrap()86};8788Some(v)89},90QuantileMethod::Linear => {91let float_idx = (length as f64 - 1.0) * self.prob;92let top_idx = f64::ceil(float_idx) as usize;9394if top_idx == idx {95Some(self.sorted.get(idx + null_count).unwrap())96} else {97let low = self.sorted.get(idx + null_count).unwrap();98let high = self.sorted.get(top_idx + null_count).unwrap();99let proportion = T::from(float_idx - idx as f64).unwrap();100Some(proportion * (high - low) + low)101}102},103_ => Some(self.sorted.get(idx + null_count).unwrap()),104}105}106107fn is_valid(&self, min_periods: usize) -> bool {108self.sorted.is_valid(min_periods)109}110111fn slice_len(&self) -> usize {112self.sorted.slice_len()113}114}115116pub fn rolling_quantile<T>(117arr: &PrimitiveArray<T>,118window_size: usize,119min_periods: usize,120center: bool,121weights: Option<&[f64]>,122params: Option<RollingFnParams>,123) -> ArrayRef124where125T: NativeType126+ IsFloat127+ Float128+ std::iter::Sum129+ AddAssign130+ SubAssign131+ Div<Output = T>132+ NumCast133+ One134+ Zero135+ SealedRolling136+ PartialOrd137+ Sub<Output = T>,138{139if weights.is_some() {140panic!("weights not yet supported on array with null values")141}142let offset_fn = match center {143true => det_offsets_center,144false => det_offsets,145};146/*147TODO: fix or remove the dancing links based rolling implementation148see https://github.com/pola-rs/polars/issues/23480149if !center {150let params = params.as_ref().unwrap();151let RollingFnParams::Quantile(params) = params else {152unreachable!("expected Quantile params");153};154155let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(156params.method,157min_periods,158window_size,159arr.clone(),160params.prob,161);162let out: PrimitiveArray<T> = out.into();163return Box::new(out);164}165*/166rolling_apply_agg_window::<QuantileWindow<T>, _, _, _>(167arr.values().as_slice(),168arr.validity().as_ref().unwrap(),169window_size,170min_periods,171offset_fn,172params,173)174}175176#[cfg(test)]177mod test {178use arrow::datatypes::ArrowDataType;179use polars_buffer::Buffer;180181use super::*;182183#[test]184fn test_rolling_median_nulls() {185let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);186let arr = &PrimitiveArray::new(187ArrowDataType::Float64,188buf,189Some(Bitmap::from(&[true, false, true, true])),190);191let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {192prob: 0.5,193method: QuantileMethod::Linear,194}));195196let out = rolling_quantile(arr, 2, 2, false, None, med_pars);197let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();198let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();199assert_eq!(out, &[None, None, None, Some(3.5)]);200201let out = rolling_quantile(arr, 2, 1, false, None, med_pars);202let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();203let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();204assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.5)]);205206let out = rolling_quantile(arr, 4, 1, false, None, med_pars);207let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();208let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();209assert_eq!(out, &[Some(1.0), Some(1.0), Some(2.0), Some(3.0)]);210211let out = rolling_quantile(arr, 4, 1, true, None, med_pars);212let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();213let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();214assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(3.5)]);215216let out = rolling_quantile(arr, 4, 4, true, None, med_pars);217let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();218let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();219assert_eq!(out, &[None, None, None, None]);220}221222#[test]223fn test_rolling_quantile_nulls_limits() {224// compare quantiles to corresponding min/max/median values225let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);226let values = &PrimitiveArray::new(227ArrowDataType::Float64,228buf,229Some(Bitmap::from(&[true, false, false, true, true])),230);231232let methods = vec![233QuantileMethod::Lower,234QuantileMethod::Higher,235QuantileMethod::Nearest,236QuantileMethod::Midpoint,237QuantileMethod::Linear,238QuantileMethod::Equiprobable,239];240241for method in methods {242let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {243prob: 0.0,244method,245}));246let out1 = rolling_min(values, 2, 1, false, None, None);247let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();248let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();249let out2 = rolling_quantile(values, 2, 1, false, None, min_pars);250let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();251let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();252assert_eq!(out1, out2);253254let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {255prob: 1.0,256method,257}));258let out1 = rolling_max(values, 2, 1, false, None, None);259let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();260let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();261let out2 = rolling_quantile(values, 2, 1, false, None, max_pars);262let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();263let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();264assert_eq!(out1, out2);265}266}267}268269270