Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/rank.rs
7884 views
1
use std::marker::PhantomData;
2
3
use polars_utils::IdxSize;
4
use polars_utils::order_statistic_tree::OrderStatisticTree;
5
6
use super::super::rank::*;
7
use super::*;
8
9
pub struct RankWindow<'a, T, Out, P> {
10
slice: &'a [T],
11
validity: &'a Bitmap,
12
last_start: usize,
13
last_end: usize,
14
ost: OrderStatisticTree<&'a T>,
15
policy: P,
16
_out: PhantomData<Out>,
17
}
18
19
impl<'a, T, Out, P> RollingAggWindowNulls<'a, T, Out> for RankWindow<'a, T, Out, P>
20
where
21
T: NativeType,
22
Out: NativeType,
23
P: RankPolicy<T, Out>,
24
{
25
unsafe fn new(
26
slice: &'a [T],
27
validity: &'a Bitmap,
28
start: usize,
29
end: usize,
30
params: Option<RollingFnParams>,
31
window_size: Option<usize>,
32
) -> Self {
33
let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);
34
let ost: OrderStatisticTree<&T> = match window_size {
35
Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),
36
None => OrderStatisticTree::new(cmp),
37
};
38
let mut slf = Self {
39
slice,
40
validity,
41
last_start: 0,
42
last_end: 0,
43
ost,
44
policy: P::new(&params.unwrap()),
45
_out: PhantomData,
46
};
47
unsafe {
48
slf.update(start, end);
49
}
50
slf
51
}
52
53
unsafe fn update(&mut self, new_start: usize, new_end: usize) -> Option<Out> {
54
debug_assert!(self.last_start <= self.last_end);
55
debug_assert!(self.last_end <= self.slice.len());
56
debug_assert!(new_start <= new_end);
57
debug_assert!(new_end <= self.slice.len());
58
debug_assert!(self.last_start <= new_start);
59
debug_assert!(self.last_end <= new_end);
60
61
for i in self.last_end..new_end {
62
if !self.validity.get(i).unwrap() {
63
continue;
64
}
65
self.ost.insert(unsafe { self.slice.get_unchecked(i) });
66
}
67
for i in self.last_start..new_start {
68
if !self.validity.get(i).unwrap() {
69
continue;
70
}
71
self.ost
72
.remove(&unsafe { self.slice.get_unchecked(i) })
73
.expect("previously added value is missing");
74
}
75
self.last_start = new_start;
76
self.last_end = new_end;
77
let cur = unsafe { self.slice.get_unchecked(self.last_end - 1) };
78
self.policy.rank(&self.ost, cur)
79
}
80
81
fn is_valid(&self, _min_periods: usize) -> bool {
82
self.validity.get(self.last_end - 1).unwrap()
83
}
84
}
85
86
type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;
87
type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;
88
type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;
89
type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;
90
type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;
91
92
pub fn rolling_rank<T>(
93
arr: &PrimitiveArray<T>,
94
window_size: usize,
95
min_periods: usize,
96
center: bool,
97
weights: Option<&[f64]>,
98
params: Option<RollingFnParams>,
99
) -> ArrayRef
100
where
101
T: NativeType,
102
{
103
assert!(weights.is_none(), "weights are not supported for rank");
104
105
let offset_fn = match center {
106
true => det_offsets_center,
107
false => det_offsets,
108
};
109
let method = if let Some(RollingFnParams::Rank { method, .. }) = params {
110
method
111
} else {
112
unreachable!("expected RollingFnParams::Rank");
113
};
114
115
match method {
116
RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(
117
arr.values().as_slice(),
118
arr.validity().as_ref().unwrap(),
119
window_size,
120
min_periods,
121
offset_fn,
122
params,
123
),
124
RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(
125
arr.values().as_slice(),
126
arr.validity().as_ref().unwrap(),
127
window_size,
128
min_periods,
129
offset_fn,
130
params,
131
),
132
RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(
133
arr.values().as_slice(),
134
arr.validity().as_ref().unwrap(),
135
window_size,
136
min_periods,
137
offset_fn,
138
params,
139
),
140
RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(
141
arr.values().as_slice(),
142
arr.validity().as_ref().unwrap(),
143
window_size,
144
min_periods,
145
offset_fn,
146
params,
147
),
148
RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(
149
arr.values().as_slice(),
150
arr.validity().as_ref().unwrap(),
151
window_size,
152
min_periods,
153
offset_fn,
154
params,
155
),
156
}
157
}
158
159