// Path: blob/main/crates/polars-compute/src/rolling/nulls/mod.rs
// 6939 views
mod mean;1mod min_max;2mod moment;3mod quantile;4mod sum;56use arrow::legacy::utils::CustomIterTools;7pub use mean::*;8pub use min_max::*;9pub use moment::*;10pub use quantile::*;11pub use sum::*;1213use super::*;1415pub trait RollingAggWindowNulls<'a, T: NativeType> {16/// # Safety17/// `start` and `end` must be in bounds for `slice` and `validity`18unsafe fn new(19slice: &'a [T],20validity: &'a Bitmap,21start: usize,22end: usize,23params: Option<RollingFnParams>,24window_size: Option<usize>,25) -> Self;2627/// # Safety28/// `start` and `end` must be in bounds of `slice` and `bitmap`29unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;3031fn is_valid(&self, min_periods: usize) -> bool;32}3334// Use an aggregation window that maintains the state35pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(36values: &'a [T],37validity: &'a Bitmap,38window_size: usize,39min_periods: usize,40det_offsets_fn: Fo,41params: Option<RollingFnParams>,42) -> ArrayRef43where44Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,45Agg: RollingAggWindowNulls<'a, T>,46T: IsFloat + NativeType,47{48let len = values.len();49let (start, end) = det_offsets_fn(0, window_size, len);50// SAFETY; we are in bounds51let mut agg_window =52unsafe { Agg::new(values, validity, start, end, params, Some(window_size)) };5354let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)55.unwrap_or_else(|| {56let mut validity = MutableBitmap::with_capacity(len);57validity.extend_constant(len, true);58validity59});6061let out = (0..len)62.map(|idx| {63let (start, end) = det_offsets_fn(idx, window_size, len);64// SAFETY:65// we are in bounds66let agg = unsafe { agg_window.update(start, end) };67match agg {68Some(val) => {69if agg_window.is_valid(min_periods) {70val71} else {72// SAFETY: we are in bounds73unsafe { validity.set_unchecked(idx, false) };74T::default()75}76},77None => {78// SAFETY: we are in bounds79unsafe { validity.set_unchecked(idx, false) 
};80T::default()81},82}83})84.collect_trusted::<Vec<_>>();8586Box::new(PrimitiveArray::new(87T::PRIMITIVE.into(),88out.into(),89Some(validity.into()),90))91}9293#[cfg(test)]94mod test {95use arrow::array::{Array, Int32Array};96use arrow::buffer::Buffer;97use arrow::datatypes::ArrowDataType;98use polars_utils::min_max::MaxIgnoreNan;99100use super::*;101use crate::rolling::min_max::MinMaxWindow;102103fn get_null_arr() -> PrimitiveArray<f64> {104// 1, None, -1, 4105let buf = Buffer::from(vec![1.0, 0.0, -1.0, 4.0]);106PrimitiveArray::new(107ArrowDataType::Float64,108buf,109Some(Bitmap::from(&[true, false, true, true])),110)111}112113#[test]114fn test_rolling_sum_nulls() {115let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);116let arr = &PrimitiveArray::new(117ArrowDataType::Float64,118buf,119Some(Bitmap::from(&[true, false, true, true])),120);121122let out = rolling_sum(arr, 2, 2, false, None, None);123let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();124let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();125assert_eq!(out, &[None, None, None, Some(7.0)]);126127let out = rolling_sum(arr, 2, 1, false, None, None);128let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();129let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();130assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(7.0)]);131132let out = rolling_sum(arr, 4, 1, false, None, None);133let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();134let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();135assert_eq!(out, &[Some(1.0), Some(1.0), Some(4.0), Some(8.0)]);136137let out = rolling_sum(arr, 4, 1, true, None, None);138let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();139let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();140assert_eq!(out, &[Some(1.0), Some(4.0), Some(8.0), Some(7.0)]);141142let out = rolling_sum(arr, 4, 4, true, None, None);143let out = 
out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();144let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();145assert_eq!(out, &[None, None, None, None]);146}147148#[test]149fn test_rolling_mean_nulls() {150let arr = get_null_arr();151let arr = &arr;152153let out = rolling_mean(arr, 2, 2, false, None, None);154let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();155let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();156assert_eq!(out, &[None, None, None, Some(1.5)]);157158let out = rolling_mean(arr, 2, 1, false, None, None);159let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();160let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();161assert_eq!(out, &[Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]);162163let out = rolling_mean(arr, 4, 1, false, None, None);164let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();165let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();166assert_eq!(out, &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]);167}168169#[test]170fn test_rolling_var_nulls() {171let arr = get_null_arr();172let arr = &arr;173174let out = rolling_var(arr, 3, 1, false, None, None);175let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();176let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();177178assert_eq!(out, &[None, None, Some(2.0), Some(12.5)]);179180let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));181let out = rolling_var(arr, 3, 1, false, None, testpars);182let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();183let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();184185assert_eq!(out, &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]);186187let out = rolling_var(arr, 4, 1, false, None, None);188let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();189let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();190assert_eq!(out, &[None, 
None, Some(2.0), Some(6.333333333333334)]);191192let out = rolling_var(arr, 4, 1, false, None, testpars);193let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();194let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();195assert_eq!(196out,197&[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)]198);199}200201#[test]202fn test_rolling_max_no_nulls() {203let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);204let arr = &PrimitiveArray::new(205ArrowDataType::Float64,206buf,207Some(Bitmap::from(&[true, true, true, true])),208);209let out = rolling_max(arr, 4, 1, false, None, None);210let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();211let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();212assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);213214let out = rolling_max(arr, 2, 2, false, None, None);215let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();216let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();217assert_eq!(out, &[None, Some(2.0), Some(3.0), Some(4.0)]);218219let out = rolling_max(arr, 4, 4, false, None, None);220let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();221let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();222assert_eq!(out, &[None, None, None, Some(4.0)]);223224let buf = Buffer::from(vec![4.0, 3.0, 2.0, 1.0]);225let arr = &PrimitiveArray::new(226ArrowDataType::Float64,227buf,228Some(Bitmap::from(&[true, true, true, true])),229);230let out = rolling_max(arr, 2, 1, false, None, None);231let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();232let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();233assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);234235let out =236super::no_nulls::rolling_max(arr.values().as_slice(), 2, 1, false, None, None).unwrap();237let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();238let out = out.into_iter().map(|v| 
v.copied()).collect::<Vec<_>>();239assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);240}241242#[test]243fn test_rolling_extrema_nulls() {244let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];245let validity = Bitmap::new_with_value(true, vals.len());246let window_size = 3;247let min_periods = 3;248249let arr = Int32Array::new(ArrowDataType::Int32, vals.into(), Some(validity));250251let out = rolling_apply_agg_window::<MinMaxWindow<i32, MaxIgnoreNan>, _, _>(252arr.values().as_slice(),253arr.validity().as_ref().unwrap(),254window_size,255min_periods,256det_offsets,257None,258);259let arr = out.as_any().downcast_ref::<Int32Array>().unwrap();260assert_eq!(arr.null_count(), 2);261assert_eq!(262&arr.values().as_slice()[2..],263&[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3]264);265}266}267268269