Path: blob/main/crates/polars-expr/src/expressions/rolling.rs
8424 views
use arrow::array::PrimitiveArray;
use polars_time::prelude::RollingWindower;
use polars_time::{ClosedWindow, Duration, PolarsTemporalGroupby, RollingGroupOptions};
use polars_utils::UnitVec;

use super::*;

/// Physical expression for `expr.rolling(...)`: evaluates `phys_function` over
/// temporal rolling windows derived from `index_column`.
///
/// TODO! support keys?
/// The challenge is that the group_by will reorder the results and the
/// keys, and time index would need to be updated, or the result should be joined back
/// For now, don't support it.
pub(crate) struct RollingExpr {
    /// A function Expr. i.e. Mean, Median, Max, etc.
    pub(crate) phys_function: Arc<dyn PhysicalExpr>,
    /// The root column that the Function will be applied on.
    /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index.
    pub(crate) index_column: Arc<dyn PhysicalExpr>,
    /// Window length.
    pub(crate) period: Duration,
    /// Offset of the window start relative to each index value.
    pub(crate) offset: Duration,
    /// Which window bounds are inclusive.
    pub(crate) closed_window: ClosedWindow,
    /// The original logical expression, kept for error messages / introspection.
    pub(crate) expr: Expr,
    /// Pre-computed output field; `to_field` returns this directly.
    pub(crate) output_field: Field,
}

impl PhysicalExpr for RollingExpr {
    /// Flat (non-grouped) evaluation: compute the rolling groups over the whole
    /// `df`, then evaluate `phys_function` once per window.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // Fast path: the index expression is a plain column, so the rolling
        // groups can be computed on `df` directly and cached in the window cache.
        let groups = if let Some(index_column_name) = self.index_column.as_column() {
            let options = RollingGroupOptions {
                index_column: index_column_name.clone(),
                period: self.period,
                offset: self.offset,
                closed_window: self.closed_window,
            };
            // The Debug representation of the options doubles as the cache key.
            let groups_key = format!("{options:?}");
            let groups = {
                // Groups must be set by expression runner.
                state.window_cache.get_groups(&groups_key)
            };

            // There can be multiple rolling expressions in a single expr.
            // E.g. `min().rolling() + max().rolling()`
            // So if we hit that we will compute them here.
            match groups {
                Some(groups) => groups,
                None => {
                    let (_time_key, groups) = df.rolling(None, &options)?;
                    state.window_cache.insert_groups(groups_key, groups.clone());
                    groups
                },
            }
        } else {
            // Slow path: the index is a computed expression. Materialize it
            // under a reserved name on a clone of `df` and group on that.
            // NOTE(review): this path is not cached — presumably because the
            // computed column has no stable cache key.
            let index_column_name = PlSmallStr::from_static("__PL_INDEX_COL");
            let options = RollingGroupOptions {
                index_column: index_column_name.clone(),
                period: self.period,
                offset: self.offset,
                closed_window: self.closed_window,
            };

            let index_column = self.index_column.evaluate(df, state)?;

            let mut df = df.clone();
            df.with_column(index_column.with_name(index_column_name))?;
            let (_time_key, groups) = df.rolling(None, &options)?;
            groups
        };

        // One output element per window; enforce that before returning.
        let out = self
            .phys_function
            .evaluate_on_groups(df, &groups, state)?
            .finalize();
        polars_ensure!(out.len() == groups.len(), agg_len = out.len(), groups.len());
        Ok(out.into_column())
    }

    /// Grouped evaluation: for each existing group, compute rolling windows
    /// *within* that group and evaluate `phys_function` on those sub-windows.
    ///
    /// The returned context is `NotAggregated`: one value per element of the
    /// original groups, re-sliced by `slice_groups`.
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut index_column = self.index_column.evaluate_on_groups(df, groups, state)?;

        // Force the context's groups to be updated/materialized before we
        // inspect `index_column.groups` below (result intentionally dropped).
        index_column.groups();

        let mut index_column_data = index_column.flat_naive();
        use DataType as DT;
        // Normalize the index dtype to an i64-backed Datetime:
        // - Datetime keeps its unit/zone;
        // - Date is handled with microsecond precision;
        // - integer dtypes are widened to Int64 and treated as nanoseconds.
        let (time_unit, time_zone): (TimeUnit, Option<TimeZone>) = match index_column_data.dtype() {
            DT::Datetime(tu, tz) => (*tu, tz.clone()),
            DT::Date => (TimeUnit::Microseconds, None),
            DT::UInt32 | DT::UInt64 | DT::Int32 => {
                index_column_data = Cow::Owned(index_column_data.cast(&DT::Int64)?);
                (TimeUnit::Nanoseconds, None)
            },
            DT::Int64 => (TimeUnit::Nanoseconds, None),
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}",
                dt
            ),
        };
        let index_column_data =
            index_column_data.cast(&DataType::Datetime(time_unit, time_zone.clone()))?;

        // @NOTE: This is a bit strange since it ignores errors, but it mirrors the in-memory
        // engine.
        let tz = time_zone.and_then(|tz| tz.parse::<chrono_tz::Tz>().ok());

        polars_ensure!(
            index_column_data.null_count() == 0,
            ComputeError: "null values in `rolling` not supported, fill nulls."
        );
        // Single contiguous arrow chunk so we can borrow the raw i64 buffer.
        // The downcast cannot fail: we just cast to Datetime (i64-backed) above.
        let index_column_data = index_column_data.rechunk_to_arrow(CompatLevel::newest());
        let index_column_data = index_column_data
            .as_any()
            .downcast_ref::<PrimitiveArray<i64>>()
            .unwrap();
        let mut index_column_data = Cow::Borrowed(index_column_data.values().as_slice());
        let mut rolling =
            RollingWindower::new(self.period, self.offset, self.closed_window, time_unit, tz);

        let num_elements = groups.num_elements();

        // Convert the index groups to slices.
        //
        // This is not strictly necessary but allows us to reuse the existing `RollingWindower`
        // struct.
        let (slice_groups, overlapping, monotonic) = match &**index_column.groups {
            GroupsType::Idx(idx) => {
                // Gather the index values group-by-group into one contiguous
                // buffer so each group is addressable as `[start, len]`.
                let mut data = Vec::with_capacity(num_elements);
                let mut slices = Vec::with_capacity(groups.len());
                for i in idx.all() {
                    slices.push([data.len() as IdxSize, i.len() as IdxSize]);
                    data.extend(i.iter().map(|i| index_column_data[*i as usize]));
                }
                index_column_data = Cow::Owned(data);
                // The gathered slices are disjoint and in order by construction.
                (Cow::Owned(slices), false, true)
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => (Cow::Borrowed(groups), *overlapping, *monotonic),
        };

        // We need to make sure there are no length mismatches, otherwise we will have problems
        // down the line.
        assert_eq!(slice_groups.len(), groups.len());
        let length_mismatch = match &**groups {
            GroupsType::Idx(idx) => idx
                .all()
                .iter()
                .zip(slice_groups.iter())
                .map(|(i, [_, s])| (i.len(), *s as usize))
                .find(|(l, r)| *l != *r),
            GroupsType::Slice {
                groups,
                overlapping: _,
                monotonic: _,
            } => groups
                .iter()
                .zip(slice_groups.iter())
                .map(|([_, s1], [_, s2])| (*s1 as usize, *s2 as usize))
                .find(|(l, r)| *l != *r),
        };
        if let Some((l, r)) = length_mismatch {
            polars_bail!(length_mismatch = "rolling", l, r);
        }

        // Get the subslices within each group.
        // `windows` accumulates `[start, len]` pairs *relative to the group*,
        // one per element, across all groups in order.
        let mut windows = Vec::with_capacity(num_elements);
        for [start, length] in slice_groups.as_ref() {
            rolling.reset();
            let time = &index_column_data[*start as usize..][..*length as usize];
            let offset = rolling.insert(&[time], &mut windows)?;
            // `finalize` handles the tail that `insert` did not emit windows for.
            let time = &time[offset as usize..];
            rolling.finalize(&[time], &mut windows);
        }

        // Create new groups as subgroups of the existing groups.
        let nested_groups = match &**groups {
            GroupsType::Idx(idx) => {
                // Map each group-relative window back through the group's index
                // list to absolute row indices.
                let mut nested_groups = Vec::with_capacity(num_elements);
                let mut i = 0;
                for idx in idx.all() {
                    nested_groups.extend(windows[i..][..idx.len()].iter().map(|[s, l]| {
                        (
                            idx[*s as usize],
                            UnitVec::from_iter(idx[*s as usize..][..*l as usize].iter().copied()),
                        )
                    }));
                    i += idx.len();
                }
                GroupsType::Idx(nested_groups.into())
            },
            GroupsType::Slice {
                groups,
                overlapping: _,
                monotonic,
            } => {
                // Offset each group-relative window by the group's start.
                let mut nested_groups = Vec::with_capacity(num_elements);
                let mut i = 0;
                for [start, length] in groups {
                    nested_groups.extend(
                        windows[i..][..*length as usize]
                            .iter()
                            .map(|[s, l]| [*start + *s, *l]),
                    );
                    i += *length as usize;
                }
                // Rolling windows within a group can overlap, hence `true`.
                GroupsType::new_slice(nested_groups, true, *monotonic)
            },
        };

        let nested_groups = nested_groups.into_sliceable();
        let out = self
            .phys_function
            .evaluate_on_groups(df, &nested_groups, state)?
            .finalize();
        // One aggregated value per window — anything else is an internal error.
        polars_ensure!(
            out.len() == nested_groups.len(),
            agg_len = out.len(),
            nested_groups.len()
        );

        // Re-wrap as element-wise output sliced by the *original* group layout,
        // so downstream consumers see one value per input row of each group.
        let out = AggregationContext {
            state: AggState::NotAggregated(out.into_column()),
            groups: Cow::Owned(
                GroupsType::new_slice(slice_groups.into_owned(), overlapping, monotonic)
                    .into_sliceable(),
            ),
            update_groups: UpdateGroups::No,
            original_len: false,
        };
        Ok(out)
    }

    /// Output schema is pre-computed at construction time.
    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        Ok(self.output_field.clone())
    }

    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    /// Rolling produces one value per input row, never a scalar.
    fn is_scalar(&self) -> bool {
        false
    }
}