Path: blob/main/crates/polars-stream/src/nodes/ordered_union.rs
8422 views
use std::sync::Arc;12use polars_core::schema::Schema;34use super::compute_node_prelude::*;56/// A node that first passes through all data from the first input, then the7/// second input, etc.8pub struct OrderedUnionNode {9cur_input_idx: usize,10max_morsel_seq_sent: MorselSeq,11morsel_offset: MorselSeq,12output_schema: Arc<Schema>,13}1415impl OrderedUnionNode {16pub fn new(output_schema: Arc<Schema>) -> Self {17Self {18cur_input_idx: 0,19max_morsel_seq_sent: MorselSeq::new(0),20morsel_offset: MorselSeq::new(0),21output_schema,22}23}24}2526impl ComputeNode for OrderedUnionNode {27fn name(&self) -> &str {28"ordered-union"29}3031fn update_state(32&mut self,33recv: &mut [PortState],34send: &mut [PortState],35_state: &StreamingExecutionState,36) -> PolarsResult<()> {37assert!(self.cur_input_idx <= recv.len() && send.len() == 1);3839// Skip inputs that are done.40while self.cur_input_idx < recv.len() && recv[self.cur_input_idx] == PortState::Done {41self.cur_input_idx += 1;42}4344// Act like a normal pass-through node for the current input, or mark45// ourselves as done if all inputs are handled.46if self.cur_input_idx < recv.len() {47core::mem::swap(&mut recv[self.cur_input_idx], &mut send[0]);48} else {49send[0] = PortState::Done;50}5152// Mark all inputs after the current one as blocked.53for r in recv.iter_mut().skip(self.cur_input_idx + 1) {54*r = PortState::Blocked;55}5657// Set the morsel offset one higher than any sent so far.58self.morsel_offset = self.max_morsel_seq_sent.successor();59Ok(())60}6162fn spawn<'env, 's>(63&'env mut self,64scope: &'s TaskScope<'s, 'env>,65recv_ports: &mut [Option<RecvPort<'_>>],66send_ports: &mut [Option<SendPort<'_>>],67_state: &'s StreamingExecutionState,68join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,69) {70let ready_count = recv_ports.iter().filter(|r| r.is_some()).count();71assert!(ready_count == 1 && send_ports.len() == 1);72let receivers = recv_ports[self.cur_input_idx].take().unwrap().parallel();73let senders = send_ports[0].take().unwrap().parallel();7475let mut inner_handles = Vec::new();76for (mut recv, mut send) in receivers.into_iter().zip(senders) {77let output_schema = self.output_schema.clone();78let morsel_offset = self.morsel_offset;79inner_handles.push(scope.spawn_task(TaskPriority::High, async move {80let mut max_seq = MorselSeq::new(0);81while let Ok(mut morsel) = recv.recv().await {82// Ensure the morsel matches the expected output schema,83// casting nulls to the appropriate output type.84morsel.df_mut().ensure_matches_schema(&output_schema)?;8586// Ensure the morsel sequence id stream is monotonic.87let seq = morsel.seq().offset_by(morsel_offset);88max_seq = max_seq.max(seq);8990morsel.set_seq(seq);91if send.send(morsel).await.is_err() {92break;93}94}95PolarsResult::Ok(max_seq)96}));97}9899join_handles.push(scope.spawn_task(TaskPriority::High, async move {100// Update our global maximum.101for handle in inner_handles {102self.max_morsel_seq_sent = self.max_morsel_seq_sent.max(handle.await?);103}104Ok(())105}));106}107}108109110