Path: blob/main/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/shared.rs
8395 views
//! This module implements logic shared between nulls and no_nulls.12use arrow::array::{ArrayRef, PrimitiveArray};3use arrow::bitmap::MutableBitmap;4use arrow::trusted_len::TrustedLen;5use arrow::types::NativeType;6use bytemuck::allocation::zeroed_vec;7#[cfg(feature = "timezones")]8use chrono_tz::Tz;9use polars_compute::rolling::no_nulls::RollingAggWindowNoNulls;10use polars_compute::rolling::nulls::RollingAggWindowNulls;11use polars_core::prelude::*;1213use crate::windows::duration::Duration;14use crate::windows::group_by::{ClosedWindow, group_by_values_iter};1516pub(crate) trait RollingAggWindow<T: NativeType, Out: NativeType> {17/// # Safety18/// `start` and `end` must be in bounds of `slice` and associated structures.19unsafe fn update(&mut self, start: usize, end: usize);2021/// Get the aggregate of the current window relative to the value at `idx`.22fn get_agg(&self, idx: usize) -> Option<Out>;2324/// Returns the length of the underlying input.25fn slice_len(&self) -> usize;26}2728#[repr(transparent)]29pub(crate) struct RollingAggWindowNoNullsWrapper<T>(pub T);30#[repr(transparent)]31pub(crate) struct RollingAggWindowNullsWrapper<T>(pub T);3233impl<T: NativeType, Out: NativeType, Agg: RollingAggWindowNoNulls<T, Out>> RollingAggWindow<T, Out>34for RollingAggWindowNoNullsWrapper<Agg>35{36unsafe fn update(&mut self, start: usize, end: usize) {37// SAFETY: Caller MUST uphold function safety contract.38unsafe { self.0.update(start, end) }39}4041fn get_agg(&self, idx: usize) -> Option<Out> {42self.0.get_agg(idx)43}4445fn slice_len(&self) -> usize {46self.0.slice_len()47}48}4950impl<T: NativeType, Out: NativeType, Agg: RollingAggWindowNulls<T, Out>> RollingAggWindow<T, Out>51for RollingAggWindowNullsWrapper<Agg>52{53unsafe fn update(&mut self, start: usize, end: usize) {54// SAFETY: Caller MUST uphold function safety contract.55unsafe { self.0.update(start, end) }56}5758fn get_agg(&self, idx: usize) -> Option<Out> {59self.0.get_agg(idx)60}6162fn slice_len(&self) -> usize {63self.0.slice_len()64}65}6667#[expect(clippy::too_many_arguments)]68pub(crate) fn rolling_apply_agg<T, Out, Agg>(69agg_window: &mut Agg,70period: Duration,71time: &[i64],72closed_window: ClosedWindow,73min_periods: usize,74tu: TimeUnit,75tz: Option<&TimeZone>,76sorting_indices: Option<&[IdxSize]>,77) -> PolarsResult<ArrayRef>78where79T: NativeType,80Out: NativeType,81Agg: RollingAggWindow<T, Out>,82{83let offset_iter = match tz {84#[cfg(feature = "timezones")]85Some(tz) => group_by_values_iter(period, time, closed_window, tu, tz.parse::<Tz>().ok()),86_ => group_by_values_iter(period, time, closed_window, tu, None),87}?;8889if let Some(indices) = sorting_indices {90rolling_apply_agg_window(agg_window, offset_iter, min_periods, indices)91} else {92rolling_apply_agg_window_sorted(agg_window, offset_iter, min_periods)93}94}9596// Use an aggregation window that maintains the state.97// Fastpath if values were known to already be sorted by time.98fn rolling_apply_agg_window_sorted<Agg, O, T, Out>(99agg_window: &mut Agg,100offsets: O,101min_periods: usize,102) -> PolarsResult<ArrayRef>103where104Agg: RollingAggWindow<T, Out>,105O: Iterator<Item = PolarsResult<(IdxSize, IdxSize)>> + TrustedLen,106T: NativeType,107Out: NativeType,108{109let out = offsets110.enumerate()111.map(|(idx, result)| {112result.map(|(start, len)| {113let end = start + len;114115// On the Python side, if `min_periods` wasn't specified, it is set to116// `1`. In that case, this condition is the same as checking117// `if start == end`.118if len < (min_periods as IdxSize) {119None120} else {121// SAFETY: we are in bounds122unsafe { agg_window.update(start as usize, end as usize) }123agg_window.get_agg(idx)124}125})126})127.collect::<PolarsResult<PrimitiveArray<Out>>>()?;128129Ok(Box::new(out))130}131132// Use an aggregation window that maintains the state133fn rolling_apply_agg_window<Agg, O, T, Out>(134agg_window: &mut Agg,135offsets: O,136min_periods: usize,137sorting_indices: &[IdxSize],138) -> PolarsResult<ArrayRef>139where140Agg: RollingAggWindow<T, Out>,141O: Iterator<Item = PolarsResult<(IdxSize, IdxSize)>> + TrustedLen,142T: NativeType,143Out: NativeType,144{145let mut out = zeroed_vec(agg_window.slice_len());146let mut validity: Option<MutableBitmap> = None;147offsets.enumerate().try_for_each(|(idx, result)| {148let (start, len) = result?;149let end = start + len;150let out_idx = unsafe { sorting_indices.get_unchecked(idx) };151152// On the Python side, if `min_periods` wasn't specified, it is set to153// `1`. In that case, this condition is the same as checking154// `if start == end`.155if len >= (min_periods as IdxSize) {156// SAFETY:157// we are in bound158unsafe { agg_window.update(start as usize, end as usize) };159let res = agg_window.get_agg(*out_idx as usize);160161if let Some(res) = res {162// SAFETY: `idx` is in bounds because `sorting_indices` was just taken from163// `by`, which has already been checked to be the same length as the values.164unsafe { *out.get_unchecked_mut(*out_idx as usize) = res };165} else {166instantiate_bitmap_if_null_and_set_false_at_idx(167&mut validity,168agg_window.slice_len(),169*out_idx as usize,170)171}172} else {173instantiate_bitmap_if_null_and_set_false_at_idx(174&mut validity,175agg_window.slice_len(),176*out_idx as usize,177)178}179Ok::<(), PolarsError>(())180})?;181182let out = PrimitiveArray::<Out>::from_vec(out).with_validity(validity.map(|x| x.into()));183184Ok(Box::new(out))185}186187// Instantiate a bitmap when the first null value is encountered.188// Set the validity at index `idx` to `false`.189fn instantiate_bitmap_if_null_and_set_false_at_idx(190validity: &mut Option<MutableBitmap>,191len: usize,192idx: usize,193) {194let bitmap = validity.get_or_insert_with(|| {195let mut bitmap = MutableBitmap::with_capacity(len);196bitmap.extend_constant(len, true);197bitmap198});199bitmap.set(idx, false);200}201202203