GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/window.rs
use ::skiplist::OrderedSkipList;
use polars_utils::total_ord::TotalOrd;

use super::*;
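/// Sorted view of the values inside a sliding window over `slice`.
///
/// Callers (the rolling kernels in this module) move the window with `update`,
/// which removes the values that left the window and inserts the ones that
/// entered it, keeping `buf` sorted at all times.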
pub(super) struct SortedBuf<'a, T: NativeType> {
    // slice over which the window slides
    slice: &'a [T],
    last_start: usize,
    last_end: usize,
    // values within the window that we keep sorted
    buf: OrderedSkipList<T>,
}

impl<'a, T: NativeType + PartialOrd + Copy> SortedBuf<'a, T> {
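    /// Build a buffer over `slice[start..end]`, kept sorted with a total ordering.
    /// `max_window_size`, when known, pre-allocates the skiplist.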
    pub(super) fn new(
        slice: &'a [T],
        start: usize,
        end: usize,
        max_window_size: Option<usize>,
    ) -> Self {
        let mut buf = if let Some(max_window_size) = max_window_size {
            OrderedSkipList::with_capacity(max_window_size)
        } else {
            OrderedSkipList::new()
        };
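        // Switch the skiplist to the total ordering from `TotalOrd`, so float
        // windows containing NaN still have a well-defined sort order. The
        // buffer is empty at this point, so replacing the comparator cannot
        // invalidate existing entries.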
        unsafe { buf.sort_by(TotalOrd::tot_cmp) };
        let mut out = Self {
            slice,
            last_start: start,
            last_end: end,
            buf,
        };
        let init = &slice[start..end];
        out.reset(init);
        out
    }

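    // Rebuild the sorted buffer from scratch for the given window slice.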
    fn reset(&mut self, slice: &[T]) {
        self.buf.clear();
        self.buf.extend(slice.iter().copied());
    }

    /// Update the window position by setting the `start` index and the `end` index.
    ///
    /// # Safety
    /// The caller must ensure that `start` and `end` are within bounds of `self.slice`.
    pub(super) unsafe fn update(&mut self, start: usize, end: usize) {
        // The new window has no overlap with the previous one: rebuild the buffer.
        if start >= self.last_end {
            self.buf.clear();
            let new_window = unsafe { self.slice.get_unchecked(start..end) };
            self.reset(new_window);
        } else {
            // remove elements that should leave the window
            for idx in self.last_start..start {
                // SAFETY: in bounds
                let val = unsafe { self.slice.get_unchecked(idx) };
                self.buf.remove(val);
            }

            // insert elements that enter the window, but insert them sorted
            for idx in self.last_end..end {
                // SAFETY: we are in bounds
                let val = unsafe { *self.slice.get_unchecked(idx) };
                self.buf.insert(val);
            }
        }
        self.last_start = start;
        self.last_end = end;
    }

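    /// Value at `index` within the current window, in sorted order.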
    pub(super) fn get(&self, index: usize) -> T {
        self.buf[index]
    }

    pub(super) fn len(&self) -> usize {
        self.buf.len()
    }

    // Note: `range` is half-open (the end index is excluded).
    pub(super) fn index_range(
        &self,
        range: std::ops::Range<usize>,
    ) -> skiplist::ordered_skiplist::Iter<'_, T> {
        self.buf.index_range(range)
    }
}

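/// Like [`SortedBuf`], but for slices with a validity bitmap: null entries are
/// counted in `null_count` instead of being stored in the sorted buffer, and
/// they are treated as sorting before all non-null values.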
pub(super) struct SortedBufNulls<'a, T: NativeType> {
    // slice over which the window slides
    slice: &'a [T],
    validity: &'a Bitmap,
    last_start: usize,
    last_end: usize,
    // non-null values within the window that we keep sorted
    buf: OrderedSkipList<T>,
    pub null_count: usize,
}

impl<'a, T: NativeType + PartialOrd> SortedBufNulls<'a, T> {
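    // Reset `null_count` and refill the sorted buffer from `slice[start..end]`,
    // skipping nulls. The closure updates `null_count` as a side effect while
    // `extend` drives the iterator.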
    unsafe fn fill_and_sort_buf(&mut self, start: usize, end: usize) {
        self.null_count = 0;
        let iter = (start..end).flat_map(|idx| unsafe {
            if self.validity.get_bit_unchecked(idx) {
                Some(*self.slice.get_unchecked(idx))
            } else {
                self.null_count += 1;
                None
            }
        });

        self.buf.clear();
        self.buf.extend(iter);
    }

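    /// Build a buffer over the non-null values in `slice[start..end]`.
    ///
    /// # Safety
    /// The caller must ensure that `start` and `end` are within bounds of
    /// `slice` and `validity`.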
    pub unsafe fn new(
        slice: &'a [T],
        validity: &'a Bitmap,
        start: usize,
        end: usize,
        max_window_size: Option<usize>,
    ) -> Self {
        let mut buf = if let Some(max_window_size) = max_window_size {
            OrderedSkipList::with_capacity(max_window_size)
        } else {
            OrderedSkipList::new()
        };
        unsafe { buf.sort_by(TotalOrd::tot_cmp) };

        // sort_opt_buf(&mut buf);
        let mut out = Self {
            slice,
            validity,
            last_start: start,
            last_end: end,
            buf,
            null_count: 0,
        };
        unsafe { out.fill_and_sort_buf(start, end) };
        out
    }

    /// Update the window position by setting the `start` index and the `end` index.
    ///
    /// # Safety
    /// The caller must ensure that `start` and `end` are within bounds of `self.slice`.
    pub unsafe fn update(&mut self, start: usize, end: usize) -> usize {
        // The new window has no overlap with the previous one: rebuild the buffer.
        if start >= self.last_end {
            unsafe { self.fill_and_sort_buf(start, end) };
        } else {
            // Remove elements that should leave the window.
            for idx in self.last_start..start {
                // SAFETY: we are in bounds.
                if unsafe { self.validity.get_bit_unchecked(idx) } {
                    self.buf.remove(unsafe { self.slice.get_unchecked(idx) });
                } else {
                    self.null_count -= 1;
                }
            }

            // Insert elements that enter the window, but insert them sorted.
            for idx in self.last_end..end {
                // SAFETY: we are in bounds.
                if unsafe { self.validity.get_bit_unchecked(idx) } {
                    self.buf.insert(unsafe { *self.slice.get_unchecked(idx) });
                } else {
                    self.null_count += 1;
                }
            }
        }

        self.last_start = start;
        self.last_end = end;
        self.null_count
    }

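    /// Whether the current window holds at least `min_periods` non-null values.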
    pub fn is_valid(&self, min_periods: usize) -> bool {
        ((self.last_end - self.last_start) - self.null_count) >= min_periods
    }

    pub fn len(&self) -> usize {
        self.null_count + self.buf.len()
    }

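    /// Element at `idx` of the current window in sorted order, with nulls first:
    /// indices below `null_count` yield `None`.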
    pub fn get(&self, idx: usize) -> Option<T> {
        if idx >= self.null_count {
            Some(self.buf[idx - self.null_count])
        } else {
            None
        }
    }

    // Note: `range` is half-open (the end index is excluded).
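    // Indices below `null_count` map to nulls; the rest of the range is served
    // from the sorted non-null buffer, shifted down by `null_count`.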
    pub fn index_range(&self, range: std::ops::Range<usize>) -> impl Iterator<Item = Option<T>> {
        let nonnull_range =
            range.start.saturating_sub(self.null_count)..range.end.saturating_sub(self.null_count);
        (0..range.len() - nonnull_range.len())
            .map(|_| None)
            .chain(self.buf.index_range(nonnull_range).map(|x| Some(*x)))
    }
}
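#[cfg(test)]
mod sorted_buf_usage_sketch {
    use super::*;

    // A minimal usage sketch: slide a width-3 window over an `i32` slice the
    // way a rolling kernel would, reading the window minimum (index 0 of the
    // sorted buffer) after each `update`. Assumes `i32: NativeType` via the
    // arrow re-exports pulled in by `super::*`.
    #[test]
    fn sorted_buf_slides_and_stays_sorted() {
        let values: &[i32] = &[4, 1, 5, 2, 3, 0];
        let mut buf = SortedBuf::new(values, 0, 3, Some(3));
        assert_eq!(buf.get(0), 1); // min of [4, 1, 5]
        assert_eq!(buf.len(), 3);

        // SAFETY: 1..4 and 3..6 are in bounds of `values`.
        unsafe { buf.update(1, 4) };
        assert_eq!(buf.get(0), 1); // min of [1, 5, 2]

        unsafe { buf.update(3, 6) };
        assert_eq!(buf.get(0), 0); // min of [2, 3, 0]
    }
}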