GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/joins/equi_join.rs
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use arrow::array::builder::ShareStrategy;
use polars_core::frame::builder::DataFrameBuilder;
use polars_core::prelude::*;
use polars_core::schema::{Schema, SchemaExt};
use polars_core::{POOL, config};
use polars_expr::hash_keys::HashKeys;
use polars_expr::idx_table::{IdxTable, new_idx_table};
use polars_io::pl_async::get_runtime;
use polars_ops::frame::{JoinArgs, JoinBuildSide, JoinType, MaintainOrderJoin};
use polars_ops::series::coalesce_columns;
use polars_utils::cardinality_sketch::CardinalitySketch;
use polars_utils::hashing::HashPartitioner;
use polars_utils::itertools::Itertools;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::priority::Priority;
use polars_utils::relaxed_cell::RelaxedCell;
use polars_utils::sparse_init_vec::SparseInitVec;
use polars_utils::{IdxSize, format_pl_smallstr};
use rayon::prelude::*;

use super::{BufferedStream, JOIN_SAMPLE_LIMIT, LOPSIDED_SAMPLE_FACTOR};
use crate::async_executor;
use crate::async_primitives::wait_group::WaitGroup;
use crate::expression::StreamExpr;
use crate::morsel::{SourceToken, get_ideal_morsel_size};
use crate::nodes::compute_node_prelude::*;
use crate::nodes::in_memory_source::InMemorySourceNode;

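/// Parameters for a streaming equi-join: which side builds the hash table
/// (fixed up-front or decided after sampling), whether morsel order must be
/// preserved on each side, and the key/payload schemas and selectors for both
/// inputs.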
struct EquiJoinParams {
    left_is_build: Option<bool>,
    preserve_order_build: bool,
    preserve_order_probe: bool,
    left_key_schema: Arc<Schema>,
    left_key_selectors: Vec<StreamExpr>,
    #[allow(dead_code)]
    right_key_schema: Arc<Schema>,
    right_key_selectors: Vec<StreamExpr>,
    left_payload_select: Vec<Option<PlSmallStr>>,
    right_payload_select: Vec<Option<PlSmallStr>>,
    left_payload_schema: Arc<Schema>,
    right_payload_schema: Arc<Schema>,
    args: JoinArgs,
    random_state: PlRandomState,
}

impl EquiJoinParams {
    /// Should we emit unmatched rows from the build side?
    fn emit_unmatched_build(&self) -> bool {
        if self.left_is_build.unwrap() {
            self.args.how == JoinType::Left || self.args.how == JoinType::Full
        } else {
            self.args.how == JoinType::Right || self.args.how == JoinType::Full
        }
    }

    /// Should we emit unmatched rows from the probe side?
    fn emit_unmatched_probe(&self) -> bool {
        if self.left_is_build.unwrap() {
            self.args.how == JoinType::Right || self.args.how == JoinType::Full
        } else {
            self.args.how == JoinType::Left || self.args.how == JoinType::Full
        }
    }
}

/// A payload selector contains, for each column, whether that column should be
/// included in the payload and, if so, under what name.
fn compute_payload_selector(
    this: &Schema,
    other: &Schema,
    this_key_schema: &Schema,
    other_key_schema: &Schema,
    is_left: bool,
    args: &JoinArgs,
) -> PolarsResult<Vec<Option<PlSmallStr>>> {
    let should_coalesce = args.should_coalesce();

    this.iter_names()
        .map(|c| {
            #[expect(clippy::never_loop)]
            loop {
                let selector = if args.how == JoinType::Right {
                    if is_left {
                        if should_coalesce && this_key_schema.contains(c) {
                            // Coalesced to RHS output key.
                            None
                        } else {
                            Some(c.clone())
                        }
                    } else if !other.contains(c) || (should_coalesce && other_key_schema.contains(c)) {
                        Some(c.clone())
                    } else {
                        break;
                    }
                } else if should_coalesce && this_key_schema.contains(c) {
                    if is_left {
                        Some(c.clone())
                    } else if args.how == JoinType::Full {
                        // We must keep the right-hand side keycols around for
                        // coalescing.
                        let key_idx = this_key_schema.index_of(c).unwrap();
                        let name = format_pl_smallstr!("__POLARS_COALESCE_KEYCOL_{key_idx}");
                        Some(name)
                    } else {
                        None
                    }
                } else if !other.contains(c) || is_left {
                    Some(c.clone())
                } else {
                    break;
                };

                return Ok(selector);
            }

            let suffixed = format_pl_smallstr!("{}{}", c, args.suffix());
            if other.contains(&suffixed) {
                polars_bail!(Duplicate: "column with name '{suffixed}' already exists\n\n\
You may want to try:\n\
- renaming the column prior to joining\n\
- using the `suffix` parameter to specify a suffix different to the default one ('_right')")
            }

            Ok(Some(suffixed))
        })
        .collect()
}

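// Illustrative example: in a full join with coalescing enabled on key `id`,
// the right-hand key column is carried as `__POLARS_COALESCE_KEYCOL_0` so
// that `postprocess_join` below can coalesce it into the left `id` and then
// drop it.
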
/// Fixes names and does coalescing of columns post-join.
fn postprocess_join(df: DataFrame, params: &EquiJoinParams) -> DataFrame {
    if params.args.how == JoinType::Full && params.args.should_coalesce() {
        // TODO: don't do string-based column lookups for each dataframe; pre-compute coalesce indices.
        let new_cols = df
            .columns()
            .iter()
            .filter_map(|c| {
                if let Some(key_idx) = params.left_key_schema.index_of(c.name()) {
                    let other = df
                        .column(&format_pl_smallstr!("__POLARS_COALESCE_KEYCOL_{key_idx}"))
                        .unwrap();
                    return Some(coalesce_columns(&[c.clone(), other.clone()]).unwrap());
                }

                if c.name().starts_with("__POLARS_COALESCE_KEYCOL") {
                    return None;
                }

                Some(c.clone())
            })
            .collect();

        unsafe { DataFrame::new_unchecked(df.height(), new_cols) }
    } else {
        df
    }
}

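/// Projects a schema through a payload selector, keeping only the selected
/// columns under their output names.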
fn select_schema(schema: &Schema, selector: &[Option<PlSmallStr>]) -> Schema {
    schema
        .iter_fields()
        .zip(selector)
        .filter_map(|(f, name)| Some(f.with_name(name.clone()?)))
        .collect()
}

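/// Evaluates the key selectors against a dataframe and hashes the resulting
/// key columns.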
async fn select_keys(
    df: &DataFrame,
    key_selectors: &[StreamExpr],
    params: &EquiJoinParams,
    state: &ExecutionState,
) -> PolarsResult<HashKeys> {
    let mut key_columns = Vec::new();
    for selector in key_selectors {
        key_columns.push(selector.evaluate(df, state).await?.into_column());
    }
    let keys = unsafe { DataFrame::new_unchecked_with_broadcast(df.height(), key_columns)? };
    Ok(HashKeys::from_df(
        &keys,
        params.random_state.clone(),
        params.args.nulls_equal,
        false,
    ))
}

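/// Restricts a dataframe to the selected payload columns, renaming them
/// according to the selector.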
fn select_payload(df: DataFrame, selector: &[Option<PlSmallStr>]) -> DataFrame {
    let height = df.height();
    let new_cols = df
        .into_columns()
        .into_iter()
        .zip(selector)
        .filter_map(|(c, name)| Some(c.with_name(name.clone()?)))
        .collect();

    unsafe { DataFrame::new_unchecked(height, new_cols) }
}

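/// Estimates the ratio of unique keys to rows by combining per-morsel
/// cardinality sketches over at most JOIN_SAMPLE_LIMIT sampled rows.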
fn estimate_cardinality(
    morsels: &[Morsel],
    key_selectors: &[StreamExpr],
    params: &EquiJoinParams,
    state: &ExecutionState,
) -> PolarsResult<f64> {
    let sample_limit = *JOIN_SAMPLE_LIMIT;
    if morsels.is_empty() || sample_limit == 0 {
        return Ok(0.0);
    }

    let mut total_height = 0;
    let mut to_process_end = 0;
    while to_process_end < morsels.len() && total_height < sample_limit {
        total_height += morsels[to_process_end].df().height();
        to_process_end += 1;
    }
    let last_morsel_idx = to_process_end - 1;
    let last_morsel_len = morsels[last_morsel_idx].df().height();
    let last_morsel_slice = last_morsel_len - total_height.saturating_sub(sample_limit);
    let runtime = get_runtime();

    POOL.install(|| {
        let sample_cardinality = morsels[..to_process_end]
            .par_iter()
            .enumerate()
            .try_fold(
                CardinalitySketch::new,
                |mut sketch, (morsel_idx, morsel)| {
                    let sliced;
                    let df = if morsel_idx == last_morsel_idx {
                        sliced = morsel.df().slice(0, last_morsel_slice);
                        &sliced
                    } else {
                        morsel.df()
                    };
                    let hash_keys =
                        runtime.block_on(select_keys(df, key_selectors, params, state))?;
                    hash_keys.sketch_cardinality(&mut sketch);
                    PolarsResult::Ok(sketch)
                },
            )
            .map(|sketch| PolarsResult::Ok(sketch?.estimate()))
            .try_reduce_with(|a, b| Ok(a + b))
            .unwrap()?;
        Ok(sample_cardinality as f64 / total_height.min(sample_limit) as f64)
    })
}

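/// Morsels buffered from both sides while we sample to decide the build side.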
#[derive(Default)]
struct SampleState {
    left: Vec<Morsel>,
    left_len: usize,
    right: Vec<Morsel>,
    right_len: usize,
}

impl SampleState {
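    /// Buffers morsels from one side, requesting a stop once this side has
    /// sampled enough rows, either in absolute terms or relative to the
    /// other (finished) side.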
    async fn sink(
        mut recv: PortReceiver,
        morsels: &mut Vec<Morsel>,
        len: &mut usize,
        this_final_len: Arc<RelaxedCell<usize>>,
        other_final_len: Arc<RelaxedCell<usize>>,
    ) -> PolarsResult<()> {
        while let Ok(mut morsel) = recv.recv().await {
            *len += morsel.df().height();
            if *len >= *JOIN_SAMPLE_LIMIT
                || *len
                    >= other_final_len
                        .load()
                        .saturating_mul(LOPSIDED_SAMPLE_FACTOR)
            {
                morsel.source_token().stop();
            }

            drop(morsel.take_consume_token());
            morsels.push(morsel);
        }
        this_final_len.store(*len);
        Ok(())
    }

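    /// Once sampling can stop, picks the build side (preferring the smaller
    /// or lower-cardinality side), replays the sampled build morsels into a
    /// fresh BuildState and returns it. Returns None while sampling should
    /// continue.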
    fn try_transition_to_build(
        &mut self,
        recv: &[PortState],
        params: &mut EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<Option<BuildState>> {
        let left_saturated = self.left_len >= *JOIN_SAMPLE_LIMIT;
        let right_saturated = self.right_len >= *JOIN_SAMPLE_LIMIT;
        let left_done = recv[0] == PortState::Done || left_saturated;
        let right_done = recv[1] == PortState::Done || right_saturated;
        #[expect(clippy::nonminimal_bool)]
        let stop_sampling = (left_done && right_done)
            || (left_done && self.right_len >= LOPSIDED_SAMPLE_FACTOR * self.left_len)
            || (right_done && self.left_len >= LOPSIDED_SAMPLE_FACTOR * self.right_len);
        if !stop_sampling {
            return Ok(None);
        }

        if config::verbose() {
            eprintln!(
                "choosing build side, sample lengths are: {} vs. {}",
                self.left_len, self.right_len
            );
        }

        let estimate_cardinalities = || {
            let left_cardinality = estimate_cardinality(
                &self.left,
                &params.left_key_selectors,
                params,
                &state.in_memory_exec_state,
            )?;
            let right_cardinality = estimate_cardinality(
                &self.right,
                &params.right_key_selectors,
                params,
                &state.in_memory_exec_state,
            )?;
            if config::verbose() {
                eprintln!(
                    "estimated cardinalities are: {left_cardinality} vs. {right_cardinality}"
                );
            }
            PolarsResult::Ok((left_cardinality, right_cardinality))
        };

        let left_is_build = match (left_saturated, right_saturated) {
            // Don't bother estimating cardinality; just choose the smaller
            // side, as we have everything in memory anyway.
            (false, false) => self.left_len < self.right_len,

            // Choose the unsaturated side; the saturated side could be
            // arbitrarily big.
            (false, true) => true,
            (true, false) => false,

            (true, true) => {
                match params.args.build_side {
                    Some(JoinBuildSide::PreferLeft) => true,
                    Some(JoinBuildSide::PreferRight) => false,
                    Some(JoinBuildSide::ForceLeft | JoinBuildSide::ForceRight) => unreachable!(),
                    None => {
                        // Estimate cardinality and choose smaller.
                        let (lc, rc) = estimate_cardinalities()?;
                        lc < rc
                    },
                }
            },
        };

        if config::verbose() {
            eprintln!(
                "build side chosen: {}",
                if left_is_build { "left" } else { "right" }
            );
        }

        // Transition to building state.
        params.left_is_build = Some(left_is_build);
        let mut sampled_build_morsels =
            BufferedStream::new(core::mem::take(&mut self.left), MorselSeq::default());
        let mut sampled_probe_morsels =
            BufferedStream::new(core::mem::take(&mut self.right), MorselSeq::default());
        if !left_is_build {
            core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels);
        }

        let partitioner = HashPartitioner::new(state.num_pipelines, 0);
        let mut build_state = BuildState::new(
            state.num_pipelines,
            state.num_pipelines,
            sampled_probe_morsels,
        );

        // Simulate the sample build morsels flowing into the build side.
        if !sampled_build_morsels.is_empty() {
            crate::async_executor::task_scope(|scope| {
                let mut join_handles = Vec::new();
                let receivers = sampled_build_morsels
                    .reinsert(state.num_pipelines, None, scope, &mut join_handles)
                    .unwrap();

                for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        BuildState::partition_and_sink(
                            recv,
                            local_builder,
                            partitioner.clone(),
                            params,
                            state,
                        ),
                    ));
                }

                polars_io::pl_async::get_runtime().block_on(async move {
                    for handle in join_handles {
                        handle.await?;
                    }
                    PolarsResult::Ok(())
                })
            })?;
        }

        Ok(Some(build_state))
    }
}

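/// Per-pipeline accumulator for the build side: the morsels one worker saw,
/// along with per-partition cardinality sketches and row-index buckets.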
#[derive(Default)]
struct LocalBuilder {
    // The complete list of morsels and their computed hashes seen by this builder.
    morsels: Vec<(MorselSeq, DataFrame, HashKeys)>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]
    // for partition p, where start, stop are:
    // let start = morsel_idxs_offsets[i * num_partitions + p];
    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
    morsel_idxs_offsets_per_p: Vec<usize>,
}

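/// State while the build side is being consumed: one LocalBuilder per
/// pipeline, plus any probe morsels already buffered during sampling.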
struct BuildState {
    local_builders: Vec<LocalBuilder>,
    sampled_probe_morsels: BufferedStream,
}

impl BuildState {
    fn new(
        num_pipelines: usize,
        num_partitions: usize,
        sampled_probe_morsels: BufferedStream,
    ) -> Self {
        let local_builders = (0..num_pipelines)
            .map(|_| LocalBuilder {
                morsels: Vec::new(),
                sketch_per_p: vec![CardinalitySketch::default(); num_partitions],
                morsel_idxs_values_per_p: vec![Vec::new(); num_partitions],
                morsel_idxs_offsets_per_p: vec![0; num_partitions],
            })
            .collect();
        Self {
            local_builders,
            sampled_probe_morsels,
        }
    }

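    /// Sinks one pipeline's worth of build-side morsels: hashes the keys,
    /// rechunks the payload for later gathers, and records per-partition row
    /// indices and cardinality sketches.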
    async fn partition_and_sink(
        mut recv: PortReceiver,
        local: &mut LocalBuilder,
        partitioner: HashPartitioner,
        params: &EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        let track_unmatchable = params.emit_unmatched_build();
        let (key_selectors, payload_selector);
        if params.left_is_build.unwrap() {
            payload_selector = &params.left_payload_select;
            key_selectors = &params.left_key_selectors;
        } else {
            payload_selector = &params.right_payload_select;
            key_selectors = &params.right_key_selectors;
        };

        while let Ok(morsel) = recv.recv().await {
            // Compute hashed keys and payload. We must rechunk the payload for
            // later gathers.
            let hash_keys = select_keys(
                morsel.df(),
                key_selectors,
                params,
                &state.in_memory_exec_state,
            )
            .await?;
            let mut payload = select_payload(morsel.df().clone(), payload_selector);
            payload.rechunk_mut();

            hash_keys.gen_idxs_per_partition(
                &partitioner,
                &mut local.morsel_idxs_values_per_p,
                &mut local.sketch_per_p,
                track_unmatchable,
            );

            local
                .morsel_idxs_offsets_per_p
                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
            local.morsels.push((morsel.seq(), payload, hash_keys));
        }
        Ok(())
    }

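    /// Builds the per-partition hash tables while preserving morsel order: a
    /// k-way merge over each local builder's morsel sequences linearizes
    /// every partition before insertion, recording seq_ids so unmatched
    /// build rows can later be emitted in order.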
    fn finalize_ordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
        let track_unmatchable = params.emit_unmatched_build();
        let payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);

        POOL.scope(|s| {
            for p in 0..num_partitions {
                let probe_tables = &probe_tables;
                s.spawn(move |_| {
                    // TODO: every thread does an identical linearize; we could do a single parallel one.
                    let mut kmerge = BinaryHeap::with_capacity(local_builders.len());
                    let mut cur_idx_per_loc = vec![0; local_builders.len()];

                    // Compute cardinality estimate and total amount of
                    // payload for this partition, and initialize k-way merge.
                    let mut sketch = CardinalitySketch::new();
                    let mut payload_rows = 0;
                    for (l_idx, l) in local_builders.iter().enumerate() {
                        let Some((seq, _, _)) = l.morsels.first() else {
                            continue;
                        };
                        kmerge.push(Priority(Reverse(seq), l_idx));

                        sketch.combine(&l.sketch_per_p[p]);
                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
                        payload_rows +=
                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
                    }

                    // Allocate hash table and payload builder.
                    let mut p_table = table.new_empty();
                    p_table.reserve(sketch.estimate() * 5 / 4);
                    let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
                    p_payload.reserve(payload_rows);

                    let mut p_seq_ids = Vec::new();
                    if track_unmatchable {
                        p_seq_ids.reserve(payload_rows);
                    }

                    // Linearize and build.
                    unsafe {
                        let mut norm_seq_id = 0 as IdxSize;
                        while let Some(Priority(Reverse(_seq), l_idx)) = kmerge.pop() {
                            let l = local_builders.get_unchecked(l_idx);
                            let idx_in_l = *cur_idx_per_loc.get_unchecked(l_idx);
                            *cur_idx_per_loc.get_unchecked_mut(l_idx) += 1;
                            if let Some((next_seq, _, _)) = l.morsels.get(idx_in_l + 1) {
                                kmerge.push(Priority(Reverse(next_seq), l_idx));
                            }

                            let (_mseq, payload, keys) = l.morsels.get_unchecked(idx_in_l);
                            let p_morsel_idxs_start =
                                l.morsel_idxs_offsets_per_p[idx_in_l * num_partitions + p];
                            let p_morsel_idxs_stop =
                                l.morsel_idxs_offsets_per_p[(idx_in_l + 1) * num_partitions + p];
                            let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                [p_morsel_idxs_start..p_morsel_idxs_stop];
                            p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                            p_payload.gather_extend(payload, p_morsel_idxs, ShareStrategy::Never);

                            if track_unmatchable {
                                p_seq_ids.resize(p_payload.len(), norm_seq_id);
                                norm_seq_id += 1;
                            }
                        }
                    }

                    probe_tables
                        .try_set(
                            p,
                            ProbeTable {
                                hash_table: p_table,
                                payload: p_payload.freeze(),
                                seq_ids: p_seq_ids,
                            },
                        )
                        .ok()
                        .unwrap();
                });
            }
        });

        ProbeState {
            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
            max_seq_sent: MorselSeq::default(),
            sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
            unordered_morsel_seq: AtomicU64::new(0),
        }
    }

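    /// Builds the per-partition hash tables without ordering guarantees;
    /// processed morsels are moved into Arcs and dropped through a shared
    /// work queue so no single thread stalls on expensive deallocations.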
    fn finalize_unordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
        let track_unmatchable = params.emit_unmatched_build();
        let payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        // To reduce maximum memory usage we want to drop the morsels as soon
        // as they're processed, so we move them into Arcs. The drops might
        // also be expensive, so instead of directly dropping we put that on
        // a work queue.
        let morsels_per_local_builder = self
            .local_builders
            .iter_mut()
            .map(|b| Arc::new(core::mem::take(&mut b.morsels)))
            .collect_vec();
        let (morsel_drop_q_send, morsel_drop_q_recv) =
            async_channel::bounded(morsels_per_local_builder.len());
        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);

        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_morsels_per_local_builder = Arc::new(morsels_per_local_builder);
            let mut join_handles = Vec::new();
            for p in 0..num_partitions {
                let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder);
                let morsel_drop_q_send = morsel_drop_q_send.clone();
                let morsel_drop_q_recv = morsel_drop_q_recv.clone();
                let probe_tables = &probe_tables;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from outer arc and drop outer arc.
                    let morsels_per_local_builder =
                        Arc::unwrap_or_clone(arc_morsels_per_local_builder);

                    // Compute cardinality estimate and total amount of
                    // payload for this partition.
                    let mut sketch = CardinalitySketch::new();
                    let mut payload_rows = 0;
                    for l in local_builders {
                        sketch.combine(&l.sketch_per_p[p]);
                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
                        payload_rows +=
                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
                    }

                    // Allocate hash table and payload builder.
                    let mut p_table = table.new_empty();
                    p_table.reserve(sketch.estimate() * 5 / 4);
                    let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
                    p_payload.reserve(payload_rows);

                    // Build.
                    let mut skip_drop_attempt = false;
                    for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) {
                        // Try to help with dropping the processed morsels.
                        if !skip_drop_attempt {
                            drop(morsel_drop_q_recv.try_recv());
                        }

                        for (i, morsel) in l_morsels.iter().enumerate() {
                            let (_mseq, payload, keys) = morsel;
                            unsafe {
                                let p_morsel_idxs_start =
                                    l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                let p_morsel_idxs_stop =
                                    l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                    [p_morsel_idxs_start..p_morsel_idxs_stop];
                                p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                                p_payload.gather_extend(
                                    payload,
                                    p_morsel_idxs,
                                    ShareStrategy::Never,
                                );
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_morsels) {
                            // If we're the last thread to process this set of morsels, we're
                            // probably falling behind the rest; since the drop can be quite
                            // expensive, we skip a drop attempt, hoping someone else will
                            // pick up the slack.
                            drop(morsel_drop_q_send.try_send(l));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done; help others out by doing drops.
                    drop(morsel_drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(l_morsels) = morsel_drop_q_recv.recv().await {
                        drop(l_morsels);
                    }

                    probe_tables
                        .try_set(
                            p,
                            ProbeTable {
                                hash_table: p_table,
                                payload: p_payload.freeze(),
                                seq_ids: Vec::new(),
                            },
                        )
                        .ok()
                        .unwrap();
                }));
            }

            // Drop the outer arc after spawning each thread so the inner arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_morsels_per_local_builder);
            drop(morsel_drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await;
                }
            });
        });

        ProbeState {
            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
            max_seq_sent: MorselSeq::default(),
            sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
            unordered_morsel_seq: AtomicU64::new(0),
        }
    }
}

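/// A single partition's hash table over the build-side keys together with
/// the gathered build payload; seq_ids is only populated by finalize_ordered
/// when unmatched build rows must be emitted in order.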
struct ProbeTable {
    hash_table: Box<dyn IdxTable>,
    payload: DataFrame,
    seq_ids: Vec<IdxSize>,
}

struct ProbeState {
    table_per_partition: Vec<ProbeTable>,
    max_seq_sent: MorselSeq,
    sampled_probe_morsels: BufferedStream,

    // For unordered joins we relabel output morsels to speed up the linearizer.
    unordered_morsel_seq: AtomicU64,
}

impl ProbeState {
    /// Returns the max morsel sequence sent.
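    ///
    /// Hashes the keys of each incoming probe morsel and probes the
    /// partition hash tables: in morsel order when preserve_order_probe is
    /// set, in bulk per partition otherwise. Joined output is emitted in
    /// morsels of roughly the ideal morsel size.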
    async fn partition_and_probe(
        mut recv: PortReceiver,
        mut send: PortSender,
        partitions: &[ProbeTable],
        unordered_morsel_seq: &AtomicU64,
        partitioner: HashPartitioner,
        params: &EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<MorselSeq> {
        // TODO: shuffle after partitioning and keep probe tables thread-local.
        let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()];
        let mut probe_partitions = Vec::new();
        let mut materialized_idxsize_range = Vec::new();
        let mut table_match = Vec::new();
        let mut probe_match = Vec::new();
        let mut max_seq = MorselSeq::default();

        let probe_limit = get_ideal_morsel_size() as IdxSize;
        let mark_matches = params.emit_unmatched_build();
        let emit_unmatched = params.emit_unmatched_probe();

        let (key_selectors, payload_selector, build_payload_schema, probe_payload_schema);
        if params.left_is_build.unwrap() {
            key_selectors = &params.right_key_selectors;
            payload_selector = &params.right_payload_select;
            build_payload_schema = &params.left_payload_schema;
            probe_payload_schema = &params.right_payload_schema;
        } else {
            key_selectors = &params.left_key_selectors;
            payload_selector = &params.left_payload_select;
            build_payload_schema = &params.right_payload_schema;
            probe_payload_schema = &params.left_payload_schema;
        };

        let mut build_out = DataFrameBuilder::new(build_payload_schema.clone());
        let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone());

        // A simple estimate used to size reserves.
        let mut selectivity_estimate = 1.0;
        let mut selectivity_estimate_confidence = 0.0;

        while let Ok(morsel) = recv.recv().await {
            // Compute hashed keys and payload.
            let (df, in_seq, src_token, wait_token) = morsel.into_inner();

            let df_height = df.height();
            if df_height == 0 {
                continue;
            }

            let hash_keys =
                select_keys(&df, key_selectors, params, &state.in_memory_exec_state).await?;
            let mut payload = select_payload(df, payload_selector);
            let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches.
            let mut total_matches = 0;

            // Use selectivity estimate to reserve for morsel builders.
            let max_match_per_key_est = (selectivity_estimate * 1.2) as usize + 16;
            let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize)
                .min(probe_limit as usize);
            build_out.reserve(out_est_size + max_match_per_key_est);

            unsafe {
                let mut new_morsel =
                    |build: &mut DataFrameBuilder, probe: &mut DataFrameBuilder| {
                        let mut build_df = build.freeze_reset();
                        let mut probe_df = probe.freeze_reset();
                        let out_df = if params.left_is_build.unwrap() {
                            build_df.hstack_mut_unchecked(probe_df.columns());
                            build_df
                        } else {
                            probe_df.hstack_mut_unchecked(build_df.columns());
                            probe_df
                        };
                        let out_df = postprocess_join(out_df, params);
                        let out_seq = if params.preserve_order_probe {
                            in_seq
                        } else {
                            MorselSeq::new(unordered_morsel_seq.fetch_add(1, Ordering::Relaxed))
                        };
                        max_seq = out_seq;
                        Morsel::new(out_df, out_seq, src_token.clone())
                    };

                if params.preserve_order_probe {
                    // To preserve the order we can't do bulk probes per partition and
                    // must follow the order of the probe morsel. We can still group
                    // probes that are consecutive on the same partition.
                    probe_partitions.clear();
                    hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched);

                    let mut probe_group_start = 0;
                    while probe_group_start < probe_partitions.len() {
                        let p_idx = probe_partitions[probe_group_start];
                        let mut probe_group_end = probe_group_start + 1;
                        while probe_partitions.get(probe_group_end) == Some(&p_idx) {
                            probe_group_end += 1;
                        }
                        let Some(p) = partitions.get(p_idx as usize) else {
                            probe_group_start = probe_group_end;
                            continue;
                        };

                        materialized_idxsize_range.extend(
                            materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize,
                        );

                        while probe_group_start < probe_group_end {
                            let matches_before_limit = probe_limit - probe_match.len() as IdxSize;
                            table_match.clear();
                            probe_group_start += p.hash_table.probe_subset(
                                &hash_keys,
                                &materialized_idxsize_range[probe_group_start..probe_group_end],
                                &mut table_match,
                                &mut probe_match,
                                mark_matches,
                                emit_unmatched,
                                matches_before_limit,
                            ) as usize;

                            if emit_unmatched {
                                build_out.opt_gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            } else {
                                build_out.gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            };

                            if probe_match.len() >= probe_limit as usize
                                || probe_group_start == probe_partitions.len()
                            {
                                if !payload_rechunked {
                                    payload.rechunk_mut();
                                    payload_rechunked = true;
                                }
                                probe_out.gather_extend(
                                    &payload,
                                    &probe_match,
                                    ShareStrategy::Always,
                                );
                                let out_len = probe_match.len();
                                probe_match.clear();
                                let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                                if send.send(out_morsel).await.is_err() {
                                    return Ok(max_seq);
                                }
                                if probe_group_end != probe_partitions.len() {
                                    // We had enough matches to need a mid-partition flush;
                                    // assume there are a lot of matches and just do a large
                                    // reserve.
                                    let old_est = probe_limit as usize + max_match_per_key_est;
                                    build_out.reserve(old_est.max(out_len + 16));
                                }
                            }
                        }
                    }
                } else {
                    // Partition and probe the tables.
                    for p in partition_idxs.iter_mut() {
                        p.clear();
                    }
                    hash_keys.gen_idxs_per_partition(
                        &partitioner,
                        &mut partition_idxs,
                        &mut [],
                        emit_unmatched,
                    );

                    for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) {
                        let mut offset = 0;
                        while offset < idxs_in_p.len() {
                            let matches_before_limit = probe_limit - probe_match.len() as IdxSize;
                            table_match.clear();
                            offset += p.hash_table.probe_subset(
                                &hash_keys,
                                &idxs_in_p[offset..],
                                &mut table_match,
                                &mut probe_match,
                                mark_matches,
                                emit_unmatched,
                                matches_before_limit,
                            ) as usize;

                            if table_match.is_empty() {
                                continue;
                            }
                            total_matches += table_match.len();

                            if emit_unmatched {
                                build_out.opt_gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            } else {
                                build_out.gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            };

                            if probe_match.len() >= probe_limit as usize {
                                if !payload_rechunked {
                                    payload.rechunk_mut();
                                    payload_rechunked = true;
                                }
                                probe_out.gather_extend(
                                    &payload,
                                    &probe_match,
                                    ShareStrategy::Always,
                                );
                                let out_len = probe_match.len();
                                probe_match.clear();
                                let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                                if send.send(out_morsel).await.is_err() {
                                    return Ok(max_seq);
                                }
                                // We had enough matches to need a mid-partition flush;
                                // assume there are a lot of matches and just do a large
                                // reserve.
                                let old_est = probe_limit as usize + max_match_per_key_est;
                                build_out.reserve(old_est.max(out_len + 16));
                            }
                        }
                    }
                }

                if !probe_match.is_empty() {
                    if !payload_rechunked {
                        payload.rechunk_mut();
                    }
                    probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);
                    probe_match.clear();
                    let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                    if send.send(out_morsel).await.is_err() {
                        return Ok(max_seq);
                    }
                }
            }

            drop(wait_token);

            // Move the selectivity estimate a bit towards the latest value;
            // this allows rapid changes at first.
            // TODO: implement something more re-usable and robust.
            selectivity_estimate = selectivity_estimate_confidence * selectivity_estimate
                + (1.0 - selectivity_estimate_confidence)
                    * (total_matches as f64 / df_height as f64);
            selectivity_estimate_confidence = (selectivity_estimate_confidence + 0.1).min(0.8);
        }

        Ok(max_seq)
    }

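    /// Gathers the unmatched build rows from all partitions and restores
    /// their original morsel order using the seq_ids recorded by
    /// finalize_ordered.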
    fn ordered_unmatched(&mut self, params: &EquiJoinParams) -> DataFrame {
        // TODO: parallelize this operator.

        let build_payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        let mut unmarked_idxs = Vec::new();
        let mut linearized_idxs = Vec::new();

        for (p_idx, p) in self.table_per_partition.iter().enumerate_idx() {
            p.hash_table
                .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX);
            linearized_idxs.extend(
                unmarked_idxs
                    .iter()
                    .map(|i| (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i)),
            );
        }

        linearized_idxs.sort_by_key(|(seq_id, _, _)| *seq_id);

        unsafe {
            let mut build_out = DataFrameBuilder::new(build_payload_schema.clone());
            build_out.reserve(linearized_idxs.len());

            // Group indices from the same partition.
            let mut group_start = 0;
            let mut gather_idxs = Vec::new();
            while group_start < linearized_idxs.len() {
                gather_idxs.clear();

                let (_seq, p_idx, idx_in_p) = linearized_idxs[group_start];
                gather_idxs.push(idx_in_p);
                let mut group_end = group_start + 1;
                while group_end < linearized_idxs.len() && linearized_idxs[group_end].1 == p_idx {
                    gather_idxs.push(linearized_idxs[group_end].2);
                    group_end += 1;
                }

                build_out.gather_extend(
                    &self.table_per_partition[p_idx as usize].payload,
                    &gather_idxs,
                    ShareStrategy::Never, // Don't keep entire table alive for unmatched indices.
                );

                group_start = group_end;
            }

            let mut build_df = build_out.freeze();
            let out_df = if params.left_is_build.unwrap() {
                let probe_df =
                    DataFrame::full_null(&params.right_payload_schema, build_df.height());
                build_df.hstack_mut_unchecked(probe_df.columns());
                build_df
            } else {
                let mut probe_df =
                    DataFrame::full_null(&params.left_payload_schema, build_df.height());
                probe_df.hstack_mut_unchecked(build_df.columns());
                probe_df
            };
            postprocess_join(out_df, params)
        }
    }
}

impl Drop for ProbeState {
    fn drop(&mut self) {
        POOL.install(|| {
            // Parallel drop as the state might be quite big.
            self.table_per_partition.par_drain(..).for_each(drop);
        })
    }
}

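/// Cursor over the probe tables, used to stream out the unmatched build rows
/// after probing has finished.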
struct EmitUnmatchedState {
    partitions: Vec<ProbeTable>,
    active_partition_idx: usize,
    offset_in_active_p: usize,
    morsel_seq: MorselSeq,
}

impl EmitUnmatchedState {
    async fn emit_unmatched(
        &mut self,
        mut send: PortSender,
        params: &EquiJoinParams,
        num_pipelines: usize,
    ) -> PolarsResult<()> {
        let total_len: usize = self
            .partitions
            .iter()
            .map(|p| p.hash_table.num_keys() as usize)
            .sum();
        let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1);
        let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);
        let morsel_size = total_len.div_ceil(morsel_count).max(1);

        let wait_group = WaitGroup::default();
        let source_token = SourceToken::new();
        let mut unmarked_idxs = Vec::new();
        while let Some(p) = self.partitions.get(self.active_partition_idx) {
            loop {
                // Generate a chunk of unmarked key indices.
                self.offset_in_active_p += p.hash_table.unmarked_keys(
                    &mut unmarked_idxs,
                    self.offset_in_active_p as IdxSize,
                    morsel_size as IdxSize,
                ) as usize;
                if unmarked_idxs.is_empty() {
                    break;
                }

                // Gather and create full-null counterpart.
                let out_df = unsafe {
                    let mut build_df = p.payload.take_slice_unchecked_impl(&unmarked_idxs, false);
                    let len = build_df.height();
                    if params.left_is_build.unwrap() {
                        let probe_df = DataFrame::full_null(&params.right_payload_schema, len);
                        build_df.hstack_mut_unchecked(probe_df.columns());
                        build_df
                    } else {
                        let mut probe_df = DataFrame::full_null(&params.left_payload_schema, len);
                        probe_df.hstack_mut_unchecked(build_df.columns());
                        probe_df
                    }
                };
                let out_df = postprocess_join(out_df, params);

                // Send and wait until consume token is consumed.
                let mut morsel = Morsel::new(out_df, self.morsel_seq, source_token.clone());
                self.morsel_seq = self.morsel_seq.successor();
                morsel.set_consume_token(wait_group.token());
                if send.send(morsel).await.is_err() {
                    return Ok(());
                }

                wait_group.wait().await;
                if source_token.stop_requested() {
                    return Ok(());
                }
            }

            self.active_partition_idx += 1;
            self.offset_in_active_p = 0;
        }

        Ok(())
    }
}

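/// The node's state machine: Sample (optional) -> Build -> Probe ->
/// EmitUnmatchedBuild or EmitUnmatchedBuildInOrder (only when unmatched build
/// rows must be emitted) -> Done.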
enum EquiJoinState {
    Sample(SampleState),
    Build(BuildState),
    Probe(ProbeState),
    EmitUnmatchedBuild(EmitUnmatchedState),
    EmitUnmatchedBuildInOrder(InMemorySourceNode),
    Done,
}

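/// Streaming equi-join compute node: two receivers (left and right inputs),
/// one sender (the joined output).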
pub struct EquiJoinNode {
    state: EquiJoinState,
    params: EquiJoinParams,
    table: Box<dyn IdxTable>,
}

impl EquiJoinNode {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        left_input_schema: Arc<Schema>,
        right_input_schema: Arc<Schema>,
        left_key_schema: Arc<Schema>,
        right_key_schema: Arc<Schema>,
        unique_key_schema: Arc<Schema>,
        left_key_selectors: Vec<StreamExpr>,
        right_key_selectors: Vec<StreamExpr>,
        args: JoinArgs,
        num_pipelines: usize,
    ) -> PolarsResult<Self> {
        let left_is_build = match args.maintain_order {
            MaintainOrderJoin::None => match args.build_side {
                Some(JoinBuildSide::ForceLeft) => Some(true),
                Some(JoinBuildSide::ForceRight) => Some(false),
                Some(JoinBuildSide::PreferLeft) | Some(JoinBuildSide::PreferRight) | None => {
                    if *JOIN_SAMPLE_LIMIT == 0 {
                        Some(args.build_side != Some(JoinBuildSide::PreferRight))
                    } else {
                        None
                    }
                },
            },
            MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => {
                if args.build_side == Some(JoinBuildSide::ForceLeft) {
                    polars_warn!("can't force left build-side with a left-order-maintaining equi-join");
                }
                Some(false)
            },
            MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => {
                if args.build_side == Some(JoinBuildSide::ForceRight) {
                    polars_warn!("can't force right build-side with a right-order-maintaining equi-join");
                }
                Some(true)
            },
        };

        let preserve_order_probe = args.maintain_order != MaintainOrderJoin::None;
        let preserve_order_build = matches!(
            args.maintain_order,
            MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft
        );

        let left_payload_select = compute_payload_selector(
            &left_input_schema,
            &right_input_schema,
            &left_key_schema,
            &right_key_schema,
            true,
            &args,
        )?;
        let right_payload_select = compute_payload_selector(
            &right_input_schema,
            &left_input_schema,
            &right_key_schema,
            &left_key_schema,
            false,
            &args,
        )?;

        let state = if left_is_build.is_some() {
            EquiJoinState::Build(BuildState::new(
                num_pipelines,
                num_pipelines,
                BufferedStream::default(),
            ))
        } else {
            EquiJoinState::Sample(SampleState::default())
        };

        let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select));
        let right_payload_schema =
            Arc::new(select_schema(&right_input_schema, &right_payload_select));
        Ok(Self {
            state,
            params: EquiJoinParams {
                left_is_build,
                preserve_order_build,
                preserve_order_probe,
                left_key_schema,
                left_key_selectors,
                right_key_schema,
                right_key_selectors,
                left_payload_select,
                right_payload_select,
                left_payload_schema,
                right_payload_schema,
                args,
                random_state: PlRandomState::default(),
            },
            table: new_idx_table(unique_key_schema),
        })
    }
}

impl ComputeNode for EquiJoinNode {
    fn name(&self) -> &str {
        "equi-join"
    }

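    /// Advances the state machine: Sample -> Build once sampling settles,
    /// Build -> Probe once the build input is done, and Probe -> emitting
    /// unmatched rows or Done once the probe input is exhausted; then sets
    /// the port states to match.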
    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 2 && send.len() == 1);

        // If the output doesn't want any more data, transition to being done.
        if send[0] == PortState::Done {
            self.state = EquiJoinState::Done;
        }

        // If we are sampling and both sides are done/filled, transition to building.
        if let EquiJoinState::Sample(sample_state) = &mut self.state {
            if let Some(build_state) =
                sample_state.try_transition_to_build(recv, &mut self.params, state)?
            {
                self.state = EquiJoinState::Build(build_state);
            }
        }

        let build_idx = if self.params.left_is_build == Some(true) {
            0
        } else {
            1
        };
        let probe_idx = 1 - build_idx;

        // If we are building and the build input is done, transition to probing.
        if let EquiJoinState::Build(build_state) = &mut self.state {
            if recv[build_idx] == PortState::Done {
                let probe_state = if self.params.preserve_order_build {
                    build_state.finalize_ordered(&self.params, &*self.table)
                } else {
                    build_state.finalize_unordered(&self.params, &*self.table)
                };
                self.state = EquiJoinState::Probe(probe_state);
            }
        }

        // If we are probing and the probe input is done, emit unmatched rows
        // if necessary; otherwise we're done.
        if let EquiJoinState::Probe(probe_state) = &mut self.state {
            let samples_consumed = probe_state.sampled_probe_morsels.is_empty();
            if samples_consumed && recv[probe_idx] == PortState::Done {
                if self.params.emit_unmatched_build() {
                    if self.params.preserve_order_build {
                        let unmatched = probe_state.ordered_unmatched(&self.params);
                        let src = InMemorySourceNode::new(
                            Arc::new(unmatched),
                            probe_state.max_seq_sent.successor(),
                        );
                        self.state = EquiJoinState::EmitUnmatchedBuildInOrder(src);
                    } else {
                        self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState {
                            partitions: core::mem::take(&mut probe_state.table_per_partition),
                            active_partition_idx: 0,
                            offset_in_active_p: 0,
                            morsel_seq: probe_state.max_seq_sent.successor(),
                        });
                    }
                } else {
                    self.state = EquiJoinState::Done;
                }
            }
        }

        // Finally, check if we are done emitting unmatched keys.
        if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state {
            if emit_state.active_partition_idx >= emit_state.partitions.len() {
                self.state = EquiJoinState::Done;
            }
        }

        match &mut self.state {
            EquiJoinState::Sample(sample_state) => {
                send[0] = PortState::Blocked;
                if recv[0] != PortState::Done {
                    recv[0] = if sample_state.left_len < *JOIN_SAMPLE_LIMIT {
                        PortState::Ready
                    } else {
                        PortState::Blocked
                    };
                }
                if recv[1] != PortState::Done {
                    recv[1] = if sample_state.right_len < *JOIN_SAMPLE_LIMIT {
                        PortState::Ready
                    } else {
                        PortState::Blocked
                    };
                }
            },
            EquiJoinState::Build(_) => {
                send[0] = PortState::Blocked;
                if recv[build_idx] != PortState::Done {
                    recv[build_idx] = PortState::Ready;
                }
                if recv[probe_idx] != PortState::Done {
                    recv[probe_idx] = PortState::Blocked;
                }
            },
            EquiJoinState::Probe(probe_state) => {
                if recv[probe_idx] != PortState::Done {
                    core::mem::swap(&mut send[0], &mut recv[probe_idx]);
                } else {
                    let samples_consumed = probe_state.sampled_probe_morsels.is_empty();
                    send[0] = if samples_consumed {
                        PortState::Done
                    } else {
                        PortState::Ready
                    };
                }
                recv[build_idx] = PortState::Done;
            },
            EquiJoinState::EmitUnmatchedBuild(_) => {
                send[0] = PortState::Ready;
                recv[build_idx] = PortState::Done;
                recv[probe_idx] = PortState::Done;
            },
            EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => {
                recv[build_idx] = PortState::Done;
                recv[probe_idx] = PortState::Done;
                src_node.update_state(&mut [], &mut send[0..1], state)?;
                if send[0] == PortState::Done {
                    self.state = EquiJoinState::Done;
                }
            },
            EquiJoinState::Done => {
                send[0] = PortState::Done;
                recv[0] = PortState::Done;
                recv[1] = PortState::Done;
            },
        }
        Ok(())
    }

    fn is_memory_intensive_pipeline_blocker(&self) -> bool {
        matches!(
            self.state,
            EquiJoinState::Sample { .. } | EquiJoinState::Build { .. }
        )
    }

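    /// Spawns the tasks for the current state: sampling sinks, build sinks,
    /// parallel probe pipelines, or the unmatched-row emitter.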
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 2);
        assert!(send_ports.len() == 1);

        let build_idx = if self.params.left_is_build == Some(true) {
            0
        } else {
            1
        };
        let probe_idx = 1 - build_idx;

        match &mut self.state {
            EquiJoinState::Sample(sample_state) => {
                assert!(send_ports[0].is_none());
                let left_final_len = Arc::new(RelaxedCell::from(if recv_ports[0].is_none() {
                    sample_state.left_len
                } else {
                    usize::MAX
                }));
                let right_final_len = Arc::new(RelaxedCell::from(if recv_ports[1].is_none() {
                    sample_state.right_len
                } else {
                    usize::MAX
                }));

                if let Some(left_recv) = recv_ports[0].take() {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        SampleState::sink(
                            left_recv.serial(),
                            &mut sample_state.left,
                            &mut sample_state.left_len,
                            left_final_len.clone(),
                            right_final_len.clone(),
                        ),
                    ));
                }
                if let Some(right_recv) = recv_ports[1].take() {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        SampleState::sink(
                            right_recv.serial(),
                            &mut sample_state.right,
                            &mut sample_state.right_len,
                            right_final_len,
                            left_final_len,
                        ),
                    ));
                }
            },
            EquiJoinState::Build(build_state) => {
                assert!(send_ports[0].is_none());
                assert!(recv_ports[probe_idx].is_none());
                let receivers = recv_ports[build_idx].take().unwrap().parallel();

                let partitioner = HashPartitioner::new(state.num_pipelines, 0);
                for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        BuildState::partition_and_sink(
                            recv,
                            local_builder,
                            partitioner.clone(),
                            &self.params,
                            state,
                        ),
                    ));
                }
            },
            EquiJoinState::Probe(probe_state) => {
                assert!(recv_ports[build_idx].is_none());
                let senders = send_ports[0].take().unwrap().parallel();
                let receivers = probe_state
                    .sampled_probe_morsels
                    .reinsert(
                        state.num_pipelines,
                        recv_ports[probe_idx].take(),
                        scope,
                        join_handles,
                    )
                    .unwrap();

                let partitioner = HashPartitioner::new(state.num_pipelines, 0);
                let probe_tasks = receivers
                    .into_iter()
                    .zip(senders)
                    .map(|(recv, send)| {
                        scope.spawn_task(
                            TaskPriority::High,
                            ProbeState::partition_and_probe(
                                recv,
                                send,
                                &probe_state.table_per_partition,
                                &probe_state.unordered_morsel_seq,
                                partitioner.clone(),
                                &self.params,
                                state,
                            ),
                        )
                    })
                    .collect_vec();

                let max_seq_sent = &mut probe_state.max_seq_sent;
                join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                    for probe_task in probe_tasks {
                        *max_seq_sent = (*max_seq_sent).max(probe_task.await?);
                    }
                    Ok(())
                }));
            },
            EquiJoinState::EmitUnmatchedBuild(emit_state) => {
                assert!(recv_ports[build_idx].is_none());
                assert!(recv_ports[probe_idx].is_none());
                let send = send_ports[0].take().unwrap().serial();
                join_handles.push(scope.spawn_task(
                    TaskPriority::Low,
                    emit_state.emit_unmatched(send, &self.params, state.num_pipelines),
                ));
            },
            EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => {
                assert!(recv_ports[build_idx].is_none());
                assert!(recv_ports[probe_idx].is_none());
                src_node.spawn(scope, &mut [], send_ports, state, join_handles);
            },
            EquiJoinState::Done => unreachable!(),
        }
    }
}