Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/rank.rs
8480 views
1
use std::marker::PhantomData;
2
3
use polars_utils::IdxSize;
4
use polars_utils::order_statistic_tree::OrderStatisticTree;
5
6
use super::super::rank::*;
7
use super::*;
8
9
#[derive(Debug)]
10
pub struct RankWindow<'a, T, Out, P>
11
where
12
T: NativeType,
13
Out: NativeType,
14
P: RankPolicy<T, Out>,
15
{
16
slice: &'a [T],
17
pub(super) start: usize,
18
pub(super) end: usize,
19
ost: OrderStatisticTree<&'a T>,
20
policy: P,
21
_out: PhantomData<Out>,
22
}
23
24
impl<T, Out, P> RollingAggWindowNoNulls<T, Out> for RankWindow<'_, T, Out, P>
25
where
26
T: NativeType,
27
Out: NativeType,
28
P: RankPolicy<T, Out>,
29
{
30
type This<'a> = RankWindow<'a, T, Out, P>;
31
32
fn new<'a>(
33
slice: &'a [T],
34
start: usize,
35
end: usize,
36
params: Option<RollingFnParams>,
37
window_size: Option<usize>,
38
) -> Self::This<'a> {
39
assert!(start <= slice.len() && end <= slice.len() && start <= end);
40
41
let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);
42
let ost = match window_size {
43
Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),
44
None => OrderStatisticTree::new(cmp),
45
};
46
let policy = P::new(&params.unwrap());
47
let mut this = RankWindow {
48
slice,
49
start: 0,
50
end: 0,
51
ost,
52
policy,
53
_out: PhantomData,
54
};
55
56
// SAFETY: We checked that `start` and `end` are in-bounds.
57
unsafe {
58
this.update(start, end);
59
}
60
61
this
62
}
63
64
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
65
debug_assert!(self.ost.len() == self.end - self.start);
66
debug_assert!(self.start <= self.end);
67
debug_assert!(self.end <= self.slice.len());
68
debug_assert!(new_start <= new_end);
69
debug_assert!(new_end <= self.slice.len());
70
debug_assert!(self.start <= new_start);
71
debug_assert!(self.end <= new_end);
72
73
for i in self.end..new_end {
74
self.ost.insert(unsafe { self.slice.get_unchecked(i) });
75
}
76
for i in self.start..new_start {
77
self.ost
78
.remove(&unsafe { self.slice.get_unchecked(i) })
79
.expect("previously added value is missing");
80
}
81
self.start = new_start;
82
self.end = new_end;
83
self.policy.bump_rng();
84
}
85
86
fn get_agg(&self, idx: usize) -> Option<Out> {
87
if !(self.start..self.end).contains(&idx) {
88
return None;
89
}
90
self.policy.rank(&self.ost, &self.slice[idx])
91
}
92
93
fn slice_len(&self) -> usize {
94
self.slice.len()
95
}
96
}
97
98
pub type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;
99
pub type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;
100
pub type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;
101
pub type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;
102
pub type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;
103
104
pub fn rolling_rank<T>(
105
values: &[T],
106
window_size: usize,
107
min_periods: usize,
108
center: bool,
109
weights: Option<&[f64]>,
110
params: Option<RollingFnParams>,
111
) -> PolarsResult<ArrayRef>
112
where
113
T: NativeType + num_traits::Num,
114
{
115
assert!(weights.is_none(), "weights are not supported for rank");
116
117
let offset_fn = match center {
118
true => det_offsets_center,
119
false => det_offsets,
120
};
121
let method = if let Some(RollingFnParams::Rank { method, .. }) = params {
122
method
123
} else {
124
unreachable!("expected RollingFnParams::Rank");
125
};
126
127
match method {
128
RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(
129
values,
130
window_size,
131
min_periods,
132
offset_fn,
133
params,
134
),
135
RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(
136
values,
137
window_size,
138
min_periods,
139
offset_fn,
140
params,
141
),
142
RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(
143
values,
144
window_size,
145
min_periods,
146
offset_fn,
147
params,
148
),
149
RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(
150
values,
151
window_size,
152
min_periods,
153
offset_fn,
154
params,
155
),
156
RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(
157
values,
158
window_size,
159
min_periods,
160
offset_fn,
161
params,
162
),
163
}
164
}
165
166