Path: blob/main/crates/polars-compute/src/rolling/no_nulls/min_max.rs
6939 views
use polars_utils::min_max::{MaxPropagateNan, MinMaxPolicy, MinPropagateNan};12use super::super::min_max::MinMaxWindow;3use super::*;45pub type MinWindow<'a, T> = MinMaxWindow<'a, T, MinPropagateNan>;6pub type MaxWindow<'a, T> = MinMaxWindow<'a, T, MaxPropagateNan>;78fn weighted_min_max<T, P>(values: &[T], weights: &[T]) -> T9where10T: NativeType + std::ops::Mul<Output = T>,11P: MinMaxPolicy,12{13values14.iter()15.zip(weights)16.map(|(v, w)| *v * *w)17.reduce(P::best)18.unwrap()19}2021macro_rules! rolling_minmax_func {22($rolling_m:ident, $policy:ident) => {23pub fn $rolling_m<T>(24values: &[T],25window_size: usize,26min_periods: usize,27center: bool,28weights: Option<&[f64]>,29_params: Option<RollingFnParams>,30) -> PolarsResult<ArrayRef>31where32T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,33{34let offset_fn = match center {35true => det_offsets_center,36false => det_offsets,37};38match weights {39None => rolling_apply_agg_window::<MinMaxWindow<T, $policy>, _, _>(40values,41window_size,42min_periods,43offset_fn,44None,45),46Some(weights) => {47assert!(48T::is_float(),49"implementation error, should only be reachable by float types"50);51let weights = weights52.iter()53.map(|v| NumCast::from(*v).unwrap())54.collect::<Vec<_>>();55no_nulls::rolling_apply_weights(56values,57window_size,58min_periods,59offset_fn,60weighted_min_max::<T, $policy>,61&weights,62center,63)64},65}66}67};68}6970rolling_minmax_func!(rolling_min, MinPropagateNan);71rolling_minmax_func!(rolling_max, MaxPropagateNan);7273#[cfg(test)]74mod test {75use super::*;7677#[test]78fn test_rolling_min_max() {79let values = &[1.0f64, 5.0, 3.0, 4.0];8081let out = rolling_min(values, 2, 2, false, None, None).unwrap();82let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();83let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();84assert_eq!(out, &[None, Some(1.0), Some(3.0), Some(3.0)]);85let out = rolling_max(values, 2, 2, false, None, None).unwrap();86let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();87let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();88assert_eq!(out, &[None, Some(5.0), Some(5.0), Some(4.0)]);8990let out = rolling_min(values, 2, 1, false, None, None).unwrap();91let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();92let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();93assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.0)]);94let out = rolling_max(values, 2, 1, false, None, None).unwrap();95let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();96let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();97assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(4.0)]);9899let out = rolling_max(values, 3, 1, false, None, None).unwrap();100let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();101let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();102assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(5.0)]);103104// test nan handling.105let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];106let out = rolling_min(values, 3, 3, false, None, None).unwrap();107let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();108let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();109// we cannot compare nans, so we compare the string values110assert_eq!(111format!("{:?}", out.as_slice()),112format!(113"{:?}",114&[115None,116None,117Some(1.0),118Some(f64::nan()),119Some(f64::nan()),120Some(f64::nan()),121Some(5.0)122]123)124);125126let out = rolling_max(values, 3, 3, false, None, None).unwrap();127let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();128let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();129assert_eq!(130format!("{:?}", out.as_slice()),131format!(132"{:?}",133&[134None,135None,136Some(3.0),137Some(f64::nan()),138Some(f64::nan()),139Some(f64::nan()),140Some(7.0)141]142)143);144}145}146147148