Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/unordered_union.rs
8430 views
1
use std::sync::Arc;
2
3
use polars_core::schema::Schema;
4
use tokio::sync::mpsc;
5
6
use super::compute_node_prelude::*;
7
8
pub struct UnorderedUnionNode {
9
max_morsel_seq_sent: MorselSeq,
10
output_schema: Arc<Schema>,
11
}
12
13
impl UnorderedUnionNode {
14
pub fn new(output_schema: Arc<Schema>) -> Self {
15
Self {
16
max_morsel_seq_sent: MorselSeq::new(0),
17
output_schema,
18
}
19
}
20
}
21
22
impl ComputeNode for UnorderedUnionNode {
23
fn name(&self) -> &str {
24
"unordered-union"
25
}
26
27
fn update_state(
28
&mut self,
29
recv: &mut [PortState],
30
send: &mut [PortState],
31
_state: &StreamingExecutionState,
32
) -> PolarsResult<()> {
33
assert_eq!(send.len(), 1);
34
35
let done = send[0] == PortState::Done || recv.iter().all(|r| *r == PortState::Done);
36
if done {
37
send[0] = PortState::Done;
38
recv.fill(PortState::Done);
39
return Ok(());
40
}
41
42
let any_ready = recv.contains(&PortState::Ready);
43
recv.fill(send[0]);
44
send[0] = if any_ready {
45
PortState::Ready
46
} else {
47
PortState::Blocked
48
};
49
Ok(())
50
}
51
52
fn spawn<'env, 's>(
53
&'env mut self,
54
scope: &'s TaskScope<'s, 'env>,
55
recv_ports: &mut [Option<RecvPort<'_>>],
56
send_ports: &mut [Option<SendPort<'_>>],
57
state: &'s StreamingExecutionState,
58
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
59
) {
60
assert_eq!(send_ports.len(), 1);
61
let output_senders = send_ports[0].take().unwrap().parallel();
62
let num_pipelines = output_senders.len();
63
assert_eq!(num_pipelines, state.num_pipelines);
64
65
let (mpsc_senders, mpsc_receivers): (Vec<_>, Vec<_>) = (0..num_pipelines)
66
.map(|_| mpsc::channel::<Morsel>(1))
67
.unzip();
68
69
for recv_port in recv_ports {
70
if let Some(recv) = recv_port.take() {
71
let receivers = recv.parallel();
72
let mpsc_senders_clone = mpsc_senders.clone();
73
74
for (mut receiver, sender) in receivers.into_iter().zip(mpsc_senders_clone) {
75
let output_schema = self.output_schema.clone();
76
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
77
while let Ok(mut morsel) = receiver.recv().await {
78
// Ensure the morsel matches the expected output schema,
79
// casting nulls to the appropriate output type.
80
morsel.df_mut().ensure_matches_schema(&output_schema)?;
81
82
if sender.send(morsel).await.is_err() {
83
break;
84
}
85
}
86
PolarsResult::Ok(())
87
}));
88
}
89
}
90
}
91
92
drop(mpsc_senders);
93
94
// Each pipeline relabels morsel sequences independently of the others.
95
// We first compute the `morsel_offset` as (max morsel sequence sent so far + 1), so this
96
// phase never reuses sequence numbers from earlier phases.
97
//
98
// Then, each pipeline assigns sequences by:
99
// - starting at `morsel_offset + pipeline_idx` (so pipelines start at different values),
100
// - advancing by `num_pipelines` each time it emits a morsel.
101
//
102
// Example with 2 pipelines (num_pipelines = 2) and morsel_offset = 1000:
103
// pipeline 0: 1000, 1002, 1004, ...
104
// pipeline 1: 1001, 1003, 1005, ...
105
//
106
// This guarantees:
107
// - Global uniqueness: no collisions with earlier phases, and no collisions across pipelines.
108
// - Per-pipeline non-decreasing: each pipeline only moves forward by a fixed positive step.
109
let morsel_offset = self.max_morsel_seq_sent.successor();
110
111
let mut inner_handles = Vec::new();
112
for (lane_idx, (mut mpsc_receiver, mut output_sender)) in
113
mpsc_receivers.into_iter().zip(output_senders).enumerate()
114
{
115
inner_handles.push(scope.spawn_task(TaskPriority::High, async move {
116
let mut local_seq = morsel_offset.offset_by_u64(lane_idx as u64);
117
let seq_step = num_pipelines as u64;
118
let mut max_seq = MorselSeq::new(0);
119
120
while let Some(mut morsel) = mpsc_receiver.recv().await {
121
morsel.set_seq(local_seq);
122
max_seq = max_seq.max(local_seq);
123
local_seq = local_seq.offset_by_u64(seq_step);
124
125
if output_sender.send(morsel).await.is_err() {
126
break;
127
}
128
}
129
130
PolarsResult::Ok(max_seq)
131
}));
132
}
133
134
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
135
for handle in inner_handles {
136
self.max_morsel_seq_sent = self.max_morsel_seq_sent.max(handle.await?);
137
}
138
Ok(())
139
}));
140
}
141
}
142
143