Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/joins/cross_join.rs
6939 views
1
use std::sync::Arc;
2
3
use arrow::array::builder::ShareStrategy;
4
use polars_core::frame::builder::DataFrameBuilder;
5
use polars_core::schema::Schema;
6
use polars_ops::frame::{JoinArgs, MaintainOrderJoin};
7
use polars_utils::format_pl_smallstr;
8
use polars_utils::pl_str::PlSmallStr;
9
10
use crate::morsel::get_ideal_morsel_size;
11
use crate::nodes::compute_node_prelude::*;
12
use crate::nodes::in_memory_sink::InMemorySinkNode;
13
14
pub struct CrossJoinNode {
15
left_is_build: bool,
16
left_input_schema: Arc<Schema>,
17
right_input_schema: Arc<Schema>,
18
right_rename: Vec<Option<PlSmallStr>>,
19
state: CrossJoinState,
20
}
21
22
impl CrossJoinNode {
23
pub fn new(
24
left_input_schema: Arc<Schema>,
25
right_input_schema: Arc<Schema>,
26
args: &JoinArgs,
27
) -> Self {
28
let left_is_build = match args.maintain_order {
29
MaintainOrderJoin::None => true, // TODO: size estimation.
30
MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => false,
31
MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => true,
32
};
33
let build_input_schema = if left_is_build {
34
&left_input_schema
35
} else {
36
&right_input_schema
37
};
38
let sink_node = InMemorySinkNode::new(build_input_schema.clone());
39
let right_rename = right_input_schema
40
.iter_names()
41
.map(|rname| {
42
if left_input_schema.contains(rname) {
43
Some(format_pl_smallstr!("{}{}", rname, args.suffix()))
44
} else {
45
None
46
}
47
})
48
.collect();
49
50
Self {
51
left_is_build,
52
left_input_schema,
53
right_input_schema,
54
right_rename,
55
state: CrossJoinState::Build(sink_node),
56
}
57
}
58
}
59
60
enum CrossJoinState {
61
Build(InMemorySinkNode),
62
Probe(DataFrame),
63
Done,
64
}
65
66
impl ComputeNode for CrossJoinNode {
67
fn name(&self) -> &str {
68
"cross-join"
69
}
70
71
fn is_memory_intensive_pipeline_blocker(&self) -> bool {
72
true
73
}
74
75
fn update_state(
76
&mut self,
77
recv: &mut [PortState],
78
send: &mut [PortState],
79
_state: &StreamingExecutionState,
80
) -> PolarsResult<()> {
81
assert!(recv.len() == 2 && send.len() == 1);
82
83
let build_idx = if self.left_is_build { 0 } else { 1 };
84
let probe_idx = 1 - build_idx;
85
86
// Are we done?
87
if send[0] == PortState::Done || recv[probe_idx] == PortState::Done {
88
self.state = CrossJoinState::Done;
89
}
90
91
// Transition to build?
92
if recv[build_idx] == PortState::Done {
93
if let CrossJoinState::Build(sink_node) = &mut self.state {
94
let df = sink_node.get_output()?.unwrap();
95
if df.height() > 0 {
96
self.state = CrossJoinState::Probe(df);
97
} else {
98
self.state = CrossJoinState::Done;
99
}
100
}
101
}
102
103
match &self.state {
104
CrossJoinState::Build(_) => {
105
recv[build_idx] = PortState::Ready;
106
recv[probe_idx] = PortState::Blocked;
107
send[0] = PortState::Blocked;
108
},
109
CrossJoinState::Probe(_) => {
110
recv[build_idx] = PortState::Done;
111
core::mem::swap(&mut recv[probe_idx], &mut send[0]);
112
},
113
CrossJoinState::Done => {
114
recv[0] = PortState::Done;
115
recv[1] = PortState::Done;
116
send[0] = PortState::Done;
117
},
118
}
119
Ok(())
120
}
121
122
fn spawn<'env, 's>(
123
&'env mut self,
124
scope: &'s TaskScope<'s, 'env>,
125
recv_ports: &mut [Option<RecvPort<'_>>],
126
send_ports: &mut [Option<SendPort<'_>>],
127
state: &'s StreamingExecutionState,
128
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
129
) {
130
assert!(recv_ports.len() == 2 && send_ports.len() == 1);
131
let build_idx = if self.left_is_build { 0 } else { 1 };
132
let probe_idx = 1 - build_idx;
133
match &mut self.state {
134
CrossJoinState::Build(sink_node) => {
135
assert!(send_ports[0].is_none());
136
assert!(recv_ports[probe_idx].is_none());
137
sink_node.spawn(
138
scope,
139
&mut recv_ports[build_idx..build_idx + 1],
140
&mut [],
141
state,
142
join_handles,
143
);
144
},
145
CrossJoinState::Probe(build_df) => {
146
assert!(recv_ports[build_idx].is_none());
147
let receivers = recv_ports[probe_idx].take().unwrap().parallel();
148
let senders = send_ports[0].take().unwrap().parallel();
149
let ideal_morsel_size = get_ideal_morsel_size();
150
151
for (mut recv, mut send) in receivers.into_iter().zip(senders) {
152
let left_is_build = self.left_is_build;
153
let left_input_schema = self.left_input_schema.clone();
154
let right_input_schema = self.right_input_schema.clone();
155
let right_rename = &self.right_rename;
156
let build_df = &*build_df;
157
join_handles.push(
158
scope.spawn_task(TaskPriority::High, async move {
159
let mut build_repeater = DataFrameBuilder::new(left_input_schema);
160
let mut probe_repeater = DataFrameBuilder::new(right_input_schema);
161
if !left_is_build {
162
core::mem::swap(&mut build_repeater, &mut probe_repeater);
163
}
164
let mut cached_build_df_repeated = DataFrame::empty();
165
166
while let Ok(morsel) = recv.recv().await {
167
let combine =
168
|build_join_df: DataFrame, probe_join_df: DataFrame| unsafe {
169
let (mut left_join_df, mut right_join_df);
170
left_join_df = build_join_df;
171
right_join_df = probe_join_df;
172
if !left_is_build {
173
core::mem::swap(&mut left_join_df, &mut right_join_df);
174
}
175
176
for (col, opt_rename) in right_join_df
177
.get_columns_mut()
178
.iter_mut()
179
.zip(right_rename)
180
{
181
if let Some(rename) = opt_rename {
182
col.rename(rename.clone());
183
}
184
}
185
186
left_join_df
187
.hstack_mut_unchecked(right_join_df.get_columns());
188
Morsel::new(
189
left_join_df,
190
morsel.seq(),
191
morsel.source_token().clone(),
192
)
193
};
194
195
let probe_df = morsel.df();
196
if build_df.height() >= ideal_morsel_size {
197
for probe_offset in 0..probe_df.height() {
198
let mut build_offset = 0;
199
while build_offset < build_df.height() {
200
let height = (build_df.height() - build_offset)
201
.min(ideal_morsel_size);
202
let build_join_df =
203
build_df.slice(build_offset as i64, height);
204
let probe_join_df =
205
probe_df.new_from_index(probe_offset, height);
206
let combined = combine(build_join_df, probe_join_df);
207
if send.send(combined).await.is_err() {
208
return Ok(());
209
}
210
build_offset += height;
211
}
212
}
213
} else {
214
let max_build_repeats = ideal_morsel_size / build_df.height();
215
let mut probe_offset = 0;
216
while probe_offset < probe_df.height() {
217
let build_repeats = (probe_df.height() - probe_offset)
218
.min(max_build_repeats);
219
let build_height = build_repeats * build_df.height();
220
if build_height > cached_build_df_repeated.height() {
221
build_repeater.subslice_extend_repeated(
222
build_df,
223
0,
224
build_df.height(),
225
build_repeats,
226
ShareStrategy::Never,
227
);
228
cached_build_df_repeated =
229
build_repeater.freeze_reset();
230
}
231
let build_join_df =
232
cached_build_df_repeated.slice(0, build_height);
233
234
probe_repeater.subslice_extend_each_repeated(
235
probe_df,
236
probe_offset,
237
build_repeats,
238
build_df.height(),
239
ShareStrategy::Always,
240
);
241
let probe_join_df = probe_repeater.freeze_reset();
242
243
let combined = combine(build_join_df, probe_join_df);
244
if send.send(combined).await.is_err() {
245
return Ok(());
246
}
247
248
probe_offset += build_repeats;
249
}
250
}
251
}
252
Ok(())
253
}),
254
);
255
}
256
},
257
CrossJoinState::Done => unreachable!(),
258
}
259
}
260
}
261
262