GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/zip.rs

use std::collections::VecDeque;
use std::sync::Arc;

use polars_core::functions::concat_df_horizontal;
use polars_core::schema::Schema;
use polars_core::series::Series;
use polars_error::polars_ensure;
use polars_utils::itertools::Itertools;

use super::compute_node_prelude::*;
use crate::DEFAULT_ZIP_HEAD_BUFFER_SIZE;
use crate::morsel::SourceToken;

/// The head of an input stream.
#[derive(Debug)]
struct InputHead {
    /// The schema of the input, needed when creating full-null dataframes.
    schema: Arc<Schema>,

    // None when it is unknown whether this input stream is a broadcasting input or not.
    is_broadcast: Option<bool>,

    // True when there are no more morsels after the ones in the head.
    stream_exhausted: bool,

    // A FIFO queue of morsels belonging to this input stream.
    morsels: VecDeque<Morsel>,

    // The total length of the morsels in the input head.
    total_len: usize,
}

impl InputHead {
    fn new(schema: Arc<Schema>, may_broadcast: bool) -> Self {
        Self {
            schema,
            morsels: VecDeque::new(),
            is_broadcast: if may_broadcast { None } else { Some(false) },
            total_len: 0,
            stream_exhausted: false,
        }
    }

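    /// Buffer a morsel in this head, updating the total length and resolving
    /// the broadcast status once more than one row has been seen.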
    fn add_morsel(&mut self, mut morsel: Morsel) {
        self.total_len += morsel.df().height();

        if self.is_broadcast.is_none() {
            if self.total_len > 1 {
                self.is_broadcast = Some(false);
            } else {
                // Make sure we don't deadlock while waiting to resolve our
                // ambiguous broadcast status.
                drop(morsel.take_consume_token());
            }
        }

        if morsel.df().height() > 0 {
            self.morsels.push_back(morsel);
        }
    }

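    /// Mark this input as exhausted. If the broadcast status was still
    /// ambiguous, it is resolved now: an input with exactly one row becomes a
    /// broadcast.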
    fn notify_no_more_morsels(&mut self) {
        if self.is_broadcast.is_none() {
            self.is_broadcast = Some(self.total_len == 1);
        }
        self.stream_exhausted = true;
    }

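    /// Whether this head can participate in sending: the broadcast status is
    /// known and there is either buffered data or no more data is coming.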
    fn ready_to_send(&self) -> bool {
        self.is_broadcast.is_some() && (self.total_len > 0 || self.stream_exhausted)
    }

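    /// The height of the first buffered morsel for non-broadcast inputs,
    /// `None` for (potential) broadcasts.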
    fn min_len(&self) -> Option<usize> {
        if self.is_broadcast == Some(false) {
            self.morsels.front().map(|m| m.df().height())
        } else {
            None
        }
    }

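    /// Take `len` rows from this head: repeat the single row for broadcasts,
    /// split rows off the buffered morsels otherwise, or produce a full-null
    /// dataframe if this input is already exhausted (null-extension).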
    fn take(&mut self, len: usize) -> DataFrame {
        if self.is_broadcast.unwrap() {
            self.morsels[0]
                .df()
                .iter()
                .map(|s| s.new_from_index(0, len))
                .collect()
        } else if self.total_len > 0 {
            self.total_len -= len;
            if self.morsels[0].df().height() == len {
                self.morsels.pop_front().unwrap().into_df()
            } else {
                let (head, tail) = self.morsels[0].df().split_at(len as i64);
                *self.morsels[0].df_mut() = tail;
                head
            }
        } else {
            self.schema
                .iter()
                .map(|(name, dtype)| Series::full_null(name.clone(), len, dtype))
                .collect()
        }
    }

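    /// Consume the single-row broadcast input, returning its dataframe and
    /// clearing the head.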
    fn consume_broadcast(&mut self) -> DataFrame {
        assert!(self.is_broadcast == Some(true) && self.total_len == 1);
        let out = self.morsels.pop_front().unwrap().into_df();
        self.clear();
        out
    }

    fn clear(&mut self) {
        self.total_len = 0;
        self.is_broadcast = Some(false);
        self.morsels.clear();
    }
}

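/// A compute node that horizontally concatenates (zips) multiple input
/// streams into a single output stream, optionally extending shorter inputs
/// with nulls.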
pub struct ZipNode {
    null_extend: bool,
    out_seq: MorselSeq,
    input_heads: Vec<InputHead>,
}

impl ZipNode {
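    /// Create a zip node over the given input schemas. Single-row inputs may
    /// only broadcast when not null-extending.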
    pub fn new(null_extend: bool, schemas: Vec<Arc<Schema>>) -> Self {
        let input_heads = schemas
            .into_iter()
            .map(|s| InputHead::new(s, !null_extend))
            .collect();
        Self {
            null_extend,
            out_seq: MorselSeq::new(0),
            input_heads,
        }
    }
}

impl ComputeNode for ZipNode {
    fn name(&self) -> &str {
        if self.null_extend {
            "zip-null-extend"
        } else {
            "zip"
        }
    }

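    // Decide whether the zip is completely done; otherwise propagate
    // Blocked/Ready state between the output port and the input ports.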
    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        _state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(send.len() == 1);
        assert!(recv.len() == self.input_heads.len());

        let mut all_broadcast = true;
        let mut all_done_or_broadcast = true;
        let mut at_least_one_non_broadcast_done = false;
        let mut at_least_one_non_broadcast_nonempty = false;
        for (recv_idx, recv_state) in recv.iter().enumerate() {
            let input_head = &mut self.input_heads[recv_idx];
            if *recv_state == PortState::Done {
                input_head.notify_no_more_morsels();

                all_done_or_broadcast &=
                    input_head.is_broadcast == Some(true) || input_head.total_len == 0;
                at_least_one_non_broadcast_done |=
                    input_head.is_broadcast == Some(false) && input_head.total_len == 0;
            } else {
                all_done_or_broadcast = false;
            }

            all_broadcast &= input_head.is_broadcast == Some(true);
            at_least_one_non_broadcast_nonempty |=
                input_head.is_broadcast == Some(false) && input_head.total_len > 0;
        }

        if !self.null_extend {
            polars_ensure!(
                !(at_least_one_non_broadcast_done && at_least_one_non_broadcast_nonempty),
                ShapeMismatch: "zip node received non-equal length inputs"
            );
        }

        let all_output_sent = all_done_or_broadcast && !all_broadcast;

        // Are we completely done?
        if send[0] == PortState::Done || all_output_sent {
            for input_head in &mut self.input_heads {
                input_head.clear();
            }
            send[0] = PortState::Done;
            recv.fill(PortState::Done);
            return Ok(());
        }

        let num_inputs_blocked = recv.iter().filter(|r| **r == PortState::Blocked).count();
        send[0] = if num_inputs_blocked > 0 {
            PortState::Blocked
        } else {
            PortState::Ready
        };

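        // An input should only be Ready when no other port (the output or
        // another input) is blocked.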
        let num_total_blocked = num_inputs_blocked + (send[0] == PortState::Blocked) as usize;
        for r in recv {
            let num_others_blocked = num_total_blocked - (*r == PortState::Blocked) as usize;
            *r = if num_others_blocked > 0 {
                PortState::Blocked
            } else {
                PortState::Ready
            };
        }
        Ok(())
    }

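    // Spawns one buffering task per input plus a single serial task that
    // fills the input heads and emits horizontally concatenated morsels.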
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        _state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(send_ports.len() == 1);
        assert!(!recv_ports.is_empty());
        let mut sender = send_ports[0].take().unwrap().serial();

        let mut receivers = recv_ports
            .iter_mut()
            .map(|recv_port| {
                // Add buffering to each receiver to reduce contention between input heads.
                let mut serial_recv = recv_port.take()?.serial();
                let (buf_send, buf_recv) =
                    tokio::sync::mpsc::channel(*DEFAULT_ZIP_HEAD_BUFFER_SIZE);
                join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                    while let Ok(morsel) = serial_recv.recv().await {
                        if buf_send.send(morsel).await.is_err() {
                            break;
                        }
                    }
                    Ok(())
                }));
                Some(buf_recv)
            })
            .collect_vec();

        join_handles.push(scope.spawn_task(TaskPriority::High, async move {
            let mut out = Vec::new();
            let source_token = SourceToken::new();
            loop {
                if source_token.stop_requested() {
                    break;
                }

                // Fill input heads until they are ready to send or the input is
                // exhausted (in this phase).
                let mut all_ready = true;
                for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() {
                    if let Some(recv) = opt_recv {
                        while !self.input_heads[recv_idx].ready_to_send() {
                            if let Some(morsel) = recv.recv().await {
                                self.input_heads[recv_idx].add_morsel(morsel);
                            } else {
                                break;
                            }
                        }
                    }
                    all_ready &= self.input_heads[recv_idx].ready_to_send();
                }

                if !all_ready {
                    // One or more of the input heads is exhausted (this phase).
                    break;
                }

                // TODO: recombine morsels to make sure the concatenation is
                // close to the ideal morsel size.

                // Compute common size and send a combined morsel.
                let Some(common_size) = self.input_heads.iter().flat_map(|h| h.min_len()).min()
                else {
                    // If all input heads are broadcasts we don't get a common size;
                    // that case is handled below.
                    break;
                };

                for input_head in &mut self.input_heads {
                    out.push(input_head.take(common_size));
                }
                let out_df = concat_df_horizontal(&out, false)?;
                out.clear();

                let morsel = Morsel::new(out_df, self.out_seq, source_token.clone());
                self.out_seq = self.out_seq.successor();
                if sender.send(morsel).await.is_err() {
                    // Our receiver is no longer interested in any data; no need
                    // to store the rest of the incoming stream, we can return
                    // directly.
                    return Ok(());
                }
            }

            // We can't continue because one or more input heads is empty or all
            // inputs are broadcasts. We must tell everyone to stop, unblock all
            // pipes by consuming all ConsumeTokens, and then store all data
            // that was still flowing through the pipelines into input_heads for
            // the next phase.
            for input_head in &mut self.input_heads {
                for morsel in &mut input_head.morsels {
                    morsel.source_token().stop();
                    drop(morsel.take_consume_token());
                }
            }

            for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() {
                if let Some(recv) = opt_recv {
                    while let Some(mut morsel) = recv.recv().await {
                        morsel.source_token().stop();
                        drop(morsel.take_consume_token());
                        self.input_heads[recv_idx].add_morsel(morsel);
                    }
                }
            }

            // If all our input heads are broadcasts we need to send a single
            // morsel with their combined output, consuming all broadcast inputs.
            let all_broadcast = self
                .input_heads
                .iter()
                .all(|h| h.is_broadcast == Some(true));
            if all_broadcast {
                for input_head in &mut self.input_heads {
                    out.push(input_head.consume_broadcast());
                }
                let out_df = concat_df_horizontal(&out, false)?;
                out.clear();

                let morsel = Morsel::new(out_df, self.out_seq, source_token.clone());
                self.out_seq = self.out_seq.successor();
                let _ = sender.send(morsel).await;
            }

            Ok(())
        }));
    }
}