Path: blob/main/crates/polars-compute/src/rolling/moment.rs
8327 views
use num_traits::{FromPrimitive, ToPrimitive};12use super::no_nulls::RollingAggWindowNoNulls;3use super::nulls::RollingAggWindowNulls;4use super::*;5use crate::moment::{KurtosisState, SkewState, VarState};67pub trait StateUpdate {8fn new(params: Option<RollingFnParams>) -> Self;9fn reset(&mut self);10fn insert_one(&mut self, x: f64);11fn remove_one(&mut self, x: f64);12fn finalize(&self) -> Option<f64>;13}1415pub struct VarianceMoment {16state: VarState,17ddof: u8,18}1920impl StateUpdate for VarianceMoment {21fn new(params: Option<RollingFnParams>) -> Self {22let ddof = if let Some(RollingFnParams::Var(params)) = params {23params.ddof24} else {25126};2728Self {29state: VarState::default(),30ddof,31}32}3334#[inline(always)]35fn reset(&mut self) {36self.state = VarState::default();37}3839#[inline(always)]40fn insert_one(&mut self, x: f64) {41self.state.insert_one(x);42}4344#[inline(always)]45fn remove_one(&mut self, x: f64) {46self.state.remove_one(x);47}4849#[inline(always)]50fn finalize(&self) -> Option<f64> {51self.state.finalize(self.ddof)52}53}5455pub struct KurtosisMoment {56state: KurtosisState,57fisher: bool,58bias: bool,59}6061impl StateUpdate for KurtosisMoment {62fn new(params: Option<RollingFnParams>) -> Self {63let (fisher, bias) = if let Some(RollingFnParams::Kurtosis { fisher, bias }) = params {64(fisher, bias)65} else {66(false, false)67};6869Self {70state: KurtosisState::default(),71fisher,72bias,73}74}7576#[inline(always)]77fn reset(&mut self) {78self.state = KurtosisState::default();79}8081#[inline(always)]82fn insert_one(&mut self, x: f64) {83self.state.insert_one(x);84}8586#[inline(always)]87fn remove_one(&mut self, x: f64) {88self.state.remove_one(x);89}9091#[inline(always)]92fn finalize(&self) -> Option<f64> {93self.state.finalize(self.fisher, self.bias)94}95}9697pub struct SkewMoment {98state: SkewState,99bias: bool,100}101102impl StateUpdate for SkewMoment {103fn new(params: Option<RollingFnParams>) -> Self {104let bias = if let Some(RollingFnParams::Skew { bias }) = params {105bias106} else {107false108};109110Self {111state: SkewState::default(),112bias,113}114}115116#[inline(always)]117fn reset(&mut self) {118self.state = SkewState::default();119}120121#[inline(always)]122fn insert_one(&mut self, x: f64) {123self.state.insert_one(x);124}125126#[inline(always)]127fn remove_one(&mut self, x: f64) {128self.state.remove_one(x);129}130131#[inline(always)]132fn finalize(&self) -> Option<f64> {133self.state.finalize(self.bias)134}135}136137pub struct MomentWindow<'a, T, M: StateUpdate> {138slice: &'a [T],139validity: Option<&'a Bitmap>,140moment: M,141non_finite_count: usize, // NaN or infinity.142null_count: usize,143start: usize,144end: usize,145}146147impl<'a, T, M> MomentWindow<'a, T, M>148where149T: NativeType + ToPrimitive + IsFloat + FromPrimitive,150M: StateUpdate,151{152fn new_impl(153slice: &'a [T],154validity: Option<&'a Bitmap>,155params: Option<RollingFnParams>,156) -> Self {157Self {158slice,159validity,160moment: M::new(params),161non_finite_count: 0,162null_count: 0,163start: 0,164end: 0,165}166}167168#[inline(always)]169fn reset(&mut self) {170self.moment.reset();171self.non_finite_count = 0;172self.null_count = 0;173}174175#[inline(always)]176fn insert(&mut self, val: T) {177if val.is_finite() {178self.moment.insert_one(NumCast::from(val).unwrap());179} else {180self.moment.insert_one(0.0); // A hack to replicate ddof null behavior.181self.non_finite_count += 1;182}183}184185#[inline(always)]186fn remove(&mut self, val: T) {187if val.is_finite() {188self.moment.remove_one(NumCast::from(val).unwrap());189} else {190self.moment.remove_one(0.0); // A hack to replicate ddof null behavior.191self.non_finite_count -= 1;192}193}194195#[inline(always)]196fn get_moment(&self) -> Option<T> {197if self.non_finite_count > 0 {198self.moment199.finalize()200.map(|_v| T::from_f64(f64::NAN).unwrap())201} else {202self.moment.finalize().map(|v| T::from_f64(v).unwrap())203}204}205}206207impl<T, M> RollingAggWindowNoNulls<T> for MomentWindow<'_, T, M>208where209T: NativeType + ToPrimitive + IsFloat + FromPrimitive,210M: StateUpdate,211{212type This<'a> = MomentWindow<'a, T, M>;213214fn new<'a>(215slice: &'a [T],216start: usize,217end: usize,218params: Option<RollingFnParams>,219_window_size: Option<usize>,220) -> Self::This<'a> {221let mut out = MomentWindow::new_impl(slice, None, params);222unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };223out224}225226// # Safety227// The start, end range must be in-bounds.228#[inline]229unsafe fn update(&mut self, new_start: usize, new_end: usize) {230if new_start >= self.end {231self.reset();232self.start = new_start;233self.end = new_start;234}235236for val in &self.slice[self.start..new_start] {237self.remove(*val);238}239240for val in &self.slice[self.end..new_end] {241self.insert(*val);242}243244self.start = new_start;245self.end = new_end;246}247248fn get_agg(&self, _idx: usize) -> Option<T> {249self.get_moment()250}251252fn slice_len(&self) -> usize {253self.slice.len()254}255}256257impl<T, M> RollingAggWindowNulls<T> for MomentWindow<'_, T, M>258where259T: NativeType + ToPrimitive + IsFloat + FromPrimitive,260M: StateUpdate,261{262type This<'a> = MomentWindow<'a, T, M>;263264fn new<'a>(265slice: &'a [T],266validity: &'a Bitmap,267start: usize,268end: usize,269params: Option<RollingFnParams>,270_window_size: Option<usize>,271) -> Self::This<'a> {272assert!(start <= slice.len() && end <= slice.len() && start <= end);273let mut out = MomentWindow::new_impl(slice, Some(validity), params);274// SAFETY: We bounds checked `start` and `end`.275unsafe { RollingAggWindowNulls::update(&mut out, start, end) };276out277}278279// # Safety280// The start, end range must be in-bounds.281#[inline]282unsafe fn update(&mut self, new_start: usize, new_end: usize) {283let validity = unsafe { self.validity.unwrap_unchecked() };284285if new_start >= self.end {286self.reset();287self.start = new_start;288self.end = new_start;289}290291for idx in self.start..new_start {292let valid = unsafe { validity.get_bit_unchecked(idx) };293if valid {294self.remove(unsafe { *self.slice.get_unchecked(idx) });295} else {296self.null_count -= 1;297}298}299300for idx in self.end..new_end {301let valid = unsafe { validity.get_bit_unchecked(idx) };302if valid {303self.insert(unsafe { *self.slice.get_unchecked(idx) });304} else {305self.null_count += 1;306}307}308309self.start = new_start;310self.end = new_end;311}312313fn get_agg(&self, _idx: usize) -> Option<T> {314self.get_moment()315}316317#[inline(always)]318fn is_valid(&self, min_periods: usize) -> bool {319((self.end - self.start) - self.null_count) >= min_periods320}321322fn slice_len(&self) -> usize {323self.slice.len()324}325}326327328