Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/multiplexer.rs
6939 views
1
use std::collections::VecDeque;
2
3
use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
4
5
use super::compute_node_prelude::*;
6
use crate::async_primitives::wait_group::WaitGroup;
7
use crate::morsel::SourceToken;
8
9
// TODO: replace this with an out-of-core buffering solution.
10
enum BufferedStream {
11
Open(VecDeque<Morsel>),
12
Closed,
13
}
14
15
impl BufferedStream {
16
fn new() -> Self {
17
Self::Open(VecDeque::new())
18
}
19
}
20
21
pub struct MultiplexerNode {
22
buffers: Vec<BufferedStream>,
23
}
24
25
impl MultiplexerNode {
26
pub fn new() -> Self {
27
Self {
28
buffers: Vec::default(),
29
}
30
}
31
}
32
33
impl ComputeNode for MultiplexerNode {
34
fn name(&self) -> &str {
35
"multiplexer"
36
}
37
38
fn update_state(
39
&mut self,
40
recv: &mut [PortState],
41
send: &mut [PortState],
42
_state: &StreamingExecutionState,
43
) -> PolarsResult<()> {
44
assert!(recv.len() == 1 && !send.is_empty());
45
46
// Initialize buffered streams, and mark those for which the receiver
47
// is no longer interested as closed.
48
self.buffers.resize_with(send.len(), BufferedStream::new);
49
for (s, b) in send.iter().zip(&mut self.buffers) {
50
if *s == PortState::Done {
51
*b = BufferedStream::Closed;
52
}
53
}
54
55
// Check if either the input is done, or all outputs are done.
56
let input_done = recv[0] == PortState::Done
57
&& self.buffers.iter().all(|b| match b {
58
BufferedStream::Open(v) => v.is_empty(),
59
BufferedStream::Closed => true,
60
});
61
let output_done = send.iter().all(|p| *p == PortState::Done);
62
63
// If either side is done, everything is done.
64
if input_done || output_done {
65
recv[0] = PortState::Done;
66
for s in send {
67
*s = PortState::Done;
68
}
69
return Ok(());
70
}
71
72
let all_blocked = send.iter().all(|p| *p == PortState::Blocked);
73
74
// Pass along the input state to the output.
75
for (i, s) in send.iter_mut().enumerate() {
76
let buffer_empty = match &self.buffers[i] {
77
BufferedStream::Open(v) => v.is_empty(),
78
BufferedStream::Closed => true,
79
};
80
*s = if buffer_empty && recv[0] == PortState::Done {
81
PortState::Done
82
} else if !buffer_empty || recv[0] == PortState::Ready {
83
PortState::Ready
84
} else {
85
PortState::Blocked
86
};
87
}
88
89
// We say we are ready to receive unless all outputs are blocked.
90
recv[0] = if all_blocked {
91
PortState::Blocked
92
} else {
93
PortState::Ready
94
};
95
Ok(())
96
}
97
98
fn spawn<'env, 's>(
99
&'env mut self,
100
scope: &'s TaskScope<'s, 'env>,
101
recv_ports: &mut [Option<RecvPort<'_>>],
102
send_ports: &mut [Option<SendPort<'_>>],
103
_state: &'s StreamingExecutionState,
104
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
105
) {
106
assert!(recv_ports.len() == 1 && !send_ports.is_empty());
107
assert!(self.buffers.len() == send_ports.len());
108
109
enum Listener<'a> {
110
Active(UnboundedSender<Morsel>),
111
Buffering(&'a mut VecDeque<Morsel>),
112
Inactive,
113
}
114
115
let buffered_source_token = SourceToken::new();
116
117
let (mut buf_senders, buf_receivers): (Vec<_>, Vec<_>) = self
118
.buffers
119
.iter_mut()
120
.enumerate()
121
.map(|(port_idx, buffer)| {
122
if let BufferedStream::Open(buf) = buffer {
123
if send_ports[port_idx].is_some() {
124
// TODO: replace with a bounded channel and store data
125
// out-of-core beyond a certain size.
126
let (rx, tx) = unbounded_channel();
127
(Listener::Active(rx), Some((buf, tx)))
128
} else {
129
(Listener::Buffering(buf), None)
130
}
131
} else {
132
(Listener::Inactive, None)
133
}
134
})
135
.unzip();
136
137
// TODO: parallel multiplexing.
138
if let Some(mut receiver) = recv_ports[0].take().map(|r| r.serial()) {
139
let buffered_source_token = buffered_source_token.clone();
140
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
141
loop {
142
let Ok(mut morsel) = receiver.recv().await else {
143
break;
144
};
145
drop(morsel.take_consume_token());
146
147
let mut anyone_interested = false;
148
let mut active_listener_interested = false;
149
for buf_sender in &mut buf_senders {
150
match buf_sender {
151
Listener::Active(s) => match s.send(morsel.clone()) {
152
Ok(_) => {
153
anyone_interested = true;
154
active_listener_interested = true;
155
},
156
Err(_) => *buf_sender = Listener::Inactive,
157
},
158
Listener::Buffering(b) => {
159
b.push_front(morsel.clone());
160
anyone_interested = true;
161
},
162
Listener::Inactive => {},
163
}
164
}
165
166
if !anyone_interested {
167
break;
168
}
169
170
// If only buffering inputs are left, or we got a stop
171
// request from an input reading from old buffered data,
172
// request a stop from the source.
173
if !active_listener_interested || buffered_source_token.stop_requested() {
174
morsel.source_token().stop();
175
}
176
}
177
178
Ok(())
179
}));
180
}
181
182
for (send_port, opt_buf_recv) in send_ports.iter_mut().zip(buf_receivers) {
183
if let Some((buf, mut rx)) = opt_buf_recv {
184
let mut sender = send_port.take().unwrap().serial();
185
186
let wait_group = WaitGroup::default();
187
let buffered_source_token = buffered_source_token.clone();
188
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
189
// First we try to flush all the old buffered data.
190
while let Some(mut morsel) = buf.pop_back() {
191
morsel.replace_source_token(buffered_source_token.clone());
192
morsel.set_consume_token(wait_group.token());
193
if sender.send(morsel).await.is_err()
194
|| buffered_source_token.stop_requested()
195
{
196
break;
197
}
198
wait_group.wait().await;
199
}
200
201
// Then send along data from the multiplexer.
202
while let Some(mut morsel) = rx.recv().await {
203
morsel.set_consume_token(wait_group.token());
204
if sender.send(morsel).await.is_err() {
205
break;
206
}
207
wait_group.wait().await;
208
}
209
Ok(())
210
}));
211
}
212
}
213
}
214
}
215
216