Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/rank.rs
7884 views
1
use std::fmt::Debug;
2
3
use polars_utils::IdxSize;
4
use polars_utils::order_statistic_tree::OrderStatisticTree;
5
use rand::rngs::SmallRng;
6
use rand::{Rng, SeedableRng};
7
8
use super::*;
9
10
pub trait RankPolicy<T, Out>: Debug
11
where
12
T: NativeType,
13
Out: NativeType,
14
{
15
fn new(params: &RollingFnParams) -> Self;
16
fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<Out>;
17
}
18
19
#[derive(Debug)]
20
pub struct RankPolicyAverage;
21
22
impl<T: NativeType> RankPolicy<T, f64> for RankPolicyAverage {
23
fn new(_params: &RollingFnParams) -> Self {
24
Self
25
}
26
fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<f64> {
27
let rank_range = ost.rank_range(&value).ok()?;
28
let rank_lo = (rank_range.start() + 1) as f64;
29
let rank_hi = (rank_range.end() + 1) as f64;
30
Some((rank_lo + rank_hi) / 2.0)
31
}
32
}
33
34
#[derive(Debug)]
35
pub struct RankPolicyMin;
36
37
impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyMin {
38
fn new(_params: &RollingFnParams) -> Self {
39
Self
40
}
41
fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {
42
Some(IdxSize::try_from(ost.rank_range(&value).ok()?.start() + 1).unwrap())
43
}
44
}
45
46
#[derive(Debug)]
47
pub struct RankPolicyMax;
48
49
impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyMax {
50
fn new(_params: &RollingFnParams) -> Self {
51
Self
52
}
53
fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {
54
Some(IdxSize::try_from(*ost.rank_range(&value).ok()?.end() + 1).unwrap())
55
}
56
}
57
58
#[derive(Debug)]
59
pub struct RankPolicyDense;
60
61
impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyDense {
62
fn new(_params: &RollingFnParams) -> Self {
63
Self
64
}
65
fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {
66
Some(IdxSize::try_from(ost.rank_unique(&value).ok()? + 1).unwrap())
67
}
68
}
69
70
#[derive(Debug)]
71
pub struct RankPolicyRandom {
72
rng: SmallRng,
73
}
74
75
impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyRandom {
76
fn new(params: &RollingFnParams) -> Self {
77
let RollingFnParams::Rank { seed, .. } = params else {
78
unreachable!("expected RollingFnParams::Rank");
79
};
80
let rng = match seed {
81
Some(s) => SmallRng::seed_from_u64(*s),
82
None => SmallRng::from_os_rng(),
83
};
84
Self { rng }
85
}
86
fn rank<'a>(&mut self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {
87
let rank_range = ost.rank_range(&value).ok()?;
88
let rank_lo = rank_range.start() + 1;
89
let rank_hi = rank_range.end() + 1;
90
Some(IdxSize::try_from(self.rng.random_range(rank_lo..=rank_hi)).unwrap())
91
}
92
}
93
94