Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/ordered_union.rs
8422 views
1
use std::sync::Arc;
2
3
use polars_core::schema::Schema;
4
5
use super::compute_node_prelude::*;
6
7
/// A node that first passes through all data from the first input, then the
8
/// second input, etc.
9
pub struct OrderedUnionNode {
10
cur_input_idx: usize,
11
max_morsel_seq_sent: MorselSeq,
12
morsel_offset: MorselSeq,
13
output_schema: Arc<Schema>,
14
}
15
16
impl OrderedUnionNode {
17
pub fn new(output_schema: Arc<Schema>) -> Self {
18
Self {
19
cur_input_idx: 0,
20
max_morsel_seq_sent: MorselSeq::new(0),
21
morsel_offset: MorselSeq::new(0),
22
output_schema,
23
}
24
}
25
}
26
27
impl ComputeNode for OrderedUnionNode {
28
fn name(&self) -> &str {
29
"ordered-union"
30
}
31
32
fn update_state(
33
&mut self,
34
recv: &mut [PortState],
35
send: &mut [PortState],
36
_state: &StreamingExecutionState,
37
) -> PolarsResult<()> {
38
assert!(self.cur_input_idx <= recv.len() && send.len() == 1);
39
40
// Skip inputs that are done.
41
while self.cur_input_idx < recv.len() && recv[self.cur_input_idx] == PortState::Done {
42
self.cur_input_idx += 1;
43
}
44
45
// Act like a normal pass-through node for the current input, or mark
46
// ourselves as done if all inputs are handled.
47
if self.cur_input_idx < recv.len() {
48
core::mem::swap(&mut recv[self.cur_input_idx], &mut send[0]);
49
} else {
50
send[0] = PortState::Done;
51
}
52
53
// Mark all inputs after the current one as blocked.
54
for r in recv.iter_mut().skip(self.cur_input_idx + 1) {
55
*r = PortState::Blocked;
56
}
57
58
// Set the morsel offset one higher than any sent so far.
59
self.morsel_offset = self.max_morsel_seq_sent.successor();
60
Ok(())
61
}
62
63
fn spawn<'env, 's>(
64
&'env mut self,
65
scope: &'s TaskScope<'s, 'env>,
66
recv_ports: &mut [Option<RecvPort<'_>>],
67
send_ports: &mut [Option<SendPort<'_>>],
68
_state: &'s StreamingExecutionState,
69
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
70
) {
71
let ready_count = recv_ports.iter().filter(|r| r.is_some()).count();
72
assert!(ready_count == 1 && send_ports.len() == 1);
73
let receivers = recv_ports[self.cur_input_idx].take().unwrap().parallel();
74
let senders = send_ports[0].take().unwrap().parallel();
75
76
let mut inner_handles = Vec::new();
77
for (mut recv, mut send) in receivers.into_iter().zip(senders) {
78
let output_schema = self.output_schema.clone();
79
let morsel_offset = self.morsel_offset;
80
inner_handles.push(scope.spawn_task(TaskPriority::High, async move {
81
let mut max_seq = MorselSeq::new(0);
82
while let Ok(mut morsel) = recv.recv().await {
83
// Ensure the morsel matches the expected output schema,
84
// casting nulls to the appropriate output type.
85
morsel.df_mut().ensure_matches_schema(&output_schema)?;
86
87
// Ensure the morsel sequence id stream is monotonic.
88
let seq = morsel.seq().offset_by(morsel_offset);
89
max_seq = max_seq.max(seq);
90
91
morsel.set_seq(seq);
92
if send.send(morsel).await.is_err() {
93
break;
94
}
95
}
96
PolarsResult::Ok(max_seq)
97
}));
98
}
99
100
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
101
// Update our global maximum.
102
for handle in inner_handles {
103
self.max_morsel_seq_sent = self.max_morsel_seq_sent.max(handle.await?);
104
}
105
Ok(())
106
}));
107
}
108
}
109
110