// Source: GitHub repository pola-rs/polars (branch `main`)
// Path: crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs
1
use polars_utils::format_pl_smallstr;
2
3
use super::*;
4
use crate::plans::optimizer::join_utils::remove_suffix;
5
6
/// Upper bound on the number of join key pairs (`left_on` / `right_on`) given to an IEJoin by the
/// rewrite in this file. Inequality predicates beyond this limit are placed back into the
/// accumulated predicates instead of becoming join conditions.
const IEJOIN_MAX_PREDICATES: usize = 2;
7
8
#[allow(clippy::too_many_arguments)]
9
pub(super) fn process_join(
10
opt: &mut PredicatePushDown,
11
lp_arena: &mut Arena<IR>,
12
expr_arena: &mut Arena<AExpr>,
13
input_left: Node,
14
input_right: Node,
15
mut left_on: Vec<ExprIR>,
16
mut right_on: Vec<ExprIR>,
17
mut schema: SchemaRef,
18
mut options: Arc<JoinOptionsIR>,
19
mut acc_predicates: PlHashMap<PlSmallStr, ExprIR>,
20
streaming: bool,
21
) -> PolarsResult<IR> {
22
let schema_left = lp_arena.get(input_left).schema(lp_arena).into_owned();
23
let schema_right = lp_arena.get(input_right).schema(lp_arena).into_owned();
24
25
let opt_post_select = try_rewrite_join_type(
26
&schema_left,
27
&schema_right,
28
&mut schema,
29
&mut options,
30
&mut left_on,
31
&mut right_on,
32
&mut acc_predicates,
33
expr_arena,
34
streaming,
35
)?;
36
37
if match &options.args.how {
38
// Full-join with no coalesce. We can only push filters if they do not remove NULLs, but
39
// we don't have a reliable way to guarantee this.
40
JoinType::Full => !options.args.should_coalesce(),
41
42
_ => false,
43
} || acc_predicates.is_empty()
44
{
45
let lp = IR::Join {
46
input_left,
47
input_right,
48
left_on,
49
right_on,
50
schema,
51
options,
52
};
53
54
return opt.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena);
55
}
56
57
let should_coalesce = options.args.should_coalesce();
58
59
// AsOf has the equality join keys under `asof_options.left/right_by`. This code builds an
60
// iterator to address these generically without creating a `Box<dyn Iterator>`.
61
let get_lhs_column_keys_iter = || {
62
let len = match &options.args.how {
63
#[cfg(feature = "asof_join")]
64
JoinType::AsOf(asof_options) => {
65
asof_options.left_by.as_deref().unwrap_or_default().len()
66
},
67
_ => left_on.len(),
68
};
69
70
(0..len).map(|i| match &options.args.how {
71
#[cfg(feature = "asof_join")]
72
JoinType::AsOf(asof_options) => Some(
73
asof_options
74
.left_by
75
.as_deref()
76
.unwrap_or_default()
77
.get(i)
78
.unwrap(),
79
),
80
_ => {
81
let expr = left_on.get(i).unwrap();
82
83
// For non full-joins coalesce can still insert casts into the key exprs.
84
let node = match expr_arena.get(expr.node()) {
85
AExpr::Cast {
86
expr,
87
dtype: _,
88
options: _,
89
} if should_coalesce => *expr,
90
91
_ => expr.node(),
92
};
93
94
if let AExpr::Column(name) = expr_arena.get(node) {
95
Some(name)
96
} else {
97
None
98
}
99
},
100
})
101
};
102
103
let get_rhs_column_keys_iter = || {
104
let len = match &options.args.how {
105
#[cfg(feature = "asof_join")]
106
JoinType::AsOf(asof_options) => {
107
asof_options.right_by.as_deref().unwrap_or_default().len()
108
},
109
_ => right_on.len(),
110
};
111
112
(0..len).map(|i| match &options.args.how {
113
#[cfg(feature = "asof_join")]
114
JoinType::AsOf(asof_options) => Some(
115
asof_options
116
.right_by
117
.as_deref()
118
.unwrap_or_default()
119
.get(i)
120
.unwrap(),
121
),
122
_ => {
123
let expr = right_on.get(i).unwrap();
124
125
// For non full-joins coalesce can still insert casts into the key exprs.
126
let node = match expr_arena.get(expr.node()) {
127
AExpr::Cast {
128
expr,
129
dtype: _,
130
options: _,
131
} if should_coalesce => *expr,
132
133
_ => expr.node(),
134
};
135
136
if let AExpr::Column(name) = expr_arena.get(node) {
137
Some(name)
138
} else {
139
None
140
}
141
},
142
})
143
};
144
145
if cfg!(debug_assertions) && options.args.should_coalesce() {
146
match &options.args.how {
147
#[cfg(feature = "asof_join")]
148
JoinType::AsOf(_) => {},
149
150
_ => {
151
assert!(get_lhs_column_keys_iter().len() > 0);
152
assert!(get_rhs_column_keys_iter().len() > 0);
153
},
154
}
155
156
assert!(get_lhs_column_keys_iter().all(|x| x.is_some()));
157
assert!(get_rhs_column_keys_iter().all(|x| x.is_some()));
158
}
159
160
// Key columns of the left table that are coalesced into an output column of the right table.
161
let coalesced_to_right: PlHashSet<PlSmallStr> =
162
if matches!(&options.args.how, JoinType::Right) && options.args.should_coalesce() {
163
get_lhs_column_keys_iter()
164
.map(|x| x.unwrap().clone())
165
.collect()
166
} else {
167
Default::default()
168
};
169
170
let mut output_key_to_left_input_map: PlHashMap<PlSmallStr, PlSmallStr> =
171
PlHashMap::with_capacity(get_lhs_column_keys_iter().len());
172
let mut output_key_to_right_input_map: PlHashMap<PlSmallStr, PlSmallStr> =
173
PlHashMap::with_capacity(get_rhs_column_keys_iter().len());
174
175
for (lhs_input_key, rhs_input_key) in get_lhs_column_keys_iter().zip(get_rhs_column_keys_iter())
176
{
177
let (Some(lhs_input_key), Some(rhs_input_key)) = (lhs_input_key, rhs_input_key) else {
178
continue;
179
};
180
181
// lhs_input_key: Column name within the left table.
182
use JoinType::*;
183
// Map output name of an LHS join key output to an input key column of the right table.
184
// This will cause predicates referring to LHS join keys to also be pushed to the RHS table.
185
if match &options.args.how {
186
Left | Inner | Full => true,
187
188
#[cfg(feature = "asof_join")]
189
AsOf(_) => true,
190
#[cfg(feature = "semi_anti_join")]
191
Semi | Anti => true,
192
193
// NOTE: Right-join is excluded.
194
Right => false,
195
196
#[cfg(feature = "iejoin")]
197
IEJoin => false,
198
199
Cross => unreachable!(), // Cross left/right_on should be empty
200
} {
201
// Note: `lhs_input_key` maintains its name in the output column for all cases except
202
// for a coalescing right-join.
203
output_key_to_right_input_map.insert(lhs_input_key.clone(), rhs_input_key.clone());
204
}
205
206
// Map output name of an RHS join key output to a key column of the left table.
207
// This will cause predicates referring to RHS join keys to also be pushed to the LHS table.
208
if match &options.args.how {
209
JoinType::Right => true,
210
// Non-coalesced output columns of an inner join are equivalent between LHS and RHS.
211
JoinType::Inner => !options.args.should_coalesce(),
212
_ => false,
213
} {
214
let rhs_output_key: PlSmallStr = if schema_left.contains(rhs_input_key.as_str())
215
&& !coalesced_to_right.contains(rhs_input_key.as_str())
216
{
217
format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())
218
} else {
219
rhs_input_key.clone()
220
};
221
222
assert!(schema.contains(&rhs_output_key));
223
224
output_key_to_left_input_map.insert(rhs_output_key.clone(), lhs_input_key.clone());
225
}
226
}
227
228
let mut pushdown_left: PlHashMap<PlSmallStr, ExprIR> = init_hashmap(Some(acc_predicates.len()));
229
let mut pushdown_right: PlHashMap<PlSmallStr, ExprIR> =
230
init_hashmap(Some(acc_predicates.len()));
231
let mut local_predicates = Vec::with_capacity(acc_predicates.len());
232
233
for (_, predicate) in acc_predicates {
234
let mut push_left = true;
235
let mut push_right = true;
236
237
for col_name in aexpr_to_leaf_names_iter(predicate.node(), expr_arena) {
238
let origin: ExprOrigin = ExprOrigin::get_column_origin(
239
col_name.as_str(),
240
&schema_left,
241
&schema_right,
242
options.args.suffix(),
243
Some(&|name| coalesced_to_right.contains(name)),
244
)
245
.unwrap();
246
247
push_left &= matches!(origin, ExprOrigin::Left | ExprOrigin::None)
248
|| output_key_to_left_input_map.contains_key(col_name);
249
250
push_right &= matches!(origin, ExprOrigin::Right | ExprOrigin::None)
251
|| output_key_to_right_input_map.contains_key(col_name);
252
}
253
254
// Note: If `push_left` and `push_right` are both `true`, it means the predicate refers only
255
// to the join key columns, or the predicate does not refer any columns.
256
257
let has_residual = match &options.args.how {
258
// Pushing to a single side is enough to observe the full effect of the filter.
259
JoinType::Inner => !(push_left || push_right),
260
261
// Left-join: Pushing filters to the left table is enough to observe the effect of the
262
// filter. Pushing filters to the right is optional, but can only be done if the
263
// filter is also pushed to the left (if this is the case it means the filter only
264
// references join key columns).
265
JoinType::Left => {
266
push_right &= push_left;
267
!push_left
268
},
269
270
// Same as left-join, just flipped around.
271
JoinType::Right => {
272
push_left &= push_right;
273
!push_right
274
},
275
276
// Full-join: Filters must strictly apply only to coalesced output key columns.
277
JoinType::Full => {
278
assert!(options.args.should_coalesce());
279
280
let push = push_left && push_right;
281
push_left = push;
282
push_right = push;
283
284
!push
285
},
286
287
JoinType::Cross => {
288
// Predicate should only refer to a single side.
289
assert!(output_key_to_left_input_map.is_empty());
290
assert!(output_key_to_right_input_map.is_empty());
291
!(push_left || push_right)
292
},
293
294
// Behaves similarly to left-join on "by" columns (takes a single match instead of
295
// all matches according to asof strategy).
296
#[cfg(feature = "asof_join")]
297
JoinType::AsOf(_) => {
298
push_right &= push_left;
299
!push_left
300
},
301
302
// Same as inner-join.
303
#[cfg(feature = "semi_anti_join")]
304
JoinType::Semi => !(push_left || push_right),
305
306
// Anti-join is an exclusion of key tuples that exist in the right table, meaning that
307
// filters can only be pushed to the right table if they are also pushed to the left.
308
#[cfg(feature = "semi_anti_join")]
309
JoinType::Anti => {
310
push_right &= push_left;
311
!push_left
312
},
313
314
// Same as inner-join.
315
#[cfg(feature = "iejoin")]
316
JoinType::IEJoin => !(push_left || push_right),
317
};
318
319
if has_residual {
320
local_predicates.push(predicate.clone())
321
}
322
323
if push_left {
324
let mut predicate = predicate.clone();
325
map_column_references(&mut predicate, expr_arena, &output_key_to_left_input_map);
326
insert_predicate_dedup(&mut pushdown_left, &predicate, expr_arena);
327
}
328
329
if push_right {
330
let mut predicate = predicate;
331
map_column_references(&mut predicate, expr_arena, &output_key_to_right_input_map);
332
remove_suffix(
333
&mut predicate,
334
expr_arena,
335
&schema_right,
336
options.args.suffix(),
337
);
338
insert_predicate_dedup(&mut pushdown_right, &predicate, expr_arena);
339
}
340
}
341
342
opt.pushdown_and_assign(input_left, pushdown_left, lp_arena, expr_arena)?;
343
opt.pushdown_and_assign(input_right, pushdown_right, lp_arena, expr_arena)?;
344
345
let lp = IR::Join {
346
input_left,
347
input_right,
348
left_on,
349
right_on,
350
schema,
351
options,
352
};
353
354
let lp = opt.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena);
355
356
let lp = if let Some((projections, schema)) = opt_post_select {
357
IR::Select {
358
input: lp_arena.add(lp),
359
expr: projections,
360
schema,
361
options: ProjectionOptions {
362
run_parallel: false,
363
duplicate_check: false,
364
should_broadcast: false,
365
},
366
}
367
} else {
368
lp
369
};
370
371
Ok(lp)
372
}
373
374
/// Attempts to rewrite the join-type based on NULL-removing filters.
375
///
376
/// Changing between some join types may cause the output column order to change. If this is the
377
/// case, a Vec of column selectors will be returned that restore the original column order.
378
#[expect(clippy::too_many_arguments)]
379
fn try_rewrite_join_type(
380
schema_left: &SchemaRef,
381
schema_right: &SchemaRef,
382
output_schema: &mut SchemaRef,
383
options: &mut Arc<JoinOptionsIR>,
384
left_on: &mut Vec<ExprIR>,
385
right_on: &mut Vec<ExprIR>,
386
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
387
expr_arena: &mut Arena<AExpr>,
388
streaming: bool,
389
) -> PolarsResult<Option<(Vec<ExprIR>, SchemaRef)>> {
390
if acc_predicates.is_empty() {
391
return Ok(None);
392
}
393
394
let suffix = options.args.suffix().clone();
395
396
// * Cross -> Inner | IEJoin
397
// * IEJoin -> Inner
398
//
399
// Note: The join rewrites here all maintain output column ordering, hence this does not need
400
// to return any post-select (inserted inner joins will use JoinCoalesce::KeepColumns).
401
(|| {
402
match &options.args.how {
403
#[cfg(feature = "iejoin")]
404
JoinType::IEJoin => {},
405
JoinType::Cross => {},
406
407
_ => return PolarsResult::Ok(()),
408
}
409
410
match &options.options {
411
Some(JoinTypeOptionsIR::CrossAndFilter { .. }) => {
412
let Some(JoinTypeOptionsIR::CrossAndFilter { predicate }) =
413
Arc::make_mut(options).options.take()
414
else {
415
unreachable!()
416
};
417
418
insert_predicate_dedup(acc_predicates, &predicate, expr_arena);
419
},
420
421
#[cfg(feature = "iejoin")]
422
Some(JoinTypeOptionsIR::IEJoin(_)) => {},
423
None => {},
424
}
425
426
// Try converting to inner join
427
let equality_conditions = take_inner_join_compatible_filters(
428
acc_predicates,
429
expr_arena,
430
schema_left,
431
schema_right,
432
&suffix,
433
)?;
434
435
for InnerJoinKeys {
436
input_lhs,
437
input_rhs,
438
} in equality_conditions
439
{
440
let join_options = Arc::make_mut(options);
441
join_options.args.how = JoinType::Inner;
442
join_options.args.coalesce = JoinCoalesce::KeepColumns;
443
444
left_on.push(ExprIR::from_node(input_lhs, expr_arena));
445
let mut rexpr = ExprIR::from_node(input_rhs, expr_arena);
446
remove_suffix(&mut rexpr, expr_arena, schema_right, &suffix);
447
right_on.push(rexpr);
448
}
449
450
if options.args.how == JoinType::Inner {
451
return Ok(());
452
}
453
454
// Try converting cross join to IEJoin
455
#[cfg(feature = "iejoin")]
456
if matches!(options.args.maintain_order, MaintainOrderJoin::None)
457
&& left_on.len() < IEJOIN_MAX_PREDICATES
458
{
459
let ie_conditions = take_iejoin_compatible_filters(
460
acc_predicates,
461
expr_arena,
462
schema_left,
463
schema_right,
464
output_schema,
465
&suffix,
466
)?;
467
468
for IEJoinCompatiblePredicate {
469
input_lhs,
470
input_rhs,
471
ie_op,
472
source_node,
473
} in ie_conditions
474
{
475
let join_options = Arc::make_mut(options);
476
join_options.args.how = JoinType::IEJoin;
477
478
if left_on.len() >= IEJOIN_MAX_PREDICATES {
479
// Important: Place these back into acc_predicates.
480
insert_predicate_dedup(
481
acc_predicates,
482
&ExprIR::from_node(source_node, expr_arena),
483
expr_arena,
484
);
485
} else {
486
left_on.push(ExprIR::from_node(input_lhs, expr_arena));
487
let mut rexpr = ExprIR::from_node(input_rhs, expr_arena);
488
remove_suffix(&mut rexpr, expr_arena, schema_right, &suffix);
489
right_on.push(rexpr);
490
491
let JoinTypeOptionsIR::IEJoin(ie_options) = join_options
492
.options
493
.get_or_insert(JoinTypeOptionsIR::IEJoin(IEJoinOptions::default()))
494
else {
495
unreachable!()
496
};
497
498
match left_on.len() {
499
1 => ie_options.operator1 = ie_op,
500
2 => ie_options.operator2 = Some(ie_op),
501
_ => unreachable!("{}", IEJOIN_MAX_PREDICATES),
502
};
503
}
504
}
505
506
if options.args.how == JoinType::IEJoin {
507
return Ok(());
508
}
509
}
510
511
debug_assert_eq!(options.args.how, JoinType::Cross);
512
513
if options.args.how != JoinType::Cross {
514
return Ok(());
515
}
516
517
if streaming {
518
return Ok(());
519
}
520
521
let Some(nested_loop_predicates) = take_nested_loop_join_compatible_filters(
522
acc_predicates,
523
expr_arena,
524
schema_left,
525
schema_right,
526
&suffix,
527
)?
528
.reduce(|left, right| {
529
expr_arena.add(AExpr::BinaryExpr {
530
left,
531
op: Operator::And,
532
right,
533
})
534
}) else {
535
return Ok(());
536
};
537
538
let existing = Arc::make_mut(options)
539
.options
540
.replace(JoinTypeOptionsIR::CrossAndFilter {
541
predicate: ExprIR::from_node(nested_loop_predicates, expr_arena),
542
});
543
assert!(existing.is_none()); // Important
544
545
Ok(())
546
})()?;
547
548
if !matches!(
549
&options.args.how,
550
JoinType::Full | JoinType::Left | JoinType::Right
551
) {
552
return Ok(None);
553
}
554
555
let should_coalesce = options.args.should_coalesce();
556
557
/// Note: This may panic if `args.should_coalesce()` is false.
558
macro_rules! lhs_input_column_keys_iter {
559
() => {{
560
left_on.iter().map(|expr| {
561
let node = match expr_arena.get(expr.node()) {
562
AExpr::Cast {
563
expr,
564
dtype: _,
565
options: _,
566
} if should_coalesce => *expr,
567
568
_ => expr.node(),
569
};
570
571
let AExpr::Column(name) = expr_arena.get(node) else {
572
// All keys should be columns when coalesce=True
573
unreachable!()
574
};
575
576
name.clone()
577
})
578
}};
579
}
580
581
let mut coalesced_to_right: PlHashSet<PlSmallStr> = Default::default();
582
// Removing NULLs on these columns do not allow for join downgrading.
583
// We only need to track these for full-join - e.g. for left-join, removing NULLs from any left
584
// column does not cause any join rewrites.
585
let mut coalesced_full_join_key_outputs: PlHashSet<PlSmallStr> = Default::default();
586
587
if options.args.should_coalesce() {
588
match &options.args.how {
589
JoinType::Full => {
590
coalesced_full_join_key_outputs = lhs_input_column_keys_iter!().collect()
591
},
592
JoinType::Right => coalesced_to_right = lhs_input_column_keys_iter!().collect(),
593
_ => {},
594
}
595
}
596
597
let mut non_null_side = ExprOrigin::None;
598
599
for predicate in acc_predicates.values() {
600
for node in MintermIter::new(predicate.node(), expr_arena) {
601
predicate_non_null_column_outputs(node, expr_arena, &mut |non_null_column| {
602
if coalesced_full_join_key_outputs.contains(non_null_column) {
603
return;
604
}
605
606
non_null_side |= ExprOrigin::get_column_origin(
607
non_null_column.as_str(),
608
schema_left,
609
schema_right,
610
options.args.suffix(),
611
Some(&|x| coalesced_to_right.contains(x)),
612
)
613
.unwrap();
614
});
615
}
616
}
617
618
let Some(new_join_type) = (match non_null_side {
619
ExprOrigin::Both => Some(JoinType::Inner),
620
621
ExprOrigin::Left => match &options.args.how {
622
JoinType::Full => Some(JoinType::Left),
623
JoinType::Right => Some(JoinType::Inner),
624
_ => None,
625
},
626
627
ExprOrigin::Right => match &options.args.how {
628
JoinType::Full => Some(JoinType::Right),
629
JoinType::Left => Some(JoinType::Inner),
630
_ => None,
631
},
632
633
ExprOrigin::None => None,
634
}) else {
635
return Ok(None);
636
};
637
638
let options = Arc::make_mut(options);
639
// Ensure JoinSpecific is materialized to a specific config option, as we change the join type.
640
options.args.coalesce = if options.args.should_coalesce() {
641
JoinCoalesce::CoalesceColumns
642
} else {
643
JoinCoalesce::KeepColumns
644
};
645
let original_join_type = std::mem::replace(&mut options.args.how, new_join_type.clone());
646
let original_output_schema = match (&original_join_type, &new_join_type) {
647
(JoinType::Right, _) | (_, JoinType::Right) => std::mem::replace(
648
output_schema,
649
det_join_schema(
650
schema_left,
651
schema_right,
652
left_on,
653
right_on,
654
options,
655
expr_arena,
656
)
657
.unwrap(),
658
),
659
_ => {
660
debug_assert_eq!(
661
output_schema,
662
&det_join_schema(
663
schema_left,
664
schema_right,
665
left_on,
666
right_on,
667
options,
668
expr_arena,
669
)
670
.unwrap()
671
);
672
output_schema.clone()
673
},
674
};
675
676
// Maps the original join output names to the new join output names (used for mapping column
677
// references of the predicates).
678
let mut original_to_new_names_map: PlHashMap<PlSmallStr, PlSmallStr> = Default::default();
679
// Projects the new join output table back into the original join output table.
680
let mut project_to_original: Option<Vec<ExprIR>> = None;
681
682
if options.args.should_coalesce() {
683
// If we changed join types between a coalescing right-join, we need to do a select() to restore the column
684
// order of the original join type. The column references in the predicates may also need to be changed.
685
match (&original_join_type, &new_join_type) {
686
(JoinType::Right, JoinType::Right) => unreachable!(),
687
688
// Right-join rewritten to inner-join.
689
//
690
// E.g.
691
// Left: | a | b | c |
692
// Right: | a | b | c |
693
//
694
// right_join(left_on='a', right_on='b'): | b | c | a | *b_right | c_right |
695
// inner_join(left_on='a', right_on='b'): | *a | b | c | a_right | c_right |
696
// note: '*' means coalesced key output column
697
//
698
// project_to_original: | col(b) | col(c) | col(a_right).alias(a) | col(a).alias(b_right) | col(c_right) |
699
// original_to_new_names_map: {'a': 'a_right', 'b_right': 'a'}
700
//
701
(JoinType::Right, JoinType::Inner) => {
702
let mut join_output_key_selectors = PlHashMap::with_capacity(right_on.len());
703
704
for (l, r) in left_on.iter().zip(right_on) {
705
let (AExpr::Column(lhs_input_key), AExpr::Column(rhs_input_key)) =
706
(expr_arena.get(l.node()), expr_arena.get(r.node()))
707
else {
708
// `should_coalesce() == true` should guarantee all are columns.
709
unreachable!()
710
};
711
712
let original_key_output_name: PlSmallStr = if schema_left
713
.contains(rhs_input_key.as_str())
714
&& !coalesced_to_right.contains(rhs_input_key.as_str())
715
{
716
format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())
717
} else {
718
rhs_input_key.clone()
719
};
720
721
let new_key_output_name = lhs_input_key.clone();
722
let rhs_input_key = rhs_input_key.clone();
723
724
let node = expr_arena.add(AExpr::Column(lhs_input_key.clone()));
725
let mut ae = ExprIR::from_node(node, expr_arena);
726
727
if original_key_output_name != new_key_output_name {
728
// E.g. left_on=col(a), right_on=col(b)
729
// rhs_output_key = 'b', lhs_input_key = 'a', the original right-join is supposed to output 'b'.
730
original_to_new_names_map.insert(
731
original_key_output_name.clone(),
732
new_key_output_name.clone(),
733
);
734
ae.set_alias(original_key_output_name)
735
}
736
737
join_output_key_selectors.insert(rhs_input_key, ae);
738
}
739
740
let mut column_selectors: Vec<ExprIR> = Vec::with_capacity(output_schema.len());
741
742
for lhs_input_col in schema_left.iter_names() {
743
if coalesced_to_right.contains(lhs_input_col) {
744
continue;
745
}
746
747
let node = expr_arena.add(AExpr::Column(lhs_input_col.clone()));
748
column_selectors.push(ExprIR::from_node(node, expr_arena));
749
}
750
751
for rhs_input_col in schema_right.iter_names() {
752
let expr = if let Some(expr) = join_output_key_selectors.get(rhs_input_col) {
753
expr.clone()
754
} else if schema_left.contains(rhs_input_col) {
755
let new_join_output_name =
756
format_pl_smallstr!("{}{}", rhs_input_col, options.args.suffix());
757
758
let node = expr_arena.add(AExpr::Column(new_join_output_name.clone()));
759
let mut expr = ExprIR::from_node(node, expr_arena);
760
761
// The column with the same name from the LHS is not projected in the original
762
// right-join, so we alias to remove the suffix that was added from the inner-join.
763
if coalesced_to_right.contains(rhs_input_col.as_str()) {
764
original_to_new_names_map
765
.insert(rhs_input_col.clone(), new_join_output_name);
766
expr.set_alias(rhs_input_col.clone());
767
}
768
769
expr
770
} else {
771
let node = expr_arena.add(AExpr::Column(rhs_input_col.clone()));
772
ExprIR::from_node(node, expr_arena)
773
};
774
775
column_selectors.push(expr)
776
}
777
778
assert_eq!(column_selectors.len(), output_schema.len());
779
assert_eq!(column_selectors.len(), original_output_schema.len());
780
781
if cfg!(debug_assertions) {
782
assert!(
783
column_selectors
784
.iter()
785
.zip(original_output_schema.iter_names())
786
.all(|(l, r)| l.output_name() == r)
787
)
788
}
789
790
project_to_original = Some(column_selectors)
791
},
792
793
// Full-join rewritten to right-join
794
//
795
// E.g.
796
// Left: | a | b | c |
797
// Right: | a | b | c |
798
//
799
// full_join(left_on='a', right_on='b'): | *a | b | c | a_right | c_right |
800
// right_join(left_on='a', right_on='b'): | b | c | a | *b_right | c_right |
801
// note: '*' means coalesced key output column
802
//
803
// project_to_original: | col(b_right).alias(a) | col(b) | col(c) | col(a).alias(a_right) | col(c_right) |
804
// original_to_new_names_map: {'a': 'b_right', 'a_right': 'a'}
805
//
806
(JoinType::Full, JoinType::Right) => {
807
let mut join_output_key_selectors = PlHashMap::with_capacity(left_on.len());
808
809
// The existing one is empty because the original join type was not a right-join.
810
assert!(coalesced_to_right.is_empty());
811
// LHS input key columns that are coalesced (i.e. not projected) for the right-join.
812
let coalesced_to_right: PlHashSet<PlSmallStr> =
813
lhs_input_column_keys_iter!().collect();
814
// RHS input key columns that are coalesced (i.e. not projected) for the full-join.
815
let mut coalesced_to_left: PlHashSet<PlSmallStr> =
816
PlHashSet::with_capacity(right_on.len());
817
818
for (l, r) in left_on.iter().zip(right_on) {
819
let (AExpr::Column(lhs_input_key), AExpr::Column(rhs_input_key)) =
820
(expr_arena.get(l.node()), expr_arena.get(r.node()))
821
else {
822
// `should_coalesce() == true` should guarantee all columns.
823
unreachable!()
824
};
825
826
let new_key_output_name: PlSmallStr = if schema_left
827
.contains(rhs_input_key.as_str())
828
&& !coalesced_to_right.contains(rhs_input_key.as_str())
829
{
830
format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())
831
} else {
832
rhs_input_key.clone()
833
};
834
835
let lhs_input_key = lhs_input_key.clone();
836
let rhs_input_key = rhs_input_key.clone();
837
let original_key_output_name = &lhs_input_key;
838
839
coalesced_to_left.insert(rhs_input_key);
840
841
let node = expr_arena.add(AExpr::Column(new_key_output_name.clone()));
842
843
let mut ae = ExprIR::from_node(node, expr_arena);
844
845
// E.g. left_on=col(a), right_on=col(b)
846
// rhs_output_key = 'b', lhs_input_key = 'a'
847
if new_key_output_name != original_key_output_name {
848
original_to_new_names_map.insert(
849
original_key_output_name.clone(),
850
new_key_output_name.clone(),
851
);
852
ae.set_alias(original_key_output_name.clone())
853
}
854
855
join_output_key_selectors.insert(lhs_input_key.clone(), ae);
856
}
857
858
let mut column_selectors = Vec::with_capacity(output_schema.len());
859
860
for lhs_input_col in schema_left.iter_names() {
861
let expr = if let Some(expr) = join_output_key_selectors.get(lhs_input_col) {
862
expr.clone()
863
} else {
864
let node = expr_arena.add(AExpr::Column(lhs_input_col.clone()));
865
ExprIR::from_node(node, expr_arena)
866
};
867
868
column_selectors.push(expr)
869
}
870
871
for rhs_input_col in schema_right.iter_names() {
872
if coalesced_to_left.contains(rhs_input_col) {
873
continue;
874
}
875
876
let mut original_output_name: Option<PlSmallStr> = None;
877
878
let new_join_output_name = if schema_left.contains(rhs_input_col) {
879
let suffixed =
880
format_pl_smallstr!("{}{}", rhs_input_col, options.args.suffix());
881
882
if coalesced_to_right.contains(rhs_input_col) {
883
original_output_name = Some(suffixed);
884
rhs_input_col.clone()
885
} else {
886
suffixed
887
}
888
} else {
889
rhs_input_col.clone()
890
};
891
892
let node = expr_arena.add(AExpr::Column(new_join_output_name));
893
894
let mut expr = ExprIR::from_node(node, expr_arena);
895
896
if let Some(original_output_name) = original_output_name {
897
original_to_new_names_map
898
.insert(original_output_name.clone(), rhs_input_col.clone());
899
expr.set_alias(original_output_name);
900
}
901
902
column_selectors.push(expr);
903
}
904
905
assert_eq!(column_selectors.len(), output_schema.len());
906
assert_eq!(column_selectors.len(), original_output_schema.len());
907
908
if cfg!(debug_assertions) {
909
assert!(
910
column_selectors
911
.iter()
912
.zip(original_output_schema.iter_names())
913
.all(|(l, r)| l.output_name() == r)
914
)
915
}
916
917
project_to_original = Some(column_selectors)
918
},
919
920
(JoinType::Right, _) | (_, JoinType::Right) => unreachable!(),
921
922
_ => {},
923
}
924
}
925
926
if !original_to_new_names_map.is_empty() {
927
assert!(project_to_original.is_some());
928
929
for (_, predicate_expr) in acc_predicates.iter_mut() {
930
map_column_references(predicate_expr, expr_arena, &original_to_new_names_map);
931
}
932
}
933
934
Ok(project_to_original.map(|p| (p, original_output_schema)))
935
}
936
937
struct InnerJoinKeys {
938
input_lhs: Node,
939
input_rhs: Node,
940
}
941
942
/// Removes all equality predicates that can be used as inner-join conditions from `acc_predicates`.
943
fn take_inner_join_compatible_filters(
944
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
945
expr_arena: &mut Arena<AExpr>,
946
schema_left: &Schema,
947
schema_right: &Schema,
948
suffix: &str,
949
) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, InnerJoinKeys>> {
950
take_predicates_mut(acc_predicates, expr_arena, |ae, _ae_node, expr_arena| {
951
Ok(match ae {
952
AExpr::BinaryExpr {
953
left,
954
op: Operator::Eq,
955
right,
956
} => {
957
let left_origin = ExprOrigin::get_expr_origin(
958
*left,
959
expr_arena,
960
schema_left,
961
schema_right,
962
suffix,
963
None, // is_coalesced_to_right
964
)?;
965
let right_origin = ExprOrigin::get_expr_origin(
966
*right,
967
expr_arena,
968
schema_left,
969
schema_right,
970
suffix,
971
None,
972
)?;
973
974
match (left_origin, right_origin) {
975
(ExprOrigin::Left, ExprOrigin::Right) => Some(InnerJoinKeys {
976
input_lhs: *left,
977
input_rhs: *right,
978
}),
979
(ExprOrigin::Right, ExprOrigin::Left) => Some(InnerJoinKeys {
980
input_lhs: *right,
981
input_rhs: *left,
982
}),
983
_ => None,
984
}
985
},
986
_ => None,
987
})
988
})
989
}
990
991
/// An inequality predicate that can serve as an IEJoin condition.
#[cfg(feature = "iejoin")]
struct IEJoinCompatiblePredicate {
    /// Operand whose columns originate from the left table.
    input_lhs: Node,
    /// Operand whose columns originate from the right table.
    input_rhs: Node,
    /// Inequality operator, oriented as `input_lhs <op> input_rhs`.
    ie_op: InequalityOperator,
    /// Original input node.
    source_node: Node,
}
999
1000
/// Takes every inequality filter usable as an IEJoin condition out of `acc_predicates`.
///
/// A min-term qualifies if it is a `<`, `<=`, `>` or `>=` comparison between an operand from the
/// left table and an operand from the right table, and both operands have a non-nested dtype
/// whose physical representation is primitive numeric. The returned operator is flipped when
/// the operands were written in (right table, left table) order.
#[cfg(feature = "iejoin")]
fn take_iejoin_compatible_filters(
    acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
    expr_arena: &mut Arena<AExpr>,
    schema_left: &Schema,
    schema_right: &Schema,
    output_schema: &Schema,
    suffix: &str,
) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, IEJoinCompatiblePredicate>> {
    /// Maps a comparison operator to its IEJoin counterpart, if it has one.
    fn to_inequality_operator(op: &Operator) -> Option<InequalityOperator> {
        match op {
            Operator::Lt => Some(InequalityOperator::Lt),
            Operator::LtEq => Some(InequalityOperator::LtEq),
            Operator::Gt => Some(InequalityOperator::Gt),
            Operator::GtEq => Some(InequalityOperator::GtEq),
            _ => None,
        }
    }

    take_predicates_mut(acc_predicates, expr_arena, |ae, ae_node, expr_arena| {
        let AExpr::BinaryExpr { left, op, right } = ae else {
            return Ok(None);
        };

        let Some(ie_op) = to_inequality_operator(op) else {
            return Ok(None);
        };

        let origin_of = |operand: Node| {
            ExprOrigin::get_expr_origin(
                operand,
                expr_arena,
                schema_left,
                schema_right,
                suffix,
                None, // is_coalesced_to_right
            )
        };

        let lhs_operand_origin = origin_of(*left)?;
        let rhs_operand_origin = origin_of(*right)?;

        // IEJoin only supports numeric.
        let is_supported_type = |operand: Node| -> PolarsResult<bool> {
            let field = expr_arena
                .get(operand)
                .to_field(&ToFieldContext::new(expr_arena, output_schema))?;
            let dtype = field.dtype();

            Ok(!dtype.is_nested() && dtype.to_physical().is_primitive_numeric())
        };

        if !(is_supported_type(*left)? && is_supported_type(*right)?) {
            return Ok(None);
        }

        let condition = match (lhs_operand_origin, rhs_operand_origin) {
            (ExprOrigin::Left, ExprOrigin::Right) => Some(IEJoinCompatiblePredicate {
                input_lhs: *left,
                input_rhs: *right,
                ie_op,
                source_node: ae_node,
            }),
            // Operands were written in (right table, left table) order: swap them and flip
            // the operator accordingly.
            (ExprOrigin::Right, ExprOrigin::Left) => Some(IEJoinCompatiblePredicate {
                input_lhs: *right,
                input_rhs: *left,
                ie_op: to_inequality_operator(&op.swap_operands()).unwrap(),
                source_node: ae_node,
            }),
            _ => None,
        };

        Ok(condition)
    })
}
1086
1087
/// Removes all filters that can be used as nested loop join conditions from `acc_predicates`.
1088
///
1089
/// Note that filters that refer only to a single side are not removed so that they can be pushed
1090
/// into the LHS/RHS tables.
1091
fn take_nested_loop_join_compatible_filters(
1092
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
1093
expr_arena: &mut Arena<AExpr>,
1094
schema_left: &Schema,
1095
schema_right: &Schema,
1096
suffix: &str,
1097
) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, Node>> {
1098
take_predicates_mut(acc_predicates, expr_arena, |_ae, ae_node, expr_arena| {
1099
Ok(
1100
match ExprOrigin::get_expr_origin(
1101
ae_node,
1102
expr_arena,
1103
schema_left,
1104
schema_right,
1105
suffix,
1106
None,
1107
)? {
1108
// Leave single-origin exprs as they get pushed to the left/right tables individually.
1109
ExprOrigin::Left | ExprOrigin::Right | ExprOrigin::None => None,
1110
_ => Some(ae_node),
1111
},
1112
)
1113
})
1114
}
1115
1116
/// Removes predicates from the map according to a function.
1117
fn take_predicates_mut<F, T>(
1118
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
1119
expr_arena: &mut Arena<AExpr>,
1120
take_predicate: F,
1121
) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, T>>
1122
where
1123
F: Fn(&AExpr, Node, &Arena<AExpr>) -> PolarsResult<Option<T>>,
1124
{
1125
let mut selected_predicates: PlHashMap<Node, T> = PlHashMap::new();
1126
1127
for predicate in acc_predicates.values() {
1128
for node in MintermIter::new(predicate.node(), expr_arena) {
1129
let ae = expr_arena.get(node);
1130
1131
if let Some(t) = take_predicate(ae, node, expr_arena)? {
1132
selected_predicates.insert(node, t);
1133
}
1134
}
1135
}
1136
1137
if !selected_predicates.is_empty() {
1138
remove_min_terms(acc_predicates, expr_arena, &|node| {
1139
selected_predicates.contains_key(node)
1140
});
1141
}
1142
1143
return Ok(selected_predicates.into_values());
1144
1145
#[inline(never)]
1146
fn remove_min_terms(
1147
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
1148
expr_arena: &mut Arena<AExpr>,
1149
should_remove: &dyn Fn(&Node) -> bool,
1150
) {
1151
let mut remove_keys = PlHashSet::new();
1152
let mut nodes_scratch = vec![];
1153
1154
for (k, predicate) in acc_predicates.iter_mut() {
1155
let mut has_removed = false;
1156
1157
nodes_scratch.clear();
1158
nodes_scratch.extend(
1159
MintermIter::new(predicate.node(), expr_arena).filter(|node| {
1160
let remove = should_remove(node);
1161
has_removed |= remove;
1162
!remove
1163
}),
1164
);
1165
1166
if nodes_scratch.is_empty() {
1167
remove_keys.insert(k.clone());
1168
continue;
1169
};
1170
1171
if has_removed {
1172
let new_predicate_node = nodes_scratch
1173
.drain(..)
1174
.reduce(|left, right| {
1175
expr_arena.add(AExpr::BinaryExpr {
1176
left,
1177
op: Operator::And,
1178
right,
1179
})
1180
})
1181
.unwrap();
1182
1183
*predicate = ExprIR::from_node(new_predicate_node, expr_arena);
1184
}
1185
}
1186
1187
for k in remove_keys {
1188
let v = acc_predicates.remove(&k);
1189
assert!(v.is_some());
1190
}
1191
}
1192
}
1193
1194