Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/joins/cross_join.rs
8479 views
1
use std::sync::Arc;
2
3
use arrow::array::builder::ShareStrategy;
4
use polars_core::frame::builder::DataFrameBuilder;
5
use polars_core::schema::Schema;
6
use polars_error::polars_warn;
7
use polars_ops::frame::{JoinArgs, JoinBuildSide, MaintainOrderJoin};
8
use polars_utils::format_pl_smallstr;
9
use polars_utils::pl_str::PlSmallStr;
10
11
use crate::morsel::get_ideal_morsel_size;
12
use crate::nodes::compute_node_prelude::*;
13
use crate::nodes::in_memory_sink::InMemorySinkNode;
14
15
pub struct CrossJoinNode {
16
left_is_build: bool,
17
left_input_schema: Arc<Schema>,
18
right_input_schema: Arc<Schema>,
19
right_rename: Vec<Option<PlSmallStr>>,
20
state: CrossJoinState,
21
}
22
23
impl CrossJoinNode {
24
pub fn new(
25
left_input_schema: Arc<Schema>,
26
right_input_schema: Arc<Schema>,
27
args: &JoinArgs,
28
) -> Self {
29
let left_is_build = match args.maintain_order {
30
MaintainOrderJoin::None => match args.build_side {
31
// TODO: size estimation.
32
None | Some(JoinBuildSide::PreferLeft) | Some(JoinBuildSide::ForceLeft) => true,
33
Some(JoinBuildSide::PreferRight) | Some(JoinBuildSide::ForceRight) => false,
34
},
35
MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => {
36
if args.build_side == Some(JoinBuildSide::ForceLeft) {
37
polars_warn!("can't force left build-side with left-maintaining cross-join");
38
}
39
false
40
},
41
MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => {
42
if args.build_side == Some(JoinBuildSide::ForceRight) {
43
polars_warn!("can't force right build-side with right-maintaining cross-join");
44
}
45
true
46
},
47
};
48
let build_input_schema = if left_is_build {
49
&left_input_schema
50
} else {
51
&right_input_schema
52
};
53
let sink_node = InMemorySinkNode::new(build_input_schema.clone());
54
let right_rename = right_input_schema
55
.iter_names()
56
.map(|rname| {
57
if left_input_schema.contains(rname) {
58
Some(format_pl_smallstr!("{}{}", rname, args.suffix()))
59
} else {
60
None
61
}
62
})
63
.collect();
64
65
Self {
66
left_is_build,
67
left_input_schema,
68
right_input_schema,
69
right_rename,
70
state: CrossJoinState::Build(sink_node),
71
}
72
}
73
}
74
75
enum CrossJoinState {
76
Build(InMemorySinkNode),
77
Probe(DataFrame),
78
Done,
79
}
80
81
impl ComputeNode for CrossJoinNode {
82
fn name(&self) -> &str {
83
"cross-join"
84
}
85
86
fn is_memory_intensive_pipeline_blocker(&self) -> bool {
87
true
88
}
89
90
fn update_state(
91
&mut self,
92
recv: &mut [PortState],
93
send: &mut [PortState],
94
_state: &StreamingExecutionState,
95
) -> PolarsResult<()> {
96
assert!(recv.len() == 2 && send.len() == 1);
97
98
let build_idx = if self.left_is_build { 0 } else { 1 };
99
let probe_idx = 1 - build_idx;
100
101
// Are we done?
102
if send[0] == PortState::Done || recv[probe_idx] == PortState::Done {
103
self.state = CrossJoinState::Done;
104
}
105
106
// Transition to build?
107
if recv[build_idx] == PortState::Done {
108
if let CrossJoinState::Build(sink_node) = &mut self.state {
109
let df = sink_node.get_output()?.unwrap();
110
if df.height() > 0 {
111
self.state = CrossJoinState::Probe(df);
112
} else {
113
self.state = CrossJoinState::Done;
114
}
115
}
116
}
117
118
match &self.state {
119
CrossJoinState::Build(_) => {
120
recv[build_idx] = PortState::Ready;
121
recv[probe_idx] = PortState::Blocked;
122
send[0] = PortState::Blocked;
123
},
124
CrossJoinState::Probe(_) => {
125
recv[build_idx] = PortState::Done;
126
core::mem::swap(&mut recv[probe_idx], &mut send[0]);
127
},
128
CrossJoinState::Done => {
129
recv[0] = PortState::Done;
130
recv[1] = PortState::Done;
131
send[0] = PortState::Done;
132
},
133
}
134
Ok(())
135
}
136
137
fn spawn<'env, 's>(
138
&'env mut self,
139
scope: &'s TaskScope<'s, 'env>,
140
recv_ports: &mut [Option<RecvPort<'_>>],
141
send_ports: &mut [Option<SendPort<'_>>],
142
state: &'s StreamingExecutionState,
143
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
144
) {
145
assert!(recv_ports.len() == 2 && send_ports.len() == 1);
146
let build_idx = if self.left_is_build { 0 } else { 1 };
147
let probe_idx = 1 - build_idx;
148
match &mut self.state {
149
CrossJoinState::Build(sink_node) => {
150
assert!(send_ports[0].is_none());
151
assert!(recv_ports[probe_idx].is_none());
152
sink_node.spawn(
153
scope,
154
&mut recv_ports[build_idx..build_idx + 1],
155
&mut [],
156
state,
157
join_handles,
158
);
159
},
160
CrossJoinState::Probe(build_df) => {
161
assert!(recv_ports[build_idx].is_none());
162
let receivers = recv_ports[probe_idx].take().unwrap().parallel();
163
let senders = send_ports[0].take().unwrap().parallel();
164
let ideal_morsel_size = get_ideal_morsel_size();
165
166
for (mut recv, mut send) in receivers.into_iter().zip(senders) {
167
let left_is_build = self.left_is_build;
168
let left_input_schema = self.left_input_schema.clone();
169
let right_input_schema = self.right_input_schema.clone();
170
let right_rename = &self.right_rename;
171
let build_df = &*build_df;
172
join_handles.push(
173
scope.spawn_task(TaskPriority::High, async move {
174
let mut build_repeater = DataFrameBuilder::new(left_input_schema);
175
let mut probe_repeater = DataFrameBuilder::new(right_input_schema);
176
if !left_is_build {
177
core::mem::swap(&mut build_repeater, &mut probe_repeater);
178
}
179
let mut cached_build_df_repeated = DataFrame::empty();
180
181
while let Ok(morsel) = recv.recv().await {
182
let combine =
183
|build_join_df: DataFrame, probe_join_df: DataFrame| unsafe {
184
let (mut left_join_df, mut right_join_df);
185
left_join_df = build_join_df;
186
right_join_df = probe_join_df;
187
if !left_is_build {
188
core::mem::swap(&mut left_join_df, &mut right_join_df);
189
}
190
191
for (col, opt_rename) in
192
right_join_df.columns_mut().iter_mut().zip(right_rename)
193
{
194
if let Some(rename) = opt_rename {
195
col.rename(rename.clone());
196
}
197
}
198
199
left_join_df.hstack_mut_unchecked(right_join_df.columns());
200
Morsel::new(
201
left_join_df,
202
morsel.seq(),
203
morsel.source_token().clone(),
204
)
205
};
206
207
let probe_df = morsel.df();
208
if build_df.height() >= ideal_morsel_size {
209
for probe_offset in 0..probe_df.height() {
210
let mut build_offset = 0;
211
while build_offset < build_df.height() {
212
let height = (build_df.height() - build_offset)
213
.min(ideal_morsel_size);
214
let build_join_df =
215
build_df.slice(build_offset as i64, height);
216
let probe_join_df =
217
probe_df.new_from_index(probe_offset, height);
218
let combined = combine(build_join_df, probe_join_df);
219
if send.send(combined).await.is_err() {
220
return Ok(());
221
}
222
build_offset += height;
223
}
224
}
225
} else {
226
let max_build_repeats = ideal_morsel_size / build_df.height();
227
let mut probe_offset = 0;
228
while probe_offset < probe_df.height() {
229
let build_repeats = (probe_df.height() - probe_offset)
230
.min(max_build_repeats);
231
let build_height = build_repeats * build_df.height();
232
if build_height > cached_build_df_repeated.height() {
233
build_repeater.subslice_extend_repeated(
234
build_df,
235
0,
236
build_df.height(),
237
build_repeats,
238
ShareStrategy::Never,
239
);
240
cached_build_df_repeated =
241
build_repeater.freeze_reset();
242
}
243
let build_join_df =
244
cached_build_df_repeated.slice(0, build_height);
245
246
probe_repeater.subslice_extend_each_repeated(
247
probe_df,
248
probe_offset,
249
build_repeats,
250
build_df.height(),
251
ShareStrategy::Always,
252
);
253
let probe_join_df = probe_repeater.freeze_reset();
254
255
let combined = combine(build_join_df, probe_join_df);
256
if send.send(combined).await.is_err() {
257
return Ok(());
258
}
259
260
probe_offset += build_repeats;
261
}
262
}
263
}
264
Ok(())
265
}),
266
);
267
}
268
},
269
CrossJoinState::Done => unreachable!(),
270
}
271
}
272
}
273
274