// GitHub Repository: pola-rs/polars
// Path: blob/main/crates/polars-stream/src/nodes/zip.rs
use std::collections::VecDeque;
use std::sync::Arc;

use polars_core::functions::concat_df_horizontal;
use polars_core::prelude::{Column, IntoColumn};
use polars_core::schema::Schema;
use polars_core::series::Series;
use polars_error::polars_ensure;
use polars_utils::itertools::Itertools;

use super::compute_node_prelude::*;
use crate::DEFAULT_ZIP_HEAD_BUFFER_SIZE;
use crate::morsel::SourceToken;
use crate::physical_plan::ZipBehavior;

/// The head of an input stream.
17
#[derive(Debug)]
18
struct InputHead {
19
/// The schema of the input, needed when creating full-null dataframes.
20
schema: Arc<Schema>,
21
22
// None when it is unknown whether this input stream is a broadcasting input or not.
23
is_broadcast: Option<bool>,
24
25
// True when there are no more morsels after the ones in the head.
26
stream_exhausted: bool,
27
28
// A FIFO queue of morsels belonging to this input stream.
29
morsels: VecDeque<Morsel>,
30
31
// The total length of the morsels in the input head.
32
total_len: usize,
33
}
impl InputHead {
36
fn new(schema: Arc<Schema>, may_broadcast: bool) -> Self {
37
Self {
38
schema,
39
morsels: VecDeque::new(),
40
is_broadcast: if may_broadcast { None } else { Some(false) },
41
total_len: 0,
42
stream_exhausted: false,
43
}
44
}
45
46
fn add_morsel(&mut self, mut morsel: Morsel) {
47
self.total_len += morsel.df().height();
48
49
if self.is_broadcast.is_none() {
50
if self.total_len > 1 {
51
self.is_broadcast = Some(false);
52
} else {
53
// Make sure we don't deadlock trying to wait to clear our ambiguous
54
// broadcast status.
55
drop(morsel.take_consume_token());
56
}
57
}
58
59
if morsel.df().height() > 0 {
60
self.morsels.push_back(morsel);
61
}
62
}
63
64
fn notify_no_more_morsels(&mut self) {
65
if self.is_broadcast.is_none() {
66
self.is_broadcast = Some(self.total_len == 1);
67
}
68
self.stream_exhausted = true;
69
}
70
71
fn ready_to_send(&self) -> bool {
72
self.is_broadcast.is_some() && (self.total_len > 0 || self.stream_exhausted)
73
}
74
75
fn min_len(&self) -> Option<usize> {
76
if self.is_broadcast == Some(false) {
77
self.morsels.front().map(|m| m.df().height())
78
} else {
79
None
80
}
81
}
82
83
fn take(&mut self, len: usize) -> DataFrame {
84
let columns: Vec<Column> = if self.is_broadcast.unwrap() {
85
self.morsels[0]
86
.df()
87
.columns()
88
.iter()
89
.map(|s| s.new_from_index(0, len))
90
.collect()
91
} else if self.total_len > 0 {
92
self.total_len -= len;
93
94
return if self.morsels[0].df().height() == len {
95
self.morsels.pop_front().unwrap().into_df()
96
} else {
97
let (head, tail) = self.morsels[0].df().split_at(len as i64);
98
*self.morsels[0].df_mut() = tail;
99
head
100
};
101
} else {
102
self.schema
103
.iter()
104
.map(|(name, dtype)| Series::full_null(name.clone(), len, dtype).into_column())
105
.collect()
106
};
107
108
unsafe { DataFrame::new_unchecked(len, columns) }
109
}
110
111
fn consume_broadcast(&mut self) -> DataFrame {
112
assert!(self.is_broadcast == Some(true) && self.total_len == 1);
113
let out = self.morsels.pop_front().unwrap().into_df();
114
self.clear();
115
out
116
}
117
118
fn clear(&mut self) {
119
self.total_len = 0;
120
self.is_broadcast = Some(false);
121
self.morsels.clear();
122
}
123
}
pub struct ZipNode {
126
zip_behavior: ZipBehavior,
127
out_seq: MorselSeq,
128
input_heads: Vec<InputHead>,
129
}
impl ZipNode {
132
pub fn new(zip_behavior: ZipBehavior, schemas: Vec<Arc<Schema>>) -> Self {
133
let input_heads = schemas
134
.into_iter()
135
.map(|s| InputHead::new(s, matches!(zip_behavior, ZipBehavior::Broadcast)))
136
.collect();
137
Self {
138
zip_behavior,
139
out_seq: MorselSeq::new(0),
140
input_heads,
141
}
142
}
143
}
impl ComputeNode for ZipNode {
146
fn name(&self) -> &str {
147
match self.zip_behavior {
148
ZipBehavior::NullExtend => "zip-null-extend",
149
ZipBehavior::Broadcast => "zip-broadcast",
150
ZipBehavior::Strict => "zip-strict",
151
}
152
}
153
154
fn update_state(
155
&mut self,
156
recv: &mut [PortState],
157
send: &mut [PortState],
158
_state: &StreamingExecutionState,
159
) -> PolarsResult<()> {
160
assert!(send.len() == 1);
161
assert!(recv.len() == self.input_heads.len());
162
163
let mut all_broadcast = true;
164
let mut all_done_or_broadcast = true;
165
let mut at_least_one_non_broadcast_done = false;
166
let mut at_least_one_non_broadcast_nonempty = false;
167
for (recv_idx, recv_state) in recv.iter().enumerate() {
168
let input_head = &mut self.input_heads[recv_idx];
169
if *recv_state == PortState::Done {
170
input_head.notify_no_more_morsels();
171
172
all_done_or_broadcast &=
173
input_head.is_broadcast == Some(true) || input_head.total_len == 0;
174
at_least_one_non_broadcast_done |=
175
input_head.is_broadcast == Some(false) && input_head.total_len == 0;
176
} else {
177
all_done_or_broadcast = false;
178
}
179
180
all_broadcast &= input_head.is_broadcast == Some(true);
181
at_least_one_non_broadcast_nonempty |=
182
input_head.is_broadcast == Some(false) && input_head.total_len > 0;
183
}
184
185
match self.zip_behavior {
186
ZipBehavior::Broadcast => {
187
polars_ensure!(
188
!(at_least_one_non_broadcast_done && at_least_one_non_broadcast_nonempty),
189
ShapeMismatch: "zip node received non-equal length inputs"
190
);
191
},
192
ZipBehavior::Strict => {
193
if let Some(first_len) = self.input_heads.first().map(|h| h.total_len) {
194
let all_len_equal = self.input_heads.iter().all(|h| h.total_len == first_len);
195
polars_ensure!(
196
all_len_equal,
197
ShapeMismatch: "zip node received non-equal length inputs"
198
);
199
}
200
},
201
ZipBehavior::NullExtend => {},
202
}
203
204
let all_output_sent = all_done_or_broadcast && !all_broadcast;
205
206
// Are we completely done?
207
if send[0] == PortState::Done || all_output_sent {
208
for input_head in &mut self.input_heads {
209
input_head.clear();
210
}
211
send[0] = PortState::Done;
212
recv.fill(PortState::Done);
213
return Ok(());
214
}
215
216
let num_inputs_blocked = recv.iter().filter(|r| **r == PortState::Blocked).count();
217
send[0] = if num_inputs_blocked > 0 {
218
PortState::Blocked
219
} else {
220
PortState::Ready
221
};
222
223
let num_total_blocked = num_inputs_blocked + (send[0] == PortState::Blocked) as usize;
224
for r in recv {
225
let num_others_blocked = num_total_blocked - (*r == PortState::Blocked) as usize;
226
*r = if num_others_blocked > 0 {
227
PortState::Blocked
228
} else {
229
PortState::Ready
230
};
231
}
232
Ok(())
233
}
234
235
fn spawn<'env, 's>(
236
&'env mut self,
237
scope: &'s TaskScope<'s, 'env>,
238
recv_ports: &mut [Option<RecvPort<'_>>],
239
send_ports: &mut [Option<SendPort<'_>>],
240
_state: &'s StreamingExecutionState,
241
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
242
) {
243
assert!(send_ports.len() == 1);
244
assert!(!recv_ports.is_empty());
245
let mut sender = send_ports[0].take().unwrap().serial();
246
247
let mut receivers = recv_ports
248
.iter_mut()
249
.map(|recv_port| {
250
// Add buffering to each receiver to reduce contention between input heads.
251
let mut serial_recv = recv_port.take()?.serial();
252
let (buf_send, buf_recv) =
253
tokio::sync::mpsc::channel(*DEFAULT_ZIP_HEAD_BUFFER_SIZE);
254
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
255
while let Ok(morsel) = serial_recv.recv().await {
256
if buf_send.send(morsel).await.is_err() {
257
break;
258
}
259
}
260
Ok(())
261
}));
262
Some(buf_recv)
263
})
264
.collect_vec();
265
266
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
267
let mut out = Vec::new();
268
let source_token = SourceToken::new();
269
loop {
270
if source_token.stop_requested() {
271
break;
272
}
273
274
// Fill input heads until they are ready to send or the input is
275
// exhausted (in this phase).
276
let mut all_ready = true;
277
for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() {
278
if let Some(recv) = opt_recv {
279
while !self.input_heads[recv_idx].ready_to_send() {
280
if let Some(morsel) = recv.recv().await {
281
self.input_heads[recv_idx].add_morsel(morsel);
282
} else {
283
break;
284
}
285
}
286
}
287
all_ready &= self.input_heads[recv_idx].ready_to_send();
288
}
289
290
if !all_ready {
291
// One or more of the input heads is exhausted (this phase).
292
break;
293
}
294
295
// TODO: recombine morsels to make sure the concatenation is
296
// close to the ideal morsel size.
297
298
// Compute common size and send a combined morsel.
299
let Some(common_size) = self.input_heads.iter().flat_map(|h| h.min_len()).min()
300
else {
301
// If all input heads are broadcasts we don't get a common size,
302
// we handle this below.
303
break;
304
};
305
306
for input_head in &mut self.input_heads {
307
out.push(input_head.take(common_size));
308
}
309
let out_df = concat_df_horizontal(&out, false, true, false)?;
310
out.clear();
311
312
let morsel = Morsel::new(out_df, self.out_seq, source_token.clone());
313
self.out_seq = self.out_seq.successor();
314
if sender.send(morsel).await.is_err() {
315
// Our receiver is no longer interested in any data, no
316
// need store the rest of the incoming stream, can directly
317
// return.
318
return Ok(());
319
}
320
}
321
322
// We can't continue because one or more input heads is empty or all
323
// inputs are broadcasts. We must tell everyone to stop, unblock all
324
// pipes by consuming all ConsumeTokens, and then store all data
325
// that was still flowing through the pipelines into input_heads for
326
// the next phase.
327
for input_head in &mut self.input_heads {
328
for morsel in &mut input_head.morsels {
329
morsel.source_token().stop();
330
drop(morsel.take_consume_token());
331
}
332
}
333
334
for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() {
335
if let Some(recv) = opt_recv {
336
while let Some(mut morsel) = recv.recv().await {
337
morsel.source_token().stop();
338
drop(morsel.take_consume_token());
339
self.input_heads[recv_idx].add_morsel(morsel);
340
}
341
}
342
}
343
344
// If all our input heads are broadcasts we need to send a morsel
345
// once with their output, consuming all broadcast inputs.
346
let all_broadcast = self
347
.input_heads
348
.iter()
349
.all(|h| h.is_broadcast == Some(true));
350
if all_broadcast {
351
for input_head in &mut self.input_heads {
352
out.push(input_head.consume_broadcast());
353
}
354
let out_df = concat_df_horizontal(&out, false, true, false)?;
355
out.clear();
356
357
let morsel = Morsel::new(out_df, self.out_seq, source_token.clone());
358
self.out_seq = self.out_seq.successor();
359
let _ = sender.send(morsel).await;
360
}
361
362
Ok(())
363
}));
364
}
365
}