// Path: blob/main/crates/polars-compute/src/rolling/no_nulls/rank.rs
// 7884 views
use std::marker::PhantomData;12use polars_utils::IdxSize;3use polars_utils::order_statistic_tree::OrderStatisticTree;45use super::super::rank::*;6use super::*;78#[derive(Debug)]9pub struct RankWindow<'a, T, Out, P>10where11T: NativeType,12Out: NativeType,13P: RankPolicy<T, Out>,14{15slice: &'a [T],16last_start: usize,17last_end: usize,18ost: OrderStatisticTree<&'a T>,19policy: P,20_out: PhantomData<Out>,21}2223impl<'a, T, Out, P> RollingAggWindowNoNulls<'a, T, Out> for RankWindow<'a, T, Out, P>24where25T: NativeType,26Out: NativeType,27P: RankPolicy<T, Out>,28{29fn new(30slice: &'a [T],31start: usize,32end: usize,33params: Option<RollingFnParams>,34window_size: Option<usize>,35) -> Self {36let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);37let ost = match window_size {38Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),39None => OrderStatisticTree::new(cmp),40};41let policy = P::new(¶ms.unwrap());42let mut slf = Self {43slice,44last_start: 0,45last_end: 0,46ost,47policy,48_out: PhantomData,49};50unsafe {51slf.update(start, end);52}53slf54}5556unsafe fn update(&mut self, new_start: usize, new_end: usize) -> Option<Out> {57debug_assert!(self.ost.len() == self.last_end - self.last_start);58debug_assert!(self.last_start <= self.last_end);59debug_assert!(self.last_end <= self.slice.len());60debug_assert!(new_start <= new_end);61debug_assert!(new_end <= self.slice.len());62debug_assert!(self.last_start <= new_start);63debug_assert!(self.last_end <= new_end);6465for i in self.last_end..new_end {66self.ost.insert(unsafe { self.slice.get_unchecked(i) });67}68for i in self.last_start..new_start {69self.ost70.remove(&unsafe { self.slice.get_unchecked(i) })71.expect("previously added value is missing");72}73self.last_start = new_start;74self.last_end = new_end;75if self.last_end == 0 {76return None;77}78let cur = unsafe { self.slice.get_unchecked(self.last_end - 1) };79self.policy.rank(&self.ost, cur)80}81}8283pub type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, 
RankPolicyAverage>;84pub type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;85pub type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;86pub type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;87pub type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;8889pub fn rolling_rank<T>(90values: &[T],91window_size: usize,92min_periods: usize,93center: bool,94weights: Option<&[f64]>,95params: Option<RollingFnParams>,96) -> PolarsResult<ArrayRef>97where98T: NativeType + num_traits::Num,99{100assert!(weights.is_none(), "weights are not supported for rank");101102let offset_fn = match center {103true => det_offsets_center,104false => det_offsets,105};106let method = if let Some(RollingFnParams::Rank { method, .. }) = params {107method108} else {109unreachable!("expected RollingFnParams::Rank");110};111112match method {113RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(114values,115window_size,116min_periods,117offset_fn,118params,119),120RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(121values,122window_size,123min_periods,124offset_fn,125params,126),127RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(128values,129window_size,130min_periods,131offset_fn,132params,133),134RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(135values,136window_size,137min_periods,138offset_fn,139params,140),141RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(142values,143window_size,144min_periods,145offset_fn,146params,147),148}149}150151152