// Path: blob/main/crates/polars-stream/src/nodes/unordered_union.rs
use std::sync::Arc;12use polars_core::schema::Schema;3use tokio::sync::mpsc;45use super::compute_node_prelude::*;67pub struct UnorderedUnionNode {8max_morsel_seq_sent: MorselSeq,9output_schema: Arc<Schema>,10}1112impl UnorderedUnionNode {13pub fn new(output_schema: Arc<Schema>) -> Self {14Self {15max_morsel_seq_sent: MorselSeq::new(0),16output_schema,17}18}19}2021impl ComputeNode for UnorderedUnionNode {22fn name(&self) -> &str {23"unordered-union"24}2526fn update_state(27&mut self,28recv: &mut [PortState],29send: &mut [PortState],30_state: &StreamingExecutionState,31) -> PolarsResult<()> {32assert_eq!(send.len(), 1);3334let done = send[0] == PortState::Done || recv.iter().all(|r| *r == PortState::Done);35if done {36send[0] = PortState::Done;37recv.fill(PortState::Done);38return Ok(());39}4041let any_ready = recv.contains(&PortState::Ready);42recv.fill(send[0]);43send[0] = if any_ready {44PortState::Ready45} else {46PortState::Blocked47};48Ok(())49}5051fn spawn<'env, 's>(52&'env mut self,53scope: &'s TaskScope<'s, 'env>,54recv_ports: &mut [Option<RecvPort<'_>>],55send_ports: &mut [Option<SendPort<'_>>],56state: &'s StreamingExecutionState,57join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,58) {59assert_eq!(send_ports.len(), 1);60let output_senders = send_ports[0].take().unwrap().parallel();61let num_pipelines = output_senders.len();62assert_eq!(num_pipelines, state.num_pipelines);6364let (mpsc_senders, mpsc_receivers): (Vec<_>, Vec<_>) = (0..num_pipelines)65.map(|_| mpsc::channel::<Morsel>(1))66.unzip();6768for recv_port in recv_ports {69if let Some(recv) = recv_port.take() {70let receivers = recv.parallel();71let mpsc_senders_clone = mpsc_senders.clone();7273for (mut receiver, sender) in receivers.into_iter().zip(mpsc_senders_clone) {74let output_schema = self.output_schema.clone();75join_handles.push(scope.spawn_task(TaskPriority::High, async move {76while let Ok(mut morsel) = receiver.recv().await {77// Ensure the morsel matches the expected output 
schema,78// casting nulls to the appropriate output type.79morsel.df_mut().ensure_matches_schema(&output_schema)?;8081if sender.send(morsel).await.is_err() {82break;83}84}85PolarsResult::Ok(())86}));87}88}89}9091drop(mpsc_senders);9293// Each pipeline relabels morsel sequences independently of the others.94// We first compute the `morsel_offset` as (max morsel sequence sent so far + 1), so this95// phase never reuses sequence numbers from earlier phases.96//97// Then, each pipeline assigns sequences by:98// - starting at `morsel_offset + pipeline_idx` (so pipelines start at different values),99// - advancing by `num_pipelines` each time it emits a morsel.100//101// Example with 2 pipelines (num_pipelines = 2) and morsel_offset = 1000:102// pipeline 0: 1000, 1002, 1004, ...103// pipeline 1: 1001, 1003, 1005, ...104//105// This guarantees:106// - Global uniqueness: no collisions with earlier phases, and no collisions across pipelines.107// - Per-pipeline non-decreasing: each pipeline only moves forward by a fixed positive step.108let morsel_offset = self.max_morsel_seq_sent.successor();109110let mut inner_handles = Vec::new();111for (lane_idx, (mut mpsc_receiver, mut output_sender)) in112mpsc_receivers.into_iter().zip(output_senders).enumerate()113{114inner_handles.push(scope.spawn_task(TaskPriority::High, async move {115let mut local_seq = morsel_offset.offset_by_u64(lane_idx as u64);116let seq_step = num_pipelines as u64;117let mut max_seq = MorselSeq::new(0);118119while let Some(mut morsel) = mpsc_receiver.recv().await {120morsel.set_seq(local_seq);121max_seq = max_seq.max(local_seq);122local_seq = local_seq.offset_by_u64(seq_step);123124if output_sender.send(morsel).await.is_err() {125break;126}127}128129PolarsResult::Ok(max_seq)130}));131}132133join_handles.push(scope.spawn_task(TaskPriority::High, async move {134for handle in inner_handles {135self.max_morsel_seq_sent = self.max_morsel_seq_sent.max(handle.await?);136}137Ok(())138}));139}140}141142143