Path: blob/main/crates/polars-compute/src/rolling/nulls/rank.rs
8430 views
use core::panic;1use std::marker::PhantomData;23use polars_utils::IdxSize;4use polars_utils::order_statistic_tree::OrderStatisticTree;56use super::super::rank::*;7use super::*;89pub struct RankWindow<'a, T, Out, P> {10slice: &'a [T],11validity: &'a Bitmap,12start: usize,13end: usize,14ost: OrderStatisticTree<&'a T>,15policy: P,16_out: PhantomData<Out>,17}1819impl<T, Out, P> RollingAggWindowNulls<T, Out> for RankWindow<'_, T, Out, P>20where21T: NativeType,22Out: NativeType,23P: RankPolicy<T, Out>,24{25type This<'a> = RankWindow<'a, T, Out, P>;2627fn new<'a>(28slice: &'a [T],29validity: &'a Bitmap,30start: usize,31end: usize,32params: Option<RollingFnParams>,33window_size: Option<usize>,34) -> Self::This<'a> {35assert!(start <= slice.len() && end <= slice.len() && start <= end);3637let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);38let ost: OrderStatisticTree<&T> = match window_size {39Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),40None => OrderStatisticTree::new(cmp),41};42let mut this = RankWindow {43slice,44validity,45start: 0,46end: 0,47ost,48policy: P::new(¶ms.unwrap()),49_out: PhantomData,50};51// SAFETY: We bounds checked `start` and `end`.52unsafe {53this.update(start, end);54}55this56}5758unsafe fn update(&mut self, new_start: usize, new_end: usize) {59debug_assert!(self.start <= self.end);60debug_assert!(self.end <= self.slice.len());61debug_assert!(new_start <= new_end);62debug_assert!(new_end <= self.slice.len());63debug_assert!(self.start <= new_start);64debug_assert!(self.end <= new_end);6566for i in self.end..new_end {67if !self.validity.get(i).unwrap() {68continue;69}70self.ost.insert(unsafe { self.slice.get_unchecked(i) });71}72for i in self.start..new_start {73if !self.validity.get(i).unwrap() {74continue;75}76self.ost77.remove(&unsafe { self.slice.get_unchecked(i) })78.expect("previously added value is missing");79}80self.start = new_start;81self.end = new_end;82}8384fn get_agg(&self, idx: usize) -> Option<Out> {85if !(self.start..self.end).contains(&idx) {86panic!("index out of bounds");87}88self.policy.rank(&self.ost, &self.slice[idx])89}9091fn is_valid(&self, _min_periods: usize) -> bool {92self.validity.get(self.end - 1).unwrap()93}9495fn slice_len(&self) -> usize {96self.slice.len()97}98}99100pub type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;101pub type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;102pub type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;103pub type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;104pub type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;105106pub fn rolling_rank<T>(107arr: &PrimitiveArray<T>,108window_size: usize,109min_periods: usize,110center: bool,111weights: Option<&[f64]>,112params: Option<RollingFnParams>,113) -> ArrayRef114where115T: NativeType,116{117assert!(weights.is_none(), "weights are not supported for rank");118119let offset_fn = match center {120true => det_offsets_center,121false => det_offsets,122};123let method = if let Some(RollingFnParams::Rank { method, .. }) = params {124method125} else {126unreachable!("expected RollingFnParams::Rank");127};128129match method {130RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(131arr.values().as_slice(),132arr.validity().as_ref().unwrap(),133window_size,134min_periods,135offset_fn,136params,137),138RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(139arr.values().as_slice(),140arr.validity().as_ref().unwrap(),141window_size,142min_periods,143offset_fn,144params,145),146RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(147arr.values().as_slice(),148arr.validity().as_ref().unwrap(),149window_size,150min_periods,151offset_fn,152params,153),154RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(155arr.values().as_slice(),156arr.validity().as_ref().unwrap(),157window_size,158min_periods,159offset_fn,160params,161),162RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(163arr.values().as_slice(),164arr.validity().as_ref().unwrap(),165window_size,166min_periods,167offset_fn,168params,169),170}171}172173174