// Path: blob/main/crates/polars-compute/src/rolling/no_nulls/mod.rs
// 8424 views
use std::fmt::Debug;12use arrow::array::PrimitiveArray;3use arrow::datatypes::ArrowDataType;4use arrow::legacy::error::PolarsResult;5use arrow::legacy::utils::CustomIterTools;6use arrow::types::NativeType;7use num_traits::{Float, Num, NumCast};89mod mean;10mod min_max;11mod moment;12mod quantile;13pub mod rank;14mod sum;1516pub use mean::*;17pub use min_max::*;18pub use moment::*;19pub use quantile::*;20pub use rank::*;21pub use sum::*;2223use super::*;2425pub trait RollingAggWindowNoNulls<T: NativeType, Out: NativeType = T> {26type This<'a>: RollingAggWindowNoNulls<T, Out>;2728fn new(29slice: &[T],30start: usize,31end: usize,32params: Option<RollingFnParams>,33window_size: Option<usize>,34) -> Self::This<'_>;3536/// Update and recompute the window37///38/// # Safety39/// `start` and `end` must be within the windows bounds40unsafe fn update(&mut self, new_start: usize, new_end: usize);4142/// Get the aggregate of the current window relative to the value at `idx`.43fn get_agg(&self, idx: usize) -> Option<Out>;4445/// Returns the length of the underlying input.46fn slice_len(&self) -> usize;47}4849// Use an aggregation window that maintains the state50pub(super) fn rolling_apply_agg_window<Agg, T, O, Fo>(51values: &[T],52window_size: usize,53min_periods: usize,54det_offsets_fn: Fo,55params: Option<RollingFnParams>,56) -> PolarsResult<ArrayRef>57where58Fo: Fn(Idx, WindowSize, Len) -> (Start, End),59Agg: RollingAggWindowNoNulls<T, O>,60T: Debug + NativeType + Num,61O: Debug + NativeType + Num,62{63let len = values.len();64let (start, end) = det_offsets_fn(0, window_size, len);65let mut agg_window = Agg::new(values, start, end, params, Some(window_size));66let out = (0..len).map(|idx| {67let (start, end) = det_offsets_fn(idx, window_size, len);68if end - start < min_periods {69None70} else {71// SAFETY:72// we are in bounds73unsafe { agg_window.update(start, end) }74agg_window.get_agg(idx)75}76});77let arr = 
PrimitiveArray::from_trusted_len_iter(out);78Ok(Box::new(arr))79}8081pub(super) fn rolling_apply_weights<T, Fo, Fa>(82values: &[T],83window_size: usize,84min_periods: usize,85det_offsets_fn: Fo,86aggregator: Fa,87weights: &[T],88centered: bool,89) -> PolarsResult<ArrayRef>90where91T: NativeType + num_traits::Zero + std::ops::Div<Output = T> + Copy,92Fo: Fn(Idx, WindowSize, Len) -> (Start, End),93Fa: Fn(&[T], &[T]) -> T,94{95assert_eq!(weights.len(), window_size);96let len = values.len();97let out = (0..len)98.map(|idx| {99let (start, end) = det_offsets_fn(idx, window_size, len);100let vals = unsafe { values.get_unchecked(start..end) };101let win_len = end - start;102let weights_start = if centered {103// When using centered weights, we need to find the right location104// in the weights array specifically by aligning the center of the105// window with idx, to handle cases where the window is smaller than106// weights array.107let center = (window_size / 2) as isize;108let offset = center - (idx as isize - start as isize);109offset.max(0) as usize110} else if start == 0 {111// When start is 0, we need to work backwards from the end of the112// weights array to ensure we are lined up correctly (since the113// start of the values array is implicitly cut off)114weights.len() - win_len115} else {1160117};118let weights_slice = &weights[weights_start..weights_start + win_len];119aggregator(vals, weights_slice)120})121.collect_trusted::<Vec<T>>();122123let validity = create_validity(min_periods, len, window_size, det_offsets_fn);124Ok(Box::new(PrimitiveArray::new(125ArrowDataType::from(T::PRIMITIVE),126out.into(),127validity.map(|b| b.into()),128)))129}130131fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T132where133T: Float + std::ops::AddAssign,134{135// Compute weighted mean and weighted sum of squares in a single pass136let (wssq, wmean, total_weight) = vals.iter().zip(weights).fold(137(T::zero(), T::zero(), T::zero()),138|(wssq, wsum, wtot), (&v, &w)| (wssq 
+ v * v * w, wsum + v * w, wtot + w),139);140if total_weight.is_zero() {141T::zero() // Will get masked to null.142} else {143let mean = wmean / total_weight;144(wssq / total_weight) - (mean * mean)145}146}147148pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T149where150T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,151{152values.iter().zip(weights).map(|(v, w)| *v * *w).sum()153}154155/// Compute the weighted mean of values, given weights (not necessarily normalized).156/// Returns sum_i(values[i] * weights[i]) / sum_i(weights[i])157pub(crate) fn compute_mean_weights<T>(values: &[T], weights: &[T]) -> T158where159T: std::iter::Sum<T>160+ Copy161+ std::ops::Mul<Output = T>162+ std::ops::Div<Output = T>163+ num_traits::Zero,164{165let (weighted_sum, total_weight) = values166.iter()167.zip(weights)168.fold((T::zero(), T::zero()), |(wsum, wtot), (&v, &w)| {169(wsum + v * w, wtot + w)170});171if total_weight.is_zero() {172T::zero() // Will get masked to null.173} else {174weighted_sum / total_weight175}176}177178pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T>179where180{181weights182.iter()183.map(|v| NumCast::from(*v).unwrap())184.collect::<Vec<_>>()185}186187188