Path: blob/main/crates/polars-compute/src/rolling/window.rs
6939 views
use ::skiplist::OrderedSkipList;1use polars_utils::total_ord::TotalOrd;23use super::*;45pub(super) struct SortedBuf<'a, T: NativeType> {6// slice over which the window slides7slice: &'a [T],8last_start: usize,9last_end: usize,10// values within the window that we keep sorted11buf: OrderedSkipList<T>,12}1314impl<'a, T: NativeType + PartialOrd + Copy> SortedBuf<'a, T> {15pub(super) fn new(16slice: &'a [T],17start: usize,18end: usize,19max_window_size: Option<usize>,20) -> Self {21let mut buf = if let Some(max_window_size) = max_window_size {22OrderedSkipList::with_capacity(max_window_size)23} else {24OrderedSkipList::new()25};26unsafe { buf.sort_by(TotalOrd::tot_cmp) };27let mut out = Self {28slice,29last_start: start,30last_end: end,31buf,32};33let init = &slice[start..end];34out.reset(init);35out36}3738fn reset(&mut self, slice: &[T]) {39self.buf.clear();40self.buf.extend(slice.iter().copied());41}4243/// Update the window position by setting the `start` index and the `end` index.44///45/// # Safety46/// The caller must ensure that `start` and `end` are within bounds of `self.slice`47///48pub(super) unsafe fn update(&mut self, start: usize, end: usize) {49// swap the whole buffer50if start >= self.last_end {51self.buf.clear();52let new_window = unsafe { self.slice.get_unchecked(start..end) };53self.reset(new_window);54} else {55// remove elements that should leave the window56for idx in self.last_start..start {57// SAFETY:58// in bounds59let val = unsafe { self.slice.get_unchecked(idx) };60self.buf.remove(val);61}6263// insert elements that enter the window, but insert them sorted64for idx in self.last_end..end {65// SAFETY:66// we are in bounds67let val = unsafe { *self.slice.get_unchecked(idx) };68self.buf.insert(val);69}70}71self.last_start = start;72self.last_end = end;73}7475pub(super) fn get(&self, index: usize) -> T {76self.buf[index]77}7879pub(super) fn len(&self) -> usize {80self.buf.len()81}82// Note: range is not inclusive83pub(super) fn index_range(84&self,85range: std::ops::Range<usize>,86) -> skiplist::ordered_skiplist::Iter<'_, T> {87self.buf.index_range(range)88}89}9091pub(super) struct SortedBufNulls<'a, T: NativeType> {92// slice over which the window slides93slice: &'a [T],94validity: &'a Bitmap,95last_start: usize,96last_end: usize,97// non-null values within the window that we keep sorted98buf: OrderedSkipList<T>,99pub null_count: usize,100}101102impl<'a, T: NativeType + PartialOrd> SortedBufNulls<'a, T> {103unsafe fn fill_and_sort_buf(&mut self, start: usize, end: usize) {104self.null_count = 0;105let iter = (start..end).flat_map(|idx| unsafe {106if self.validity.get_bit_unchecked(idx) {107Some(*self.slice.get_unchecked(idx))108} else {109self.null_count += 1;110None111}112});113114self.buf.clear();115self.buf.extend(iter);116}117118pub unsafe fn new(119slice: &'a [T],120validity: &'a Bitmap,121start: usize,122end: usize,123max_window_size: Option<usize>,124) -> Self {125let mut buf = if let Some(max_window_size) = max_window_size {126OrderedSkipList::with_capacity(max_window_size)127} else {128OrderedSkipList::new()129};130unsafe { buf.sort_by(TotalOrd::tot_cmp) };131132// sort_opt_buf(&mut buf);133let mut out = Self {134slice,135validity,136last_start: start,137last_end: end,138buf,139null_count: 0,140};141unsafe { out.fill_and_sort_buf(start, end) };142out143}144145/// Update the window position by setting the `start` index and the `end` index.146///147/// # Safety148/// The caller must ensure that `start` and `end` are within bounds of `self.slice`149pub unsafe fn update(&mut self, start: usize, end: usize) -> usize {150// Swap the whole buffer.151if start >= self.last_end {152unsafe { self.fill_and_sort_buf(start, end) };153} else {154// Vemove elements that should leave the window.155for idx in self.last_start..start {156// SAFETY: we are in bounds.157if unsafe { self.validity.get_bit_unchecked(idx) } {158self.buf.remove(unsafe { self.slice.get_unchecked(idx) });159} else {160self.null_count -= 1;161}162}163164// Insert elements that enter the window, but insert them sorted.165for idx in self.last_end..end {166// SAFETY: we are in bounds.167if unsafe { self.validity.get_bit_unchecked(idx) } {168self.buf.insert(unsafe { *self.slice.get_unchecked(idx) });169} else {170self.null_count += 1;171}172}173}174175self.last_start = start;176self.last_end = end;177self.null_count178}179180pub fn is_valid(&self, min_periods: usize) -> bool {181((self.last_end - self.last_start) - self.null_count) >= min_periods182}183184pub fn len(&self) -> usize {185self.null_count + self.buf.len()186}187188pub fn get(&self, idx: usize) -> Option<T> {189if idx >= self.null_count {190Some(self.buf[idx - self.null_count])191} else {192None193}194}195196// Note: range is not inclusive197pub fn index_range(&self, range: std::ops::Range<usize>) -> impl Iterator<Item = Option<T>> {198let nonnull_range =199range.start.saturating_sub(self.null_count)..range.end.saturating_sub(self.null_count);200(0..range.len() - nonnull_range.len())201.map(|_| None)202.chain(self.buf.index_range(nonnull_range).map(|x| Some(*x)))203}204}205206207