// Path: blob/main/crates/polars-compute/src/rolling/no_nulls/mod.rs
// 6939 views
mod mean;1mod min_max;2mod moment;3mod quantile;4mod sum;5use std::fmt::Debug;67use arrow::array::PrimitiveArray;8use arrow::datatypes::ArrowDataType;9use arrow::legacy::error::PolarsResult;10use arrow::legacy::utils::CustomIterTools;11use arrow::types::NativeType;12pub use mean::*;13pub use min_max::*;14pub use moment::*;15use num_traits::{Float, Num, NumCast};16pub use quantile::*;17pub use sum::*;1819use super::*;2021pub trait RollingAggWindowNoNulls<'a, T: NativeType> {22fn new(23slice: &'a [T],24start: usize,25end: usize,26params: Option<RollingFnParams>,27window_size: Option<usize>,28) -> Self;2930/// Update and recompute the window31///32/// # Safety33/// `start` and `end` must be within the windows bounds34unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;35}3637// Use an aggregation window that maintains the state38pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(39values: &'a [T],40window_size: usize,41min_periods: usize,42det_offsets_fn: Fo,43params: Option<RollingFnParams>,44) -> PolarsResult<ArrayRef>45where46Fo: Fn(Idx, WindowSize, Len) -> (Start, End),47Agg: RollingAggWindowNoNulls<'a, T>,48T: Debug + NativeType + Num,49{50let len = values.len();51let (start, end) = det_offsets_fn(0, window_size, len);52let mut agg_window = Agg::new(values, start, end, params, Some(window_size));53if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {54if validity.iter().all(|x| !x) {55return Ok(Box::new(PrimitiveArray::<T>::new_null(56T::PRIMITIVE.into(),57len,58)));59}60}6162let out = (0..len).map(|idx| {63let (start, end) = det_offsets_fn(idx, window_size, len);64if end - start < min_periods {65None66} else {67// SAFETY:68// we are in bounds69unsafe { agg_window.update(start, end) }70}71});72let arr = PrimitiveArray::from_trusted_len_iter(out);73Ok(Box::new(arr))74}7576pub(super) fn rolling_apply_weights<T, Fo, Fa>(77values: &[T],78window_size: usize,79min_periods: usize,80det_offsets_fn: Fo,81aggregator: 
Fa,82weights: &[T],83centered: bool,84) -> PolarsResult<ArrayRef>85where86T: NativeType + num_traits::Zero + std::ops::Div<Output = T> + Copy,87Fo: Fn(Idx, WindowSize, Len) -> (Start, End),88Fa: Fn(&[T], &[T]) -> T,89{90assert_eq!(weights.len(), window_size);91let len = values.len();92let out = (0..len)93.map(|idx| {94let (start, end) = det_offsets_fn(idx, window_size, len);95let vals = unsafe { values.get_unchecked(start..end) };96let win_len = end - start;97let weights_start = if centered {98// When using centered weights, we need to find the right location99// in the weights array specifically by aligning the center of the100// window with idx, to handle cases where the window is smaller than101// weights array.102let center = (window_size / 2) as isize;103let offset = center - (idx as isize - start as isize);104offset.max(0) as usize105} else if start == 0 {106// When start is 0, we need to work backwards from the end of the107// weights array to ensure we are lined up correctly (since the108// start of the values array is implicitly cut off)109weights.len() - win_len110} else {1110112};113let weights_slice = &weights[weights_start..weights_start + win_len];114aggregator(vals, weights_slice)115})116.collect_trusted::<Vec<T>>();117118let validity = create_validity(min_periods, len, window_size, det_offsets_fn);119Ok(Box::new(PrimitiveArray::new(120ArrowDataType::from(T::PRIMITIVE),121out.into(),122validity.map(|b| b.into()),123)))124}125126fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T127where128T: Float + std::ops::AddAssign,129{130// Compute weighted mean and weighted sum of squares in a single pass131let (wssq, wmean, total_weight) = vals.iter().zip(weights).fold(132(T::zero(), T::zero(), T::zero()),133|(wssq, wsum, wtot), (&v, &w)| (wssq + v * v * w, wsum + v * w, wtot + w),134);135if total_weight.is_zero() {136panic!("Weighted variance is undefined if weights sum to 0");137}138let mean = wmean / total_weight;139(wssq / total_weight) - (mean * 
mean)140}141142pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T143where144T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,145{146values.iter().zip(weights).map(|(v, w)| *v * *w).sum()147}148149/// Compute the weighted mean of values, given weights (not necessarily normalized).150/// Returns sum_i(values[i] * weights[i]) / sum_i(weights[i])151pub(crate) fn compute_mean_weights<T>(values: &[T], weights: &[T]) -> T152where153T: std::iter::Sum<T>154+ Copy155+ std::ops::Mul<Output = T>156+ std::ops::Div<Output = T>157+ num_traits::Zero,158{159let (weighted_sum, total_weight) = values160.iter()161.zip(weights)162.fold((T::zero(), T::zero()), |(wsum, wtot), (&v, &w)| {163(wsum + v * w, wtot + w)164});165if total_weight.is_zero() {166panic!("Weighted mean is undefined if weights sum to 0");167}168weighted_sum / total_weight169}170171pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T>172where173{174weights175.iter()176.map(|v| NumCast::from(*v).unwrap())177.collect::<Vec<_>>()178}179180181