Path: blob/main/crates/polars-compute/src/rolling/nulls/mod.rs
8422 views
mod mean;1mod min_max;2mod moment;3mod quantile;4mod rank;5mod sum;67use arrow::legacy::utils::CustomIterTools;8pub use mean::*;9pub use min_max::*;10pub use moment::*;11pub use quantile::*;12pub use rank::*;13pub use sum::*;1415use super::*;1617pub trait RollingAggWindowNulls<T: NativeType, Out: NativeType = T> {18type This<'a>: RollingAggWindowNulls<T, Out>;1920/// # Safety21/// `start` and `end` must be in bounds for `slice` and `validity`22fn new<'a>(23slice: &'a [T],24validity: &'a Bitmap,25start: usize,26end: usize,27params: Option<RollingFnParams>,28window_size: Option<usize>,29) -> Self::This<'a>;3031/// # Safety32/// `start` and `end` must be in bounds of `slice` and `bitmap`33unsafe fn update(&mut self, new_start: usize, new_end: usize);3435/// Get the aggregate of the current window relative to the value at `idx`.36fn get_agg(&self, idx: usize) -> Option<Out>;3738/// Returns the length of the underlying input.39fn slice_len(&self) -> usize;4041fn is_valid(&self, min_periods: usize) -> bool;42}4344// Use an aggregation window that maintains the state45pub(super) fn rolling_apply_agg_window<Agg, T, Out, Fo>(46values: &[T],47validity: &Bitmap,48window_size: usize,49min_periods: usize,50det_offsets_fn: Fo,51params: Option<RollingFnParams>,52) -> ArrayRef53where54Agg: RollingAggWindowNulls<T, Out>,55T: NativeType,56Out: NativeType,57Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,58{59let len = values.len();60let (start, end) = det_offsets_fn(0, window_size, len);61let mut agg_window = Agg::new(values, validity, start, end, params, Some(window_size));6263let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)64.unwrap_or_else(|| {65let mut validity = MutableBitmap::with_capacity(len);66validity.extend_constant(len, true);67validity68});6970let out = (0..len)71.map(|idx| {72let (start, end) = det_offsets_fn(idx, window_size, len);73// SAFETY:74// we are in bounds75unsafe { agg_window.update(start, end) };76match agg_window.get_agg(idx) {77Some(val) => {78if agg_window.is_valid(min_periods) {79val80} else {81// SAFETY: we are in bounds82unsafe { validity.set_unchecked(idx, false) };83Out::default()84}85},86None => {87// SAFETY: we are in bounds88unsafe { validity.set_unchecked(idx, false) };89Out::default()90},91}92})93.collect_trusted::<Vec<_>>();9495Box::new(PrimitiveArray::new(96Out::PRIMITIVE.into(),97out.into(),98Some(validity.into()),99))100}101102#[cfg(test)]103mod test {104use arrow::array::{Array, Int32Array};105use arrow::datatypes::ArrowDataType;106use polars_buffer::Buffer;107use polars_utils::min_max::MaxIgnoreNan;108109use super::*;110use crate::rolling::min_max::MinMaxWindow;111112fn get_null_arr() -> PrimitiveArray<f64> {113// 1, None, -1, 4114let buf = Buffer::from(vec![1.0, 0.0, -1.0, 4.0]);115PrimitiveArray::new(116ArrowDataType::Float64,117buf,118Some(Bitmap::from(&[true, false, true, true])),119)120}121122#[test]123fn test_rolling_sum_nulls() {124let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);125let arr = &PrimitiveArray::new(126ArrowDataType::Float64,127buf,128Some(Bitmap::from(&[true, false, true, true])),129);130131let out = rolling_sum(arr, 2, 2, false, None, None);132let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();133let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();134assert_eq!(out, &[None, None, None, Some(7.0)]);135136let out = rolling_sum(arr, 2, 1, false, None, None);137let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();138let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();139assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(7.0)]);140141let out = rolling_sum(arr, 4, 1, false, None, None);142let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();143let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();144assert_eq!(out, &[Some(1.0), Some(1.0), Some(4.0), Some(8.0)]);145146let out = rolling_sum(arr, 4, 1, true, None, None);147let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();148let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();149assert_eq!(out, &[Some(1.0), Some(4.0), Some(8.0), Some(7.0)]);150151let out = rolling_sum(arr, 4, 4, true, None, None);152let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();153let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();154assert_eq!(out, &[None, None, None, None]);155}156157#[test]158fn test_rolling_mean_nulls() {159let arr = get_null_arr();160let arr = &arr;161162let out = rolling_mean(arr, 2, 2, false, None, None);163let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();164let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();165assert_eq!(out, &[None, None, None, Some(1.5)]);166167let out = rolling_mean(arr, 2, 1, false, None, None);168let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();169let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();170assert_eq!(out, &[Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]);171172let out = rolling_mean(arr, 4, 1, false, None, None);173let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();174let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();175assert_eq!(out, &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]);176}177178#[test]179fn test_rolling_var_nulls() {180let arr = get_null_arr();181let arr = &arr;182183let out = rolling_var(arr, 3, 1, false, None, None);184let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();185let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();186187assert_eq!(out, &[None, None, Some(2.0), Some(12.5)]);188189let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));190let out = rolling_var(arr, 3, 1, false, None, testpars);191let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();192let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();193194assert_eq!(out, &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]);195196let out = rolling_var(arr, 4, 1, false, None, None);197let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();198let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();199assert_eq!(out, &[None, None, Some(2.0), Some(6.333333333333334)]);200201let out = rolling_var(arr, 4, 1, false, None, testpars);202let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();203let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();204assert_eq!(205out,206&[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)]207);208}209210#[test]211fn test_rolling_max_no_nulls() {212let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);213let arr = &PrimitiveArray::new(214ArrowDataType::Float64,215buf,216Some(Bitmap::from(&[true, true, true, true])),217);218let out = rolling_max(arr, 4, 1, false, None, None);219let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();220let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();221assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);222223let out = rolling_max(arr, 2, 2, false, None, None);224let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();225let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();226assert_eq!(out, &[None, Some(2.0), Some(3.0), Some(4.0)]);227228let out = rolling_max(arr, 4, 4, false, None, None);229let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();230let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();231assert_eq!(out, &[None, None, None, Some(4.0)]);232233let buf = Buffer::from(vec![4.0, 3.0, 2.0, 1.0]);234let arr = &PrimitiveArray::new(235ArrowDataType::Float64,236buf,237Some(Bitmap::from(&[true, true, true, true])),238);239let out = rolling_max(arr, 2, 1, false, None, None);240let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();241let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();242assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);243244let out =245super::no_nulls::rolling_max(arr.values().as_slice(), 2, 1, false, None, None).unwrap();246let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();247let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();248assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);249}250251#[test]252fn test_rolling_extrema_nulls() {253let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];254let validity = Bitmap::new_with_value(true, vals.len());255let window_size = 3;256let min_periods = 3;257258let arr = Int32Array::new(ArrowDataType::Int32, vals.into(), Some(validity));259260let out = rolling_apply_agg_window::<MinMaxWindow<i32, MaxIgnoreNan>, _, _, _>(261arr.values().as_slice(),262arr.validity().as_ref().unwrap(),263window_size,264min_periods,265det_offsets,266None,267);268let arr = out.as_any().downcast_ref::<Int32Array>().unwrap();269assert_eq!(arr.null_count(), 2);270assert_eq!(271&arr.values().as_slice()[2..],272&[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3]273);274}275}276277278