Path: blob/main/crates/polars-stream/src/nodes/dynamic_group_by.rs
use std::sync::Arc;

use arrow::legacy::time_zone::Tz;
use polars_core::frame::DataFrame;
use polars_core::prelude::{Column, DataType, GroupsType, Int64Chunked, IntoColumn, TimeUnit};
use polars_core::schema::Schema;
use polars_core::series::IsSorted;
use polars_error::{PolarsError, PolarsResult, polars_bail, polars_ensure};
use polars_expr::state::ExecutionState;
use polars_time::prelude::{GroupByDynamicWindower, Label, ensure_duration_matches_dtype};
use polars_time::{DynamicGroupOptions, LB_NAME, UB_NAME};
use polars_utils::IdxSize;
use polars_utils::pl_str::PlSmallStr;

use super::ComputeNode;
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
use crate::async_primitives::distributor_channel::distributor_channel;
use crate::async_primitives::wait_group::WaitGroup;
use crate::execute::StreamingExecutionState;
use crate::expression::StreamExpr;
use crate::graph::PortState;
use crate::morsel::{Morsel, MorselSeq, SourceToken};
use crate::pipe::{RecvPort, SendPort};

type NextWindows = (Vec<[IdxSize; 2]>, Vec<i64>, Vec<i64>, DataFrame);

pub struct DynamicGroupBy {
    buf_df: DataFrame,
    /// How many `buf_df` rows have we already discarded?
    buf_df_offset: IdxSize,
    buf_index_column: Column,

    seq: MorselSeq,

    slice_offset: IdxSize,
    slice_length: IdxSize,

    group_by: Option<PlSmallStr>,
    index_column: PlSmallStr,
    index_column_idx: usize,
    label: Label,
    include_boundaries: bool,
    windower: GroupByDynamicWindower,
    aggs: Arc<[(PlSmallStr, StreamExpr)]>,
}

impl DynamicGroupBy {
    pub fn new(
        schema: Arc<Schema>,
        options: DynamicGroupOptions,
        aggs: Arc<[(PlSmallStr, StreamExpr)]>,
        slice: Option<(IdxSize, IdxSize)>,
    ) -> PolarsResult<Self> {
        let DynamicGroupOptions {
            index_column,
            every,
            period,
            offset,
            label,
            include_boundaries,
            closed_window,
            start_by,
        } = options;

        polars_ensure!(!every.negative(), ComputeError: "'every' argument must be positive");

        let (index_column_idx, _, index_dtype) = schema.get_full(&index_column).unwrap();
        ensure_duration_matches_dtype(every, index_dtype, "every")?;
        ensure_duration_matches_dtype(period, index_dtype, "period")?;
        ensure_duration_matches_dtype(offset, index_dtype, "offset")?;

        use DataType as DT;
        let (tu, tz) = match index_dtype {
            DT::Datetime(tu, tz) => (*tu, tz.clone()),
            DT::Date => (TimeUnit::Microseconds, None),
            DT::Int64 | DT::Int32 => (TimeUnit::Nanoseconds, None),
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64 }}, got {}",
                dt
            ),
        };

        let buf_df = DataFrame::empty_with_arc_schema(schema.clone());
        let buf_index_column =
            Column::new_empty(index_column.clone(), &DT::Datetime(tu, tz.clone()));

        // @NOTE: This is a bit strange since it ignores errors, but it mirrors the in-memory
        // engine.
        let tz = tz.and_then(|tz| tz.parse::<Tz>().ok());
        let windower = GroupByDynamicWindower::new(
            period,
            offset,
            every,
            start_by,
            closed_window,
            tu,
            tz,
            include_boundaries || matches!(label, Label::Left),
            include_boundaries || matches!(label, Label::Right),
        );

        let (slice_offset, slice_length) = slice.unwrap_or((0, IdxSize::MAX));

        Ok(Self {
            buf_df,

            buf_df_offset: 0,
            buf_index_column,
            seq: MorselSeq::default(),

            slice_offset,
            slice_length,

            group_by: None,
            index_column,
            index_column_idx,
            label,
            include_boundaries,
            windower,
            aggs,
        })
    }

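    /// Evaluate the aggregation expressions for one batch of `windows` over `df`, producing one
    /// output row per window.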
    #[expect(clippy::too_many_arguments)]
    async fn evaluate_one(
        windows: Vec<[IdxSize; 2]>,
        lower_bound: Vec<i64>,
        upper_bound: Vec<i64>,
        aggs: &[(PlSmallStr, StreamExpr)],
        state: &ExecutionState,
        mut df: DataFrame,

        group_by: Option<&str>,
        index_column_name: &str,
        index_column_idx: usize,
        label: Label,
        include_boundaries: bool,
    ) -> PolarsResult<DataFrame> {
        let height = windows.len();
        let groups = GroupsType::new_slice(windows, true, true).into_sliceable();

        // @NOTE:
        // Rechunk so we can use specialized rolling/dynamic kernels.
        df.rechunk_mut();

        let mut columns =
            Vec::with_capacity(if include_boundaries { 2 } else { 0 } + 1 + aggs.len());

        // Construct `lower_bound`, `upper_bound` and `key` columns that might be included in the
        // output dataframe.
        {
            let mut lower = Int64Chunked::new_vec(PlSmallStr::from_static(LB_NAME), lower_bound);
            let mut upper = Int64Chunked::new_vec(PlSmallStr::from_static(UB_NAME), upper_bound);
            if group_by.is_none() {
                lower.set_sorted_flag(IsSorted::Ascending);
                upper.set_sorted_flag(IsSorted::Ascending);
            }
            let mut lower = lower.into_column();
            let mut upper = upper.into_column();

            let index_column = &df.get_columns()[index_column_idx];
            let index_dtype = index_column.dtype();
            let mut bound_dtype_physical = index_dtype.to_physical();
            let mut bound_dtype = index_dtype;
            if index_dtype.is_date() {
                bound_dtype = &DataType::Datetime(TimeUnit::Microseconds, None);
                bound_dtype_physical = DataType::Int64;
            }
            lower = lower.cast(&bound_dtype_physical).unwrap();
            upper = upper.cast(&bound_dtype_physical).unwrap();
            (lower, upper) = unsafe {
                (
                    lower.from_physical_unchecked(bound_dtype)?,
                    upper.from_physical_unchecked(bound_dtype)?,
                )
            };

            let key = match label {
                Label::DataPoint => unsafe { index_column.agg_first(&groups) },
                Label::Left => lower
                    .cast(index_dtype)
                    .unwrap()
                    .with_name(index_column_name.into()),
                Label::Right => upper
                    .cast(index_dtype)
                    .unwrap()
                    .with_name(index_column_name.into()),
            };

            if include_boundaries {
                columns.extend([lower, upper]);
            }
            columns.push(key);
        }

        for (name, agg) in aggs.iter() {
            let mut agg = agg.evaluate_on_groups(&df, &groups, state).await?;
            let agg = agg.finalize();
            columns.push(agg.with_name(name.clone()));
        }

        Ok(unsafe { DataFrame::new_no_checks(height, columns) })
    }

    /// Progress the state and get the next available evaluation windows, data and key.
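    /// Returns `Ok(None)` when no complete windows are available yet, or when all remaining
    /// windows fall outside the requested slice.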
    fn next_windows(&mut self, finalize: bool) -> PolarsResult<Option<NextWindows>> {
        let mut windows = Vec::new();
        let mut lower_bound = Vec::new();
        let mut upper_bound = Vec::new();

        let num_retired = if finalize {
            self.windower
                .finalize(&mut windows, &mut lower_bound, &mut upper_bound);
            self.buf_df.height() as IdxSize
        } else {
            let mut offset = self.windower.num_seen() - self.buf_df_offset;
            let ca = self.buf_index_column.datetime()?;
            for arr in ca.physical().downcast_iter() {
                let arr_len = arr.len() as IdxSize;
                if offset >= arr_len {
                    offset -= arr_len;
                    continue;
                }

                self.windower.insert(
                    &arr.values().as_slice()[offset as usize..],
                    &mut windows,
                    &mut lower_bound,
                    &mut upper_bound,
                )?;
                offset = offset.saturating_sub(arr_len);
            }
            self.windower.lowest_needed_index() - self.buf_df_offset
        };

        if windows.is_empty() {
            if num_retired > 0 {
                self.buf_df = self.buf_df.slice(num_retired as i64, usize::MAX);
                self.buf_index_column = self.buf_index_column.slice(num_retired as i64, usize::MAX);
                self.buf_df_offset += num_retired;
            }

            return Ok(None);
        }

        // Prune the data that is not covered by the windows and update the windows accordingly.
        let offset = windows[0][0];
        let end = windows.last().unwrap();
        let end = end[0] + end[1];

        if self.slice_offset as usize > windows.len() {
            self.slice_offset -= windows.len() as IdxSize;
            windows.clear();
            lower_bound.clear();
            upper_bound.clear();
        } else if self.slice_offset > 0 {
            let offset = self.slice_offset as usize;
            self.slice_offset = self.slice_offset.saturating_sub(windows.len() as IdxSize);
            windows.drain(..offset);
            lower_bound.drain(..offset.min(lower_bound.len()));
            upper_bound.drain(..offset.min(upper_bound.len()));
        }

        let trunc_length = windows.len().min(self.slice_length as usize);
        windows.truncate(trunc_length);
        lower_bound.truncate(trunc_length);
        upper_bound.truncate(trunc_length);
        self.slice_length -= windows.len() as IdxSize;

        windows.iter_mut().for_each(|[s, _]| *s -= offset);
        let data = self.buf_df.slice(
            (offset - self.buf_df_offset) as i64,
            (end - self.buf_df_offset) as usize,
        );

        self.buf_df = self.buf_df.slice(num_retired as i64, usize::MAX);
        self.buf_index_column = self.buf_index_column.slice(num_retired as i64, usize::MAX);
        self.buf_df_offset += num_retired;

        if windows.is_empty() {
            return Ok(None);
        }

        Ok(Some((windows, lower_bound, upper_bound, data)))
    }
}

impl ComputeNode for DynamicGroupBy {
    fn name(&self) -> &str {
        "dynamic-group-by"
    }

    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        _state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 1 && send.len() == 1);

        if self.slice_length == 0 {
            recv[0] = PortState::Done;
            send[0] = PortState::Done;
            std::mem::take(&mut self.buf_df);
            return Ok(());
        }

        if send[0] == PortState::Done {
            recv[0] = PortState::Done;
            std::mem::take(&mut self.buf_df);
        } else if recv[0] == PortState::Done {
            if self.buf_df.is_empty() {
                send[0] = PortState::Done;
            } else {
                send[0] = PortState::Ready;
            }
        } else {
            recv.swap_with_slice(send);
        }

        Ok(())
    }

    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 1 && send_ports.len() == 1);

        let Some(recv) = recv_ports[0].take() else {
            // We no longer have to receive data. Finalize and send all remaining data.
            assert!(!self.buf_df.is_empty());
            assert!(self.slice_length > 0);
            let mut send = send_ports[0].take().unwrap().serial();
            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
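                // The input is exhausted: finalize the windower and evaluate whatever is still
                // buffered in one go.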
                if let Some((windows, lower_bound, upper_bound, df)) = self.next_windows(true)? {
                    let df = Self::evaluate_one(
                        windows,
                        lower_bound,
                        upper_bound,
                        &self.aggs,
                        &state.in_memory_exec_state,
                        df,
                        self.group_by.as_deref(),
                        self.index_column.as_str(),
                        self.index_column_idx,
                        self.label,
                        self.include_boundaries,
                    )
                    .await?;

                    _ = send
                        .send(Morsel::new(df, self.seq.successor(), SourceToken::new()))
                        .await;
                }

                self.buf_df = self.buf_df.clear();
                Ok(())
            }));
            return;
        };

        let mut recv = recv.serial();
        let send = send_ports[0].take().unwrap().parallel();

        let (mut distributor, rxs) =
            distributor_channel::<(Morsel, Vec<[IdxSize; 2]>, Vec<i64>, Vec<i64>)>(
                send.len(),
                *DEFAULT_DISTRIBUTOR_BUFFER_SIZE,
            );

        // Worker tasks.
        //
        // These evaluate the aggregations.
        join_handles.extend(rxs.into_iter().zip(send).map(|(mut rx, mut tx)| {
            let wg = WaitGroup::default();
            let aggs = self.aggs.clone();
            let state = state.in_memory_exec_state.split();

            let group_by = self.group_by.clone();
            let index_column = self.index_column.clone();
            let index_column_idx = self.index_column_idx;
            let label = self.label;
            let include_boundaries = self.include_boundaries;

            scope.spawn_task(TaskPriority::High, async move {
                while let Ok((mut morsel, windows, lower_bound, upper_bound)) = rx.recv().await {
                    morsel = morsel
                        .async_try_map::<PolarsError, _, _>(async |df| {
                            Self::evaluate_one(
                                windows,
                                lower_bound,
                                upper_bound,
                                &aggs,
                                &state,
                                df,
                                group_by.as_deref(),
                                index_column.as_str(),
                                index_column_idx,
                                label,
                                include_boundaries,
                            )
                            .await
                        })
                        .await?;
                    morsel.set_consume_token(wg.token());

                    if tx.send(morsel).await.is_err() {
                        break;
                    }
                    wg.wait().await;
                }

                Ok(())
            })
        }));

        // Distributor task.
        //
        // This finds the window boundaries and distributes them, together with the buffered data,
        // over the worker tasks.
        join_handles.push(scope.spawn_task(TaskPriority::High, async move {
            while let Ok(morsel) = recv.recv().await
                && self.slice_length > 0
            {
                let (df, seq, source_token, wait_token) = morsel.into_inner();
                self.seq = seq;
                drop(wait_token);

                if df.height() == 0 {
                    continue;
                }

                let morsel_index_column = df.column(&self.index_column)?;
                polars_ensure!(
                    morsel_index_column.null_count() == 0,
                    ComputeError: "null values in `group_by_dynamic` not supported, fill nulls."
                );

                use DataType as DT;
                let morsel_index_column = match morsel_index_column.dtype() {
                    DT::Datetime(_, _) => morsel_index_column.clone(),
                    DT::Date => {
                        morsel_index_column.cast(&DT::Datetime(TimeUnit::Microseconds, None))?
                    },
                    DT::Int32 => morsel_index_column
                        .cast(&DT::Int64)?
                        .cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?,
                    DT::Int64 => {
                        morsel_index_column.cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?
                    },
                    _ => unreachable!(),
                };

                self.buf_df.vstack_mut_owned(df)?;
                self.buf_index_column.append_owned(morsel_index_column)?;

                if let Some((windows, lower_bound, upper_bound, df)) = self.next_windows(false)? {
                    if distributor
                        .send((
                            Morsel::new(df, seq, source_token),
                            windows,
                            lower_bound,
                            upper_bound,
                        ))
                        .await
                        .is_err()
                    {
                        break;
                    }
                }
            }

            Ok(())
        }));
    }
}