Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs
8422 views
1
use polars_utils::format_pl_smallstr;
2
3
use super::*;
4
use crate::plans::optimizer::join_utils::remove_suffix;
5
6
const IEJOIN_MAX_PREDICATES: usize = 2;
7
8
/// Pushes the accumulated predicates (`acc_predicates`) through a join node.
///
/// The join type is first rewritten where the predicates allow it (`try_rewrite_join_type`).
/// That step may also return a projection that restores the original output columns after the
/// rewrite. If the join is a non-coalescing full join, or no predicates remain, nothing is pushed
/// and the node is handed to `no_pushdown_restart_opt` instead.
///
/// Otherwise every predicate is classified by the origin of the columns it references:
/// * It is pushed into the left and/or right input wherever that is valid for the join type.
///   References to join-key outputs are remapped to the other side's input key names.
/// * Predicates that cannot be fully represented by the pushed-down copies are kept as residual
///   filters on top of the join.
#[allow(clippy::too_many_arguments)]
9
pub(super) fn process_join(
10
opt: &mut PredicatePushDown,
11
lp_arena: &mut Arena<IR>,
12
expr_arena: &mut Arena<AExpr>,
13
input_left: Node,
14
input_right: Node,
15
mut left_on: Vec<ExprIR>,
16
mut right_on: Vec<ExprIR>,
17
mut schema: SchemaRef,
18
mut options: Arc<JoinOptionsIR>,
19
mut acc_predicates: PlHashMap<PlSmallStr, ExprIR>,
20
streaming: bool,
21
) -> PolarsResult<IR> {
22
// Owned copies of the input schemas, so `lp_arena` is not kept borrowed while it is mutated below.
let schema_left = lp_arena.get(input_left).schema(lp_arena).into_owned();
23
let schema_right = lp_arena.get(input_right).schema(lp_arena).into_owned();
24
25
// May change the join type (and with it `schema`, `options`, the key lists and the predicates).
// Returns `Some((projections, schema))` when a post-select is needed to restore the original
// output column order/names.
let opt_post_select = try_rewrite_join_type(
26
&schema_left,
27
&schema_right,
28
&mut schema,
29
&mut options,
30
&mut left_on,
31
&mut right_on,
32
&mut acc_predicates,
33
expr_arena,
34
streaming,
35
)?;
36
37
if match &options.args.how {
38
// Full-join with no coalesce. We can only push filters if they do not remove NULLs, but
39
// we don't have a reliable way to guarantee this.
40
JoinType::Full => !options.args.should_coalesce(),
41
42
_ => false,
43
} || acc_predicates.is_empty()
44
{
45
let lp = IR::Join {
46
input_left,
47
input_right,
48
left_on,
49
right_on,
50
schema,
51
options,
52
};
53
54
return opt.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena);
55
}
56
57
let should_coalesce = options.args.should_coalesce();
58
59
// AsOf has the equality join keys under `asof_options.left/right_by`. This code builds an
60
// iterator to address these generically without creating a `Box<dyn Iterator>`.
61
// Yields, per left key, its column name, or `None` if the key is not a (possibly cast) column.
let get_lhs_column_keys_iter = || {
62
let len = match &options.args.how {
63
#[cfg(feature = "asof_join")]
64
JoinType::AsOf(asof_options) => {
65
asof_options.left_by.as_deref().unwrap_or_default().len()
66
},
67
_ => left_on.len(),
68
};
69
70
(0..len).map(|i| match &options.args.how {
71
#[cfg(feature = "asof_join")]
72
JoinType::AsOf(asof_options) => Some(
73
asof_options
74
.left_by
75
.as_deref()
76
.unwrap_or_default()
77
.get(i)
78
.unwrap(),
79
),
80
_ => {
81
let expr = left_on.get(i).unwrap();
82
83
// For non full-joins coalesce can still insert casts into the key exprs.
84
let node = match expr_arena.get(expr.node()) {
85
AExpr::Cast {
86
expr,
87
dtype: _,
88
options: _,
89
} if should_coalesce => *expr,
90
91
_ => expr.node(),
92
};
93
94
if let AExpr::Column(name) = expr_arena.get(node) {
95
Some(name)
96
} else {
97
None
98
}
99
},
100
})
101
};
102
103
// Mirror of `get_lhs_column_keys_iter` for the right-side keys (`right_on` / `right_by`).
let get_rhs_column_keys_iter = || {
104
let len = match &options.args.how {
105
#[cfg(feature = "asof_join")]
106
JoinType::AsOf(asof_options) => {
107
asof_options.right_by.as_deref().unwrap_or_default().len()
108
},
109
_ => right_on.len(),
110
};
111
112
(0..len).map(|i| match &options.args.how {
113
#[cfg(feature = "asof_join")]
114
JoinType::AsOf(asof_options) => Some(
115
asof_options
116
.right_by
117
.as_deref()
118
.unwrap_or_default()
119
.get(i)
120
.unwrap(),
121
),
122
_ => {
123
let expr = right_on.get(i).unwrap();
124
125
// For non full-joins coalesce can still insert casts into the key exprs.
126
let node = match expr_arena.get(expr.node()) {
127
AExpr::Cast {
128
expr,
129
dtype: _,
130
options: _,
131
} if should_coalesce => *expr,
132
133
_ => expr.node(),
134
};
135
136
if let AExpr::Column(name) = expr_arena.get(node) {
137
Some(name)
138
} else {
139
None
140
}
141
},
142
})
143
};
144
145
// Debug-only sanity checks: with coalescing, non-asof joins must have keys and every key must
// resolve to a column name.
if cfg!(debug_assertions) && options.args.should_coalesce() {
146
match &options.args.how {
147
#[cfg(feature = "asof_join")]
148
JoinType::AsOf(_) => {},
149
150
_ => {
151
assert!(get_lhs_column_keys_iter().len() > 0);
152
assert!(get_rhs_column_keys_iter().len() > 0);
153
},
154
}
155
156
assert!(get_lhs_column_keys_iter().all(|x| x.is_some()));
157
assert!(get_rhs_column_keys_iter().all(|x| x.is_some()));
158
}
159
160
// Key columns of the left table that are coalesced into an output column of the right table.
161
let coalesced_to_right: PlHashSet<PlSmallStr> =
162
if matches!(&options.args.how, JoinType::Right) && options.args.should_coalesce() {
163
get_lhs_column_keys_iter()
164
.map(|x| x.unwrap().clone())
165
.collect()
166
} else {
167
Default::default()
168
};
169
170
// Output join-key name -> key column name in the *other* input. Used to rewrite predicates
// on key columns so they can also be pushed across to the opposite side.
let mut output_key_to_left_input_map: PlHashMap<PlSmallStr, PlSmallStr> =
171
PlHashMap::with_capacity(get_lhs_column_keys_iter().len());
172
let mut output_key_to_right_input_map: PlHashMap<PlSmallStr, PlSmallStr> =
173
PlHashMap::with_capacity(get_rhs_column_keys_iter().len());
174
175
// Key pairs where either side is not a plain column are skipped (no mapping is recorded).
for (lhs_input_key, rhs_input_key) in get_lhs_column_keys_iter().zip(get_rhs_column_keys_iter())
176
{
177
let (Some(lhs_input_key), Some(rhs_input_key)) = (lhs_input_key, rhs_input_key) else {
178
continue;
179
};
180
181
// lhs_input_key: Column name within the left table.
182
use JoinType::*;
183
// Map output name of an LHS join key output to an input key column of the right table.
184
// This will cause predicates referring to LHS join keys to also be pushed to the RHS table.
185
if match &options.args.how {
186
Left | Inner | Full => true,
187
188
#[cfg(feature = "asof_join")]
189
AsOf(_) => true,
190
#[cfg(feature = "semi_anti_join")]
191
Semi | Anti => true,
192
193
// NOTE: Right-join is excluded.
194
Right => false,
195
196
#[cfg(feature = "iejoin")]
197
IEJoin => false,
198
199
Cross => unreachable!(), // Cross left/right_on should be empty
200
} {
201
// Note: `lhs_input_key` maintains its name in the output column for all cases except
202
// for a coalescing right-join.
203
output_key_to_right_input_map.insert(lhs_input_key.clone(), rhs_input_key.clone());
204
}
205
206
// Map output name of an RHS join key output to a key column of the left table.
207
// This will cause predicates referring to RHS join keys to also be pushed to the LHS table.
208
if match &options.args.how {
209
JoinType::Right => true,
210
// Non-coalesced output columns of an inner join are equivalent between LHS and RHS.
211
JoinType::Inner => !options.args.should_coalesce(),
212
_ => false,
213
} {
214
// The RHS key is suffixed in the output when its name clashes with a projected LHS column.
let rhs_output_key: PlSmallStr = if schema_left.contains(rhs_input_key.as_str())
215
&& !coalesced_to_right.contains(rhs_input_key.as_str())
216
{
217
format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())
218
} else {
219
rhs_input_key.clone()
220
};
221
222
assert!(schema.contains(&rhs_output_key));
223
224
output_key_to_left_input_map.insert(rhs_output_key.clone(), lhs_input_key.clone());
225
}
226
}
227
228
// Predicates to push into the left / right input, plus residual predicates kept above the join.
let mut pushdown_left: PlHashMap<PlSmallStr, ExprIR> = init_hashmap(Some(acc_predicates.len()));
229
let mut pushdown_right: PlHashMap<PlSmallStr, ExprIR> =
230
init_hashmap(Some(acc_predicates.len()));
231
let mut local_predicates = Vec::with_capacity(acc_predicates.len());
232
233
for (_, predicate) in acc_predicates {
234
let mut push_left = true;
235
let mut push_right = true;
236
237
// A predicate can go to a side only if every referenced column originates from that side
// (or is a join-key output that maps onto that side's key column).
for col_name in aexpr_to_leaf_names_iter(predicate.node(), expr_arena) {
238
let origin: ExprOrigin = ExprOrigin::get_column_origin(
239
col_name.as_str(),
240
&schema_left,
241
&schema_right,
242
options.args.suffix(),
243
Some(&|name| coalesced_to_right.contains(name)),
244
)
245
.unwrap();
246
247
push_left &= matches!(origin, ExprOrigin::Left | ExprOrigin::None)
248
|| output_key_to_left_input_map.contains_key(col_name);
249
250
push_right &= matches!(origin, ExprOrigin::Right | ExprOrigin::None)
251
|| output_key_to_right_input_map.contains_key(col_name);
252
}
253
254
// Note: If `push_left` and `push_right` are both `true`, it means the predicate refers only
255
// to the join key columns, or the predicate does not refer any columns.
256
257
// `has_residual`: the predicate must additionally be re-applied on top of the join.
let has_residual = match &options.args.how {
258
// Pushing to a single side is enough to observe the full effect of the filter.
259
JoinType::Inner => !(push_left || push_right),
260
261
// Left-join: Pushing filters to the left table is enough to observe the effect of the
262
// filter. Pushing filters to the right is optional, but can only be done if the
263
// filter is also pushed to the left (if this is the case it means the filter only
264
// references join key columns).
265
JoinType::Left => {
266
push_right &= push_left;
267
!push_left
268
},
269
270
// Same as left-join, just flipped around.
271
JoinType::Right => {
272
push_left &= push_right;
273
!push_right
274
},
275
276
// Full-join: Filters must strictly apply only to coalesced output key columns.
277
JoinType::Full => {
278
assert!(options.args.should_coalesce());
279
280
let push = push_left && push_right;
281
push_left = push;
282
push_right = push;
283
284
!push
285
},
286
287
JoinType::Cross => {
288
// Predicate should only refer to a single side.
289
assert!(output_key_to_left_input_map.is_empty());
290
assert!(output_key_to_right_input_map.is_empty());
291
!(push_left || push_right)
292
},
293
294
// Behaves similarly to left-join on "by" columns (takes a single match instead of
295
// all matches according to asof strategy).
296
#[cfg(feature = "asof_join")]
297
JoinType::AsOf(_) => {
298
push_right &= push_left;
299
!push_left
300
},
301
302
// Same as inner-join.
303
#[cfg(feature = "semi_anti_join")]
304
JoinType::Semi => !(push_left || push_right),
305
306
// Anti-join is an exclusion of key tuples that exist in the right table, meaning that
307
// filters can only be pushed to the right table if they are also pushed to the left.
308
#[cfg(feature = "semi_anti_join")]
309
JoinType::Anti => {
310
push_right &= push_left;
311
!push_left
312
},
313
314
// Same as inner-join.
315
#[cfg(feature = "iejoin")]
316
JoinType::IEJoin => !(push_left || push_right),
317
};
318
319
if has_residual {
320
local_predicates.push(predicate.clone())
321
}
322
323
// Rewrite RHS join-key output names into the corresponding left input key names.
if push_left {
324
let mut predicate = predicate.clone();
325
map_column_references(&mut predicate, expr_arena, &output_key_to_left_input_map);
326
insert_predicate_dedup(&mut pushdown_left, &predicate, expr_arena);
327
}
328
329
// Rewrite LHS join-key output names into right input key names, then strip the join suffix
// so the column references resolve against the right input schema.
if push_right {
330
let mut predicate = predicate;
331
map_column_references(&mut predicate, expr_arena, &output_key_to_right_input_map);
332
remove_suffix(
333
&mut predicate,
334
expr_arena,
335
&schema_right,
336
options.args.suffix(),
337
);
338
insert_predicate_dedup(&mut pushdown_right, &predicate, expr_arena);
339
}
340
}
341
342
// Continue the pushdown into both inputs with the predicates collected for each side.
opt.pushdown_and_assign(input_left, pushdown_left, lp_arena, expr_arena)?;
343
opt.pushdown_and_assign(input_right, pushdown_right, lp_arena, expr_arena)?;
344
345
let lp = IR::Join {
346
input_left,
347
input_right,
348
left_on,
349
right_on,
350
schema,
351
options,
352
};
353
354
// Re-apply the residual predicates above the join.
let lp = opt.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena);
355
356
// Restore the original output columns if `try_rewrite_join_type` changed them.
let lp = if let Some((projections, schema)) = opt_post_select {
357
IR::Select {
358
input: lp_arena.add(lp),
359
expr: projections,
360
schema,
361
options: ProjectionOptions {
362
run_parallel: false,
363
duplicate_check: false,
364
should_broadcast: false,
365
},
366
}
367
} else {
368
lp
369
};
370
371
Ok(lp)
372
}
373
374
/// Attempts to rewrite the join-type based on NULL-removing filters.
375
///
376
/// Changing between some join types may cause the output column order to change. If this is the
377
/// case, a Vec of column selectors will be returned that restore the original column order.
378
#[expect(clippy::too_many_arguments)]
379
fn try_rewrite_join_type(
380
schema_left: &SchemaRef,
381
schema_right: &SchemaRef,
382
output_schema: &mut SchemaRef,
383
options: &mut Arc<JoinOptionsIR>,
384
left_on: &mut Vec<ExprIR>,
385
right_on: &mut Vec<ExprIR>,
386
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
387
expr_arena: &mut Arena<AExpr>,
388
streaming: bool,
389
) -> PolarsResult<Option<(Vec<ExprIR>, SchemaRef)>> {
390
if acc_predicates.is_empty() {
391
return Ok(None);
392
}
393
394
let suffix = options.args.suffix().clone();
395
396
// * Cross -> Inner | IEJoin
397
// * IEJoin -> Inner
398
//
399
// Note: The join rewrites here all maintain output column ordering, hence this does not need
400
// to return any post-select (inserted inner joins will use JoinCoalesce::KeepColumns).
401
(|| {
402
match &options.args.how {
403
#[cfg(feature = "iejoin")]
404
JoinType::IEJoin => {},
405
JoinType::Cross => {},
406
407
_ => return PolarsResult::Ok(()),
408
}
409
410
match &options.options {
411
Some(JoinTypeOptionsIR::CrossAndFilter { .. }) => {
412
let Some(JoinTypeOptionsIR::CrossAndFilter { predicate }) =
413
Arc::make_mut(options).options.take()
414
else {
415
unreachable!()
416
};
417
418
insert_predicate_dedup(acc_predicates, &predicate, expr_arena);
419
},
420
421
#[cfg(feature = "iejoin")]
422
Some(JoinTypeOptionsIR::IEJoin(_)) => {},
423
None => {},
424
}
425
426
// Try converting to inner join
427
let equality_conditions = take_inner_join_compatible_filters(
428
acc_predicates,
429
expr_arena,
430
schema_left,
431
schema_right,
432
&suffix,
433
)?;
434
435
for InnerJoinKeys {
436
input_lhs,
437
input_rhs,
438
} in equality_conditions
439
{
440
let join_options = Arc::make_mut(options);
441
join_options.args.how = JoinType::Inner;
442
join_options.args.coalesce = JoinCoalesce::KeepColumns;
443
444
left_on.push(ExprIR::from_node(input_lhs, expr_arena));
445
let mut rexpr = ExprIR::from_node(input_rhs, expr_arena);
446
remove_suffix(&mut rexpr, expr_arena, schema_right, &suffix);
447
right_on.push(rexpr);
448
}
449
450
if options.args.how == JoinType::Inner {
451
return Ok(());
452
}
453
454
// Try converting cross join to IEJoin
455
#[cfg(feature = "iejoin")]
456
if matches!(options.args.maintain_order, MaintainOrderJoin::None)
457
&& left_on.len() < IEJOIN_MAX_PREDICATES
458
{
459
let ie_conditions = take_iejoin_compatible_filters(
460
acc_predicates,
461
expr_arena,
462
schema_left,
463
schema_right,
464
output_schema,
465
&suffix,
466
)?;
467
468
for IEJoinCompatiblePredicate {
469
input_lhs,
470
input_rhs,
471
ie_op,
472
source_node,
473
} in ie_conditions
474
{
475
let join_options = Arc::make_mut(options);
476
join_options.args.how = JoinType::IEJoin;
477
478
if left_on.len() >= IEJOIN_MAX_PREDICATES {
479
// Important: Place these back into acc_predicates.
480
insert_predicate_dedup(
481
acc_predicates,
482
&ExprIR::from_node(source_node, expr_arena),
483
expr_arena,
484
);
485
} else {
486
left_on.push(ExprIR::from_node(input_lhs, expr_arena));
487
let mut rexpr = ExprIR::from_node(input_rhs, expr_arena);
488
remove_suffix(&mut rexpr, expr_arena, schema_right, &suffix);
489
right_on.push(rexpr);
490
491
let JoinTypeOptionsIR::IEJoin(ie_options) = join_options
492
.options
493
.get_or_insert(JoinTypeOptionsIR::IEJoin(IEJoinOptions::default()))
494
else {
495
unreachable!()
496
};
497
498
match left_on.len() {
499
1 => ie_options.operator1 = ie_op,
500
2 => ie_options.operator2 = Some(ie_op),
501
_ => unreachable!("{}", IEJOIN_MAX_PREDICATES),
502
};
503
}
504
}
505
506
if options.args.how == JoinType::IEJoin {
507
return Ok(());
508
}
509
}
510
511
debug_assert_eq!(options.args.how, JoinType::Cross);
512
513
if options.args.how != JoinType::Cross {
514
return Ok(());
515
}
516
517
if streaming {
518
return Ok(());
519
}
520
521
let Some(nested_loop_predicates) = take_nested_loop_join_compatible_filters(
522
acc_predicates,
523
expr_arena,
524
schema_left,
525
schema_right,
526
&suffix,
527
)?
528
.reduce(|left, right| {
529
expr_arena.add(AExpr::BinaryExpr {
530
left,
531
op: Operator::And,
532
right,
533
})
534
}) else {
535
return Ok(());
536
};
537
538
let existing = Arc::make_mut(options)
539
.options
540
.replace(JoinTypeOptionsIR::CrossAndFilter {
541
predicate: ExprIR::from_node(nested_loop_predicates, expr_arena),
542
});
543
assert!(existing.is_none()); // Important
544
545
Ok(())
546
})()?;
547
548
if !matches!(
549
&options.args.how,
550
JoinType::Full | JoinType::Left | JoinType::Right
551
) {
552
return Ok(None);
553
}
554
555
let should_coalesce = options.args.should_coalesce();
556
557
/// Note: This may panic if `args.should_coalesce()` is false.
558
macro_rules! lhs_input_column_keys_iter {
559
() => {{
560
left_on.iter().map(|expr| {
561
let node = match expr_arena.get(expr.node()) {
562
AExpr::Cast {
563
expr,
564
dtype: _,
565
options: _,
566
} if should_coalesce => *expr,
567
568
_ => expr.node(),
569
};
570
571
let AExpr::Column(name) = expr_arena.get(node) else {
572
// All keys should be columns when coalesce=True
573
unreachable!()
574
};
575
576
name.clone()
577
})
578
}};
579
}
580
581
let mut coalesced_to_right: PlHashSet<PlSmallStr> = Default::default();
582
// Removing NULLs on these columns do not allow for join downgrading.
583
// We only need to track these for full-join - e.g. for left-join, removing NULLs from any left
584
// column does not cause any join rewrites.
585
let mut coalesced_full_join_key_outputs: PlHashSet<PlSmallStr> = Default::default();
586
587
if options.args.should_coalesce() {
588
match &options.args.how {
589
JoinType::Full => {
590
coalesced_full_join_key_outputs = lhs_input_column_keys_iter!().collect()
591
},
592
JoinType::Right => coalesced_to_right = lhs_input_column_keys_iter!().collect(),
593
_ => {},
594
}
595
}
596
597
let mut non_null_side = ExprOrigin::None;
598
599
for predicate in acc_predicates.values() {
600
for node in MintermIter::new(predicate.node(), expr_arena) {
601
predicate_non_null_column_outputs(node, expr_arena, &mut |non_null_column| {
602
if coalesced_full_join_key_outputs.contains(non_null_column) {
603
return;
604
}
605
606
non_null_side |= ExprOrigin::get_column_origin(
607
non_null_column.as_str(),
608
schema_left,
609
schema_right,
610
options.args.suffix(),
611
Some(&|x| coalesced_to_right.contains(x)),
612
)
613
.unwrap();
614
});
615
}
616
}
617
618
let Some(new_join_type) = (match non_null_side {
619
ExprOrigin::Both => Some(JoinType::Inner),
620
621
ExprOrigin::Left => match &options.args.how {
622
JoinType::Full => Some(JoinType::Left),
623
JoinType::Right => Some(JoinType::Inner),
624
_ => None,
625
},
626
627
ExprOrigin::Right => match &options.args.how {
628
JoinType::Full => Some(JoinType::Right),
629
JoinType::Left => Some(JoinType::Inner),
630
_ => None,
631
},
632
633
ExprOrigin::None => None,
634
}) else {
635
return Ok(None);
636
};
637
638
let options = Arc::make_mut(options);
639
// Ensure JoinSpecific is materialized to a specific config option, as we change the join type.
640
options.args.coalesce = if options.args.should_coalesce() {
641
JoinCoalesce::CoalesceColumns
642
} else {
643
JoinCoalesce::KeepColumns
644
};
645
let original_join_type = std::mem::replace(&mut options.args.how, new_join_type.clone());
646
let original_output_schema = match (&original_join_type, &new_join_type) {
647
(JoinType::Right, _) | (_, JoinType::Right) => std::mem::replace(
648
output_schema,
649
det_join_schema(
650
schema_left,
651
schema_right,
652
left_on,
653
right_on,
654
options,
655
expr_arena,
656
)
657
.unwrap(),
658
),
659
_ => {
660
debug_assert_eq!(
661
output_schema,
662
&det_join_schema(
663
schema_left,
664
schema_right,
665
left_on,
666
right_on,
667
options,
668
expr_arena,
669
)
670
.unwrap()
671
);
672
output_schema.clone()
673
},
674
};
675
676
// Maps the original join output names to the new join output names (used for mapping column
677
// references of the predicates).
678
let mut original_to_new_names_map: PlHashMap<PlSmallStr, PlSmallStr> = Default::default();
679
// Projects the new join output table back into the original join output table.
680
let mut project_to_original: Option<Vec<ExprIR>> = None;
681
682
if options.args.should_coalesce() {
683
// If we changed join types between a coalescing right-join, we need to do a select() to restore the column
684
// order of the original join type. The column references in the predicates may also need to be changed.
685
match (&original_join_type, &new_join_type) {
686
(JoinType::Right, JoinType::Right) => unreachable!(),
687
688
// Right-join rewritten to inner-join.
689
//
690
// E.g.
691
// Left: | a | b | c |
692
// Right: | a | b | c |
693
//
694
// right_join(left_on='a', right_on='b'): | b | c | a | *b_right | c_right |
695
// inner_join(left_on='a', right_on='b'): | *a | b | c | a_right | c_right |
696
// note: '*' means coalesced key output column
697
//
698
// project_to_original: | col(b) | col(c) | col(a_right).alias(a) | col(a).alias(b_right) | col(c_right) |
699
// original_to_new_names_map: {'a': 'a_right', 'b_right': 'a'}
700
//
701
(JoinType::Right, JoinType::Inner) => {
702
let mut join_output_key_selectors = PlHashMap::with_capacity(right_on.len());
703
704
for (l, r) in left_on.iter().zip(right_on) {
705
// Unwrap any Cast expressions that may have been inserted for type coercion.
706
// For non full-joins coalesce can still insert casts into the key exprs.
707
let l_node = match expr_arena.get(l.node()) {
708
AExpr::Cast {
709
expr,
710
dtype: _,
711
options: _,
712
} if should_coalesce => *expr,
713
_ => l.node(),
714
};
715
let r_node = match expr_arena.get(r.node()) {
716
AExpr::Cast {
717
expr,
718
dtype: _,
719
options: _,
720
} if should_coalesce => *expr,
721
_ => r.node(),
722
};
723
724
let (AExpr::Column(lhs_input_key), AExpr::Column(rhs_input_key)) =
725
(expr_arena.get(l_node), expr_arena.get(r_node))
726
else {
727
// `should_coalesce() == true` should guarantee all are columns.
728
unreachable!()
729
};
730
731
let original_key_output_name: PlSmallStr = if schema_left
732
.contains(rhs_input_key.as_str())
733
&& !coalesced_to_right.contains(rhs_input_key.as_str())
734
{
735
format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())
736
} else {
737
rhs_input_key.clone()
738
};
739
740
let new_key_output_name = lhs_input_key.clone();
741
let rhs_input_key = rhs_input_key.clone();
742
743
let node = expr_arena.add(AExpr::Column(lhs_input_key.clone()));
744
let mut ae = ExprIR::from_node(node, expr_arena);
745
746
if original_key_output_name != new_key_output_name {
747
// E.g. left_on=col(a), right_on=col(b)
748
// rhs_output_key = 'b', lhs_input_key = 'a', the original right-join is supposed to output 'b'.
749
original_to_new_names_map.insert(
750
original_key_output_name.clone(),
751
new_key_output_name.clone(),
752
);
753
ae.set_alias(original_key_output_name)
754
}
755
756
join_output_key_selectors.insert(rhs_input_key, ae);
757
}
758
759
let mut column_selectors: Vec<ExprIR> = Vec::with_capacity(output_schema.len());
760
761
for lhs_input_col in schema_left.iter_names() {
762
if coalesced_to_right.contains(lhs_input_col) {
763
continue;
764
}
765
766
let node = expr_arena.add(AExpr::Column(lhs_input_col.clone()));
767
column_selectors.push(ExprIR::from_node(node, expr_arena));
768
}
769
770
for rhs_input_col in schema_right.iter_names() {
771
let expr = if let Some(expr) = join_output_key_selectors.get(rhs_input_col) {
772
expr.clone()
773
} else if schema_left.contains(rhs_input_col) {
774
let new_join_output_name =
775
format_pl_smallstr!("{}{}", rhs_input_col, options.args.suffix());
776
777
let node = expr_arena.add(AExpr::Column(new_join_output_name.clone()));
778
let mut expr = ExprIR::from_node(node, expr_arena);
779
780
// The column with the same name from the LHS is not projected in the original
781
// right-join, so we alias to remove the suffix that was added from the inner-join.
782
if coalesced_to_right.contains(rhs_input_col.as_str()) {
783
original_to_new_names_map
784
.insert(rhs_input_col.clone(), new_join_output_name);
785
expr.set_alias(rhs_input_col.clone());
786
}
787
788
expr
789
} else {
790
let node = expr_arena.add(AExpr::Column(rhs_input_col.clone()));
791
ExprIR::from_node(node, expr_arena)
792
};
793
794
column_selectors.push(expr)
795
}
796
797
assert_eq!(column_selectors.len(), output_schema.len());
798
assert_eq!(column_selectors.len(), original_output_schema.len());
799
800
if cfg!(debug_assertions) {
801
assert!(
802
column_selectors
803
.iter()
804
.zip(original_output_schema.iter_names())
805
.all(|(l, r)| l.output_name() == r)
806
)
807
}
808
809
project_to_original = Some(column_selectors)
810
},
811
812
// Full-join rewritten to right-join
813
//
814
// E.g.
815
// Left: | a | b | c |
816
// Right: | a | b | c |
817
//
818
// full_join(left_on='a', right_on='b'): | *a | b | c | a_right | c_right |
819
// right_join(left_on='a', right_on='b'): | b | c | a | *b_right | c_right |
820
// note: '*' means coalesced key output column
821
//
822
// project_to_original: | col(b_right).alias(a) | col(b) | col(c) | col(a).alias(a_right) | col(c_right) |
823
// original_to_new_names_map: {'a': 'b_right', 'a_right': 'a'}
824
//
825
(JoinType::Full, JoinType::Right) => {
826
let mut join_output_key_selectors = PlHashMap::with_capacity(left_on.len());
827
828
// The existing one is empty because the original join type was not a right-join.
829
assert!(coalesced_to_right.is_empty());
830
// LHS input key columns that are coalesced (i.e. not projected) for the right-join.
831
let coalesced_to_right: PlHashSet<PlSmallStr> =
832
lhs_input_column_keys_iter!().collect();
833
// RHS input key columns that are coalesced (i.e. not projected) for the full-join.
834
let mut coalesced_to_left: PlHashSet<PlSmallStr> =
835
PlHashSet::with_capacity(right_on.len());
836
837
for (l, r) in left_on.iter().zip(right_on) {
838
// Unwrap any Cast expressions that may have been inserted for type coercion.
839
// For non full-joins coalesce can still insert casts into the key exprs.
840
let l_node = match expr_arena.get(l.node()) {
841
AExpr::Cast {
842
expr,
843
dtype: _,
844
options: _,
845
} if should_coalesce => *expr,
846
_ => l.node(),
847
};
848
let r_node = match expr_arena.get(r.node()) {
849
AExpr::Cast {
850
expr,
851
dtype: _,
852
options: _,
853
} if should_coalesce => *expr,
854
_ => r.node(),
855
};
856
857
let (AExpr::Column(lhs_input_key), AExpr::Column(rhs_input_key)) =
858
(expr_arena.get(l_node), expr_arena.get(r_node))
859
else {
860
// `should_coalesce() == true` should guarantee all columns.
861
unreachable!()
862
};
863
864
let new_key_output_name: PlSmallStr = if schema_left
865
.contains(rhs_input_key.as_str())
866
&& !coalesced_to_right.contains(rhs_input_key.as_str())
867
{
868
format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())
869
} else {
870
rhs_input_key.clone()
871
};
872
873
let lhs_input_key = lhs_input_key.clone();
874
let rhs_input_key = rhs_input_key.clone();
875
let original_key_output_name = &lhs_input_key;
876
877
coalesced_to_left.insert(rhs_input_key);
878
879
let node = expr_arena.add(AExpr::Column(new_key_output_name.clone()));
880
881
let mut ae = ExprIR::from_node(node, expr_arena);
882
883
// E.g. left_on=col(a), right_on=col(b)
884
// rhs_output_key = 'b', lhs_input_key = 'a'
885
if new_key_output_name != original_key_output_name {
886
original_to_new_names_map.insert(
887
original_key_output_name.clone(),
888
new_key_output_name.clone(),
889
);
890
ae.set_alias(original_key_output_name.clone())
891
}
892
893
join_output_key_selectors.insert(lhs_input_key.clone(), ae);
894
}
895
896
let mut column_selectors = Vec::with_capacity(output_schema.len());
897
898
for lhs_input_col in schema_left.iter_names() {
899
let expr = if let Some(expr) = join_output_key_selectors.get(lhs_input_col) {
900
expr.clone()
901
} else {
902
let node = expr_arena.add(AExpr::Column(lhs_input_col.clone()));
903
ExprIR::from_node(node, expr_arena)
904
};
905
906
column_selectors.push(expr)
907
}
908
909
for rhs_input_col in schema_right.iter_names() {
910
if coalesced_to_left.contains(rhs_input_col) {
911
continue;
912
}
913
914
let mut original_output_name: Option<PlSmallStr> = None;
915
916
let new_join_output_name = if schema_left.contains(rhs_input_col) {
917
let suffixed =
918
format_pl_smallstr!("{}{}", rhs_input_col, options.args.suffix());
919
920
if coalesced_to_right.contains(rhs_input_col) {
921
original_output_name = Some(suffixed);
922
rhs_input_col.clone()
923
} else {
924
suffixed
925
}
926
} else {
927
rhs_input_col.clone()
928
};
929
930
let node = expr_arena.add(AExpr::Column(new_join_output_name));
931
932
let mut expr = ExprIR::from_node(node, expr_arena);
933
934
if let Some(original_output_name) = original_output_name {
935
original_to_new_names_map
936
.insert(original_output_name.clone(), rhs_input_col.clone());
937
expr.set_alias(original_output_name);
938
}
939
940
column_selectors.push(expr);
941
}
942
943
assert_eq!(column_selectors.len(), output_schema.len());
944
assert_eq!(column_selectors.len(), original_output_schema.len());
945
946
if cfg!(debug_assertions) {
947
assert!(
948
column_selectors
949
.iter()
950
.zip(original_output_schema.iter_names())
951
.all(|(l, r)| l.output_name() == r)
952
)
953
}
954
955
project_to_original = Some(column_selectors)
956
},
957
958
(JoinType::Right, _) | (_, JoinType::Right) => unreachable!(),
959
960
_ => {},
961
}
962
}
963
964
if !original_to_new_names_map.is_empty() {
965
assert!(project_to_original.is_some());
966
967
for (_, predicate_expr) in acc_predicates.iter_mut() {
968
map_column_references(predicate_expr, expr_arena, &original_to_new_names_map);
969
}
970
}
971
972
Ok(project_to_original.map(|p| (p, original_output_schema)))
973
}
974
975
/// An equality predicate between the two join inputs, extracted by
/// `take_inner_join_compatible_filters` to be used as an inner-join key pair.
struct InnerJoinKeys {
976
/// Key expression node whose columns originate from the left input.
input_lhs: Node,
977
/// Key expression node whose columns originate from the right input. It may still reference
/// suffixed output names; callers strip the suffix before using it as a right key.
input_rhs: Node,
978
}
979
980
/// Removes all equality predicates that can be used as inner-join conditions from `acc_predicates`.
981
fn take_inner_join_compatible_filters(
982
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
983
expr_arena: &mut Arena<AExpr>,
984
schema_left: &Schema,
985
schema_right: &Schema,
986
suffix: &str,
987
) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, InnerJoinKeys>> {
988
take_predicates_mut(acc_predicates, expr_arena, |ae, _ae_node, expr_arena| {
989
Ok(match ae {
990
AExpr::BinaryExpr {
991
left,
992
op: Operator::Eq,
993
right,
994
} => {
995
let left_origin = ExprOrigin::get_expr_origin(
996
*left,
997
expr_arena,
998
schema_left,
999
schema_right,
1000
suffix,
1001
None, // is_coalesced_to_right
1002
)?;
1003
let right_origin = ExprOrigin::get_expr_origin(
1004
*right,
1005
expr_arena,
1006
schema_left,
1007
schema_right,
1008
suffix,
1009
None,
1010
)?;
1011
1012
match (left_origin, right_origin) {
1013
(ExprOrigin::Left, ExprOrigin::Right) => Some(InnerJoinKeys {
1014
input_lhs: *left,
1015
input_rhs: *right,
1016
}),
1017
(ExprOrigin::Right, ExprOrigin::Left) => Some(InnerJoinKeys {
1018
input_lhs: *right,
1019
input_rhs: *left,
1020
}),
1021
_ => None,
1022
}
1023
},
1024
_ => None,
1025
})
1026
})
1027
}
1028
1029
/// An inequality predicate between the two join inputs, extracted by
/// `take_iejoin_compatible_filters` to be used as an IEJoin condition.
#[cfg(feature = "iejoin")]
1030
struct IEJoinCompatiblePredicate {
1031
/// Operand node whose columns originate from the left input.
input_lhs: Node,
1032
/// Operand node whose columns originate from the right input.
input_rhs: Node,
1033
/// Comparison operator, oriented as `input_lhs <ie_op> input_rhs`. It is swapped when the
/// original expression had the right-input operand first.
ie_op: InequalityOperator,
1034
/// Original input node (the whole comparison expression). It is put back into the
/// accumulated predicates when the IEJoin cannot take another condition.
1035
source_node: Node,
1036
}
1037
1038
/// Removes all inequality filters that can be used as iejoin conditions from `acc_predicates`.
///
/// A filter qualifies when all of the following hold:
/// * it is a `<`, `<=`, `>` or `>=` comparison;
/// * its operands originate from opposite join inputs;
/// * both operands resolve (against `output_schema`) to a non-nested, primitive-numeric dtype.
///
/// Each extracted predicate is normalized so that `input_lhs` references the left input. When the
/// operands were written the other way around, the operator is swapped accordingly.
#[cfg(feature = "iejoin")]
fn take_iejoin_compatible_filters(
    acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
    expr_arena: &mut Arena<AExpr>,
    schema_left: &Schema,
    schema_right: &Schema,
    output_schema: &Schema,
    suffix: &str,
) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, IEJoinCompatiblePredicate>> {
    /// Translates a comparison operator into its IEJoin counterpart (`None` if unsupported).
    fn to_inequality_operator(op: &Operator) -> Option<InequalityOperator> {
        let ie_op = match op {
            Operator::Lt => InequalityOperator::Lt,
            Operator::LtEq => InequalityOperator::LtEq,
            Operator::Gt => InequalityOperator::Gt,
            Operator::GtEq => InequalityOperator::GtEq,
            _ => return None,
        };
        Some(ie_op)
    }

    take_predicates_mut(acc_predicates, expr_arena, |ae, ae_node, expr_arena| {
        let AExpr::BinaryExpr { left, op, right } = ae else {
            return Ok(None);
        };
        let Some(ie_op) = to_inequality_operator(op) else {
            return Ok(None);
        };

        // No coalesced-to-right lookup is supplied (`is_coalesced_to_right = None`).
        let origin_of = |node: Node| {
            ExprOrigin::get_expr_origin(node, expr_arena, schema_left, schema_right, suffix, None)
        };
        let lhs_origin = origin_of(*left)?;
        let rhs_origin = origin_of(*right)?;

        // IEJoin only supports numeric.
        let is_supported_operand = |node: Node| -> PolarsResult<bool> {
            let field = expr_arena
                .get(node)
                .to_field(&ToFieldContext::new(expr_arena, output_schema))?;
            let dtype = field.dtype();
            Ok(!dtype.is_nested() && dtype.to_physical().is_primitive_numeric())
        };
        if !(is_supported_operand(*left)? && is_supported_operand(*right)?) {
            return Ok(None);
        }

        let (input_lhs, input_rhs, ie_op) = match (lhs_origin, rhs_origin) {
            (ExprOrigin::Left, ExprOrigin::Right) => (*left, *right, ie_op),
            (ExprOrigin::Right, ExprOrigin::Left) => {
                // `right_expr <op> left_expr` becomes `left_expr <swapped op> right_expr`.
                let swapped = op.swap_operands();
                (*right, *left, to_inequality_operator(&swapped).unwrap())
            },
            _ => return Ok(None),
        };

        Ok(Some(IEJoinCompatiblePredicate {
            input_lhs,
            input_rhs,
            ie_op,
            source_node: ae_node,
        }))
    })
}
1124
1125
/// Removes all filters that can be used as nested loop join conditions from `acc_predicates`.
1126
///
1127
/// Note that filters that refer only to a single side are not removed so that they can be pushed
1128
/// into the LHS/RHS tables.
1129
fn take_nested_loop_join_compatible_filters(
1130
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
1131
expr_arena: &mut Arena<AExpr>,
1132
schema_left: &Schema,
1133
schema_right: &Schema,
1134
suffix: &str,
1135
) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, Node>> {
1136
take_predicates_mut(acc_predicates, expr_arena, |_ae, ae_node, expr_arena| {
1137
Ok(
1138
match ExprOrigin::get_expr_origin(
1139
ae_node,
1140
expr_arena,
1141
schema_left,
1142
schema_right,
1143
suffix,
1144
None,
1145
)? {
1146
// Leave single-origin exprs as they get pushed to the left/right tables individually.
1147
ExprOrigin::Left | ExprOrigin::Right | ExprOrigin::None => None,
1148
_ => Some(ae_node),
1149
},
1150
)
1151
})
1152
}
1153
1154
/// Removes predicates from the map according to a function.
1155
fn take_predicates_mut<F, T>(
1156
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
1157
expr_arena: &mut Arena<AExpr>,
1158
take_predicate: F,
1159
) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, T>>
1160
where
1161
F: Fn(&AExpr, Node, &Arena<AExpr>) -> PolarsResult<Option<T>>,
1162
{
1163
let mut selected_predicates: PlHashMap<Node, T> = PlHashMap::new();
1164
1165
for predicate in acc_predicates.values() {
1166
for node in MintermIter::new(predicate.node(), expr_arena) {
1167
let ae = expr_arena.get(node);
1168
1169
if let Some(t) = take_predicate(ae, node, expr_arena)? {
1170
selected_predicates.insert(node, t);
1171
}
1172
}
1173
}
1174
1175
if !selected_predicates.is_empty() {
1176
remove_min_terms(acc_predicates, expr_arena, &|node| {
1177
selected_predicates.contains_key(node)
1178
});
1179
}
1180
1181
return Ok(selected_predicates.into_values());
1182
1183
#[inline(never)]
1184
fn remove_min_terms(
1185
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
1186
expr_arena: &mut Arena<AExpr>,
1187
should_remove: &dyn Fn(&Node) -> bool,
1188
) {
1189
let mut remove_keys = PlHashSet::new();
1190
let mut nodes_scratch = vec![];
1191
1192
for (k, predicate) in acc_predicates.iter_mut() {
1193
let mut has_removed = false;
1194
1195
nodes_scratch.clear();
1196
nodes_scratch.extend(
1197
MintermIter::new(predicate.node(), expr_arena).filter(|node| {
1198
let remove = should_remove(node);
1199
has_removed |= remove;
1200
!remove
1201
}),
1202
);
1203
1204
if nodes_scratch.is_empty() {
1205
remove_keys.insert(k.clone());
1206
continue;
1207
};
1208
1209
if has_removed {
1210
let new_predicate_node = nodes_scratch
1211
.drain(..)
1212
.reduce(|left, right| {
1213
expr_arena.add(AExpr::BinaryExpr {
1214
left,
1215
op: Operator::And,
1216
right,
1217
})
1218
})
1219
.unwrap();
1220
1221
*predicate = ExprIR::from_node(new_predicate_node, expr_arena);
1222
}
1223
}
1224
1225
for k in remove_keys {
1226
let v = acc_predicates.remove(&k);
1227
assert!(v.is_some());
1228
}
1229
}
1230
}
1231
1232