Path: blob/main/crates/polars-compute/src/rolling/nulls/rank.rs
7884 views
use std::marker::PhantomData;12use polars_utils::IdxSize;3use polars_utils::order_statistic_tree::OrderStatisticTree;45use super::super::rank::*;6use super::*;78pub struct RankWindow<'a, T, Out, P> {9slice: &'a [T],10validity: &'a Bitmap,11last_start: usize,12last_end: usize,13ost: OrderStatisticTree<&'a T>,14policy: P,15_out: PhantomData<Out>,16}1718impl<'a, T, Out, P> RollingAggWindowNulls<'a, T, Out> for RankWindow<'a, T, Out, P>19where20T: NativeType,21Out: NativeType,22P: RankPolicy<T, Out>,23{24unsafe fn new(25slice: &'a [T],26validity: &'a Bitmap,27start: usize,28end: usize,29params: Option<RollingFnParams>,30window_size: Option<usize>,31) -> Self {32let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);33let ost: OrderStatisticTree<&T> = match window_size {34Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),35None => OrderStatisticTree::new(cmp),36};37let mut slf = Self {38slice,39validity,40last_start: 0,41last_end: 0,42ost,43policy: P::new(¶ms.unwrap()),44_out: PhantomData,45};46unsafe {47slf.update(start, end);48}49slf50}5152unsafe fn update(&mut self, new_start: usize, new_end: usize) -> Option<Out> {53debug_assert!(self.last_start <= self.last_end);54debug_assert!(self.last_end <= self.slice.len());55debug_assert!(new_start <= new_end);56debug_assert!(new_end <= self.slice.len());57debug_assert!(self.last_start <= new_start);58debug_assert!(self.last_end <= new_end);5960for i in self.last_end..new_end {61if !self.validity.get(i).unwrap() {62continue;63}64self.ost.insert(unsafe { self.slice.get_unchecked(i) });65}66for i in self.last_start..new_start {67if !self.validity.get(i).unwrap() {68continue;69}70self.ost71.remove(&unsafe { self.slice.get_unchecked(i) })72.expect("previously added value is missing");73}74self.last_start = new_start;75self.last_end = new_end;76let cur = unsafe { self.slice.get_unchecked(self.last_end - 1) };77self.policy.rank(&self.ost, cur)78}7980fn is_valid(&self, _min_periods: usize) -> bool {81self.validity.get(self.last_end - 1).unwrap()82}83}8485type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;86type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;87type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;88type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;89type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;9091pub fn rolling_rank<T>(92arr: &PrimitiveArray<T>,93window_size: usize,94min_periods: usize,95center: bool,96weights: Option<&[f64]>,97params: Option<RollingFnParams>,98) -> ArrayRef99where100T: NativeType,101{102assert!(weights.is_none(), "weights are not supported for rank");103104let offset_fn = match center {105true => det_offsets_center,106false => det_offsets,107};108let method = if let Some(RollingFnParams::Rank { method, .. }) = params {109method110} else {111unreachable!("expected RollingFnParams::Rank");112};113114match method {115RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(116arr.values().as_slice(),117arr.validity().as_ref().unwrap(),118window_size,119min_periods,120offset_fn,121params,122),123RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(124arr.values().as_slice(),125arr.validity().as_ref().unwrap(),126window_size,127min_periods,128offset_fn,129params,130),131RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(132arr.values().as_slice(),133arr.validity().as_ref().unwrap(),134window_size,135min_periods,136offset_fn,137params,138),139RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(140arr.values().as_slice(),141arr.validity().as_ref().unwrap(),142window_size,143min_periods,144offset_fn,145params,146),147RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(148arr.values().as_slice(),149arr.validity().as_ref().unwrap(),150window_size,151min_periods,152offset_fn,153params,154),155}156}157158159