// Path: blob/main/crates/polars-time/src/group_by/dynamic.rs
use arrow::legacy::time_zone::Tz;1use polars_core::POOL;2use polars_core::prelude::*;3use polars_core::series::IsSorted;4use polars_core::utils::flatten::flatten_par;5use polars_ops::series::SeriesMethods;6use polars_utils::itertools::Itertools;7use polars_utils::pl_str::PlSmallStr;8use polars_utils::slice::SortedSlice;9use rayon::prelude::*;10#[cfg(feature = "serde")]11use serde::{Deserialize, Serialize};1213use crate::prelude::*;1415#[repr(transparent)]16struct Wrap<T>(pub T);1718#[derive(Clone, Debug, PartialEq, Eq, Hash)]19#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]20#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]21pub struct DynamicGroupOptions {22/// Time or index column.23pub index_column: PlSmallStr,24/// Start a window at this interval.25pub every: Duration,26/// Window duration.27pub period: Duration,28/// Offset window boundaries.29pub offset: Duration,30/// Truncate the time column values to the window.31pub label: Label,32/// Add the boundaries to the DataFrame.33pub include_boundaries: bool,34pub closed_window: ClosedWindow,35pub start_by: StartBy,36}3738impl Default for DynamicGroupOptions {39fn default() -> Self {40Self {41index_column: "".into(),42every: Duration::new(1),43period: Duration::new(1),44offset: Duration::new(1),45label: Label::Left,46include_boundaries: false,47closed_window: ClosedWindow::Left,48start_by: Default::default(),49}50}51}5253#[derive(Clone, Debug, PartialEq, Eq, Hash)]54#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]55#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]56pub struct RollingGroupOptions {57/// Time or index column.58pub index_column: PlSmallStr,59/// Window duration.60pub period: Duration,61pub offset: Duration,62pub closed_window: ClosedWindow,63}6465impl Default for RollingGroupOptions {66fn default() -> Self {67Self {68index_column: "".into(),69period: Duration::new(1),70offset: Duration::new(1),71closed_window: 
ClosedWindow::Left,72}73}74}7576fn check_sortedness_slice(v: &[i64]) -> PolarsResult<()> {77polars_ensure!(v.is_sorted_ascending(), ComputeError: "input data is not sorted");78Ok(())79}8081const LB_NAME: &str = "_lower_boundary";82const UP_NAME: &str = "_upper_boundary";8384pub trait PolarsTemporalGroupby {85fn rolling(86&self,87group_by: Option<GroupsSlice>,88options: &RollingGroupOptions,89) -> PolarsResult<(Column, GroupPositions)>;9091fn group_by_dynamic(92&self,93group_by: Option<GroupsSlice>,94options: &DynamicGroupOptions,95) -> PolarsResult<(Column, Vec<Column>, GroupPositions)>;96}9798impl PolarsTemporalGroupby for DataFrame {99fn rolling(100&self,101group_by: Option<GroupsSlice>,102options: &RollingGroupOptions,103) -> PolarsResult<(Column, GroupPositions)> {104Wrap(self).rolling(group_by, options)105}106107fn group_by_dynamic(108&self,109group_by: Option<GroupsSlice>,110options: &DynamicGroupOptions,111) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {112Wrap(self).group_by_dynamic(group_by, options)113}114}115116impl Wrap<&DataFrame> {117fn rolling(118&self,119group_by: Option<GroupsSlice>,120options: &RollingGroupOptions,121) -> PolarsResult<(Column, GroupPositions)> {122polars_ensure!(123!options.period.is_zero() && !options.period.negative,124ComputeError:125"rolling window period should be strictly positive",126);127let time = self.0.column(&options.index_column)?.clone();128if group_by.is_none() {129// If by is given, the column must be sorted in the 'by' arg, which we can not check now130// this will be checked when the groups are materialized.131time.as_materialized_series().ensure_sorted_arg("rolling")?;132}133let time_type = time.dtype();134135polars_ensure!(time.null_count() == 0, ComputeError: "null values in `rolling` not supported, fill nulls.");136ensure_duration_matches_dtype(options.period, time_type, "period")?;137ensure_duration_matches_dtype(options.offset, time_type, "offset")?;138139use DataType::*;140let (dt, tu, tz): 
(Column, TimeUnit, Option<TimeZone>) = match time_type {141Datetime(tu, tz) => (time.clone(), *tu, tz.clone()),142Date => (143time.cast(&Datetime(TimeUnit::Microseconds, None))?,144TimeUnit::Microseconds,145None,146),147UInt32 | UInt64 | Int32 => {148let time_type_dt = Datetime(TimeUnit::Nanoseconds, None);149let dt = time.cast(&Int64).unwrap().cast(&time_type_dt).unwrap();150let (out, gt) = self.impl_rolling(151dt,152group_by,153options,154TimeUnit::Nanoseconds,155None,156&time_type_dt,157)?;158let out = out.cast(&Int64).unwrap().cast(time_type).unwrap();159return Ok((out, gt));160},161Int64 => {162let time_type = Datetime(TimeUnit::Nanoseconds, None);163let dt = time.cast(&time_type).unwrap();164let (out, gt) = self.impl_rolling(165dt,166group_by,167options,168TimeUnit::Nanoseconds,169None,170&time_type,171)?;172let out = out.cast(&Int64).unwrap();173return Ok((out, gt));174},175dt => polars_bail!(176ComputeError:177"expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}",178dt179),180};181match tz {182#[cfg(feature = "timezones")]183Some(tz) => {184self.impl_rolling(dt, group_by, options, tu, tz.parse::<Tz>().ok(), time_type)185},186_ => self.impl_rolling(dt, group_by, options, tu, None, time_type),187}188}189190/// Returns: time_keys, keys, groupsproxy.191fn group_by_dynamic(192&self,193group_by: Option<GroupsSlice>,194options: &DynamicGroupOptions,195) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {196let time = self.0.column(&options.index_column)?.rechunk();197if group_by.is_none() {198// If by is given, the column must be sorted in the 'by' arg, which we can not check now199// this will be checked when the groups are materialized.200time.as_materialized_series()201.ensure_sorted_arg("group_by_dynamic")?;202}203let time_type = time.dtype();204205polars_ensure!(time.null_count() == 0, ComputeError: "null values in dynamic group_by not supported, fill nulls.");206ensure_duration_matches_dtype(options.every, 
time_type, "every")?;207ensure_duration_matches_dtype(options.offset, time_type, "offset")?;208ensure_duration_matches_dtype(options.period, time_type, "period")?;209210use DataType::*;211let (dt, tu) = match time_type {212Datetime(tu, _) => (time.clone(), *tu),213Date => (214time.cast(&Datetime(TimeUnit::Microseconds, None))?,215TimeUnit::Microseconds,216),217Int32 => {218let time_type = Datetime(TimeUnit::Nanoseconds, None);219let dt = time.cast(&Int64).unwrap().cast(&time_type).unwrap();220let (out, mut keys, gt) = self.impl_group_by_dynamic(221dt,222group_by,223options,224TimeUnit::Nanoseconds,225&time_type,226)?;227let out = out.cast(&Int64).unwrap().cast(&Int32).unwrap();228for k in &mut keys {229if k.name().as_str() == UP_NAME || k.name().as_str() == LB_NAME {230*k = k.cast(&Int64).unwrap().cast(&Int32).unwrap()231}232}233return Ok((out, keys, gt));234},235Int64 => {236let time_type = Datetime(TimeUnit::Nanoseconds, None);237let dt = time.cast(&time_type).unwrap();238let (out, mut keys, gt) = self.impl_group_by_dynamic(239dt,240group_by,241options,242TimeUnit::Nanoseconds,243&time_type,244)?;245let out = out.cast(&Int64).unwrap();246for k in &mut keys {247if k.name().as_str() == UP_NAME || k.name().as_str() == LB_NAME {248*k = k.cast(&Int64).unwrap()249}250}251return Ok((out, keys, gt));252},253dt => polars_bail!(254ComputeError:255"expected any of the following dtypes: {{ Date, Datetime, Int32, Int64 }}, got {}",256dt257),258};259self.impl_group_by_dynamic(dt, group_by, options, tu, time_type)260}261262fn impl_group_by_dynamic(263&self,264mut dt: Column,265group_by: Option<GroupsSlice>,266options: &DynamicGroupOptions,267tu: TimeUnit,268time_type: &DataType,269) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {270polars_ensure!(!options.every.negative, ComputeError: "'every' argument must be positive");271if dt.is_empty() {272return dt.cast(time_type).map(|s| (s, vec![], Default::default()));273}274275// A requirement for the index so we can set this 
such that downstream code has this info.276dt.set_sorted_flag(IsSorted::Ascending);277278let w = Window::new(options.every, options.period, options.offset);279let dt = dt.datetime().unwrap();280let tz = dt.time_zone();281282let mut lower_bound = None;283let mut upper_bound = None;284285let mut include_lower_bound = false;286let mut include_upper_bound = false;287288if options.include_boundaries {289include_lower_bound = true;290include_upper_bound = true;291}292if options.label == Label::Left {293include_lower_bound = true;294} else if options.label == Label::Right {295include_upper_bound = true;296}297298let mut update_bounds =299|lower: Vec<i64>, upper: Vec<i64>| match (&mut lower_bound, &mut upper_bound) {300(None, None) => {301lower_bound = Some(lower);302upper_bound = Some(upper);303},304(Some(lower_bound), Some(upper_bound)) => {305lower_bound.extend_from_slice(&lower);306upper_bound.extend_from_slice(&upper);307},308_ => unreachable!(),309};310311let groups = if group_by.is_none() {312let vals = dt.physical().downcast_iter().next().unwrap();313let ts = vals.values().as_slice();314let (groups, lower, upper) = group_by_windows(315w,316ts,317options.closed_window,318tu,319tz,320include_lower_bound,321include_upper_bound,322options.start_by,323)?;324update_bounds(lower, upper);325PolarsResult::Ok(GroupsType::Slice {326groups,327rolling: false,328})329} else {330let vals = dt.physical().downcast_iter().next().unwrap();331let ts = vals.values().as_slice();332333let groups = group_by.as_ref().unwrap();334335let iter = groups.par_iter().map(|[start, len]| {336let group_offset = *start;337let start = *start as usize;338let end = start + *len as usize;339let values = &ts[start..end];340check_sortedness_slice(values)?;341342let (groups, lower, upper) = group_by_windows(343w,344values,345options.closed_window,346tu,347tz,348include_lower_bound,349include_upper_bound,350options.start_by,351)?;352353PolarsResult::Ok((354groups355.iter()356.map(|[start, len]| [*start + 
group_offset, *len])357.collect_vec(),358lower,359upper,360))361});362363let res = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;364let groups = res.iter().map(|g| &g.0).collect_vec();365let lower = res.iter().map(|g| &g.1).collect_vec();366let upper = res.iter().map(|g| &g.2).collect_vec();367368let ((groups, upper), lower) = POOL.install(|| {369rayon::join(370|| rayon::join(|| flatten_par(&groups), || flatten_par(&upper)),371|| flatten_par(&lower),372)373});374375update_bounds(lower, upper);376PolarsResult::Ok(GroupsType::Slice {377groups,378rolling: false,379})380}?;381// note that if 'group_by' is none we can be sure that the index column, the lower column and the382// upper column remain/are sorted383384let dt = unsafe { dt.clone().into_series().agg_first(&groups) };385let mut dt = dt.datetime().unwrap().physical().clone();386387let lower =388lower_bound.map(|lower| Int64Chunked::new_vec(PlSmallStr::from_static(LB_NAME), lower));389let upper =390upper_bound.map(|upper| Int64Chunked::new_vec(PlSmallStr::from_static(UP_NAME), upper));391392if options.label == Label::Left {393let mut lower = lower.clone().unwrap();394if group_by.is_none() {395lower.set_sorted_flag(IsSorted::Ascending)396}397dt = lower.with_name(dt.name().clone());398} else if options.label == Label::Right {399let mut upper = upper.clone().unwrap();400if group_by.is_none() {401upper.set_sorted_flag(IsSorted::Ascending)402}403dt = upper.with_name(dt.name().clone());404}405406let mut bounds = vec![];407if let (true, Some(mut lower), Some(mut upper)) = (options.include_boundaries, lower, upper)408{409if group_by.is_none() {410lower.set_sorted_flag(IsSorted::Ascending);411upper.set_sorted_flag(IsSorted::Ascending);412}413bounds.push(lower.into_datetime(tu, tz.clone()).into_column());414bounds.push(upper.into_datetime(tu, tz.clone()).into_column());415}416417dt.into_datetime(tu, None)418.into_column()419.cast(time_type)420.map(|s| (s, bounds, groups.into_sliceable()))421}422423/// Returns: 
time_keys, keys, groupsproxy424fn impl_rolling(425&self,426dt: Column,427group_by: Option<GroupsSlice>,428options: &RollingGroupOptions,429tu: TimeUnit,430tz: Option<Tz>,431time_type: &DataType,432) -> PolarsResult<(Column, GroupPositions)> {433let mut dt = dt.rechunk();434435let groups = if group_by.is_none() {436// a requirement for the index437// so we can set this such that downstream code has this info438dt.set_sorted_flag(IsSorted::Ascending);439let dt = dt.datetime().unwrap();440let vals = dt.physical().downcast_iter().next().unwrap();441let ts = vals.values().as_slice();442PolarsResult::Ok(GroupsType::Slice {443groups: group_by_values(444options.period,445options.offset,446ts,447options.closed_window,448tu,449tz,450)?,451rolling: true,452})453} else {454let dt = dt.datetime().unwrap();455let vals = dt.physical().downcast_iter().next().unwrap();456let ts = vals.values().as_slice();457458let groups = group_by.unwrap();459460let iter = groups.into_par_iter().map(|[start, len]| {461let group_offset = start;462let start = start as usize;463let end = start + len as usize;464let values = &ts[start..end];465check_sortedness_slice(values)?;466467let group = group_by_values(468options.period,469options.offset,470values,471options.closed_window,472tu,473tz,474)?;475476PolarsResult::Ok(477group478.iter()479.map(|[start, len]| [*start + group_offset, *len])480.collect_vec(),481)482});483484let groups = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;485let groups = POOL.install(|| flatten_par(&groups));486PolarsResult::Ok(GroupsType::Slice {487groups,488rolling: true,489})490}?;491492let dt = dt.cast(time_type).unwrap();493494Ok((dt, groups.into_sliceable()))495}496}497498#[cfg(test)]499mod test {500use polars_compute::rolling::QuantileMethod;501use polars_ops::prelude::*;502503use super::*;504505#[test]506fn test_rolling_group_by_tu() -> PolarsResult<()> {507// test multiple time units508for tu in 
[509TimeUnit::Nanoseconds,510TimeUnit::Microseconds,511TimeUnit::Milliseconds,512] {513let mut date = StringChunked::new(514"dt".into(),515[516"2020-01-01 13:45:48",517"2020-01-01 16:42:13",518"2020-01-01 16:45:09",519"2020-01-02 18:12:48",520"2020-01-03 19:45:32",521"2020-01-08 23:16:43",522],523)524.as_datetime(525None,526tu,527false,528false,529None,530&StringChunked::from_iter(std::iter::once("raise")),531)?532.into_column();533date.set_sorted_flag(IsSorted::Ascending);534let a = Column::new("a".into(), [3, 7, 5, 9, 2, 1]);535let df = DataFrame::new(vec![date, a.clone()])?;536537let (_, groups) = df538.rolling(539None,540&RollingGroupOptions {541index_column: "dt".into(),542period: Duration::parse("2d"),543offset: Duration::parse("-2d"),544closed_window: ClosedWindow::Right,545},546)547.unwrap();548549let sum = unsafe { a.agg_sum(&groups) };550let expected = Column::new("".into(), [3, 10, 15, 24, 11, 1]);551assert_eq!(sum, expected);552}553554Ok(())555}556557#[test]558fn test_rolling_group_by_aggs() -> PolarsResult<()> {559let mut date = StringChunked::new(560"dt".into(),561[562"2020-01-01 13:45:48",563"2020-01-01 16:42:13",564"2020-01-01 16:45:09",565"2020-01-02 18:12:48",566"2020-01-03 19:45:32",567"2020-01-08 23:16:43",568],569)570.as_datetime(571None,572TimeUnit::Milliseconds,573false,574false,575None,576&StringChunked::from_iter(std::iter::once("raise")),577)?578.into_column();579date.set_sorted_flag(IsSorted::Ascending);580581let a = Column::new("a".into(), [3, 7, 5, 9, 2, 1]);582let df = DataFrame::new(vec![date, a.clone()])?;583584let (_, groups) = df585.rolling(586None,587&RollingGroupOptions {588index_column: "dt".into(),589period: Duration::parse("2d"),590offset: Duration::parse("-2d"),591closed_window: ClosedWindow::Right,592},593)594.unwrap();595596let nulls = Series::new(597"".into(),598[Some(3), Some(7), None, Some(9), Some(2), Some(1)],599);600601let min = unsafe { a.as_materialized_series().agg_min(&groups) };602let expected = 
Series::new("".into(), [3, 3, 3, 3, 2, 1]);603assert_eq!(min, expected);604605// Expected for nulls is equality.606let min = unsafe { nulls.agg_min(&groups) };607assert_eq!(min, expected);608609let max = unsafe { a.as_materialized_series().agg_max(&groups) };610let expected = Series::new("".into(), [3, 7, 7, 9, 9, 1]);611assert_eq!(max, expected);612613let max = unsafe { nulls.agg_max(&groups) };614assert_eq!(max, expected);615616let var = unsafe { a.as_materialized_series().agg_var(&groups, 1) };617let expected = Series::new(618"".into(),619[0.0, 8.0, 4.000000000000002, 6.666666666666667, 24.5, 0.0],620);621assert!(abs(&(var - expected)?).unwrap().lt(1e-12).unwrap().all());622623let var = unsafe { nulls.agg_var(&groups, 1) };624let expected = Series::new("".into(), [0.0, 8.0, 8.0, 9.333333333333343, 24.5, 0.0]);625assert!(abs(&(var - expected)?).unwrap().lt(1e-12).unwrap().all());626627let quantile = unsafe {628a.as_materialized_series()629.agg_quantile(&groups, 0.5, QuantileMethod::Linear)630};631let expected = Series::new("".into(), [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]);632assert_eq!(quantile, expected);633634let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileMethod::Linear) };635let expected = Series::new("".into(), [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]);636assert_eq!(quantile, expected);637638Ok(())639}640}641642643