Path: blob/main/crates/polars-compute/src/rolling/nulls/moment.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]12use num_traits::{FromPrimitive, ToPrimitive};34pub use super::super::moment::*;5use super::*;67pub struct MomentWindow<'a, T, M: StateUpdate> {8slice: &'a [T],9validity: &'a Bitmap,10moment: Option<M>,11last_start: usize,12last_end: usize,13null_count: usize,14params: Option<RollingFnParams>,15}1617impl<T: NativeType + ToPrimitive, M: StateUpdate> MomentWindow<'_, T, M> {18// compute sum from the entire window19unsafe fn compute_moment_and_null_count(&mut self, start: usize, end: usize) {20self.moment = None;21let mut idx = start;22self.null_count = 0;23for value in &self.slice[start..end] {24let valid = self.validity.get_bit_unchecked(idx);25if valid {26let value: f64 = NumCast::from(*value).unwrap();27self.moment28.get_or_insert_with(|| M::new(self.params))29.insert_one(value);30} else {31self.null_count += 1;32}33idx += 1;34}35}36}3738impl<'a, T: NativeType + ToPrimitive + IsFloat + FromPrimitive, M: StateUpdate>39RollingAggWindowNulls<'a, T> for MomentWindow<'a, T, M>40{41unsafe fn new(42slice: &'a [T],43validity: &'a Bitmap,44start: usize,45end: usize,46params: Option<RollingFnParams>,47_window_size: Option<usize>,48) -> Self {49let mut out = Self {50slice,51validity,52moment: None,53last_start: start,54last_end: end,55null_count: 0,56params,57};58out.compute_moment_and_null_count(start, end);59out60}6162unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {63let recompute_var = if start >= self.last_end {64true65} else {66// remove elements that should leave the window67let mut recompute_var = false;68for idx in self.last_start..start {69// SAFETY:70// we are in bounds71let valid = self.validity.get_bit_unchecked(idx);72if valid {73let leaving_value = *self.slice.get_unchecked(idx);7475// if the leaving value is nan we need to recompute the window76if T::is_float() && !leaving_value.is_finite() {77recompute_var = true;78break;79}80let leaving_value: f64 = NumCast::from(leaving_value).unwrap();81if let Some(v) = self.moment.as_mut() {82v.remove_one(leaving_value)83}84} else {85// null value leaving the window86self.null_count -= 1;8788// self.sum is None and the leaving value is None89// if the entering value is valid, we might get a new sum.90if self.moment.is_none() {91recompute_var = true;92break;93}94}95}96recompute_var97};9899self.last_start = start;100101// we traverse all values and compute102if recompute_var {103self.compute_moment_and_null_count(start, end);104} else {105for idx in self.last_end..end {106let valid = self.validity.get_bit_unchecked(idx);107108if valid {109let entering_value = *self.slice.get_unchecked(idx);110let entering_value: f64 = NumCast::from(entering_value).unwrap();111self.moment112.get_or_insert_with(|| M::new(self.params))113.insert_one(entering_value);114} else {115// null value entering the window116self.null_count += 1;117}118}119}120self.last_end = end;121self.moment.as_ref().and_then(|v| {122let out = v.finalize();123out.map(|v| T::from_f64(v).unwrap())124})125}126127fn is_valid(&self, min_periods: usize) -> bool {128((self.last_end - self.last_start) - self.null_count) >= min_periods129}130}131132pub fn rolling_var<T>(133arr: &PrimitiveArray<T>,134window_size: usize,135min_periods: usize,136center: bool,137weights: Option<&[f64]>,138params: Option<RollingFnParams>,139) -> ArrayRef140where141T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,142{143if weights.is_some() {144panic!("weights not yet supported on array with null values")145}146let offsets_fn = if center {147det_offsets_center148} else {149det_offsets150};151rolling_apply_agg_window::<MomentWindow<_, VarianceMoment>, _, _>(152arr.values().as_slice(),153arr.validity().as_ref().unwrap(),154window_size,155min_periods,156offsets_fn,157params,158)159}160161pub fn rolling_skew<T>(162arr: &PrimitiveArray<T>,163window_size: usize,164min_periods: usize,165center: bool,166params: Option<RollingFnParams>,167) -> ArrayRef168where169T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,170{171let offsets_fn = if center {172det_offsets_center173} else {174det_offsets175};176rolling_apply_agg_window::<MomentWindow<_, SkewMoment>, _, _>(177arr.values().as_slice(),178arr.validity().as_ref().unwrap(),179window_size,180min_periods,181offsets_fn,182params,183)184}185186pub fn rolling_kurtosis<T>(187arr: &PrimitiveArray<T>,188window_size: usize,189min_periods: usize,190center: bool,191params: Option<RollingFnParams>,192) -> ArrayRef193where194T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,195{196let offsets_fn = if center {197det_offsets_center198} else {199det_offsets200};201rolling_apply_agg_window::<MomentWindow<_, KurtosisMoment>, _, _>(202arr.values().as_slice(),203arr.validity().as_ref().unwrap(),204window_size,205min_periods,206offsets_fn,207params,208)209}210211212