Path: blob/main/crates/polars-compute/src/rolling/mod.rs
8421 views
mod mean;1mod min_max;2mod moment;3pub mod no_nulls;4pub mod nulls;5pub mod quantile_filter;6mod rank;7mod sum;89mod arg_min_max;10pub(super) mod window;11use std::hash::Hash;12use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};1314pub use arg_min_max::{ArgMaxWindow, ArgMinMaxWindow, ArgMinWindow};15use arrow::array::{ArrayRef, PrimitiveArray};16use arrow::bitmap::{Bitmap, MutableBitmap};17use arrow::types::NativeType;18pub use mean::MeanWindow;19use num_traits::{Bounded, Float, NumCast, One, Zero};20use polars_utils::float::IsFloat;21#[cfg(feature = "serde")]22use serde::{Deserialize, Serialize};23use strum_macros::IntoStaticStr;24pub use sum::SumWindow;25use window::*;2627type Start = usize;28type End = usize;29type Idx = usize;30type WindowSize = usize;31type Len = usize;3233#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]34#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]35#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]36#[strum(serialize_all = "snake_case")]37pub enum QuantileMethod {38#[default]39Nearest,40Lower,41Higher,42Midpoint,43Linear,44Equiprobable,45}4647#[deprecated(note = "use QuantileMethod instead")]48pub type QuantileInterpolOptions = QuantileMethod;4950#[derive(Clone, Copy, Debug, PartialEq, Hash)]51#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]52#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]53pub enum RollingFnParams {54Quantile(RollingQuantileParams),55Var(RollingVarParams),56Rank {57method: RollingRankMethod,58seed: Option<u64>,59},60Skew {61bias: bool,62},63Kurtosis {64fisher: bool,65bias: bool,66},67}6869fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {70(i.saturating_sub(window_size - 1), i + 1)71}72fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {73let right_window = window_size.div_ceil(2);74(75i.saturating_sub(window_size - right_window),76std::cmp::min(len, i + right_window),77)78}7980fn create_validity<Fo>(81min_periods: usize,82len: usize,83window_size: usize,84det_offsets_fn: Fo,85) -> Option<MutableBitmap>86where87Fo: Fn(Idx, WindowSize, Len) -> (Start, End),88{89if min_periods > 1 {90let mut validity = MutableBitmap::with_capacity(len);91validity.extend_constant(len, true);9293// Set the null values at the boundaries9495// Head.96for i in 0..len {97let (start, end) = det_offsets_fn(i, window_size, len);98if (end - start) < min_periods {99validity.set(i, false)100} else {101break;102}103}104// Tail.105for i in (0..len).rev() {106let (start, end) = det_offsets_fn(i, window_size, len);107if (end - start) < min_periods {108validity.set(i, false)109} else {110break;111}112}113114Some(validity)115} else {116None117}118}119120// Parameters allowed for rolling operations.121#[derive(Clone, Copy, Debug, PartialEq, Hash)]122#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]123#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]124pub struct RollingVarParams {125pub ddof: u8,126}127128#[derive(Clone, Copy, Debug, PartialEq)]129#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]130#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]131pub struct RollingQuantileParams {132pub prob: f64,133pub method: QuantileMethod,134}135136impl Hash for RollingQuantileParams {137fn hash<H: std::hash::Hasher>(&self, state: &mut H) {138// Will not be NaN, so hash + eq symmetry will hold.139self.prob.to_bits().hash(state);140self.method.hash(state);141}142}143144#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]145#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]146#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]147#[strum(serialize_all = "snake_case")]148pub enum RollingRankMethod {149#[default]150Average,151Min,152Max,153Dense,154Random,155}156157158