// Path: crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs
use polars_core::{with_match_physical_float_polars_type, with_match_physical_numeric_polars_type};1use polars_ops::series::SeriesMethods;23use super::*;4use crate::prelude::*;5use crate::series::AsSeries;67#[cfg(feature = "rolling_window")]8#[allow(clippy::type_complexity)]9fn rolling_agg<T>(10ca: &ChunkedArray<T>,11options: RollingOptionsFixedWindow,12rolling_agg_fn: &dyn Fn(13&[T::Native],14usize,15usize,16bool,17Option<&[f64]>,18Option<RollingFnParams>,19) -> PolarsResult<ArrayRef>,20rolling_agg_fn_nulls: &dyn Fn(21&PrimitiveArray<T::Native>,22usize,23usize,24bool,25Option<&[f64]>,26Option<RollingFnParams>,27) -> ArrayRef,28) -> PolarsResult<Series>29where30T: PolarsNumericType,31{32polars_ensure!(options.min_periods <= options.window_size, InvalidOperation: "`min_periods` should be <= `window_size`");33if ca.is_empty() {34return Ok(Series::new_empty(ca.name().clone(), ca.dtype()));35}36let ca = ca.rechunk();3738let arr = ca.downcast_iter().next().unwrap();39let arr = match ca.null_count() {400 => rolling_agg_fn(41arr.values().as_slice(),42options.window_size,43options.min_periods,44options.center,45options.weights.as_deref(),46options.fn_params,47)?,48_ => rolling_agg_fn_nulls(49arr,50options.window_size,51options.min_periods,52options.center,53options.weights.as_deref(),54options.fn_params,55),56};57Series::try_from((ca.name().clone(), arr))58}5960#[cfg(feature = "rolling_window_by")]61#[allow(clippy::type_complexity)]62fn rolling_agg_by<T>(63ca: &ChunkedArray<T>,64by: &Series,65options: RollingOptionsDynamicWindow,66rolling_agg_fn_dynamic: &dyn Fn(67&[T::Native],68Duration,69&[i64],70ClosedWindow,71usize,72TimeUnit,73Option<&TimeZone>,74Option<RollingFnParams>,75Option<&[IdxSize]>,76) -> PolarsResult<ArrayRef>,77) -> PolarsResult<Series>78where79T: PolarsNumericType,80{81if ca.is_empty() {82return Ok(Series::new_empty(ca.name().clone(), ca.dtype()));83}84polars_ensure!(by.null_count() == 0 && ca.null_count() == 0, InvalidOperation: "'Expr.rolling_*_by(...)' 
not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'");85polars_ensure!(ca.len() == by.len(), InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column");86ensure_duration_matches_dtype(options.window_size, by.dtype(), "window_size")?;87polars_ensure!(!options.window_size.is_zero() && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive");88let (by, tz) = match by.dtype() {89DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz),90DataType::Date => (91by.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,92&None,93),94DataType::Int64 => (95by.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,96&None,97),98DataType::Int32 | DataType::UInt64 | DataType::UInt32 => (99by.cast(&DataType::Int64)?100.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,101&None,102),103dt => polars_bail!(InvalidOperation:104"in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)",105dt,106"Date/Datetime/Int64/Int32/UInt64/UInt32"),107};108let ca = ca.rechunk();109let by = by.rechunk();110let by_is_sorted = by.is_sorted(SortOptions {111descending: false,112..Default::default()113})?;114let by = by.datetime().unwrap();115let tu = by.time_unit();116117let func = rolling_agg_fn_dynamic;118let out: ArrayRef = if by_is_sorted {119let arr = ca.downcast_iter().next().unwrap();120let by_values = by.physical().cont_slice().unwrap();121let values = arr.values().as_slice();122func(123values,124options.window_size,125by_values,126options.closed_window,127options.min_periods,128tu,129tz.as_ref(),130options.fn_params,131None,132)?133} else {134let sorting_indices = by.physical().arg_sort(Default::default());135let ca = unsafe { ca.take_unchecked(&sorting_indices) };136let by = unsafe { by.physical().take_unchecked(&sorting_indices) };137let arr = ca.downcast_iter().next().unwrap();138let by_values = 
by.cont_slice().unwrap();139let values = arr.values().as_slice();140func(141values,142options.window_size,143by_values,144options.closed_window,145options.min_periods,146tu,147tz.as_ref(),148options.fn_params,149Some(sorting_indices.cont_slice().unwrap()),150)?151};152Series::try_from((ca.name().clone(), out))153}154155pub trait SeriesOpsTime: AsSeries {156/// Apply a rolling mean to a Series based on another Series.157#[cfg(feature = "rolling_window_by")]158fn rolling_mean_by(159&self,160by: &Series,161options: RollingOptionsDynamicWindow,162) -> PolarsResult<Series> {163let s = self.as_series().to_float()?;164with_match_physical_float_polars_type!(s.dtype(), |$T| {165let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();166rolling_agg_by(167ca,168by,169options,170&super::rolling_kernels::no_nulls::rolling_mean,171)172})173}174/// Apply a rolling mean to a Series.175///176/// See: [`RollingAgg::rolling_mean`]177#[cfg(feature = "rolling_window")]178fn rolling_mean(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {179let s = self.as_series().to_float()?;180with_match_physical_float_polars_type!(s.dtype(), |$T| {181let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();182rolling_agg(183ca,184options,185&rolling::no_nulls::rolling_mean,186&rolling::nulls::rolling_mean,187)188})189}190/// Apply a rolling sum to a Series based on another Series.191#[cfg(feature = "rolling_window_by")]192fn rolling_sum_by(193&self,194by: &Series,195options: RollingOptionsDynamicWindow,196) -> PolarsResult<Series> {197let mut s = self.as_series().clone();198if s.dtype() == &DataType::Boolean {199s = s.cast(&DataType::IDX_DTYPE).unwrap();200}201if matches!(202s.dtype(),203DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16204) {205s = s.cast(&DataType::Int64).unwrap();206}207208polars_ensure!(209s.dtype().is_primitive_numeric() && !s.dtype().is_unknown(),210op = 
"rolling_sum_by",211s.dtype()212);213214with_match_physical_numeric_polars_type!(s.dtype(), |$T| {215let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();216rolling_agg_by(217ca,218by,219options,220&super::rolling_kernels::no_nulls::rolling_sum,221)222})223}224225/// Apply a rolling sum to a Series.226#[cfg(feature = "rolling_window")]227fn rolling_sum(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {228let mut s = self.as_series().clone();229if options.weights.is_some() {230s = s.to_float()?;231} else if s.dtype() == &DataType::Boolean {232s = s.cast(&DataType::IDX_DTYPE).unwrap();233} else if matches!(234s.dtype(),235DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16236) {237s = s.cast(&DataType::Int64).unwrap();238}239240polars_ensure!(241s.dtype().is_primitive_numeric() && !s.dtype().is_unknown(),242op = "rolling_sum",243s.dtype()244);245246with_match_physical_numeric_polars_type!(s.dtype(), |$T| {247let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();248rolling_agg(249ca,250options,251&rolling::no_nulls::rolling_sum,252&rolling::nulls::rolling_sum,253)254})255}256257/// Apply a rolling quantile to a Series based on another Series.258#[cfg(feature = "rolling_window_by")]259fn rolling_quantile_by(260&self,261by: &Series,262options: RollingOptionsDynamicWindow,263) -> PolarsResult<Series> {264let s = self.as_series().to_float()?;265with_match_physical_float_polars_type!(s.dtype(), |$T| {266let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();267rolling_agg_by(268ca,269by,270options,271&super::rolling_kernels::no_nulls::rolling_quantile,272)273})274}275276/// Apply a rolling quantile to a Series.277#[cfg(feature = "rolling_window")]278fn rolling_quantile(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {279let s = self.as_series().to_float()?;280with_match_physical_float_polars_type!(s.dtype(), |$T| {281let ca: &ChunkedArray<$T> = 
s.as_ref().as_ref().as_ref();282rolling_agg(283ca,284options,285&rolling::no_nulls::rolling_quantile,286&rolling::nulls::rolling_quantile,287)288})289}290291/// Apply a rolling min to a Series based on another Series.292#[cfg(feature = "rolling_window_by")]293fn rolling_min_by(294&self,295by: &Series,296options: RollingOptionsDynamicWindow,297) -> PolarsResult<Series> {298let s = self.as_series().clone();299300let dt = s.dtype();301match dt {302// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.303&DataType::Boolean => {304return s305.cast(&DataType::UInt8)?306.rolling_min_by(by, options)?307.cast(&DataType::Boolean);308},309dt if dt.is_temporal() => {310return s.to_physical_repr().rolling_min_by(by, options)?.cast(dt);311},312dt => {313polars_ensure!(314dt.is_primitive_numeric() && !dt.is_unknown(),315op = "rolling_min_by",316dt317);318},319}320321with_match_physical_numeric_polars_type!(s.dtype(), |$T| {322let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();323rolling_agg_by(324ca,325by,326options,327&super::rolling_kernels::no_nulls::rolling_min,328)329})330}331332/// Apply a rolling min to a Series.333#[cfg(feature = "rolling_window")]334fn rolling_min(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {335let mut s = self.as_series().clone();336if options.weights.is_some() {337s = s.to_float()?;338}339340let dt = s.dtype();341match dt {342// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.343&DataType::Boolean => {344return s345.cast(&DataType::UInt8)?346.rolling_min(options)?347.cast(&DataType::Boolean);348},349dt if dt.is_temporal() => {350return s.to_physical_repr().rolling_min(options)?.cast(dt);351},352dt => {353polars_ensure!(354dt.is_primitive_numeric() && !dt.is_unknown(),355op = "rolling_min",356dt357);358},359}360361with_match_physical_numeric_polars_type!(dt, |$T| {362let ca: &ChunkedArray<$T> = 
s.as_ref().as_ref().as_ref();363rolling_agg(364ca,365options,366&rolling::no_nulls::rolling_min,367&rolling::nulls::rolling_min,368)369})370}371372/// Apply a rolling max to a Series based on another Series.373#[cfg(feature = "rolling_window_by")]374fn rolling_max_by(375&self,376by: &Series,377options: RollingOptionsDynamicWindow,378) -> PolarsResult<Series> {379let s = self.as_series().clone();380381let dt = s.dtype();382match dt {383// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.384&DataType::Boolean => {385return s386.cast(&DataType::UInt8)?387.rolling_max_by(by, options)?388.cast(&DataType::Boolean);389},390dt if dt.is_temporal() => {391return s.to_physical_repr().rolling_max_by(by, options)?.cast(dt);392},393dt => {394polars_ensure!(395dt.is_primitive_numeric() && !dt.is_unknown(),396op = "rolling_max_by",397dt398);399},400}401402with_match_physical_numeric_polars_type!(s.dtype(), |$T| {403let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();404rolling_agg_by(405ca,406by,407options,408&super::rolling_kernels::no_nulls::rolling_max,409)410})411}412413/// Apply a rolling max to a Series.414#[cfg(feature = "rolling_window")]415fn rolling_max(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {416let mut s = self.as_series().clone();417if options.weights.is_some() {418s = s.to_float()?;419}420421let dt = s.dtype();422match dt {423// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.424&DataType::Boolean => {425return s426.cast(&DataType::UInt8)?427.rolling_max(options)?428.cast(&DataType::Boolean);429},430dt if dt.is_temporal() => {431return s.to_physical_repr().rolling_max(options)?.cast(dt);432},433dt => {434polars_ensure!(435dt.is_primitive_numeric() && !dt.is_unknown(),436op = "rolling_max",437dt438);439},440}441442with_match_physical_numeric_polars_type!(s.dtype(), |$T| {443let ca: &ChunkedArray<$T> = 
s.as_ref().as_ref().as_ref();444rolling_agg(445ca,446options,447&rolling::no_nulls::rolling_max,448&rolling::nulls::rolling_max,449)450})451}452453/// Apply a rolling variance to a Series based on another Series.454#[cfg(feature = "rolling_window_by")]455fn rolling_var_by(456&self,457by: &Series,458options: RollingOptionsDynamicWindow,459) -> PolarsResult<Series> {460let s = self.as_series().to_float()?;461462with_match_physical_float_polars_type!(s.dtype(), |$T| {463let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();464let mut ca = ca.clone();465466rolling_agg_by(467&ca,468by,469options,470&super::rolling_kernels::no_nulls::rolling_var,471)472})473}474475/// Apply a rolling variance to a Series.476#[cfg(feature = "rolling_window")]477fn rolling_var(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {478let s = self.as_series().to_float()?;479480with_match_physical_float_polars_type!(s.dtype(), |$T| {481let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();482let mut ca = ca.clone();483484rolling_agg(485&ca,486options,487&rolling::no_nulls::rolling_var,488&rolling::nulls::rolling_var,489)490})491}492493/// Apply a rolling std_dev to a Series based on another Series.494#[cfg(feature = "rolling_window_by")]495fn rolling_std_by(496&self,497by: &Series,498options: RollingOptionsDynamicWindow,499) -> PolarsResult<Series> {500self.rolling_var_by(by, options).map(|mut s| {501match s.dtype().clone() {502DataType::Float32 => {503let ca: &mut ChunkedArray<Float32Type> = s._get_inner_mut().as_mut();504ca.apply_mut(|v| v.powf(0.5))505},506DataType::Float64 => {507let ca: &mut ChunkedArray<Float64Type> = s._get_inner_mut().as_mut();508ca.apply_mut(|v| v.powf(0.5))509},510_ => unreachable!(),511}512s513})514}515516/// Apply a rolling std_dev to a Series.517#[cfg(feature = "rolling_window")]518fn rolling_std(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {519self.rolling_var(options).map(|mut s| {520match s.dtype().clone() 
{521DataType::Float32 => {522let ca: &mut ChunkedArray<Float32Type> = s._get_inner_mut().as_mut();523ca.apply_mut(|v| v.powf(0.5))524},525DataType::Float64 => {526let ca: &mut ChunkedArray<Float64Type> = s._get_inner_mut().as_mut();527ca.apply_mut(|v| v.powf(0.5))528},529_ => unreachable!(),530}531s532})533}534}535536impl SeriesOpsTime for Series {}537538539