Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/rank.rs
7884 views
1
use std::marker::PhantomData;
2
3
use polars_utils::IdxSize;
4
use polars_utils::order_statistic_tree::OrderStatisticTree;
5
6
use super::super::rank::*;
7
use super::*;
8
9
#[derive(Debug)]
10
pub struct RankWindow<'a, T, Out, P>
11
where
12
T: NativeType,
13
Out: NativeType,
14
P: RankPolicy<T, Out>,
15
{
16
slice: &'a [T],
17
last_start: usize,
18
last_end: usize,
19
ost: OrderStatisticTree<&'a T>,
20
policy: P,
21
_out: PhantomData<Out>,
22
}
23
24
impl<'a, T, Out, P> RollingAggWindowNoNulls<'a, T, Out> for RankWindow<'a, T, Out, P>
25
where
26
T: NativeType,
27
Out: NativeType,
28
P: RankPolicy<T, Out>,
29
{
30
fn new(
31
slice: &'a [T],
32
start: usize,
33
end: usize,
34
params: Option<RollingFnParams>,
35
window_size: Option<usize>,
36
) -> Self {
37
let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);
38
let ost = match window_size {
39
Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),
40
None => OrderStatisticTree::new(cmp),
41
};
42
let policy = P::new(&params.unwrap());
43
let mut slf = Self {
44
slice,
45
last_start: 0,
46
last_end: 0,
47
ost,
48
policy,
49
_out: PhantomData,
50
};
51
unsafe {
52
slf.update(start, end);
53
}
54
slf
55
}
56
57
unsafe fn update(&mut self, new_start: usize, new_end: usize) -> Option<Out> {
58
debug_assert!(self.ost.len() == self.last_end - self.last_start);
59
debug_assert!(self.last_start <= self.last_end);
60
debug_assert!(self.last_end <= self.slice.len());
61
debug_assert!(new_start <= new_end);
62
debug_assert!(new_end <= self.slice.len());
63
debug_assert!(self.last_start <= new_start);
64
debug_assert!(self.last_end <= new_end);
65
66
for i in self.last_end..new_end {
67
self.ost.insert(unsafe { self.slice.get_unchecked(i) });
68
}
69
for i in self.last_start..new_start {
70
self.ost
71
.remove(&unsafe { self.slice.get_unchecked(i) })
72
.expect("previously added value is missing");
73
}
74
self.last_start = new_start;
75
self.last_end = new_end;
76
if self.last_end == 0 {
77
return None;
78
}
79
let cur = unsafe { self.slice.get_unchecked(self.last_end - 1) };
80
self.policy.rank(&self.ost, cur)
81
}
82
}
83
84
pub type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;
85
pub type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;
86
pub type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;
87
pub type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;
88
pub type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;
89
90
pub fn rolling_rank<T>(
91
values: &[T],
92
window_size: usize,
93
min_periods: usize,
94
center: bool,
95
weights: Option<&[f64]>,
96
params: Option<RollingFnParams>,
97
) -> PolarsResult<ArrayRef>
98
where
99
T: NativeType + num_traits::Num,
100
{
101
assert!(weights.is_none(), "weights are not supported for rank");
102
103
let offset_fn = match center {
104
true => det_offsets_center,
105
false => det_offsets,
106
};
107
let method = if let Some(RollingFnParams::Rank { method, .. }) = params {
108
method
109
} else {
110
unreachable!("expected RollingFnParams::Rank");
111
};
112
113
match method {
114
RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(
115
values,
116
window_size,
117
min_periods,
118
offset_fn,
119
params,
120
),
121
RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(
122
values,
123
window_size,
124
min_periods,
125
offset_fn,
126
params,
127
),
128
RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(
129
values,
130
window_size,
131
min_periods,
132
offset_fn,
133
params,
134
),
135
RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(
136
values,
137
window_size,
138
min_periods,
139
offset_fn,
140
params,
141
),
142
RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(
143
values,
144
window_size,
145
min_periods,
146
offset_fn,
147
params,
148
),
149
}
150
}
151
152