Path: blob/main/crates/polars-stream/src/nodes/multiplexer.rs
6939 views
use std::collections::VecDeque;12use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};34use super::compute_node_prelude::*;5use crate::async_primitives::wait_group::WaitGroup;6use crate::morsel::SourceToken;78// TODO: replace this with an out-of-core buffering solution.9enum BufferedStream {10Open(VecDeque<Morsel>),11Closed,12}1314impl BufferedStream {15fn new() -> Self {16Self::Open(VecDeque::new())17}18}1920pub struct MultiplexerNode {21buffers: Vec<BufferedStream>,22}2324impl MultiplexerNode {25pub fn new() -> Self {26Self {27buffers: Vec::default(),28}29}30}3132impl ComputeNode for MultiplexerNode {33fn name(&self) -> &str {34"multiplexer"35}3637fn update_state(38&mut self,39recv: &mut [PortState],40send: &mut [PortState],41_state: &StreamingExecutionState,42) -> PolarsResult<()> {43assert!(recv.len() == 1 && !send.is_empty());4445// Initialize buffered streams, and mark those for which the receiver46// is no longer interested as closed.47self.buffers.resize_with(send.len(), BufferedStream::new);48for (s, b) in send.iter().zip(&mut self.buffers) {49if *s == PortState::Done {50*b = BufferedStream::Closed;51}52}5354// Check if either the input is done, or all outputs are done.55let input_done = recv[0] == PortState::Done56&& self.buffers.iter().all(|b| match b {57BufferedStream::Open(v) => v.is_empty(),58BufferedStream::Closed => true,59});60let output_done = send.iter().all(|p| *p == PortState::Done);6162// If either side is done, everything is done.63if input_done || output_done {64recv[0] = PortState::Done;65for s in send {66*s = PortState::Done;67}68return Ok(());69}7071let all_blocked = send.iter().all(|p| *p == PortState::Blocked);7273// Pass along the input state to the output.74for (i, s) in send.iter_mut().enumerate() {75let buffer_empty = match &self.buffers[i] {76BufferedStream::Open(v) => v.is_empty(),77BufferedStream::Closed => true,78};79*s = if buffer_empty && recv[0] == PortState::Done {80PortState::Done81} else if !buffer_empty || recv[0] == PortState::Ready {82PortState::Ready83} else {84PortState::Blocked85};86}8788// We say we are ready to receive unless all outputs are blocked.89recv[0] = if all_blocked {90PortState::Blocked91} else {92PortState::Ready93};94Ok(())95}9697fn spawn<'env, 's>(98&'env mut self,99scope: &'s TaskScope<'s, 'env>,100recv_ports: &mut [Option<RecvPort<'_>>],101send_ports: &mut [Option<SendPort<'_>>],102_state: &'s StreamingExecutionState,103join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,104) {105assert!(recv_ports.len() == 1 && !send_ports.is_empty());106assert!(self.buffers.len() == send_ports.len());107108enum Listener<'a> {109Active(UnboundedSender<Morsel>),110Buffering(&'a mut VecDeque<Morsel>),111Inactive,112}113114let buffered_source_token = SourceToken::new();115116let (mut buf_senders, buf_receivers): (Vec<_>, Vec<_>) = self117.buffers118.iter_mut()119.enumerate()120.map(|(port_idx, buffer)| {121if let BufferedStream::Open(buf) = buffer {122if send_ports[port_idx].is_some() {123// TODO: replace with a bounded channel and store data124// out-of-core beyond a certain size.125let (rx, tx) = unbounded_channel();126(Listener::Active(rx), Some((buf, tx)))127} else {128(Listener::Buffering(buf), None)129}130} else {131(Listener::Inactive, None)132}133})134.unzip();135136// TODO: parallel multiplexing.137if let Some(mut receiver) = recv_ports[0].take().map(|r| r.serial()) {138let buffered_source_token = buffered_source_token.clone();139join_handles.push(scope.spawn_task(TaskPriority::High, async move {140loop {141let Ok(mut morsel) = receiver.recv().await else {142break;143};144drop(morsel.take_consume_token());145146let mut anyone_interested = false;147let mut active_listener_interested = false;148for buf_sender in &mut buf_senders {149match buf_sender {150Listener::Active(s) => match s.send(morsel.clone()) {151Ok(_) => {152anyone_interested = true;153active_listener_interested = true;154},155Err(_) => *buf_sender = Listener::Inactive,156},157Listener::Buffering(b) => {158b.push_front(morsel.clone());159anyone_interested = true;160},161Listener::Inactive => {},162}163}164165if !anyone_interested {166break;167}168169// If only buffering inputs are left, or we got a stop170// request from an input reading from old buffered data,171// request a stop from the source.172if !active_listener_interested || buffered_source_token.stop_requested() {173morsel.source_token().stop();174}175}176177Ok(())178}));179}180181for (send_port, opt_buf_recv) in send_ports.iter_mut().zip(buf_receivers) {182if let Some((buf, mut rx)) = opt_buf_recv {183let mut sender = send_port.take().unwrap().serial();184185let wait_group = WaitGroup::default();186let buffered_source_token = buffered_source_token.clone();187join_handles.push(scope.spawn_task(TaskPriority::High, async move {188// First we try to flush all the old buffered data.189while let Some(mut morsel) = buf.pop_back() {190morsel.replace_source_token(buffered_source_token.clone());191morsel.set_consume_token(wait_group.token());192if sender.send(morsel).await.is_err()193|| buffered_source_token.stop_requested()194{195break;196}197wait_group.wait().await;198}199200// Then send along data from the multiplexer.201while let Some(mut morsel) = rx.recv().await {202morsel.set_consume_token(wait_group.token());203if sender.send(morsel).await.is_err() {204break;205}206wait_group.wait().await;207}208Ok(())209}));210}211}212}213}214215216