Path: blob/main/crates/polars-stream/src/nodes/rolling_group_by.rs
7884 views
use std::sync::Arc;12use chrono_tz::Tz;3use polars_core::frame::DataFrame;4use polars_core::prelude::{Column, DataType, GroupsType, TimeUnit};5use polars_core::schema::Schema;6use polars_error::{PolarsError, PolarsResult, polars_bail, polars_ensure};7use polars_expr::state::ExecutionState;8use polars_time::prelude::{RollingWindower, ensure_duration_matches_dtype};9use polars_time::{ClosedWindow, Duration};10use polars_utils::IdxSize;11use polars_utils::pl_str::PlSmallStr;1213use super::ComputeNode;14use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;15use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};16use crate::async_primitives::distributor_channel::distributor_channel;17use crate::async_primitives::wait_group::WaitGroup;18use crate::execute::StreamingExecutionState;19use crate::expression::StreamExpr;20use crate::graph::PortState;21use crate::morsel::{Morsel, MorselSeq, SourceToken};22use crate::pipe::{RecvPort, SendPort};2324type NextWindows = (Vec<[IdxSize; 2]>, DataFrame, Column);2526pub struct RollingGroupBy {27buf_df: DataFrame,28/// How many `buf_df` rows did we discard of already?29buf_df_offset: IdxSize,30/// Casted index column, which may need to keep around old values.31buf_index_column: Column,32/// Uncasted index column.33buf_key_column: Column,3435seq: MorselSeq,3637slice_offset: IdxSize,38slice_length: IdxSize,3940index_column: PlSmallStr,41windower: RollingWindower,42aggs: Arc<[(PlSmallStr, StreamExpr)]>,43}44impl RollingGroupBy {45pub fn new(46schema: Arc<Schema>,47index_column: PlSmallStr,48period: Duration,49offset: Duration,50closed: ClosedWindow,51slice: Option<(IdxSize, IdxSize)>,52aggs: Arc<[(PlSmallStr, StreamExpr)]>,53) -> PolarsResult<Self> {54polars_ensure!(55!period.is_zero() && !period.negative(),56ComputeError: "rolling window period should be strictly positive",57);5859let key_dtype = schema.get(&index_column).unwrap();60ensure_duration_matches_dtype(period, key_dtype, "period")?;61ensure_duration_matches_dtype(offset, key_dtype, "offset")?;6263use DataType as DT;64let (tu, tz) = match key_dtype {65DT::Datetime(tu, tz) => (*tu, tz.clone()),66DT::Date => (TimeUnit::Microseconds, None),67DT::UInt32 | DT::UInt64 | DT::Int64 | DT::Int32 => (TimeUnit::Nanoseconds, None),68dt => polars_bail!(69ComputeError:70"expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}",71dt72),73};7475let buf_df = DataFrame::empty_with_arc_schema(schema.clone());76let buf_key_column = Column::new_empty(index_column.clone(), key_dtype);77let buf_index_column =78Column::new_empty(index_column.clone(), &DT::Datetime(tu, tz.clone()));7980// @NOTE: This is a bit strange since it ignores errors, but it mirrors the in-memory81// engine.82let tz = tz.and_then(|tz| tz.parse::<Tz>().ok());83let windower = RollingWindower::new(period, offset, closed, tu, tz);8485let (slice_offset, slice_length) = slice.unwrap_or((0, IdxSize::MAX));8687Ok(Self {88buf_df,89buf_df_offset: 0,90buf_index_column,91buf_key_column,92seq: MorselSeq::default(),93slice_offset,94slice_length,95index_column,96windower,97aggs,98})99}100101async fn evaluate_one(102windows: Vec<[IdxSize; 2]>,103key: Column,104aggs: &[(PlSmallStr, StreamExpr)],105state: &ExecutionState,106mut df: DataFrame,107) -> PolarsResult<DataFrame> {108assert_eq!(windows.len(), key.len());109110let groups = GroupsType::new_slice(windows, true, true).into_sliceable();111112// @NOTE:113// Rechunk so we can use specialized rolling kernels.114//115// This can be removed if / when the rolling kernels are chunking aware.116df.rechunk_mut();117118let mut columns = Vec::with_capacity(1 + aggs.len());119let height = key.len();120columns.push(key);121for (name, agg) in aggs.iter() {122let mut agg = agg.evaluate_on_groups(&df, &groups, state).await?;123let agg = agg.finalize();124columns.push(agg.with_name(name.clone()));125}126127Ok(unsafe { DataFrame::new_no_checks(height, columns) })128}129130/// Progress the state and get the next available evaluation windows, data and key.131fn next_windows(&mut self, finalize: bool) -> PolarsResult<Option<NextWindows>> {132let buf_index_col_dt = self.buf_index_column.datetime()?;133let mut time = Vec::new();134time.extend(135buf_index_col_dt136.physical()137.downcast_iter()138.map(|arr| arr.values().as_slice()),139);140141let mut windows = Vec::new();142let num_retired = if finalize {143self.windower.finalize(&time, &mut windows);144self.buf_key_column.len() as IdxSize145} else {146self.windower.insert(&time, &mut windows)?147};148149if num_retired == 0 && windows.is_empty() {150return Ok(None);151}152153let start_row_offset = self.buf_df_offset;154155self.buf_index_column = self.buf_index_column.slice(num_retired as i64, usize::MAX);156let new_buf_df = self.buf_df.slice(num_retired as i64, usize::MAX);157let data = std::mem::replace(&mut self.buf_df, new_buf_df);158self.buf_df_offset += num_retired;159160if windows.is_empty() {161return Ok(None);162}163164let key;165(key, self.buf_key_column) = self.buf_key_column.split_at(windows.len() as i64);166let key = key.slice(self.slice_offset as i64, self.slice_length as usize);167168let offset = windows[0][0];169let end = windows.last().unwrap();170let end = end[0] + end[1];171172if self.slice_offset as usize > windows.len() {173self.slice_offset -= windows.len() as IdxSize;174windows.clear();175} else if self.slice_offset > 0 {176let offset = self.slice_offset as usize;177self.slice_offset = self.slice_offset.saturating_sub(windows.len() as IdxSize);178windows.drain(..offset);179}180181windows.truncate(windows.len().min(self.slice_length as usize));182self.slice_length -= windows.len() as IdxSize;183184if windows.is_empty() {185return Ok(None);186}187188// Prune the data that is not covered by the windows and update the windows accordingly.189windows.iter_mut().for_each(|[s, _]| *s -= offset);190let data = data.slice(191(offset - start_row_offset) as i64,192(end - start_row_offset) as usize,193);194195Ok(Some((windows, data, key)))196}197}198199impl ComputeNode for RollingGroupBy {200fn name(&self) -> &str {201"rolling-group-by"202}203204fn update_state(205&mut self,206recv: &mut [PortState],207send: &mut [PortState],208_state: &StreamingExecutionState,209) -> PolarsResult<()> {210assert!(recv.len() == 1 && send.len() == 1);211212if self.slice_length == 0 {213recv[0] = PortState::Done;214send[0] = PortState::Done;215std::mem::take(&mut self.buf_df);216return Ok(());217}218219if send[0] == PortState::Done {220recv[0] = PortState::Done;221std::mem::take(&mut self.buf_df);222} else if recv[0] == PortState::Done {223if self.buf_df.is_empty() {224send[0] = PortState::Done;225} else {226send[0] = PortState::Ready;227}228} else {229recv.swap_with_slice(send);230}231232Ok(())233}234235fn spawn<'env, 's>(236&'env mut self,237scope: &'s TaskScope<'s, 'env>,238recv_ports: &mut [Option<RecvPort<'_>>],239send_ports: &mut [Option<SendPort<'_>>],240state: &'s StreamingExecutionState,241join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,242) {243assert!(recv_ports.len() == 1 && send_ports.len() == 1);244245let Some(recv) = recv_ports[0].take() else {246// We no longer have to receive data. Finalize and send all remaining data.247assert!(!self.buf_df.is_empty());248assert!(self.slice_length > 0);249let mut send = send_ports[0].take().unwrap().serial();250join_handles.push(scope.spawn_task(TaskPriority::High, async move {251if let Some((windows, df, key)) = self.next_windows(true)? {252let df = Self::evaluate_one(253windows,254key,255&self.aggs,256&state.in_memory_exec_state,257df,258)259.await?;260261_ = send262.send(Morsel::new(df, self.seq.successor(), SourceToken::new()))263.await;264}265266self.buf_df = self.buf_df.clear();267self.buf_key_column = self.buf_key_column.clear();268self.buf_index_column = self.buf_index_column.clear();269270Ok(())271}));272return;273};274275let mut recv = recv.serial();276let send = send_ports[0].take().unwrap().parallel();277278let (mut distributor, rxs) = distributor_channel::<(Morsel, Column, Vec<[IdxSize; 2]>)>(279send.len(),280*DEFAULT_DISTRIBUTOR_BUFFER_SIZE,281);282283// Worker tasks.284//285// These evaluate the aggregations.286join_handles.extend(rxs.into_iter().zip(send).map(|(mut rx, mut tx)| {287let wg = WaitGroup::default();288let aggs = self.aggs.clone();289let state = state.in_memory_exec_state.split();290scope.spawn_task(TaskPriority::High, async move {291while let Ok((mut morsel, key, windows)) = rx.recv().await {292morsel = morsel293.async_try_map::<PolarsError, _, _>(async |df| {294Self::evaluate_one(windows, key, &aggs, &state, df).await295})296.await?;297morsel.set_consume_token(wg.token());298299if tx.send(morsel).await.is_err() {300break;301}302wg.wait().await;303}304305Ok(())306})307}));308309// Distributor task.310//311// This finds boundaries to distribute to worker threads over.312join_handles.push(scope.spawn_task(TaskPriority::High, async move {313while let Ok(morsel) = recv.recv().await314&& self.slice_length > 0315{316let (df, seq, source_token, wait_token) = morsel.into_inner();317self.seq = seq;318drop(wait_token);319320if df.height() == 0 {321continue;322}323324let morsel_index_column = df.column(&self.index_column)?;325polars_ensure!(326morsel_index_column.null_count() == 0,327ComputeError: "null values in `rolling` not supported, fill nulls."328);329330self.buf_key_column.append(morsel_index_column)?;331332use DataType as DT;333let morsel_index_column = match morsel_index_column.dtype() {334DT::Datetime(_, _) => morsel_index_column.clone(),335DT::Date => {336morsel_index_column.cast(&DT::Datetime(TimeUnit::Microseconds, None))?337},338DT::UInt32 | DT::UInt64 | DT::Int32 => morsel_index_column339.cast(&DT::Int64)?340.cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?,341DT::Int64 => {342morsel_index_column.cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?343},344_ => unreachable!(),345};346self.buf_index_column.append(&morsel_index_column)?;347self.buf_df.vstack_mut_owned(df)?;348349if let Some((windows, df, key)) = self.next_windows(false)? {350if distributor351.send((Morsel::new(df, seq, source_token), key, windows))352.await353.is_err()354{355break;356}357}358}359360Ok(())361}));362}363}364365366