Path: blob/main/crates/polars-stream/src/nodes/cum_agg.rs
6939 views
use polars_core::prelude::{AnyValue, IntoColumn};1use polars_core::utils::last_non_null;2use polars_error::PolarsResult;3use polars_ops::series::{4cum_count_with_init, cum_max_with_init, cum_min_with_init, cum_prod_with_init,5cum_sum_with_init,6};78use super::ComputeNode;9use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};10use crate::execute::StreamingExecutionState;11use crate::graph::PortState;12use crate::pipe::{RecvPort, SendPort};1314pub struct CumAggNode {15state: AnyValue<'static>,16kind: CumAggKind,17}1819#[derive(Debug, Clone, Copy)]20pub enum CumAggKind {21Min,22Max,23Sum,24Count,25Prod,26}2728impl CumAggNode {29pub fn new(kind: CumAggKind) -> Self {30Self {31state: AnyValue::Null,32kind,33}34}35}3637impl ComputeNode for CumAggNode {38fn name(&self) -> &str {39match self.kind {40CumAggKind::Min => "cum_min",41CumAggKind::Max => "cum_max",42CumAggKind::Sum => "cum_sum",43CumAggKind::Count => "cum_count",44CumAggKind::Prod => "cum_prod",45}46}4748fn update_state(49&mut self,50recv: &mut [PortState],51send: &mut [PortState],52_state: &StreamingExecutionState,53) -> PolarsResult<()> {54assert!(recv.len() == 1 && send.len() == 1);5556recv.swap_with_slice(send);57Ok(())58}5960fn spawn<'env, 's>(61&'env mut self,62scope: &'s TaskScope<'s, 'env>,63recv_ports: &mut [Option<RecvPort<'_>>],64send_ports: &mut [Option<SendPort<'_>>],65_state: &'s StreamingExecutionState,66join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,67) {68assert_eq!(recv_ports.len(), 1);69assert_eq!(send_ports.len(), 1);7071let mut recv = recv_ports[0].take().unwrap().serial();72let mut send = send_ports[0].take().unwrap().serial();7374join_handles.push(scope.spawn_task(TaskPriority::High, async move {75while let Ok(mut m) = recv.recv().await {76assert_eq!(m.df().width(), 1);77if m.df().height() == 0 {78continue;79}8081let s = m.df()[0].as_materialized_series();82let out = match self.kind {83CumAggKind::Min => cum_min_with_init(s, false, &self.state),84CumAggKind::Max => cum_max_with_init(s, false, &self.state),85CumAggKind::Sum => cum_sum_with_init(s, false, &self.state),86CumAggKind::Count => {87cum_count_with_init(s, false, self.state.extract().unwrap_or_default())88},89CumAggKind::Prod => cum_prod_with_init(s, false, &self.state),90}?;9192// Find the last non-null value and set that as the state.93let last_non_null_idx = if out.has_nulls() {94last_non_null(out.chunks().iter().map(|c| c.validity()), out.len())95} else {96Some(out.len() - 1)97};98if let Some(idx) = last_non_null_idx {99self.state = out.get(idx).unwrap().into_static();100}101*m.df_mut() = out.into_column().into_frame();102103if send.send(m).await.is_err() {104break;105}106}107108Ok(())109}));110}111}112113114