Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/arg_min_max.rs
8448 views
1
use std::collections::VecDeque;
2
use std::marker::PhantomData;
3
4
use arrow::bitmap::Bitmap;
5
use arrow::types::NativeType;
6
use polars_utils::IdxSize;
7
use polars_utils::min_max::{MaxPropagateNan, MinMaxPolicy, MinPropagateNan};
8
9
use super::RollingFnParams;
10
use super::no_nulls::RollingAggWindowNoNulls;
11
use super::nulls::RollingAggWindowNulls;
12
13
// Algorithm: https://cs.stackexchange.com/questions/120915/interview-question-with-arrays-and-consecutive-subintervals/120936#120936
14
// Modified to return the argmin/argmax instead of the value:
15
pub struct ArgMinMaxWindow<'a, T, P> {
16
pub(crate) values: &'a [T],
17
validity: Option<&'a Bitmap>,
18
// values[monotonic_idxs[i]] is better than values[monotonic_idxs[i+1]] for
19
// all i, as per the policy.
20
monotonic_idxs: VecDeque<usize>,
21
nonnulls_in_window: usize,
22
pub(super) start: usize,
23
pub(super) end: usize,
24
policy: PhantomData<P>,
25
}
26
27
impl<T: NativeType, P: MinMaxPolicy> ArgMinMaxWindow<'_, T, P> {
    /// Appends `idx` to the monotonic queue, first evicting from the back every
    /// queued index whose value the new value is better than (per
    /// `P::is_better`), and counts the value as a non-null window member.
    ///
    /// # Safety
    /// The index must be in-bounds.
    unsafe fn insert_nonnull_value(&mut self, idx: usize) {
        // SAFETY: the caller guarantees `idx` is in-bounds for `self.values`.
        let value = unsafe { self.values.get_unchecked(idx) };

        // Remove values which are older and worse.
        while let Some(&tail_idx) = self.monotonic_idxs.back() {
            // SAFETY: indices only enter the queue through this function, whose
            // contract requires them to be in-bounds for `self.values`.
            let tail_value = unsafe { self.values.get_unchecked(tail_idx) };
            if !P::is_better(value, tail_value) {
                break;
            }
            self.monotonic_idxs.pop_back();
        }

        self.monotonic_idxs.push_back(idx);
        self.nonnulls_in_window += 1;
    }

    /// Drops every queued index that lies before `window_start`.
    ///
    /// Only the queue is touched; `nonnulls_in_window` is not adjusted here.
    fn remove_old_values(&mut self, window_start: usize) {
        // Remove values which have fallen outside the window start.
        while let Some(&head_idx) = self.monotonic_idxs.front() {
            if head_idx >= window_start {
                break;
            }
            self.monotonic_idxs.pop_front();
        }
    }
}
58
59
impl<T: NativeType, P: MinMaxPolicy> RollingAggWindowNulls<T, IdxSize>
60
for ArgMinMaxWindow<'_, T, P>
61
{
62
type This<'a> = ArgMinMaxWindow<'a, T, P>;
63
64
fn new<'a>(
65
slice: &'a [T],
66
validity: &'a Bitmap,
67
start: usize,
68
end: usize,
69
params: Option<RollingFnParams>,
70
_window_size: Option<usize>,
71
) -> Self::This<'a> {
72
assert!(params.is_none());
73
assert!(start <= slice.len() && end <= slice.len() && start <= end);
74
75
let mut this = ArgMinMaxWindow {
76
values: slice,
77
validity: Some(validity),
78
monotonic_idxs: VecDeque::new(),
79
nonnulls_in_window: 0,
80
start: 0,
81
end: 0,
82
policy: PhantomData,
83
};
84
// SAFETY: We bounds checked `start` and `end`.
85
unsafe { RollingAggWindowNulls::update(&mut this, start, end) };
86
this
87
}
88
89
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
90
unsafe {
91
let v = self.validity.unwrap_unchecked();
92
self.remove_old_values(new_start);
93
for i in self.start..new_start.min(self.end) {
94
self.nonnulls_in_window -= v.get_bit_unchecked(i) as usize;
95
}
96
for i in new_start.max(self.end)..new_end {
97
if v.get_bit_unchecked(i) {
98
self.insert_nonnull_value(i);
99
}
100
}
101
};
102
self.start = new_start;
103
self.end = new_end;
104
}
105
106
fn get_agg(&self, _idx: usize) -> Option<IdxSize> {
107
self.monotonic_idxs
108
.front()
109
.map(|&best_abs| (best_abs - self.start) as IdxSize)
110
}
111
112
fn is_valid(&self, min_periods: usize) -> bool {
113
self.nonnulls_in_window >= min_periods
114
}
115
116
fn slice_len(&self) -> usize {
117
self.values.len()
118
}
119
}
120
121
impl<T: NativeType, P: MinMaxPolicy> RollingAggWindowNoNulls<T, IdxSize>
122
for ArgMinMaxWindow<'_, T, P>
123
{
124
type This<'a> = ArgMinMaxWindow<'a, T, P>;
125
126
fn new<'a>(
127
slice: &'a [T],
128
start: usize,
129
end: usize,
130
params: Option<RollingFnParams>,
131
_window_size: Option<usize>,
132
) -> Self::This<'a> {
133
assert!(params.is_none());
134
assert!(start <= slice.len() && end <= slice.len() && start <= end);
135
136
let mut this = ArgMinMaxWindow {
137
values: slice,
138
validity: None,
139
monotonic_idxs: VecDeque::new(),
140
nonnulls_in_window: 0,
141
start: 0,
142
end: 0,
143
policy: PhantomData,
144
};
145
146
// SAFETY: We bounds checked `start` and `end`.
147
unsafe { RollingAggWindowNoNulls::update(&mut this, start, end) };
148
this
149
}
150
151
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
152
unsafe {
153
self.remove_old_values(new_start);
154
155
for i in new_start.max(self.end)..new_end {
156
self.insert_nonnull_value(i);
157
}
158
};
159
self.start = new_start;
160
self.end = new_end;
161
}
162
163
fn get_agg(&self, _idx: usize) -> Option<IdxSize> {
164
self.monotonic_idxs
165
.front()
166
.map(|&best_abs| (best_abs - self.start) as IdxSize)
167
}
168
169
fn slice_len(&self) -> usize {
170
self.values.len()
171
}
172
}
173
174
pub type ArgMinWindow<'a, T> = ArgMinMaxWindow<'a, T, MinPropagateNan>;
175
pub type ArgMaxWindow<'a, T> = ArgMinMaxWindow<'a, T, MaxPropagateNan>;
176
177