Path: blob/main/crates/polars-stream/src/nodes/reduce.rs
6939 views
use std::sync::Arc;12use polars_core::frame::column::ScalarColumn;3use polars_core::prelude::Column;4use polars_core::schema::{Schema, SchemaExt};5use polars_expr::reduce::GroupedReduction;6use polars_utils::itertools::Itertools;78use super::compute_node_prelude::*;9use crate::expression::StreamExpr;10use crate::morsel::SourceToken;1112enum ReduceState {13Sink {14selectors: Vec<StreamExpr>,15reductions: Vec<Box<dyn GroupedReduction>>,16},17Source(Option<DataFrame>),18Done,19}2021pub struct ReduceNode {22state: ReduceState,23output_schema: Arc<Schema>,24}2526impl ReduceNode {27pub fn new(28selectors: Vec<StreamExpr>,29reductions: Vec<Box<dyn GroupedReduction>>,30output_schema: Arc<Schema>,31) -> Self {32Self {33state: ReduceState::Sink {34selectors,35reductions,36},37output_schema,38}39}4041fn spawn_sink<'env, 's>(42selectors: &'env [StreamExpr],43reductions: &'env mut [Box<dyn GroupedReduction>],44scope: &'s TaskScope<'s, 'env>,45recv: RecvPort<'_>,46state: &'s StreamingExecutionState,47join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,48) {49let parallel_tasks: Vec<_> = recv50.parallel()51.into_iter()52.map(|mut recv| {53let mut local_reducers: Vec<_> = reductions54.iter()55.map(|d| {56let mut r = d.new_empty();57r.resize(1);58r59})60.collect();6162scope.spawn_task(TaskPriority::High, async move {63while let Ok(morsel) = recv.recv().await {64for (reducer, selector) in local_reducers.iter_mut().zip(selectors) {65let input = selector66.evaluate(morsel.df(), &state.in_memory_exec_state)67.await?;68reducer.update_group(&input, 0, morsel.seq().to_u64())?;69}70}7172PolarsResult::Ok(local_reducers)73})74})75.collect();7677join_handles.push(scope.spawn_task(TaskPriority::High, async move {78for task in parallel_tasks {79let local_reducers = task.await?;80for (r1, r2) in reductions.iter_mut().zip(local_reducers) {81r1.resize(1);82unsafe {83r1.combine_subset(&*r2, &[0], &[0])?;84}85}86}8788Ok(())89}));90}9192fn spawn_source<'env, 's>(93df: &'env mut Option<DataFrame>,94scope: &'s TaskScope<'s, 'env>,95send: SendPort<'_>,96join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,97) {98let mut send = send.serial();99join_handles.push(scope.spawn_task(TaskPriority::High, async move {100let morsel = Morsel::new(df.take().unwrap(), MorselSeq::new(0), SourceToken::new());101let _ = send.send(morsel).await;102Ok(())103}));104}105}106107impl ComputeNode for ReduceNode {108fn name(&self) -> &str {109"reduce"110}111112fn update_state(113&mut self,114recv: &mut [PortState],115send: &mut [PortState],116_state: &StreamingExecutionState,117) -> PolarsResult<()> {118assert!(recv.len() == 1 && send.len() == 1);119120// State transitions.121match &mut self.state {122// If the output doesn't want any more data, transition to being done.123_ if send[0] == PortState::Done => {124self.state = ReduceState::Done;125},126// Input is done, transition to being a source.127ReduceState::Sink { reductions, .. } if matches!(recv[0], PortState::Done) => {128let columns = reductions129.iter_mut()130.zip(self.output_schema.iter_fields())131.map(|(r, field)| {132r.resize(1);133r.finalize().map(|s| {134let s = s.with_name(field.name.clone()).cast(&field.dtype).unwrap();135Column::Scalar(ScalarColumn::unit_scalar_from_series(s))136})137})138.try_collect_vec()?;139let out = DataFrame::new(columns).unwrap();140141self.state = ReduceState::Source(Some(out));142},143// We have sent the reduced dataframe, we are done.144ReduceState::Source(df) if df.is_none() => {145self.state = ReduceState::Done;146},147// Nothing to change.148ReduceState::Done | ReduceState::Sink { .. } | ReduceState::Source(_) => {},149}150151// Communicate our state.152match &self.state {153ReduceState::Sink { .. } => {154send[0] = PortState::Blocked;155recv[0] = PortState::Ready;156},157ReduceState::Source(..) => {158recv[0] = PortState::Done;159send[0] = PortState::Ready;160},161ReduceState::Done => {162recv[0] = PortState::Done;163send[0] = PortState::Done;164},165}166Ok(())167}168169fn spawn<'env, 's>(170&'env mut self,171scope: &'s TaskScope<'s, 'env>,172recv_ports: &mut [Option<RecvPort<'_>>],173send_ports: &mut [Option<SendPort<'_>>],174state: &'s StreamingExecutionState,175join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,176) {177assert!(send_ports.len() == 1 && recv_ports.len() == 1);178match &mut self.state {179ReduceState::Sink {180selectors,181reductions,182} => {183assert!(send_ports[0].is_none());184let recv_port = recv_ports[0].take().unwrap();185Self::spawn_sink(selectors, reductions, scope, recv_port, state, join_handles)186},187ReduceState::Source(df) => {188assert!(recv_ports[0].is_none());189let send_port = send_ports[0].take().unwrap();190Self::spawn_source(df, scope, send_port, join_handles)191},192ReduceState::Done => unreachable!(),193}194}195}196197198