Path: blob/main/crates/polars-compute/src/rolling/rank.rs
7884 views
use std::fmt::Debug;12use polars_utils::IdxSize;3use polars_utils::order_statistic_tree::OrderStatisticTree;4use rand::rngs::SmallRng;5use rand::{Rng, SeedableRng};67use super::*;89pub trait RankPolicy<T, Out>: Debug10where11T: NativeType,12Out: NativeType,13{14fn new(params: &RollingFnParams) -> Self;15fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<Out>;16}1718#[derive(Debug)]19pub struct RankPolicyAverage;2021impl<T: NativeType> RankPolicy<T, f64> for RankPolicyAverage {22fn new(_params: &RollingFnParams) -> Self {23Self24}25fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<f64> {26let rank_range = ost.rank_range(&value).ok()?;27let rank_lo = (rank_range.start() + 1) as f64;28let rank_hi = (rank_range.end() + 1) as f64;29Some((rank_lo + rank_hi) / 2.0)30}31}3233#[derive(Debug)]34pub struct RankPolicyMin;3536impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyMin {37fn new(_params: &RollingFnParams) -> Self {38Self39}40fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {41Some(IdxSize::try_from(ost.rank_range(&value).ok()?.start() + 1).unwrap())42}43}4445#[derive(Debug)]46pub struct RankPolicyMax;4748impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyMax {49fn new(_params: &RollingFnParams) -> Self {50Self51}52fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {53Some(IdxSize::try_from(*ost.rank_range(&value).ok()?.end() + 1).unwrap())54}55}5657#[derive(Debug)]58pub struct RankPolicyDense;5960impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyDense {61fn new(_params: &RollingFnParams) -> Self {62Self63}64fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {65Some(IdxSize::try_from(ost.rank_unique(&value).ok()? + 1).unwrap())66}67}6869#[derive(Debug)]70pub struct RankPolicyRandom {71rng: SmallRng,72}7374impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyRandom {75fn new(params: &RollingFnParams) -> Self {76let RollingFnParams::Rank { seed, .. } = params else {77unreachable!("expected RollingFnParams::Rank");78};79let rng = match seed {80Some(s) => SmallRng::seed_from_u64(*s),81None => SmallRng::from_os_rng(),82};83Self { rng }84}85fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {86let rank_range = ost.rank_range(&value).ok()?;87let rank_lo = rank_range.start() + 1;88let rank_hi = rank_range.end() + 1;89Some(IdxSize::try_from(self.rng.random_range(rank_lo..=rank_hi)).unwrap())90}91}929394