// Path: blob/main/crates/polars-compute/src/rolling/no_nulls/rank.rs
// 8480 views
use std::marker::PhantomData;12use polars_utils::IdxSize;3use polars_utils::order_statistic_tree::OrderStatisticTree;45use super::super::rank::*;6use super::*;78#[derive(Debug)]9pub struct RankWindow<'a, T, Out, P>10where11T: NativeType,12Out: NativeType,13P: RankPolicy<T, Out>,14{15slice: &'a [T],16pub(super) start: usize,17pub(super) end: usize,18ost: OrderStatisticTree<&'a T>,19policy: P,20_out: PhantomData<Out>,21}2223impl<T, Out, P> RollingAggWindowNoNulls<T, Out> for RankWindow<'_, T, Out, P>24where25T: NativeType,26Out: NativeType,27P: RankPolicy<T, Out>,28{29type This<'a> = RankWindow<'a, T, Out, P>;3031fn new<'a>(32slice: &'a [T],33start: usize,34end: usize,35params: Option<RollingFnParams>,36window_size: Option<usize>,37) -> Self::This<'a> {38assert!(start <= slice.len() && end <= slice.len() && start <= end);3940let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);41let ost = match window_size {42Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),43None => OrderStatisticTree::new(cmp),44};45let policy = P::new(¶ms.unwrap());46let mut this = RankWindow {47slice,48start: 0,49end: 0,50ost,51policy,52_out: PhantomData,53};5455// SAFETY: We checked that `start` and `end` are in-bounds.56unsafe {57this.update(start, end);58}5960this61}6263unsafe fn update(&mut self, new_start: usize, new_end: usize) {64debug_assert!(self.ost.len() == self.end - self.start);65debug_assert!(self.start <= self.end);66debug_assert!(self.end <= self.slice.len());67debug_assert!(new_start <= new_end);68debug_assert!(new_end <= self.slice.len());69debug_assert!(self.start <= new_start);70debug_assert!(self.end <= new_end);7172for i in self.end..new_end {73self.ost.insert(unsafe { self.slice.get_unchecked(i) });74}75for i in self.start..new_start {76self.ost77.remove(&unsafe { self.slice.get_unchecked(i) })78.expect("previously added value is missing");79}80self.start = new_start;81self.end = new_end;82self.policy.bump_rng();83}8485fn get_agg(&self, idx: usize) -> Option<Out> {86if 
!(self.start..self.end).contains(&idx) {87return None;88}89self.policy.rank(&self.ost, &self.slice[idx])90}9192fn slice_len(&self) -> usize {93self.slice.len()94}95}9697pub type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;98pub type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;99pub type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;100pub type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;101pub type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;102103pub fn rolling_rank<T>(104values: &[T],105window_size: usize,106min_periods: usize,107center: bool,108weights: Option<&[f64]>,109params: Option<RollingFnParams>,110) -> PolarsResult<ArrayRef>111where112T: NativeType + num_traits::Num,113{114assert!(weights.is_none(), "weights are not supported for rank");115116let offset_fn = match center {117true => det_offsets_center,118false => det_offsets,119};120let method = if let Some(RollingFnParams::Rank { method, .. }) = params {121method122} else {123unreachable!("expected RollingFnParams::Rank");124};125126match method {127RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(128values,129window_size,130min_periods,131offset_fn,132params,133),134RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(135values,136window_size,137min_periods,138offset_fn,139params,140),141RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(142values,143window_size,144min_periods,145offset_fn,146params,147),148RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(149values,150window_size,151min_periods,152offset_fn,153params,154),155RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(156values,157window_size,158min_periods,159offset_fn,160params,161),162}163}164165166