Path: blob/main/crates/polars-compute/src/rolling/arg_min_max.rs
8448 views
use std::collections::VecDeque;1use std::marker::PhantomData;23use arrow::bitmap::Bitmap;4use arrow::types::NativeType;5use polars_utils::IdxSize;6use polars_utils::min_max::{MaxPropagateNan, MinMaxPolicy, MinPropagateNan};78use super::RollingFnParams;9use super::no_nulls::RollingAggWindowNoNulls;10use super::nulls::RollingAggWindowNulls;1112// Algorithm: https://cs.stackexchange.com/questions/120915/interview-question-with-arrays-and-consecutive-subintervals/120936#12093613// Modified to return the argmin/argmax instead of the value:14pub struct ArgMinMaxWindow<'a, T, P> {15pub(crate) values: &'a [T],16validity: Option<&'a Bitmap>,17// values[monotonic_idxs[i]] is better than values[monotonic_idxs[i+1]] for18// all i, as per the policy.19monotonic_idxs: VecDeque<usize>,20nonnulls_in_window: usize,21pub(super) start: usize,22pub(super) end: usize,23policy: PhantomData<P>,24}2526impl<T: NativeType, P: MinMaxPolicy> ArgMinMaxWindow<'_, T, P> {27/// # Safety28/// The index must be in-bounds.29unsafe fn insert_nonnull_value(&mut self, idx: usize) {30unsafe {31let value = self.values.get_unchecked(idx);3233// Remove values which are older and worse.34while let Some(&tail_idx) = self.monotonic_idxs.back() {35let tail_value = self.values.get_unchecked(tail_idx);36if !P::is_better(value, tail_value) {37break;38}39self.monotonic_idxs.pop_back();40}4142self.monotonic_idxs.push_back(idx);43self.nonnulls_in_window += 1;44}45}4647fn remove_old_values(&mut self, window_start: usize) {48// Remove values which have fallen outside the window start.49while let Some(&head_idx) = self.monotonic_idxs.front() {50if head_idx >= window_start {51break;52}53self.monotonic_idxs.pop_front();54}55}56}5758impl<T: NativeType, P: MinMaxPolicy> RollingAggWindowNulls<T, IdxSize>59for ArgMinMaxWindow<'_, T, P>60{61type This<'a> = ArgMinMaxWindow<'a, T, P>;6263fn new<'a>(64slice: &'a [T],65validity: &'a Bitmap,66start: usize,67end: usize,68params: Option<RollingFnParams>,69_window_size: Option<usize>,70) -> Self::This<'a> {71assert!(params.is_none());72assert!(start <= slice.len() && end <= slice.len() && start <= end);7374let mut this = ArgMinMaxWindow {75values: slice,76validity: Some(validity),77monotonic_idxs: VecDeque::new(),78nonnulls_in_window: 0,79start: 0,80end: 0,81policy: PhantomData,82};83// SAFETY: We bounds checked `start` and `end`.84unsafe { RollingAggWindowNulls::update(&mut this, start, end) };85this86}8788unsafe fn update(&mut self, new_start: usize, new_end: usize) {89unsafe {90let v = self.validity.unwrap_unchecked();91self.remove_old_values(new_start);92for i in self.start..new_start.min(self.end) {93self.nonnulls_in_window -= v.get_bit_unchecked(i) as usize;94}95for i in new_start.max(self.end)..new_end {96if v.get_bit_unchecked(i) {97self.insert_nonnull_value(i);98}99}100};101self.start = new_start;102self.end = new_end;103}104105fn get_agg(&self, _idx: usize) -> Option<IdxSize> {106self.monotonic_idxs107.front()108.map(|&best_abs| (best_abs - self.start) as IdxSize)109}110111fn is_valid(&self, min_periods: usize) -> bool {112self.nonnulls_in_window >= min_periods113}114115fn slice_len(&self) -> usize {116self.values.len()117}118}119120impl<T: NativeType, P: MinMaxPolicy> RollingAggWindowNoNulls<T, IdxSize>121for ArgMinMaxWindow<'_, T, P>122{123type This<'a> = ArgMinMaxWindow<'a, T, P>;124125fn new<'a>(126slice: &'a [T],127start: usize,128end: usize,129params: Option<RollingFnParams>,130_window_size: Option<usize>,131) -> Self::This<'a> {132assert!(params.is_none());133assert!(start <= slice.len() && end <= slice.len() && start <= end);134135let mut this = ArgMinMaxWindow {136values: slice,137validity: None,138monotonic_idxs: VecDeque::new(),139nonnulls_in_window: 0,140start: 0,141end: 0,142policy: PhantomData,143};144145// SAFETY: We bounds checked `start` and `end`.146unsafe { RollingAggWindowNoNulls::update(&mut this, start, end) };147this148}149150unsafe fn update(&mut self, new_start: usize, new_end: usize) {151unsafe {152self.remove_old_values(new_start);153154for i in new_start.max(self.end)..new_end {155self.insert_nonnull_value(i);156}157};158self.start = new_start;159self.end = new_end;160}161162fn get_agg(&self, _idx: usize) -> Option<IdxSize> {163self.monotonic_idxs164.front()165.map(|&best_abs| (best_abs - self.start) as IdxSize)166}167168fn slice_len(&self) -> usize {169self.values.len()170}171}172173pub type ArgMinWindow<'a, T> = ArgMinMaxWindow<'a, T, MinPropagateNan>;174pub type ArgMaxWindow<'a, T> = ArgMinMaxWindow<'a, T, MaxPropagateNan>;175176177