Path: blob/main/crates/polars-compute/src/rolling/rank.rs
8440 views
use std::fmt::Debug;12use polars_utils::IdxSize;3use polars_utils::order_statistic_tree::OrderStatisticTree;4use rand::rngs::SmallRng;5use rand::{Rng, SeedableRng};67use super::*;89pub trait RankPolicy<T, Out>: Debug10where11T: NativeType,12Out: NativeType,13{14fn new(params: &RollingFnParams) -> Self;15fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<Out>;16fn bump_rng(&mut self) {}17}1819#[derive(Debug)]20pub struct RankPolicyAverage;2122impl<T: NativeType> RankPolicy<T, f64> for RankPolicyAverage {23fn new(_params: &RollingFnParams) -> Self {24Self25}26fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<f64> {27let rank_range = ost.rank_range(&value).ok()?;28let rank_lo = (rank_range.start() + 1) as f64;29let rank_hi = (rank_range.end() + 1) as f64;30Some((rank_lo + rank_hi) / 2.0)31}32}3334#[derive(Debug)]35pub struct RankPolicyMin;3637impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyMin {38fn new(_params: &RollingFnParams) -> Self {39Self40}41fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {42let range = ost.rank_range(&value).ok()?;43Some(IdxSize::try_from(range.start() + 1).unwrap())44}45}4647#[derive(Debug)]48pub struct RankPolicyMax;4950impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyMax {51fn new(_params: &RollingFnParams) -> Self {52Self53}54fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {55let range = ost.rank_range(&value).ok()?;56Some(IdxSize::try_from(range.end() + 1).unwrap())57}58}5960#[derive(Debug)]61pub struct RankPolicyDense;6263impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyDense {64fn new(_params: &RollingFnParams) -> Self {65Self66}67fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {68let rank = ost.rank_unique(&value).ok()?;69Some(IdxSize::try_from(rank + 1).unwrap())70}71}7273#[derive(Debug)]74pub struct RankPolicyRandom {75rng: SmallRng,76}7778impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyRandom {79fn new(params: &RollingFnParams) -> Self {80let RollingFnParams::Rank { seed, .. } = params else {81unreachable!("expected RollingFnParams::Rank");82};83let rng = match seed {84Some(s) => SmallRng::seed_from_u64(*s),85None => SmallRng::from_os_rng(),86};87Self { rng }88}89fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {90let rank_range = ost.rank_range(&value).ok()?;91let rank_lo = rank_range.start() + 1;92let rank_hi = rank_range.end() + 1;93Some(IdxSize::try_from(self.rng.clone().random_range(rank_lo..=rank_hi)).unwrap())94}95fn bump_rng(&mut self) {96self.rng.random::<u32>();97}98}99100101