Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/rank.rs
8440 views
1
use std::fmt::Debug;
2
3
use polars_utils::IdxSize;
4
use polars_utils::order_statistic_tree::OrderStatisticTree;
5
use rand::rngs::SmallRng;
6
use rand::{Rng, SeedableRng};
7
8
use super::*;
9
10
pub trait RankPolicy<T, Out>: Debug
11
where
12
T: NativeType,
13
Out: NativeType,
14
{
15
fn new(params: &RollingFnParams) -> Self;
16
fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<Out>;
17
fn bump_rng(&mut self) {}
18
}
19
20
#[derive(Debug)]
21
pub struct RankPolicyAverage;
22
23
impl<T: NativeType> RankPolicy<T, f64> for RankPolicyAverage {
24
fn new(_params: &RollingFnParams) -> Self {
25
Self
26
}
27
fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<f64> {
28
let rank_range = ost.rank_range(&value).ok()?;
29
let rank_lo = (rank_range.start() + 1) as f64;
30
let rank_hi = (rank_range.end() + 1) as f64;
31
Some((rank_lo + rank_hi) / 2.0)
32
}
33
}
34
35
#[derive(Debug)]
36
pub struct RankPolicyMin;
37
38
impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyMin {
39
fn new(_params: &RollingFnParams) -> Self {
40
Self
41
}
42
fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {
43
let range = ost.rank_range(&value).ok()?;
44
Some(IdxSize::try_from(range.start() + 1).unwrap())
45
}
46
}
47
48
#[derive(Debug)]
49
pub struct RankPolicyMax;
50
51
impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyMax {
52
fn new(_params: &RollingFnParams) -> Self {
53
Self
54
}
55
fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {
56
let range = ost.rank_range(&value).ok()?;
57
Some(IdxSize::try_from(range.end() + 1).unwrap())
58
}
59
}
60
61
#[derive(Debug)]
62
pub struct RankPolicyDense;
63
64
impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyDense {
65
fn new(_params: &RollingFnParams) -> Self {
66
Self
67
}
68
fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {
69
let rank = ost.rank_unique(&value).ok()?;
70
Some(IdxSize::try_from(rank + 1).unwrap())
71
}
72
}
73
74
#[derive(Debug)]
75
pub struct RankPolicyRandom {
76
rng: SmallRng,
77
}
78
79
impl<T: NativeType> RankPolicy<T, IdxSize> for RankPolicyRandom {
80
fn new(params: &RollingFnParams) -> Self {
81
let RollingFnParams::Rank { seed, .. } = params else {
82
unreachable!("expected RollingFnParams::Rank");
83
};
84
let rng = match seed {
85
Some(s) => SmallRng::seed_from_u64(*s),
86
None => SmallRng::from_os_rng(),
87
};
88
Self { rng }
89
}
90
fn rank<'a>(&self, ost: &OrderStatisticTree<&'a T>, value: &'a T) -> Option<IdxSize> {
91
let rank_range = ost.rank_range(&value).ok()?;
92
let rank_lo = rank_range.start() + 1;
93
let rank_hi = rank_range.end() + 1;
94
Some(IdxSize::try_from(self.rng.clone().random_range(rank_lo..=rank_hi)).unwrap())
95
}
96
fn bump_rng(&mut self) {
97
self.rng.random::<u32>();
98
}
99
}
100
101