// Path: blob/main/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs
// 8406 views
use std::borrow::Cow;12use arrow::types::NativeType;3#[cfg(feature = "dtype-f16")]4use num_traits::real::Real;5use polars_compute::rolling::no_nulls::RollingAggWindowNoNulls;6use polars_compute::rolling::nulls::RollingAggWindowNulls;7use polars_compute::rolling::{MeanWindow, SumWindow, no_nulls, nulls};8use polars_core::{with_match_physical_float_polars_type, with_match_physical_numeric_polars_type};9use polars_ops::series::SeriesMethods;10use polars_utils::float::IsFloat;1112use super::*;13use crate::prelude::*;14use crate::series::AsSeries;1516#[cfg(feature = "rolling_window")]17#[allow(clippy::type_complexity)]18fn rolling_agg<T>(19ca: &ChunkedArray<T>,20options: RollingOptionsFixedWindow,21rolling_agg_fn: &dyn Fn(22&[T::Native],23usize,24usize,25bool,26Option<&[f64]>,27Option<RollingFnParams>,28) -> PolarsResult<ArrayRef>,29rolling_agg_fn_nulls: &dyn Fn(30&PrimitiveArray<T::Native>,31usize,32usize,33bool,34Option<&[f64]>,35Option<RollingFnParams>,36) -> ArrayRef,37) -> PolarsResult<Series>38where39T: PolarsNumericType,40{41polars_ensure!(options.min_periods <= options.window_size, InvalidOperation: "`min_periods` should be <= `window_size`");42if ca.is_empty() {43return Ok(Series::new_empty(ca.name().clone(), ca.dtype()));44}45let ca = ca.rechunk();4647let arr = ca.downcast_iter().next().unwrap();48let arr = match ca.null_count() {490 => rolling_agg_fn(50arr.values().as_slice(),51options.window_size,52options.min_periods,53options.center,54options.weights.as_deref(),55options.fn_params,56)?,57_ => rolling_agg_fn_nulls(58arr,59options.window_size,60options.min_periods,61options.center,62options.weights.as_deref(),63options.fn_params,64),65};66Series::try_from((ca.name().clone(), arr))67}6869#[cfg(feature = "rolling_window_by")]70fn rolling_agg_by<T, Out, NoNullsAgg, NullsAgg>(71ca: &ChunkedArray<T>,72by: &Series,73options: RollingOptionsDynamicWindow,74) -> PolarsResult<Series>75where76T: PolarsNumericType,77T::Native: NativeType + IsFloat,78Out: 
NativeType,79NoNullsAgg: RollingAggWindowNoNulls<T::Native, Out>,80NullsAgg: RollingAggWindowNulls<T::Native, Out>,81{82use crate::chunkedarray::rolling_window::rolling_kernels::shared::{83RollingAggWindowNoNullsWrapper, RollingAggWindowNullsWrapper, rolling_apply_agg,84};8586if ca.is_empty() {87return Ok(Series::new_empty(ca.name().clone(), ca.dtype()));88}8990polars_ensure!(91ca.len() == by.len(),92InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column"93);94ensure_duration_matches_dtype(options.window_size, by.dtype(), "window_size")?;95polars_ensure!(96!options.window_size.is_zero() && !options.window_size.negative,97InvalidOperation: "`window_size` must be strictly positive"98);99100let (by, tz) = match by.dtype() {101DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz),102DataType::Date => (103by.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,104&None,105),106DataType::Int64 => (107by.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,108&None,109),110DataType::Int32 | DataType::UInt64 | DataType::UInt32 => (111by.cast(&DataType::Int64)?112.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,113&None,114),115dt => polars_bail!(InvalidOperation:116"in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)",117dt,118"Date/Datetime/Int64/Int32/UInt64/UInt32"),119};120let mut ca_rechunked = ca.rechunk();121let by = by.rechunk();122let by_is_sorted = by.is_sorted(SortOptions {123descending: false,124..Default::default()125})?;126let by_logical = by.datetime().unwrap();127let tu = by_logical.time_unit();128let mut by_physical = Cow::Borrowed(by_logical.physical());129let sorting_indices_opt = (!by_is_sorted).then(|| by_physical.arg_sort(Default::default()));130131if let Some(sorting_indices) = &sorting_indices_opt {132// SAFETY: `sorting_indices` is in-bounds because we checked that `ca.len() == by.len()` and133// they are derived from `by`.134ca_rechunked 
= Cow::Owned(unsafe { ca_rechunked.take_unchecked(sorting_indices) });135// SAFETY: `sorting_indices` is in-bounds because they are derived from `by`.136by_physical = Cow::Owned(unsafe { by_physical.take_unchecked(sorting_indices) });137}138139let by_values = by_physical.cont_slice().unwrap();140let arr = ca_rechunked.downcast_iter().next().unwrap();141let values = arr.values().as_slice();142143// We explicitly branch here because we want to compile different versions based on the no_nulls144// or nulls kernel.145let out: ArrayRef = if ca.null_count() == 0 {146let mut agg_window =147RollingAggWindowNoNullsWrapper(NoNullsAgg::new(values, 0, 0, options.fn_params, None));148149rolling_apply_agg(150&mut agg_window,151options.window_size,152by_values,153options.closed_window,154options.min_periods,155tu,156tz.as_ref(),157sorting_indices_opt158.as_ref()159.map(|s| s.cont_slice().unwrap()),160)?161} else {162let validity = arr.validity().unwrap();163let mut agg_window = RollingAggWindowNullsWrapper(NullsAgg::new(164values,165validity,1660,1670,168options.fn_params,169None,170));171172rolling_apply_agg(173&mut agg_window,174options.window_size,175by_values,176options.closed_window,177options.min_periods,178tu,179tz.as_ref(),180sorting_indices_opt181.as_ref()182.map(|s| s.cont_slice().unwrap()),183)?184};185186Series::try_from((ca.name().clone(), out))187}188189pub trait SeriesOpsTime: AsSeries {190/// Apply a rolling mean to a Series based on another Series.191#[cfg(feature = "rolling_window_by")]192fn rolling_mean_by(193&self,194by: &Series,195options: RollingOptionsDynamicWindow,196) -> PolarsResult<Series> {197let s = self.as_series().to_float()?;198with_match_physical_float_polars_type!(s.dtype(), |$T| {199let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();200rolling_agg_by::<$T, _, MeanWindow<_>, MeanWindow<_>>(ca, by, options)201})202}203/// Apply a rolling mean to a Series.204///205/// See: [`RollingAgg::rolling_mean`]206#[cfg(feature = "rolling_window")]207fn 
rolling_mean(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {208let s = self.as_series().to_float()?;209with_match_physical_float_polars_type!(s.dtype(), |$T| {210let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();211rolling_agg(212ca,213options,214&rolling::no_nulls::rolling_mean,215&rolling::nulls::rolling_mean,216)217})218}219/// Apply a rolling sum to a Series based on another Series.220#[cfg(feature = "rolling_window_by")]221fn rolling_sum_by(222&self,223by: &Series,224options: RollingOptionsDynamicWindow,225) -> PolarsResult<Series> {226let mut s = self.as_series().clone();227if s.dtype() == &DataType::Boolean {228s = s.cast(&DataType::IDX_DTYPE).unwrap();229}230if matches!(231s.dtype(),232DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16233) {234s = s.cast(&DataType::Int64).unwrap();235}236237polars_ensure!(238s.dtype().is_primitive_numeric() && !s.dtype().is_unknown(),239op = "rolling_sum_by",240s.dtype()241);242243with_match_physical_numeric_polars_type!(s.dtype(), |$T| {244let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();245type Native = <$T as PolarsNumericType>::Native;246type SM<'a> = SumWindow<'a, Native, Native>;247rolling_agg_by::<$T, _, SM, SM>(ca, by, options)248})249}250251/// Apply a rolling sum to a Series.252#[cfg(feature = "rolling_window")]253fn rolling_sum(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {254let mut s = self.as_series().clone();255if options.weights.is_some() {256s = s.to_float()?;257} else if s.dtype() == &DataType::Boolean {258s = s.cast(&DataType::IDX_DTYPE).unwrap();259} else if matches!(260s.dtype(),261DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16262) {263s = s.cast(&DataType::Int64).unwrap();264}265266polars_ensure!(267s.dtype().is_primitive_numeric() && !s.dtype().is_unknown(),268op = "rolling_sum",269s.dtype()270);271272with_match_physical_numeric_polars_type!(s.dtype(), |$T| {273let ca: &ChunkedArray<$T> = 
s.as_ref().as_ref().as_ref();274rolling_agg(275ca,276options,277&rolling::no_nulls::rolling_sum,278&rolling::nulls::rolling_sum,279)280})281}282283/// Apply a rolling quantile to a Series based on another Series.284#[cfg(feature = "rolling_window_by")]285fn rolling_quantile_by(286&self,287by: &Series,288options: RollingOptionsDynamicWindow,289) -> PolarsResult<Series> {290let s = self.as_series().to_float()?;291with_match_physical_float_polars_type!(s.dtype(), |$T| {292let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();293rolling_agg_by::<294$T,295_,296no_nulls::QuantileWindow<_>,297nulls::QuantileWindow<_>298>(ca, by, options)299})300}301302/// Apply a rolling quantile to a Series.303#[cfg(feature = "rolling_window")]304fn rolling_quantile(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {305let s = self.as_series().to_float()?;306with_match_physical_float_polars_type!(s.dtype(), |$T| {307let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();308rolling_agg(309ca,310options,311&rolling::no_nulls::rolling_quantile,312&rolling::nulls::rolling_quantile,313)314})315}316317/// Apply a rolling min to a Series based on another Series.318#[cfg(feature = "rolling_window_by")]319fn rolling_min_by(320&self,321by: &Series,322options: RollingOptionsDynamicWindow,323) -> PolarsResult<Series> {324let s = self.as_series().clone();325326let dt = s.dtype();327match dt {328// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.329&DataType::Boolean => {330return s331.cast(&DataType::UInt8)?332.rolling_min_by(by, options)?333.cast(&DataType::Boolean);334},335dt if dt.is_temporal() => {336return s.to_physical_repr().rolling_min_by(by, options)?.cast(dt);337},338dt => {339polars_ensure!(340dt.is_primitive_numeric() && !dt.is_unknown(),341op = "rolling_min_by",342dt343);344},345}346347with_match_physical_numeric_polars_type!(s.dtype(), |$T| {348let ca: &ChunkedArray<$T> = 
s.as_ref().as_ref().as_ref();349rolling_agg_by::<350$T,351_,352no_nulls::MinWindow<_>,353nulls::MinWindow<_>354>(ca, by, options)355})356}357358/// Apply a rolling min to a Series.359#[cfg(feature = "rolling_window")]360fn rolling_min(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {361let mut s = self.as_series().clone();362if options.weights.is_some() {363s = s.to_float()?;364}365366let dt = s.dtype();367match dt {368// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.369&DataType::Boolean => {370return s371.cast(&DataType::UInt8)?372.rolling_min(options)?373.cast(&DataType::Boolean);374},375dt if dt.is_temporal() => {376return s.to_physical_repr().rolling_min(options)?.cast(dt);377},378dt => {379polars_ensure!(380dt.is_primitive_numeric() && !dt.is_unknown(),381op = "rolling_min",382dt383);384},385}386387with_match_physical_numeric_polars_type!(dt, |$T| {388let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();389rolling_agg(390ca,391options,392&rolling::no_nulls::rolling_min,393&rolling::nulls::rolling_min,394)395})396}397398/// Apply a rolling max to a Series based on another Series.399#[cfg(feature = "rolling_window_by")]400fn rolling_max_by(401&self,402by: &Series,403options: RollingOptionsDynamicWindow,404) -> PolarsResult<Series> {405let s = self.as_series().clone();406407let dt = s.dtype();408match dt {409// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.410&DataType::Boolean => {411return s412.cast(&DataType::UInt8)?413.rolling_max_by(by, options)?414.cast(&DataType::Boolean);415},416dt if dt.is_temporal() => {417return s.to_physical_repr().rolling_max_by(by, options)?.cast(dt);418},419dt => {420polars_ensure!(421dt.is_primitive_numeric() && !dt.is_unknown(),422op = "rolling_max_by",423dt424);425},426}427428with_match_physical_numeric_polars_type!(s.dtype(), |$T| {429let ca: &ChunkedArray<$T> = 
s.as_ref().as_ref().as_ref();430rolling_agg_by::<431$T,432_,433no_nulls::MaxWindow<_>,434nulls::MaxWindow<_>435>(ca, by, options)436})437}438439/// Apply a rolling max to a Series.440#[cfg(feature = "rolling_window")]441fn rolling_max(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {442let mut s = self.as_series().clone();443if options.weights.is_some() {444s = s.to_float()?;445}446447let dt = s.dtype();448match dt {449// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.450&DataType::Boolean => {451return s452.cast(&DataType::UInt8)?453.rolling_max(options)?454.cast(&DataType::Boolean);455},456dt if dt.is_temporal() => {457return s.to_physical_repr().rolling_max(options)?.cast(dt);458},459dt => {460polars_ensure!(461dt.is_primitive_numeric() && !dt.is_unknown(),462op = "rolling_max",463dt464);465},466}467468with_match_physical_numeric_polars_type!(s.dtype(), |$T| {469let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();470rolling_agg(471ca,472options,473&rolling::no_nulls::rolling_max,474&rolling::nulls::rolling_max,475)476})477}478479/// Apply a rolling variance to a Series based on another Series.480#[cfg(feature = "rolling_window_by")]481fn rolling_var_by(482&self,483by: &Series,484options: RollingOptionsDynamicWindow,485) -> PolarsResult<Series> {486let s = self.as_series().to_float()?;487488with_match_physical_float_polars_type!(s.dtype(), |$T| {489let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();490491rolling_agg_by::<492$T,493_,494no_nulls::MomentWindow<_, no_nulls::VarianceMoment>,495nulls::MomentWindow<_, nulls::VarianceMoment>496>(ca, by, options)497})498}499500/// Apply a rolling variance to a Series.501#[cfg(feature = "rolling_window")]502fn rolling_var(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {503let s = self.as_series().to_float()?;504505with_match_physical_float_polars_type!(s.dtype(), |$T| {506let ca: &ChunkedArray<$T> = 
s.as_ref().as_ref().as_ref();507508rolling_agg(509ca,510options,511&rolling::no_nulls::rolling_var,512&rolling::nulls::rolling_var,513)514})515}516517/// Apply a rolling std_dev to a Series based on another Series.518#[cfg(feature = "rolling_window_by")]519fn rolling_std_by(520&self,521by: &Series,522options: RollingOptionsDynamicWindow,523) -> PolarsResult<Series> {524self.rolling_var_by(by, options).map(|mut s| {525with_match_physical_float_polars_type!(s.dtype(), |$T| {526let ca: &mut ChunkedArray<$T> = s._get_inner_mut().as_mut();527ca.apply_mut(|v| v.sqrt());528});529530s531})532}533534/// Apply a rolling std_dev to a Series.535#[cfg(feature = "rolling_window")]536fn rolling_std(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {537self.rolling_var(options).map(|mut s| {538with_match_physical_float_polars_type!(s.dtype(), |$T| {539let ca: &mut ChunkedArray<$T> = s._get_inner_mut().as_mut();540ca.apply_mut(|v| v.sqrt());541});542543s544})545}546547/// Apply a rolling rank to a Series based on another Series.548#[cfg(feature = "rolling_window_by")]549fn rolling_rank_by(550&self,551by: &Series,552options: RollingOptionsDynamicWindow,553) -> PolarsResult<Series> {554if !matches!(555options.closed_window,556ClosedWindow::Right | ClosedWindow::Both557) {558polars_bail!(InvalidOperation: "`rolling_rank_by` window needs to be closed on the right side (i.e., `closed` must be `right` or `both`)");559}560561let s = self.as_series().clone();562563match s.dtype() {564DataType::Boolean => return s.cast(&DataType::UInt8)?.rolling_rank_by(by, options),565dt if dt.is_temporal() => return s.to_physical_repr().rolling_rank_by(by, options),566dt => {567polars_ensure!(568dt.is_primitive_numeric() && !dt.is_unknown(),569op = "rolling_rank_by",570dt571);572},573}574575let method = if let Some(RollingFnParams::Rank { method, .. 
}) = options.fn_params {576method577} else {578unreachable!("expected RollingFnParams::Rank");579};580581with_match_physical_numeric_polars_type!(s.dtype(), |$T| {582let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();583584match method {585RollingRankMethod::Average => rolling_agg_by::<586$T,587_,588no_nulls::RankWindowAvg<_>,589nulls::RankWindowAvg<_>590>(ca, by, options),591RollingRankMethod::Min => rolling_agg_by::<592$T,593_,594no_nulls::RankWindowMin<_>,595nulls::RankWindowMin<_>596>(ca, by, options),597RollingRankMethod::Max => rolling_agg_by::<598$T,599_,600no_nulls::RankWindowMax<_>,601nulls::RankWindowMax<_>602>(ca, by, options),603RollingRankMethod::Dense => rolling_agg_by::<604$T,605_,606no_nulls::RankWindowDense<_>,607nulls::RankWindowDense<_>608>(ca, by, options),609RollingRankMethod::Random => rolling_agg_by::<610$T,611_,612no_nulls::RankWindowRandom<_>,613nulls::RankWindowRandom<_>614>(ca, by, options),615_ => todo!()616}617})618}619620/// Apply a rolling rank to a Series.621#[cfg(feature = "rolling_window")]622fn rolling_rank(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {623let s = self.as_series();624625match s.dtype() {626DataType::Boolean => return s.cast(&DataType::UInt8)?.rolling_rank(options),627dt if dt.is_temporal() => return s.to_physical_repr().rolling_rank(options),628dt => {629polars_ensure!(630dt.is_primitive_numeric() && !dt.is_unknown(),631op = "rolling_rank",632dt633);634},635}636637with_match_physical_numeric_polars_type!(s.dtype(), |$T| {638let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();639let mut ca = ca.clone();640641rolling_agg(642&ca,643options,644&rolling::no_nulls::rolling_rank,645&rolling::nulls::rolling_rank,646)647})648}649}650651impl SeriesOpsTime for Series {}652653654