// Source: crates/polars-stream/src/nodes/joins/equi_join.rs (pola-rs/polars).
// NOTE(review): web-viewer chrome and interleaved display line numbers were
// stripped from this scraped copy; the code below is the actual file content.
1
use std::cmp::Reverse;
2
use std::collections::BinaryHeap;
3
use std::sync::Arc;
4
use std::sync::atomic::{AtomicU64, Ordering};
5
6
use arrow::array::builder::ShareStrategy;
7
use polars_core::frame::builder::DataFrameBuilder;
8
use polars_core::prelude::*;
9
use polars_core::schema::{Schema, SchemaExt};
10
use polars_core::{POOL, config};
11
use polars_expr::hash_keys::HashKeys;
12
use polars_expr::idx_table::{IdxTable, new_idx_table};
13
use polars_io::pl_async::get_runtime;
14
use polars_ops::frame::{JoinArgs, JoinType, MaintainOrderJoin};
15
use polars_ops::series::coalesce_columns;
16
use polars_utils::cardinality_sketch::CardinalitySketch;
17
use polars_utils::hashing::HashPartitioner;
18
use polars_utils::itertools::Itertools;
19
use polars_utils::pl_str::PlSmallStr;
20
use polars_utils::priority::Priority;
21
use polars_utils::relaxed_cell::RelaxedCell;
22
use polars_utils::sparse_init_vec::SparseInitVec;
23
use polars_utils::{IdxSize, format_pl_smallstr};
24
use rayon::prelude::*;
25
26
use super::{BufferedStream, JOIN_SAMPLE_LIMIT, LOPSIDED_SAMPLE_FACTOR};
27
use crate::async_executor;
28
use crate::async_primitives::connector::{Receiver, Sender};
29
use crate::async_primitives::wait_group::WaitGroup;
30
use crate::expression::StreamExpr;
31
use crate::morsel::{SourceToken, get_ideal_morsel_size};
32
use crate::nodes::compute_node_prelude::*;
33
use crate::nodes::in_memory_source::InMemorySourceNode;
34
35
/// Configuration shared by the sampling, build and probe phases of the
/// streaming equi-join node.
struct EquiJoinParams {
    // Which side is used to build the hash table. `None` until the sampling
    // phase decides (set in `SampleState::try_transition_to_build`); all later
    // phases `unwrap()` it.
    left_is_build: Option<bool>,
    preserve_order_build: bool,
    preserve_order_probe: bool,
    // Schemas of just the join-key columns for each side.
    left_key_schema: Arc<Schema>,
    // Expressions that evaluate the key columns from an incoming morsel.
    left_key_selectors: Vec<StreamExpr>,
    #[allow(dead_code)]
    right_key_schema: Arc<Schema>,
    right_key_selectors: Vec<StreamExpr>,
    // Per input column: `None` to drop it from the payload, or the (possibly
    // suffixed/renamed) output name. See `compute_payload_selector`.
    left_payload_select: Vec<Option<PlSmallStr>>,
    right_payload_select: Vec<Option<PlSmallStr>>,
    // Schemas of the selected payload columns (post-rename).
    left_payload_schema: Arc<Schema>,
    right_payload_schema: Arc<Schema>,
    // Join type, coalescing behavior, suffix, null-equality, etc.
    args: JoinArgs,
    // Hash seed state; must be identical for both sides so keys hash equally.
    random_state: PlRandomState,
}
51
52
impl EquiJoinParams {
53
/// Should we emit unmatched rows from the build side?
54
fn emit_unmatched_build(&self) -> bool {
55
if self.left_is_build.unwrap() {
56
self.args.how == JoinType::Left || self.args.how == JoinType::Full
57
} else {
58
self.args.how == JoinType::Right || self.args.how == JoinType::Full
59
}
60
}
61
62
/// Should we emit unmatched rows from the probe side?
63
fn emit_unmatched_probe(&self) -> bool {
64
if self.left_is_build.unwrap() {
65
self.args.how == JoinType::Right || self.args.how == JoinType::Full
66
} else {
67
self.args.how == JoinType::Left || self.args.how == JoinType::Full
68
}
69
}
70
}
71
72
/// A payload selector contains for each column whether that column should be
/// included in the payload, and if yes with what name.
///
/// Returns one entry per column of `this`: `None` means the column is dropped
/// from this side's payload (it will come from the other side), `Some(name)`
/// means it is kept under `name` (possibly suffixed to avoid collisions).
/// Errors if a suffixed name would still collide with a column in `other`.
fn compute_payload_selector(
    this: &Schema,
    other: &Schema,
    this_key_schema: &Schema,
    other_key_schema: &Schema,
    is_left: bool,
    args: &JoinArgs,
) -> PolarsResult<Vec<Option<PlSmallStr>>> {
    let should_coalesce = args.should_coalesce();

    // Counter for generating the temporary __POLARS_COALESCE_KEYCOL{n} names
    // used by full-join coalescing (consumed again in `postprocess_join`).
    let mut coalesce_idx = 0;
    this.iter_names()
        .map(|c| {
            // The `loop` never iterates: it exists so the nested if-chain can
            // `break` out to the suffixing fallback below, while the normal
            // cases `return` directly.
            #[expect(clippy::never_loop)]
            loop {
                let selector = if args.how == JoinType::Right {
                    if is_left {
                        if should_coalesce && this_key_schema.contains(c) {
                            // Coalesced to RHS output key.
                            None
                        } else {
                            Some(c.clone())
                        }
                    } else if !other.contains(c) || (should_coalesce && other_key_schema.contains(c)) {
                        // No collision possible: either unique, or the LHS
                        // occurrence is coalesced away.
                        Some(c.clone())
                    } else {
                        // Name clash with the other side: fall through to suffixing.
                        break;
                    }
                } else if should_coalesce && this_key_schema.contains(c) {
                    if is_left {
                        Some(c.clone())
                    } else if args.how == JoinType::Full {
                        // We must keep the right-hand side keycols around for
                        // coalescing.
                        let name = format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{coalesce_idx}");
                        coalesce_idx += 1;
                        Some(name)
                    } else {
                        // Right key column is merged into the left one.
                        None
                    }
                } else if !other.contains(c) || is_left {
                    // Unique name, or left side which keeps names unsuffixed.
                    Some(c.clone())
                } else {
                    // Name clash: fall through to suffixing.
                    break;
                };

                return Ok(selector);
            }

            // Fallback: disambiguate the right-hand column with the join suffix.
            let suffixed = format_pl_smallstr!("{}{}", c, args.suffix());
            if other.contains(&suffixed) {
                polars_bail!(Duplicate: "column with name '{suffixed}' already exists\n\n\
                You may want to try:\n\
                - renaming the column prior to joining\n\
                - using the `suffix` parameter to specify a suffix different to the default one ('_right')")
            }

            Ok(Some(suffixed))
        })
        .collect()
}
135
136
/// Fixes names and does coalescing of columns post-join.
///
/// Only a coalescing full join needs postprocessing here: each left key column
/// is coalesced with its matching temporary `__POLARS_COALESCE_KEYCOL{n}`
/// column (created by `compute_payload_selector`), and the temporaries are
/// then dropped. All other join configurations pass `df` through untouched.
fn postprocess_join(df: DataFrame, params: &EquiJoinParams) -> DataFrame {
    if params.args.how == JoinType::Full && params.args.should_coalesce() {
        // TODO: don't do string-based column lookups for each dataframe, pre-compute coalesce indices.
        // NOTE: relies on the left key columns appearing in the same order as
        // the temporaries were numbered, so `coalesce_idx` pairs them up.
        let mut coalesce_idx = 0;
        df.get_columns()
            .iter()
            .filter_map(|c| {
                if params.left_key_schema.contains(c.name()) {
                    let other = df
                        .column(&format_pl_smallstr!(
                            "__POLARS_COALESCE_KEYCOL{coalesce_idx}"
                        ))
                        .unwrap();
                    coalesce_idx += 1;
                    // Left key value wins; the RHS temporary fills its nulls.
                    return Some(coalesce_columns(&[c.clone(), other.clone()]).unwrap());
                }

                // Drop the temporary coalesce columns from the output.
                if c.name().starts_with("__POLARS_COALESCE_KEYCOL") {
                    return None;
                }

                Some(c.clone())
            })
            .collect()
    } else {
        df
    }
}
165
166
/// Projects `schema` through a payload selector: columns with a `None`
/// selector entry are dropped, the rest are renamed to the selected name.
fn select_schema(schema: &Schema, selector: &[Option<PlSmallStr>]) -> Schema {
    let fields = schema.iter_fields().zip(selector.iter());
    fields
        .filter_map(|(field, renamed)| {
            let new_name = renamed.as_ref()?;
            Some(field.with_name(new_name.clone()))
        })
        .collect()
}
173
174
async fn select_keys(
175
df: &DataFrame,
176
key_selectors: &[StreamExpr],
177
params: &EquiJoinParams,
178
state: &ExecutionState,
179
) -> PolarsResult<HashKeys> {
180
let mut key_columns = Vec::new();
181
for selector in key_selectors {
182
key_columns.push(selector.evaluate(df, state).await?.into_column());
183
}
184
let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?;
185
Ok(HashKeys::from_df(
186
&keys,
187
params.random_state,
188
params.args.nulls_equal,
189
false,
190
))
191
}
192
193
fn select_payload(df: DataFrame, selector: &[Option<PlSmallStr>]) -> DataFrame {
194
// Maintain height of zero-width dataframes.
195
if df.width() == 0 {
196
return df;
197
}
198
199
df.take_columns()
200
.into_iter()
201
.zip(selector)
202
.filter_map(|(c, name)| Some(c.with_name(name.clone()?)))
203
.collect()
204
}
205
206
/// Estimates the number of distinct keys per row over (a sample of) `morsels`.
///
/// At most `JOIN_SAMPLE_LIMIT` rows are sampled; the sketches for each morsel
/// are computed in parallel on the rayon pool and combined. Returns the
/// estimated unique-key count divided by the sampled row count (0.0 for empty
/// input or a zero sample limit).
fn estimate_cardinality(
    morsels: &[Morsel],
    key_selectors: &[StreamExpr],
    params: &EquiJoinParams,
    state: &ExecutionState,
) -> PolarsResult<f64> {
    let sample_limit = *JOIN_SAMPLE_LIMIT;
    if morsels.is_empty() || sample_limit == 0 {
        return Ok(0.0);
    }

    // Take whole morsels until we reach the sample limit; the last morsel may
    // overshoot and is sliced back down below.
    let mut total_height = 0;
    let mut to_process_end = 0;
    while to_process_end < morsels.len() && total_height < sample_limit {
        total_height += morsels[to_process_end].df().height();
        to_process_end += 1;
    }
    let last_morsel_idx = to_process_end - 1;
    let last_morsel_len = morsels[last_morsel_idx].df().height();
    // Trim the overshoot (total_height - sample_limit) off the last morsel;
    // saturating_sub keeps the full length when we did not reach the limit.
    let last_morsel_slice = last_morsel_len - total_height.saturating_sub(sample_limit);
    let runtime = get_runtime();

    POOL.install(|| {
        let sample_cardinality = morsels[..to_process_end]
            .par_iter()
            .enumerate()
            .try_fold(
                // One sketch per rayon worker-chunk, merged afterwards.
                CardinalitySketch::new,
                |mut sketch, (morsel_idx, morsel)| {
                    let sliced;
                    let df = if morsel_idx == last_morsel_idx {
                        sliced = morsel.df().slice(0, last_morsel_slice);
                        &sliced
                    } else {
                        morsel.df()
                    };
                    // select_keys is async; block on it from this rayon thread.
                    let hash_keys =
                        runtime.block_on(select_keys(df, key_selectors, params, state))?;
                    hash_keys.sketch_cardinality(&mut sketch);
                    PolarsResult::Ok(sketch)
                },
            )
            // NOTE: summing per-chunk estimates rather than merging sketches;
            // this over-counts keys shared between chunks — presumably an
            // acceptable approximation for build-side selection.
            .map(|sketch| PolarsResult::Ok(sketch?.estimate()))
            .try_reduce_with(|a, b| Ok(a + b))
            .unwrap()?;
        Ok(sample_cardinality as f64 / total_height.min(sample_limit) as f64)
    })
}
254
255
/// State of the initial sampling phase, which buffers morsels from both
/// inputs until there is enough information to pick a build side.
#[derive(Default)]
struct SampleState {
    // Buffered sample morsels from the left input.
    left: Vec<Morsel>,
    // Total row count buffered in `left`.
    left_len: usize,
    // Buffered sample morsels from the right input.
    right: Vec<Morsel>,
    // Total row count buffered in `right`.
    right_len: usize,
}
262
263
impl SampleState {
264
    /// Buffers incoming morsels for one side during the sampling phase.
    ///
    /// Requests a source stop once this side has buffered `JOIN_SAMPLE_LIMIT`
    /// rows, or once it is far (LOPSIDED_SAMPLE_FACTOR x) ahead of the other
    /// side's final length. On channel close, publishes this side's final row
    /// count through `this_final_len` for the other side's sink to read.
    async fn sink(
        mut recv: Receiver<Morsel>,
        morsels: &mut Vec<Morsel>,
        len: &mut usize,
        this_final_len: Arc<RelaxedCell<usize>>,
        other_final_len: Arc<RelaxedCell<usize>>,
    ) -> PolarsResult<()> {
        while let Ok(mut morsel) = recv.recv().await {
            *len += morsel.df().height();
            if *len >= *JOIN_SAMPLE_LIMIT
                || *len
                    >= other_final_len
                        .load()
                        .saturating_mul(LOPSIDED_SAMPLE_FACTOR)
            {
                // Ask the source to stop producing; already in-flight morsels
                // may still arrive and are buffered below.
                morsel.source_token().stop();
            }

            // Release backpressure before buffering the morsel.
            drop(morsel.take_consume_token());
            morsels.push(morsel);
        }
        this_final_len.store(*len);
        Ok(())
    }
288
289
    /// Decides whether sampling is finished and, if so, picks the build side
    /// and transitions into the build phase.
    ///
    /// Returns `Ok(None)` while sampling should continue. Otherwise chooses
    /// the build side (by sample length or, when both sides are saturated, by
    /// estimated key cardinality), replays the buffered build-side sample
    /// morsels into a fresh `BuildState`, and hands the probe-side samples to
    /// that state for later replay.
    fn try_transition_to_build(
        &mut self,
        recv: &[PortState],
        params: &mut EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<Option<BuildState>> {
        let left_saturated = self.left_len >= *JOIN_SAMPLE_LIMIT;
        let right_saturated = self.right_len >= *JOIN_SAMPLE_LIMIT;
        let left_done = recv[0] == PortState::Done || left_saturated;
        let right_done = recv[1] == PortState::Done || right_saturated;
        // Stop early when one side is done and the other is already far
        // bigger — more sampling won't change the decision.
        #[expect(clippy::nonminimal_bool)]
        let stop_sampling = (left_done && right_done)
            || (left_done && self.right_len >= LOPSIDED_SAMPLE_FACTOR * self.left_len)
            || (right_done && self.left_len >= LOPSIDED_SAMPLE_FACTOR * self.right_len);
        if !stop_sampling {
            return Ok(None);
        }

        if config::verbose() {
            eprintln!(
                "choosing build side, sample lengths are: {} vs. {}",
                self.left_len, self.right_len
            );
        }

        // Deferred: only evaluated when both sides hit the sample limit.
        let estimate_cardinalities = || {
            let left_cardinality = estimate_cardinality(
                &self.left,
                &params.left_key_selectors,
                params,
                &state.in_memory_exec_state,
            )?;
            let right_cardinality = estimate_cardinality(
                &self.right,
                &params.right_key_selectors,
                params,
                &state.in_memory_exec_state,
            )?;
            if config::verbose() {
                eprintln!(
                    "estimated cardinalities are: {left_cardinality} vs. {right_cardinality}"
                );
            }
            PolarsResult::Ok((left_cardinality, right_cardinality))
        };

        let left_is_build = match (left_saturated, right_saturated) {
            // Don't bother estimating cardinality, just choose smaller side as
            // we have everything in-memory anyway.
            (false, false) => self.left_len < self.right_len,

            // Choose the unsaturated side, the saturated side could be
            // arbitrarily big.
            (false, true) => true,
            (true, false) => false,

            // Estimate cardinality and choose smaller.
            (true, true) => {
                let (lc, rc) = estimate_cardinalities()?;
                lc < rc
            },
        };

        if config::verbose() {
            eprintln!(
                "build side chosen: {}",
                if left_is_build { "left" } else { "right" }
            );
        }

        // Transition to building state.
        params.left_is_build = Some(left_is_build);
        let mut sampled_build_morsels =
            BufferedStream::new(core::mem::take(&mut self.left), MorselSeq::default());
        let mut sampled_probe_morsels =
            BufferedStream::new(core::mem::take(&mut self.right), MorselSeq::default());
        if !left_is_build {
            core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels);
        }

        let partitioner = HashPartitioner::new(state.num_pipelines, 0);
        // num_partitions is set equal to num_pipelines here.
        let mut build_state = BuildState::new(
            state.num_pipelines,
            state.num_pipelines,
            sampled_probe_morsels,
        );

        // Simulate the sample build morsels flowing into the build side.
        if !sampled_build_morsels.is_empty() {
            crate::async_executor::task_scope(|scope| {
                let mut join_handles = Vec::new();
                let receivers = sampled_build_morsels
                    .reinsert(state.num_pipelines, None, scope, &mut join_handles)
                    .unwrap();

                // One partition_and_sink task per pipeline/local builder.
                for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        BuildState::partition_and_sink(
                            recv,
                            local_builder,
                            partitioner.clone(),
                            params,
                            state,
                        ),
                    ));
                }

                // Drive all spawned tasks to completion synchronously.
                polars_io::pl_async::get_runtime().block_on(async move {
                    for handle in join_handles {
                        handle.await?;
                    }
                    PolarsResult::Ok(())
                })
            })?;
        }

        Ok(Some(build_state))
    }
408
}
409
410
/// Per-pipeline accumulation state for the build phase; each build task owns
/// one `LocalBuilder` and the results are merged in `finalize_*`.
#[derive(Default)]
struct LocalBuilder {
    // The complete list of morsels and their computed hashes seen by this builder.
    morsels: Vec<(MorselSeq, DataFrame, HashKeys)>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]
    // for partition p, where start, stop are:
    // let start = morsel_idxs_offsets[i * num_partitions + p];
    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
    // Flattened (morsel, partition) -> offset table; see layout comment above.
    morsel_idxs_offsets_per_p: Vec<usize>,
}
425
426
/// State of the build phase: one `LocalBuilder` per pipeline, plus the
/// probe-side morsels buffered during sampling that must be replayed once
/// probing starts.
struct BuildState {
    local_builders: Vec<LocalBuilder>,
    sampled_probe_morsels: BufferedStream,
}
430
431
impl BuildState {
432
fn new(
433
num_pipelines: usize,
434
num_partitions: usize,
435
sampled_probe_morsels: BufferedStream,
436
) -> Self {
437
let local_builders = (0..num_pipelines)
438
.map(|_| LocalBuilder {
439
morsels: Vec::new(),
440
sketch_per_p: vec![CardinalitySketch::default(); num_partitions],
441
morsel_idxs_values_per_p: vec![Vec::new(); num_partitions],
442
morsel_idxs_offsets_per_p: vec![0; num_partitions],
443
})
444
.collect();
445
Self {
446
local_builders,
447
sampled_probe_morsels,
448
}
449
}
450
451
    /// Build-phase sink for one pipeline: hashes each incoming morsel's keys,
    /// records which rows belong to which partition, and stores the (seq,
    /// payload, keys) triple in `local` for finalization.
    async fn partition_and_sink(
        mut recv: Receiver<Morsel>,
        local: &mut LocalBuilder,
        partitioner: HashPartitioner,
        params: &EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        // Unmatched build rows must be trackable for outer joins.
        let track_unmatchable = params.emit_unmatched_build();
        let (key_selectors, payload_selector);
        if params.left_is_build.unwrap() {
            payload_selector = &params.left_payload_select;
            key_selectors = &params.left_key_selectors;
        } else {
            payload_selector = &params.right_payload_select;
            key_selectors = &params.right_key_selectors;
        };

        while let Ok(morsel) = recv.recv().await {
            // Compute hashed keys and payload. We must rechunk the payload for
            // later gathers.
            let hash_keys = select_keys(
                morsel.df(),
                key_selectors,
                params,
                &state.in_memory_exec_state,
            )
            .await?;
            let mut payload = select_payload(morsel.df().clone(), payload_selector);
            payload.rechunk_mut();

            // Append this morsel's row indices into the per-partition index
            // vectors and update the per-partition cardinality sketches.
            hash_keys.gen_idxs_per_partition(
                &partitioner,
                &mut local.morsel_idxs_values_per_p,
                &mut local.sketch_per_p,
                track_unmatchable,
            );

            // Record the new end offsets so finalize can slice out this
            // morsel's index range per partition (see LocalBuilder layout).
            local
                .morsel_idxs_offsets_per_p
                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
            local.morsels.push((morsel.seq(), payload, hash_keys));
        }
        Ok(())
    }
495
496
    /// Merges all local builders into per-partition hash tables while
    /// preserving the original morsel order (k-way merge on `MorselSeq`).
    ///
    /// Used when `preserve_order_build` requires build rows to be inserted in
    /// stream order; also records per-row normalized sequence ids when
    /// unmatched build rows must be emitted later in order.
    fn finalize_ordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
        let track_unmatchable = params.emit_unmatched_build();
        let payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);

        // One rayon task per partition; each builds its partition's table.
        POOL.scope(|s| {
            for p in 0..num_partitions {
                let probe_tables = &probe_tables;
                s.spawn(move |_| {
                    // TODO: every thread does an identical linearize, we can do a single parallel one.
                    // Min-heap (via Reverse) over the next unprocessed morsel
                    // seq of each local builder.
                    let mut kmerge = BinaryHeap::with_capacity(local_builders.len());
                    let mut cur_idx_per_loc = vec![0; local_builders.len()];

                    // Compute cardinality estimate and total amount of
                    // payload for this partition, and initialize k-way merge.
                    let mut sketch = CardinalitySketch::new();
                    let mut payload_rows = 0;
                    for (l_idx, l) in local_builders.iter().enumerate() {
                        let Some((seq, _, _)) = l.morsels.first() else {
                            continue;
                        };
                        kmerge.push(Priority(Reverse(seq), l_idx));

                        sketch.combine(&l.sketch_per_p[p]);
                        // Final offset for partition p == total rows of p in
                        // this builder (see LocalBuilder layout comment).
                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
                        payload_rows +=
                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
                    }

                    // Allocate hash table and payload builder.
                    let mut p_table = table.new_empty();
                    // 25% headroom over the estimate to limit rehashing.
                    p_table.reserve(sketch.estimate() * 5 / 4);
                    let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
                    p_payload.reserve(payload_rows);

                    let mut p_seq_ids = Vec::new();
                    if track_unmatchable {
                        p_seq_ids.reserve(payload_rows);
                    }

                    // Linearize and build.
                    unsafe {
                        let mut norm_seq_id = 0 as IdxSize;
                        while let Some(Priority(Reverse(_seq), l_idx)) = kmerge.pop() {
                            // SAFETY: l_idx/idx_in_l originate from iteration
                            // over local_builders and stay in bounds.
                            let l = local_builders.get_unchecked(l_idx);
                            let idx_in_l = *cur_idx_per_loc.get_unchecked(l_idx);
                            *cur_idx_per_loc.get_unchecked_mut(l_idx) += 1;
                            // Re-arm the heap with this builder's next morsel.
                            if let Some((next_seq, _, _)) = l.morsels.get(idx_in_l + 1) {
                                kmerge.push(Priority(Reverse(next_seq), l_idx));
                            }

                            let (_mseq, payload, keys) = l.morsels.get_unchecked(idx_in_l);
                            let p_morsel_idxs_start =
                                l.morsel_idxs_offsets_per_p[idx_in_l * num_partitions + p];
                            let p_morsel_idxs_stop =
                                l.morsel_idxs_offsets_per_p[(idx_in_l + 1) * num_partitions + p];
                            let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                [p_morsel_idxs_start..p_morsel_idxs_stop];
                            p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                            p_payload.gather_extend(payload, p_morsel_idxs, ShareStrategy::Never);

                            if track_unmatchable {
                                // All rows appended for this morsel share one
                                // normalized (global-order) sequence id.
                                p_seq_ids.resize(p_payload.len(), norm_seq_id);
                                norm_seq_id += 1;
                            }
                        }
                    }

                    probe_tables
                        .try_set(
                            p,
                            ProbeTable {
                                hash_table: p_table,
                                payload: p_payload.freeze(),
                                seq_ids: p_seq_ids,
                            },
                        )
                        .ok()
                        .unwrap();
                });
            }
        });

        ProbeState {
            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
            max_seq_sent: MorselSeq::default(),
            sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
            unordered_morsel_seq: AtomicU64::new(0),
        }
    }
593
594
    /// Merges all local builders into per-partition hash tables without
    /// preserving morsel order.
    ///
    /// Morsel payloads are moved into `Arc`s so each partition task can drop
    /// its share as soon as it is processed; the potentially expensive drops
    /// are offloaded onto a shared work queue so a slow task isn't also stuck
    /// deallocating.
    fn finalize_unordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
        let track_unmatchable = params.emit_unmatched_build();
        let payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        // To reduce maximum memory usage we want to drop the morsels
        // as soon as they're processed, so we move into Arcs. The drops might
        // also be expensive, so instead of directly dropping we put that on
        // a work queue.
        let morsels_per_local_builder = self
            .local_builders
            .iter_mut()
            .map(|b| Arc::new(core::mem::take(&mut b.morsels)))
            .collect_vec();
        let (morsel_drop_q_send, morsel_drop_q_recv) =
            async_channel::bounded(morsels_per_local_builder.len());
        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);

        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_morsels_per_local_builder = Arc::new(morsels_per_local_builder);
            let mut join_handles = Vec::new();
            // One task per partition.
            for p in 0..num_partitions {
                let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder);
                let morsel_drop_q_send = morsel_drop_q_send.clone();
                let morsel_drop_q_recv = morsel_drop_q_recv.clone();
                let probe_tables = &probe_tables;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from outer arc and drop outer arc.
                    let morsels_per_local_builder =
                        Arc::unwrap_or_clone(arc_morsels_per_local_builder);

                    // Compute cardinality estimate and total amount of
                    // payload for this partition.
                    let mut sketch = CardinalitySketch::new();
                    let mut payload_rows = 0;
                    for l in local_builders {
                        sketch.combine(&l.sketch_per_p[p]);
                        // Final offset for partition p == total rows of p in
                        // this builder (see LocalBuilder layout comment).
                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
                        payload_rows +=
                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
                    }

                    // Allocate hash table and payload builder.
                    let mut p_table = table.new_empty();
                    // 25% headroom over the estimate to limit rehashing.
                    p_table.reserve(sketch.estimate() * 5 / 4);
                    let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
                    p_payload.reserve(payload_rows);

                    // Build.
                    let mut skip_drop_attempt = false;
                    for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) {
                        // Try to help with dropping the processed morsels.
                        if !skip_drop_attempt {
                            drop(morsel_drop_q_recv.try_recv());
                        }

                        for (i, morsel) in l_morsels.iter().enumerate() {
                            let (_mseq, payload, keys) = morsel;
                            unsafe {
                                // SAFETY: offsets were produced by
                                // partition_and_sink and index into the same
                                // builder's vectors.
                                let p_morsel_idxs_start =
                                    l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                let p_morsel_idxs_stop =
                                    l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                    [p_morsel_idxs_start..p_morsel_idxs_stop];
                                p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                                p_payload.gather_extend(
                                    payload,
                                    p_morsel_idxs,
                                    ShareStrategy::Never,
                                );
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_morsels) {
                            // If we're the last thread to process this set of morsels we're probably
                            // falling behind the rest, since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(morsel_drop_q_send.try_send(l));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done, help others out by doing drops.
                    drop(morsel_drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(l_morsels) = morsel_drop_q_recv.recv().await {
                        drop(l_morsels);
                    }

                    probe_tables
                        .try_set(
                            p,
                            ProbeTable {
                                hash_table: p_table,
                                payload: p_payload.freeze(),
                                // No ordering info needed in the unordered path.
                                seq_ids: Vec::new(),
                            },
                        )
                        .ok()
                        .unwrap();
                }));
            }

            // Drop outer arc after spawning each thread so the inner arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_morsels_per_local_builder);
            drop(morsel_drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await;
                }
            });
        });

        ProbeState {
            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
            max_seq_sent: MorselSeq::default(),
            sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
            unordered_morsel_seq: AtomicU64::new(0),
        }
    }
727
}
728
729
/// A single partition's finalized build result: the key hash table, the
/// gathered build payload rows it indexes into, and (ordered joins only) the
/// normalized sequence id of each payload row.
struct ProbeTable {
    hash_table: Box<dyn IdxTable>,
    payload: DataFrame,
    // Empty unless built by `finalize_ordered` with unmatched tracking.
    seq_ids: Vec<IdxSize>,
}
734
735
/// State of the probe phase, produced by `BuildState::finalize_*`.
struct ProbeState {
    table_per_partition: Vec<ProbeTable>,
    // Highest morsel sequence emitted so far by the probe tasks.
    max_seq_sent: MorselSeq,
    // Probe morsels buffered during sampling, replayed before live input.
    sampled_probe_morsels: BufferedStream,

    // For unordered joins we relabel output morsels to speed up the linearizer.
    unordered_morsel_seq: AtomicU64,
}
743
744
impl ProbeState {
745
    /// Probe task for one pipeline: hashes each incoming morsel's keys,
    /// probes the per-partition build tables, and sends out joined morsels of
    /// roughly `get_ideal_morsel_size()` rows.
    ///
    /// Returns the max morsel sequence sent.
    async fn partition_and_probe(
        mut recv: Receiver<Morsel>,
        mut send: Sender<Morsel>,
        partitions: &[ProbeTable],
        unordered_morsel_seq: &AtomicU64,
        partitioner: HashPartitioner,
        params: &EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<MorselSeq> {
        // TODO: shuffle after partitioning and keep probe tables thread-local.
        // Scratch buffers reused across morsels.
        let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()];
        let mut probe_partitions = Vec::new();
        let mut materialized_idxsize_range = Vec::new();
        let mut table_match = Vec::new();
        let mut probe_match = Vec::new();
        let mut max_seq = MorselSeq::default();

        let probe_limit = get_ideal_morsel_size() as IdxSize;
        // Matches must be marked in the table when unmatched build rows are
        // emitted later; unmatched probe rows are emitted inline here.
        let mark_matches = params.emit_unmatched_build();
        let emit_unmatched = params.emit_unmatched_probe();

        // The probe side is the opposite of the build side.
        let (key_selectors, payload_selector, build_payload_schema, probe_payload_schema);
        if params.left_is_build.unwrap() {
            key_selectors = &params.right_key_selectors;
            payload_selector = &params.right_payload_select;
            build_payload_schema = &params.left_payload_schema;
            probe_payload_schema = &params.right_payload_schema;
        } else {
            key_selectors = &params.left_key_selectors;
            payload_selector = &params.left_payload_select;
            build_payload_schema = &params.right_payload_schema;
            probe_payload_schema = &params.left_payload_schema;
        };

        let mut build_out = DataFrameBuilder::new(build_payload_schema.clone());
        let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone());

        // A simple estimate used to size reserves.
        let mut selectivity_estimate = 1.0;
        let mut selectivity_estimate_confidence = 0.0;

        while let Ok(morsel) = recv.recv().await {
            // Compute hashed keys and payload.
            let (df, in_seq, src_token, wait_token) = morsel.into_inner();

            let df_height = df.height();
            if df_height == 0 {
                continue;
            }

            let hash_keys =
                select_keys(&df, key_selectors, params, &state.in_memory_exec_state).await?;
            let mut payload = select_payload(df, payload_selector);
            let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches.
            let mut total_matches = 0;

            // Use selectivity estimate to reserve for morsel builders.
            let max_match_per_key_est = (selectivity_estimate * 1.2) as usize + 16;
            let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize)
                .min(probe_limit as usize);
            build_out.reserve(out_est_size + max_match_per_key_est);

            unsafe {
                // Assembles one output morsel from the accumulated build and
                // probe payload builders, keeping left/right column order and
                // choosing the output sequence id per ordering mode.
                let mut new_morsel =
                    |build: &mut DataFrameBuilder, probe: &mut DataFrameBuilder| {
                        let mut build_df = build.freeze_reset();
                        let mut probe_df = probe.freeze_reset();
                        let out_df = if params.left_is_build.unwrap() {
                            build_df.hstack_mut_unchecked(probe_df.get_columns());
                            build_df
                        } else {
                            probe_df.hstack_mut_unchecked(build_df.get_columns());
                            probe_df
                        };
                        let out_df = postprocess_join(out_df, params);
                        let out_seq = if params.preserve_order_probe {
                            in_seq
                        } else {
                            // Relabel with a global counter (see ProbeState
                            // field comment).
                            MorselSeq::new(unordered_morsel_seq.fetch_add(1, Ordering::Relaxed))
                        };
                        max_seq = out_seq;
                        Morsel::new(out_df, out_seq, src_token.clone())
                    };

                if params.preserve_order_probe {
                    // To preserve the order we can't do bulk probes per partition and must follow
                    // the order of the probe morsel. We can still group probes that are
                    // consecutively on the same partition.
                    probe_partitions.clear();
                    hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched);

                    let mut probe_group_start = 0;
                    while probe_group_start < probe_partitions.len() {
                        let p_idx = probe_partitions[probe_group_start];
                        let mut probe_group_end = probe_group_start + 1;
                        while probe_partitions.get(probe_group_end) == Some(&p_idx) {
                            probe_group_end += 1;
                        }
                        // Out-of-range partition idx: rows with no partition
                        // (possible when emit_unmatched is false) are skipped.
                        let Some(p) = partitions.get(p_idx as usize) else {
                            probe_group_start = probe_group_end;
                            continue;
                        };

                        // Extend the cached identity-index buffer [0, 1, ...]
                        // up to probe_group_end so we can slice it below.
                        materialized_idxsize_range.extend(
                            materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize,
                        );

                        while probe_group_start < probe_group_end {
                            let matches_before_limit = probe_limit - probe_match.len() as IdxSize;
                            table_match.clear();
                            // probe_subset returns how many probe rows were
                            // consumed before hitting the match limit.
                            probe_group_start += p.hash_table.probe_subset(
                                &hash_keys,
                                &materialized_idxsize_range[probe_group_start..probe_group_end],
                                &mut table_match,
                                &mut probe_match,
                                mark_matches,
                                emit_unmatched,
                                matches_before_limit,
                            ) as usize;

                            if emit_unmatched {
                                // table_match may contain null idxs for
                                // unmatched probe rows.
                                build_out.opt_gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            } else {
                                build_out.gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            };

                            // Flush when we filled a morsel or exhausted input.
                            if probe_match.len() >= probe_limit as usize
                                || probe_group_start == probe_partitions.len()
                            {
                                if !payload_rechunked {
                                    payload.rechunk_mut();
                                    payload_rechunked = true;
                                }
                                probe_out.gather_extend(
                                    &payload,
                                    &probe_match,
                                    ShareStrategy::Always,
                                );
                                let out_len = probe_match.len();
                                probe_match.clear();
                                let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                                if send.send(out_morsel).await.is_err() {
                                    return Ok(max_seq);
                                }
                                if probe_group_end != probe_partitions.len() {
                                    // We had enough matches to need a mid-partition flush, let's assume there are a lot of
                                    // matches and just do a large reserve.
                                    let old_est = probe_limit as usize + max_match_per_key_est;
                                    build_out.reserve(old_est.max(out_len + 16));
                                }
                            }
                        }
                    }
                } else {
                    // Partition and probe the tables.
                    for p in partition_idxs.iter_mut() {
                        p.clear();
                    }
                    hash_keys.gen_idxs_per_partition(
                        &partitioner,
                        &mut partition_idxs,
                        &mut [],
                        emit_unmatched,
                    );

                    for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) {
                        let mut offset = 0;
                        while offset < idxs_in_p.len() {
                            let matches_before_limit = probe_limit - probe_match.len() as IdxSize;
                            table_match.clear();
                            // probe_subset returns how many probe rows were
                            // consumed before hitting the match limit.
                            offset += p.hash_table.probe_subset(
                                &hash_keys,
                                &idxs_in_p[offset..],
                                &mut table_match,
                                &mut probe_match,
                                mark_matches,
                                emit_unmatched,
                                matches_before_limit,
                            ) as usize;

                            if table_match.is_empty() {
                                continue;
                            }
                            total_matches += table_match.len();

                            if emit_unmatched {
                                // table_match may contain null idxs for
                                // unmatched probe rows.
                                build_out.opt_gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            } else {
                                build_out.gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            };

                            if probe_match.len() >= probe_limit as usize {
                                if !payload_rechunked {
                                    payload.rechunk_mut();
                                    payload_rechunked = true;
                                }
                                probe_out.gather_extend(
                                    &payload,
                                    &probe_match,
                                    ShareStrategy::Always,
                                );
                                let out_len = probe_match.len();
                                probe_match.clear();
                                let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                                if send.send(out_morsel).await.is_err() {
                                    return Ok(max_seq);
                                }
                                // We had enough matches to need a mid-partition flush, let's assume there are a lot of
                                // matches and just do a large reserve.
                                let old_est = probe_limit as usize + max_match_per_key_est;
                                build_out.reserve(old_est.max(out_len + 16));
                            }
                        }
                    }
                }

                // Flush whatever remains below the probe limit for this morsel.
                if !probe_match.is_empty() {
                    if !payload_rechunked {
                        payload.rechunk_mut();
                    }
                    probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);
                    probe_match.clear();
                    let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                    if send.send(out_morsel).await.is_err() {
                        return Ok(max_seq);
                    }
                }
            }

            // Only release backpressure once the whole morsel is processed.
            drop(wait_token);

            // Move selectivity estimate a bit towards latest value. Allows rapid changes at first.
            // TODO: implement something more re-usable and robust.
            selectivity_estimate = selectivity_estimate_confidence * selectivity_estimate
                + (1.0 - selectivity_estimate_confidence)
                    * (total_matches as f64 / df_height as f64);
            selectivity_estimate_confidence = (selectivity_estimate_confidence + 0.1).min(0.8);
        }

        Ok(max_seq)
    }
1003
1004
/// Builds the DataFrame of build-side rows that were never matched during
/// probing, emitted in the build input's original order.
///
/// Unmatched keys are collected from every partition's hash table, tagged
/// with their per-row sequence id (`seq_ids`, presumably the insertion
/// sequence — TODO confirm), globally sorted by that id, gathered into a
/// builder, and finally padded with all-null probe-side columns before
/// `postprocess_join` is applied.
fn ordered_unmatched(&mut self, params: &EquiJoinParams) -> DataFrame {
    // TODO: parallelize this operator.

    // The build side's payload schema depends on which input was chosen
    // as the build side.
    let build_payload_schema = if params.left_is_build.unwrap() {
        &params.left_payload_schema
    } else {
        &params.right_payload_schema
    };

    let mut unmarked_idxs = Vec::new();
    let mut linearized_idxs = Vec::new();

    // Collect (seq_id, partition, index-in-partition) for every unmatched
    // ("unmarked") key across all partitions.
    for (p_idx, p) in self.table_per_partition.iter().enumerate_idx() {
        p.hash_table
            .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX);
        linearized_idxs.extend(
            unmarked_idxs
                .iter()
                // SAFETY: indices returned by unmarked_keys are assumed
                // valid for seq_ids — TODO confirm invariant at call site.
                .map(|i| (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i)),
        );
    }

    // Restore the global input order.
    linearized_idxs.sort_by_key(|(seq_id, _, _)| *seq_id);

    unsafe {
        let mut build_out = DataFrameBuilder::new(build_payload_schema.clone());
        build_out.reserve(linearized_idxs.len());

        // Group indices from the same partition.
        let mut group_start = 0;
        let mut gather_idxs = Vec::new();
        while group_start < linearized_idxs.len() {
            gather_idxs.clear();

            // Extend the run while consecutive entries share a partition,
            // so each gather targets a single partition's payload.
            let (_seq, p_idx, idx_in_p) = linearized_idxs[group_start];
            gather_idxs.push(idx_in_p);
            let mut group_end = group_start + 1;
            while group_end < linearized_idxs.len() && linearized_idxs[group_end].1 == p_idx {
                gather_idxs.push(linearized_idxs[group_end].2);
                group_end += 1;
            }

            build_out.gather_extend(
                &self.table_per_partition[p_idx as usize].payload,
                &gather_idxs,
                ShareStrategy::Never, // Don't keep entire table alive for unmatched indices.
            );

            group_start = group_end;
        }

        // Pair the gathered build payload with an all-null frame for the
        // probe side, keeping left columns before right columns.
        let mut build_df = build_out.freeze();
        let out_df = if params.left_is_build.unwrap() {
            let probe_df =
                DataFrame::full_null(&params.right_payload_schema, build_df.height());
            build_df.hstack_mut_unchecked(probe_df.get_columns());
            build_df
        } else {
            let mut probe_df =
                DataFrame::full_null(&params.left_payload_schema, build_df.height());
            probe_df.hstack_mut_unchecked(build_df.get_columns());
            probe_df
        };
        postprocess_join(out_df, params)
    }
}
1070
}
1071
1072
impl Drop for ProbeState {
1073
fn drop(&mut self) {
1074
POOL.install(|| {
1075
// Parallel drop as the state might be quite big.
1076
self.table_per_partition.par_drain(..).for_each(drop);
1077
})
1078
}
1079
}
1080
1081
/// State for streaming out build-side rows that found no probe-side match
/// (used in the unordered emit path; see `EquiJoinState::EmitUnmatchedBuild`).
struct EmitUnmatchedState {
    // Probe tables taken over from the probe phase; scanned one at a time.
    partitions: Vec<ProbeTable>,
    // Index of the partition currently being drained.
    active_partition_idx: usize,
    // Offset of the next unmarked key within the active partition.
    offset_in_active_p: usize,
    // Sequence number for the next morsel to emit.
    morsel_seq: MorselSeq,
}
1087
1088
impl EmitUnmatchedState {
    /// Streams morsels of unmatched build-side rows (build payload joined
    /// with all-null probe columns) into `send`.
    ///
    /// Morsel size is derived from the total number of keys and
    /// `num_pipelines`. Returns `Ok(())` both on completion and when the
    /// receiver disconnects or requests a stop.
    async fn emit_unmatched(
        &mut self,
        mut send: Sender<Morsel>,
        params: &EquiJoinParams,
        num_pipelines: usize,
    ) -> PolarsResult<()> {
        // Size morsels so the total work splits evenly over the pipelines.
        let total_len: usize = self
            .partitions
            .iter()
            .map(|p| p.hash_table.num_keys() as usize)
            .sum();
        let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1);
        let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);
        let morsel_size = total_len.div_ceil(morsel_count).max(1);

        let wait_group = WaitGroup::default();
        let source_token = SourceToken::new();
        let mut unmarked_idxs = Vec::new();
        // Resume from (active_partition_idx, offset_in_active_p), which this
        // method persists across invocations.
        while let Some(p) = self.partitions.get(self.active_partition_idx) {
            loop {
                // Generate a chunk of unmarked key indices.
                self.offset_in_active_p += p.hash_table.unmarked_keys(
                    &mut unmarked_idxs,
                    self.offset_in_active_p as IdxSize,
                    morsel_size as IdxSize,
                ) as usize;
                if unmarked_idxs.is_empty() {
                    break;
                }

                // Gather and create full-null counterpart.
                let out_df = unsafe {
                    let mut build_df = p.payload.take_slice_unchecked_impl(&unmarked_idxs, false);
                    let len = build_df.height();
                    // Keep left columns before right columns regardless of
                    // which side was the build side.
                    if params.left_is_build.unwrap() {
                        let probe_df = DataFrame::full_null(&params.right_payload_schema, len);
                        build_df.hstack_mut_unchecked(probe_df.get_columns());
                        build_df
                    } else {
                        let mut probe_df = DataFrame::full_null(&params.left_payload_schema, len);
                        probe_df.hstack_mut_unchecked(build_df.get_columns());
                        probe_df
                    }
                };
                let out_df = postprocess_join(out_df, params);

                // Send and wait until consume token is consumed.
                let mut morsel = Morsel::new(out_df, self.morsel_seq, source_token.clone());
                self.morsel_seq = self.morsel_seq.successor();
                morsel.set_consume_token(wait_group.token());
                if send.send(morsel).await.is_err() {
                    return Ok(());
                }

                // Backpressure: wait for the consumer, and honor stop requests.
                wait_group.wait().await;
                if source_token.stop_requested() {
                    return Ok(());
                }
            }

            self.active_partition_idx += 1;
            self.offset_in_active_p = 0;
        }

        Ok(())
    }
}
1156
1157
/// Phases of the streaming equi-join, in the order they normally occur.
enum EquiJoinState {
    // Sampling both inputs to decide which side to use as the build side.
    Sample(SampleState),
    // Consuming the build side into partitioned hash tables.
    Build(BuildState),
    // Streaming the probe side against the built tables.
    Probe(ProbeState),
    // Emitting unmatched build-side rows (unordered path).
    EmitUnmatchedBuild(EmitUnmatchedState),
    // Emitting unmatched build-side rows in order, via an in-memory source.
    EmitUnmatchedBuildInOrder(InMemorySourceNode),
    Done,
}
1165
1166
/// Streaming compute node performing a hash-based equi-join.
pub struct EquiJoinNode {
    // Current phase of the join state machine.
    state: EquiJoinState,
    // Configuration shared by all phases (key schemas, payload selection, args).
    params: EquiJoinParams,
    // Index-table implementation used to build the hash table on the key columns.
    table: Box<dyn IdxTable>,
}
1171
1172
impl EquiJoinNode {
    /// Creates a new equi-join node.
    ///
    /// Chooses the build side from `args.maintain_order`: when order need
    /// not be maintained, the choice is deferred to a sampling phase
    /// (`left_is_build = None`) unless sampling is disabled
    /// (`*JOIN_SAMPLE_LIMIT == 0`); when the left side's order must be
    /// preserved, the right side builds, and vice versa.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        left_input_schema: Arc<Schema>,
        right_input_schema: Arc<Schema>,
        left_key_schema: Arc<Schema>,
        right_key_schema: Arc<Schema>,
        unique_key_schema: Arc<Schema>,
        left_key_selectors: Vec<StreamExpr>,
        right_key_selectors: Vec<StreamExpr>,
        args: JoinArgs,
        num_pipelines: usize,
    ) -> PolarsResult<Self> {
        let left_is_build = match args.maintain_order {
            MaintainOrderJoin::None => {
                if *JOIN_SAMPLE_LIMIT == 0 {
                    Some(true)
                } else {
                    None
                }
            },
            // The side whose order is preserved becomes the probe side.
            MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => Some(false),
            MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => Some(true),
        };

        let preserve_order_probe = args.maintain_order != MaintainOrderJoin::None;
        // Build-side order only matters when both sides' order must be kept.
        let preserve_order_build = matches!(
            args.maintain_order,
            MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft
        );

        // Payload selectors are computed symmetrically for both inputs.
        let left_payload_select = compute_payload_selector(
            &left_input_schema,
            &right_input_schema,
            &left_key_schema,
            &right_key_schema,
            true,
            &args,
        )?;
        let right_payload_select = compute_payload_selector(
            &right_input_schema,
            &left_input_schema,
            &right_key_schema,
            &left_key_schema,
            false,
            &args,
        )?;

        // Skip the sampling phase entirely when the build side is already known.
        let state = if left_is_build.is_some() {
            EquiJoinState::Build(BuildState::new(
                num_pipelines,
                num_pipelines,
                BufferedStream::default(),
            ))
        } else {
            EquiJoinState::Sample(SampleState::default())
        };

        let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select));
        let right_payload_schema =
            Arc::new(select_schema(&right_input_schema, &right_payload_select));
        Ok(Self {
            state,
            params: EquiJoinParams {
                left_is_build,
                preserve_order_build,
                preserve_order_probe,
                left_key_schema,
                left_key_selectors,
                right_key_schema,
                right_key_selectors,
                left_payload_select,
                right_payload_select,
                left_payload_schema,
                right_payload_schema,
                args,
                random_state: PlRandomState::default(),
            },
            table: new_idx_table(unique_key_schema),
        })
    }
}
1254
1255
impl ComputeNode for EquiJoinNode {
    fn name(&self) -> &str {
        "equi-join"
    }

    /// Advances the join state machine based on port states, then reports
    /// the desired state of both receive ports and the send port.
    ///
    /// Port convention: `recv[0]` is the left input, `recv[1]` the right;
    /// `send[0]` is the single output.
    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 2 && send.len() == 1);

        // If the output doesn't want any more data, transition to being done.
        if send[0] == PortState::Done {
            self.state = EquiJoinState::Done;
        }

        // If we are sampling and both sides are done/filled, transition to building.
        if let EquiJoinState::Sample(sample_state) = &mut self.state {
            if let Some(build_state) =
                sample_state.try_transition_to_build(recv, &mut self.params, state)?
            {
                self.state = EquiJoinState::Build(build_state);
            }
        }

        // NOTE: while sampling, left_is_build is None and this defaults the
        // build index to the right input (1).
        let build_idx = if self.params.left_is_build == Some(true) {
            0
        } else {
            1
        };
        let probe_idx = 1 - build_idx;

        // If we are building and the build input is done, transition to probing.
        if let EquiJoinState::Build(build_state) = &mut self.state {
            if recv[build_idx] == PortState::Done {
                let probe_state = if self.params.preserve_order_build {
                    build_state.finalize_ordered(&self.params, &*self.table)
                } else {
                    build_state.finalize_unordered(&self.params, &*self.table)
                };
                self.state = EquiJoinState::Probe(probe_state);
            }
        }

        // If we are probing and the probe input is done, emit unmatched if
        // necessary, otherwise we're done.
        if let EquiJoinState::Probe(probe_state) = &mut self.state {
            let samples_consumed = probe_state.sampled_probe_morsels.is_empty();
            if samples_consumed && recv[probe_idx] == PortState::Done {
                if self.params.emit_unmatched_build() {
                    if self.params.preserve_order_build {
                        // Ordered path: materialize all unmatched rows now and
                        // stream them from an in-memory source.
                        let unmatched = probe_state.ordered_unmatched(&self.params);
                        let src = InMemorySourceNode::new(
                            Arc::new(unmatched),
                            probe_state.max_seq_sent.successor(),
                        );
                        self.state = EquiJoinState::EmitUnmatchedBuildInOrder(src);
                    } else {
                        // Unordered path: drain the probe tables incrementally.
                        self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState {
                            partitions: core::mem::take(&mut probe_state.table_per_partition),
                            active_partition_idx: 0,
                            offset_in_active_p: 0,
                            morsel_seq: probe_state.max_seq_sent.successor(),
                        });
                    }
                } else {
                    self.state = EquiJoinState::Done;
                }
            }
        }

        // Finally, check if we are done emitting unmatched keys.
        if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state {
            if emit_state.active_partition_idx >= emit_state.partitions.len() {
                self.state = EquiJoinState::Done;
            }
        }

        // Report the port states for the (possibly new) phase.
        match &mut self.state {
            EquiJoinState::Sample(sample_state) => {
                send[0] = PortState::Blocked;
                // Each side keeps receiving until its sample buffer is full.
                if recv[0] != PortState::Done {
                    recv[0] = if sample_state.left_len < *JOIN_SAMPLE_LIMIT {
                        PortState::Ready
                    } else {
                        PortState::Blocked
                    };
                }
                if recv[1] != PortState::Done {
                    recv[1] = if sample_state.right_len < *JOIN_SAMPLE_LIMIT {
                        PortState::Ready
                    } else {
                        PortState::Blocked
                    };
                }
            },
            EquiJoinState::Build(_) => {
                send[0] = PortState::Blocked;
                if recv[build_idx] != PortState::Done {
                    recv[build_idx] = PortState::Ready;
                }
                if recv[probe_idx] != PortState::Done {
                    recv[probe_idx] = PortState::Blocked;
                }
            },
            EquiJoinState::Probe(probe_state) => {
                if recv[probe_idx] != PortState::Done {
                    // Pass the downstream demand straight through to the
                    // probe input.
                    core::mem::swap(&mut send[0], &mut recv[probe_idx]);
                } else {
                    // Probe input exhausted; output stays ready while buffered
                    // sample morsels remain.
                    let samples_consumed = probe_state.sampled_probe_morsels.is_empty();
                    send[0] = if samples_consumed {
                        PortState::Done
                    } else {
                        PortState::Ready
                    };
                }
                recv[build_idx] = PortState::Done;
            },
            EquiJoinState::EmitUnmatchedBuild(_) => {
                send[0] = PortState::Ready;
                recv[build_idx] = PortState::Done;
                recv[probe_idx] = PortState::Done;
            },
            EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => {
                recv[build_idx] = PortState::Done;
                recv[probe_idx] = PortState::Done;
                // Delegate the output port state to the in-memory source.
                src_node.update_state(&mut [], &mut send[0..1], state)?;
                if send[0] == PortState::Done {
                    self.state = EquiJoinState::Done;
                }
            },
            EquiJoinState::Done => {
                send[0] = PortState::Done;
                recv[0] = PortState::Done;
                recv[1] = PortState::Done;
            },
        }
        Ok(())
    }

    // Sampling and building buffer entire inputs in memory.
    fn is_memory_intensive_pipeline_blocker(&self) -> bool {
        matches!(
            self.state,
            EquiJoinState::Sample { .. } | EquiJoinState::Build { .. }
        )
    }

    /// Spawns the async tasks for the current phase onto `scope`, wiring the
    /// ports handed over in `recv_ports`/`send_ports`.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 2);
        assert!(send_ports.len() == 1);

        let build_idx = if self.params.left_is_build == Some(true) {
            0
        } else {
            1
        };
        let probe_idx = 1 - build_idx;

        match &mut self.state {
            EquiJoinState::Sample(sample_state) => {
                assert!(send_ports[0].is_none());
                // Absent port means that side finished; publish its final
                // length so the other side's sink can react. usize::MAX
                // signals "still unknown".
                let left_final_len = Arc::new(RelaxedCell::from(if recv_ports[0].is_none() {
                    sample_state.left_len
                } else {
                    usize::MAX
                }));
                let right_final_len = Arc::new(RelaxedCell::from(if recv_ports[1].is_none() {
                    sample_state.right_len
                } else {
                    usize::MAX
                }));

                if let Some(left_recv) = recv_ports[0].take() {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        SampleState::sink(
                            left_recv.serial(),
                            &mut sample_state.left,
                            &mut sample_state.left_len,
                            left_final_len.clone(),
                            right_final_len.clone(),
                        ),
                    ));
                }
                if let Some(right_recv) = recv_ports[1].take() {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        SampleState::sink(
                            right_recv.serial(),
                            &mut sample_state.right,
                            &mut sample_state.right_len,
                            // Note the swapped order: "own" length first.
                            right_final_len,
                            left_final_len,
                        ),
                    ));
                }
            },
            EquiJoinState::Build(build_state) => {
                assert!(send_ports[0].is_none());
                assert!(recv_ports[probe_idx].is_none());
                let receivers = recv_ports[build_idx].take().unwrap().parallel();

                // One partition-and-sink task per pipeline.
                let partitioner = HashPartitioner::new(state.num_pipelines, 0);
                for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        BuildState::partition_and_sink(
                            recv,
                            local_builder,
                            partitioner.clone(),
                            &self.params,
                            state,
                        ),
                    ));
                }
            },
            EquiJoinState::Probe(probe_state) => {
                assert!(recv_ports[build_idx].is_none());
                let senders = send_ports[0].take().unwrap().parallel();
                // Re-insert any probe morsels buffered during sampling ahead
                // of the live probe input.
                let receivers = probe_state
                    .sampled_probe_morsels
                    .reinsert(
                        state.num_pipelines,
                        recv_ports[probe_idx].take(),
                        scope,
                        join_handles,
                    )
                    .unwrap();

                let partitioner = HashPartitioner::new(state.num_pipelines, 0);
                let probe_tasks = receivers
                    .into_iter()
                    .zip(senders)
                    .map(|(recv, send)| {
                        scope.spawn_task(
                            TaskPriority::High,
                            ProbeState::partition_and_probe(
                                recv,
                                send,
                                &probe_state.table_per_partition,
                                &probe_state.unordered_morsel_seq,
                                partitioner.clone(),
                                &self.params,
                                state,
                            ),
                        )
                    })
                    .collect_vec();

                // Aggregate the maximum morsel sequence sent across all probe
                // tasks; needed to continue numbering in the emit phase.
                let max_seq_sent = &mut probe_state.max_seq_sent;
                join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                    for probe_task in probe_tasks {
                        *max_seq_sent = (*max_seq_sent).max(probe_task.await?);
                    }
                    Ok(())
                }));
            },
            EquiJoinState::EmitUnmatchedBuild(emit_state) => {
                assert!(recv_ports[build_idx].is_none());
                assert!(recv_ports[probe_idx].is_none());
                let send = send_ports[0].take().unwrap().serial();
                join_handles.push(scope.spawn_task(
                    TaskPriority::Low,
                    emit_state.emit_unmatched(send, &self.params, state.num_pipelines),
                ));
            },
            EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => {
                assert!(recv_ports[build_idx].is_none());
                assert!(recv_ports[probe_idx].is_none());
                src_node.spawn(scope, &mut [], send_ports, state, join_handles);
            },
            EquiJoinState::Done => unreachable!(),
        }
    }
}
1539
1540