Path: blob/main/crates/polars-stream/src/nodes/reduce.rs
8430 views
use std::sync::Arc;12use polars_core::frame::column::ScalarColumn;3use polars_core::prelude::Column;4use polars_core::schema::{Schema, SchemaExt};5use polars_expr::reduce::GroupedReduction;6use polars_utils::itertools::Itertools;78use super::compute_node_prelude::*;9use crate::expression::StreamExpr;10use crate::morsel::SourceToken;1112enum ReduceState {13Sink {14selectors: Vec<Vec<StreamExpr>>,15reductions: Vec<Box<dyn GroupedReduction>>,16},17Source(Option<DataFrame>),18Done,19}2021pub struct ReduceNode {22state: ReduceState,23output_schema: Arc<Schema>,24}2526impl ReduceNode {27pub fn new(28selectors: Vec<Vec<StreamExpr>>,29reductions: Vec<Box<dyn GroupedReduction>>,30output_schema: Arc<Schema>,31) -> Self {32Self {33state: ReduceState::Sink {34selectors,35reductions,36},37output_schema,38}39}4041fn spawn_sink<'env, 's>(42selectors: &'env [Vec<StreamExpr>],43reductions: &'env mut [Box<dyn GroupedReduction>],44scope: &'s TaskScope<'s, 'env>,45recv: RecvPort<'_>,46state: &'s StreamingExecutionState,47join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,48) {49let parallel_tasks: Vec<_> = recv50.parallel()51.into_iter()52.map(|mut recv| {53let mut local_reducers: Vec<_> = reductions54.iter()55.map(|d| {56let mut r = d.new_empty();57r.resize(1);58r59})60.collect();6162scope.spawn_task(TaskPriority::High, async move {63let mut in_columns = Vec::new();64let mut in_column_refs = Vec::new();65while let Ok(morsel) = recv.recv().await {66for (reducer, selector_set) in local_reducers.iter_mut().zip(selectors) {67for selector in selector_set {68let col = selector69.evaluate(morsel.df(), &state.in_memory_exec_state)70.await?;71in_columns.push(col);72}73for c in in_columns.iter() {74in_column_refs.push(c);75}76reducer.update_group(&in_column_refs, 0, morsel.seq().to_u64())?;77in_column_refs.clear();78in_column_refs =79in_column_refs.into_iter().map(|_| unreachable!()).collect(); // Clear lifetimes.80in_columns.clear();81}82}8384PolarsResult::Ok(local_reducers)85})86})87.collect();8889join_handles.push(scope.spawn_task(TaskPriority::High, async move {90for task in parallel_tasks {91let local_reducers = task.await?;92for (r1, r2) in reductions.iter_mut().zip(local_reducers) {93r1.resize(1);94unsafe {95r1.combine_subset(&*r2, &[0], &[0])?;96}97}98}99100Ok(())101}));102}103104fn spawn_source<'env, 's>(105df: &'env mut Option<DataFrame>,106scope: &'s TaskScope<'s, 'env>,107send: SendPort<'_>,108join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,109) {110let mut send = send.serial();111join_handles.push(scope.spawn_task(TaskPriority::High, async move {112let morsel = Morsel::new(df.take().unwrap(), MorselSeq::new(0), SourceToken::new());113let _ = send.send(morsel).await;114Ok(())115}));116}117}118119impl ComputeNode for ReduceNode {120fn name(&self) -> &str {121"reduce"122}123124fn update_state(125&mut self,126recv: &mut [PortState],127send: &mut [PortState],128_state: &StreamingExecutionState,129) -> PolarsResult<()> {130assert!(recv.len() == 1 && send.len() == 1);131132// State transitions.133match &mut self.state {134// If the output doesn't want any more data, transition to being done.135_ if send[0] == PortState::Done => {136self.state = ReduceState::Done;137},138// Input is done, transition to being a source.139ReduceState::Sink { reductions, .. } if matches!(recv[0], PortState::Done) => {140let columns = reductions141.iter_mut()142.zip(self.output_schema.iter_fields())143.map(|(r, field)| {144r.resize(1);145r.finalize().map(|s| {146let s = s.with_name(field.name.clone());147Column::Scalar(ScalarColumn::unit_scalar_from_series(s))148})149})150.try_collect_vec()?;151let out = unsafe { DataFrame::new_unchecked(1, columns) };152153self.state = ReduceState::Source(Some(out));154},155// We have sent the reduced dataframe, we are done.156ReduceState::Source(df) if df.is_none() => {157self.state = ReduceState::Done;158},159// Nothing to change.160ReduceState::Done | ReduceState::Sink { .. } | ReduceState::Source(_) => {},161}162163// Communicate our state.164match &self.state {165ReduceState::Sink { .. } => {166send[0] = PortState::Blocked;167recv[0] = PortState::Ready;168},169ReduceState::Source(..) => {170recv[0] = PortState::Done;171send[0] = PortState::Ready;172},173ReduceState::Done => {174recv[0] = PortState::Done;175send[0] = PortState::Done;176},177}178Ok(())179}180181fn spawn<'env, 's>(182&'env mut self,183scope: &'s TaskScope<'s, 'env>,184recv_ports: &mut [Option<RecvPort<'_>>],185send_ports: &mut [Option<SendPort<'_>>],186state: &'s StreamingExecutionState,187join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,188) {189assert!(send_ports.len() == 1 && recv_ports.len() == 1);190match &mut self.state {191ReduceState::Sink {192selectors,193reductions,194} => {195assert!(send_ports[0].is_none());196let recv_port = recv_ports[0].take().unwrap();197Self::spawn_sink(selectors, reductions, scope, recv_port, state, join_handles)198},199ReduceState::Source(df) => {200assert!(recv_ports[0].is_none());201let send_port = send_ports[0].take().unwrap();202Self::spawn_source(df, scope, send_port, join_handles)203},204ReduceState::Done => unreachable!(),205}206}207}208209210