Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/rank.rs
8430 views
1
use core::panic;
2
use std::marker::PhantomData;
3
4
use polars_utils::IdxSize;
5
use polars_utils::order_statistic_tree::OrderStatisticTree;
6
7
use super::super::rank::*;
8
use super::*;
9
10
pub struct RankWindow<'a, T, Out, P> {
11
slice: &'a [T],
12
validity: &'a Bitmap,
13
start: usize,
14
end: usize,
15
ost: OrderStatisticTree<&'a T>,
16
policy: P,
17
_out: PhantomData<Out>,
18
}
19
20
impl<T, Out, P> RollingAggWindowNulls<T, Out> for RankWindow<'_, T, Out, P>
21
where
22
T: NativeType,
23
Out: NativeType,
24
P: RankPolicy<T, Out>,
25
{
26
type This<'a> = RankWindow<'a, T, Out, P>;
27
28
fn new<'a>(
29
slice: &'a [T],
30
validity: &'a Bitmap,
31
start: usize,
32
end: usize,
33
params: Option<RollingFnParams>,
34
window_size: Option<usize>,
35
) -> Self::This<'a> {
36
assert!(start <= slice.len() && end <= slice.len() && start <= end);
37
38
let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);
39
let ost: OrderStatisticTree<&T> = match window_size {
40
Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),
41
None => OrderStatisticTree::new(cmp),
42
};
43
let mut this = RankWindow {
44
slice,
45
validity,
46
start: 0,
47
end: 0,
48
ost,
49
policy: P::new(&params.unwrap()),
50
_out: PhantomData,
51
};
52
// SAFETY: We bounds checked `start` and `end`.
53
unsafe {
54
this.update(start, end);
55
}
56
this
57
}
58
59
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
60
debug_assert!(self.start <= self.end);
61
debug_assert!(self.end <= self.slice.len());
62
debug_assert!(new_start <= new_end);
63
debug_assert!(new_end <= self.slice.len());
64
debug_assert!(self.start <= new_start);
65
debug_assert!(self.end <= new_end);
66
67
for i in self.end..new_end {
68
if !self.validity.get(i).unwrap() {
69
continue;
70
}
71
self.ost.insert(unsafe { self.slice.get_unchecked(i) });
72
}
73
for i in self.start..new_start {
74
if !self.validity.get(i).unwrap() {
75
continue;
76
}
77
self.ost
78
.remove(&unsafe { self.slice.get_unchecked(i) })
79
.expect("previously added value is missing");
80
}
81
self.start = new_start;
82
self.end = new_end;
83
}
84
85
fn get_agg(&self, idx: usize) -> Option<Out> {
86
if !(self.start..self.end).contains(&idx) {
87
panic!("index out of bounds");
88
}
89
self.policy.rank(&self.ost, &self.slice[idx])
90
}
91
92
fn is_valid(&self, _min_periods: usize) -> bool {
93
self.validity.get(self.end - 1).unwrap()
94
}
95
96
fn slice_len(&self) -> usize {
97
self.slice.len()
98
}
99
}
100
101
pub type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;
102
pub type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;
103
pub type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;
104
pub type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;
105
pub type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;
106
107
pub fn rolling_rank<T>(
108
arr: &PrimitiveArray<T>,
109
window_size: usize,
110
min_periods: usize,
111
center: bool,
112
weights: Option<&[f64]>,
113
params: Option<RollingFnParams>,
114
) -> ArrayRef
115
where
116
T: NativeType,
117
{
118
assert!(weights.is_none(), "weights are not supported for rank");
119
120
let offset_fn = match center {
121
true => det_offsets_center,
122
false => det_offsets,
123
};
124
let method = if let Some(RollingFnParams::Rank { method, .. }) = params {
125
method
126
} else {
127
unreachable!("expected RollingFnParams::Rank");
128
};
129
130
match method {
131
RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(
132
arr.values().as_slice(),
133
arr.validity().as_ref().unwrap(),
134
window_size,
135
min_periods,
136
offset_fn,
137
params,
138
),
139
RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(
140
arr.values().as_slice(),
141
arr.validity().as_ref().unwrap(),
142
window_size,
143
min_periods,
144
offset_fn,
145
params,
146
),
147
RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(
148
arr.values().as_slice(),
149
arr.validity().as_ref().unwrap(),
150
window_size,
151
min_periods,
152
offset_fn,
153
params,
154
),
155
RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(
156
arr.values().as_slice(),
157
arr.validity().as_ref().unwrap(),
158
window_size,
159
min_periods,
160
offset_fn,
161
params,
162
),
163
RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(
164
arr.values().as_slice(),
165
arr.validity().as_ref().unwrap(),
166
window_size,
167
min_periods,
168
offset_fn,
169
params,
170
),
171
}
172
}
173
174