Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/join.rs
8506 views
1
use arrow::legacy::error::PolarsResult;
2
use either::Either;
3
use polars_core::chunked_array::cast::CastOptions;
4
use polars_core::error::feature_gated;
5
use polars_core::utils::{get_numeric_upcast_supertype_lossless, try_get_supertype};
6
use polars_utils::format_pl_smallstr;
7
use polars_utils::itertools::Itertools;
8
9
use super::*;
10
use crate::constants::POLARS_TMP_PREFIX;
11
use crate::dsl::Expr;
12
#[cfg(feature = "iejoin")]
13
use crate::plans::AExpr;
14
15
fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> {
16
for e in keys {
17
if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) {
18
polars_bail!(
19
InvalidOperation:
20
"'alias' is not allowed in a join key, use 'with_columns' first",
21
)
22
}
23
}
24
Ok(())
25
}
26
27
/// Returns: left: join_node, right: last_node (often both the same)
28
pub fn resolve_join(
29
input_left: Either<Arc<DslPlan>, Node>,
30
input_right: Either<Arc<DslPlan>, Node>,
31
left_on: Vec<Expr>,
32
right_on: Vec<Expr>,
33
predicates: Vec<Expr>,
34
mut options: JoinOptionsIR,
35
ctxt: &mut DslConversionContext,
36
) -> PolarsResult<(Node, Node)> {
37
if !predicates.is_empty() {
38
feature_gated!("iejoin", {
39
debug_assert!(left_on.is_empty() && right_on.is_empty());
40
return resolve_join_where(
41
input_left.unwrap_left(),
42
input_right.unwrap_left(),
43
predicates,
44
options,
45
ctxt,
46
);
47
})
48
}
49
50
let owned = Arc::unwrap_or_clone;
51
let mut input_left = input_left.map_right(Ok).right_or_else(|input| {
52
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))
53
})?;
54
let mut input_right = input_right.map_right(Ok).right_or_else(|input| {
55
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))
56
})?;
57
58
let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
59
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);
60
61
if options.args.how.is_cross() {
62
polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");
63
} else {
64
polars_ensure!(left_on.len() + right_on.len() > 0, InvalidOperation: "expected join keys/predicates");
65
check_join_keys(&left_on)?;
66
check_join_keys(&right_on)?;
67
68
let mut turn_off_coalesce = false;
69
for e in left_on.iter().chain(right_on.iter()) {
70
// Any expression that is not a simple column expression will turn of coalescing.
71
turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_)));
72
}
73
if turn_off_coalesce {
74
if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) {
75
polars_warn!(
76
"coalescing join requested but not all join keys are column references, turning off key coalescing"
77
);
78
}
79
options.args.coalesce = JoinCoalesce::KeepColumns;
80
}
81
82
options.args.validation.is_valid_join(&options.args.how)?;
83
84
#[cfg(feature = "asof_join")]
85
if let JoinType::AsOf(options) = &options.args.how {
86
match (&options.left_by, &options.right_by) {
87
(None, None) => {},
88
(Some(l), Some(r)) => {
89
polars_ensure!(l.len() == r.len(), InvalidOperation: "expected equal number of columns in 'by_left' and 'by_right' in 'asof_join'");
90
validate_columns_in_input(l, &schema_left, "asof_join")?;
91
validate_columns_in_input(r, &schema_right, "asof_join")?;
92
},
93
_ => {
94
polars_bail!(InvalidOperation: "expected both 'by_left' and 'by_right' to be set in 'asof_join'")
95
},
96
}
97
}
98
99
polars_ensure!(
100
left_on.len() == right_on.len(),
101
InvalidOperation:
102
"the number of columns given as join key (left: {}, right:{}) should be equal",
103
left_on.len(),
104
right_on.len()
105
);
106
}
107
108
let mut left_on = left_on
109
.into_iter()
110
.map(|e| {
111
to_expr_ir_materialized_lit(
112
e,
113
&mut ExprToIRContext::new_with_opt_eager(
114
ctxt.expr_arena,
115
&schema_left,
116
ctxt.opt_flags,
117
),
118
)
119
})
120
.collect::<PolarsResult<Vec<_>>>()?;
121
let mut right_on = right_on
122
.into_iter()
123
.map(|e| {
124
to_expr_ir_materialized_lit(
125
e,
126
&mut ExprToIRContext::new_with_opt_eager(
127
ctxt.expr_arena,
128
&schema_right,
129
ctxt.opt_flags,
130
),
131
)
132
})
133
.collect::<PolarsResult<Vec<_>>>()?;
134
let mut joined_on = PlHashSet::new();
135
136
#[cfg(feature = "iejoin")]
137
let check = !matches!(options.args.how, JoinType::IEJoin);
138
#[cfg(not(feature = "iejoin"))]
139
let check = true;
140
if check {
141
for (l, r) in left_on.iter().zip(right_on.iter()) {
142
polars_ensure!(
143
joined_on.insert((l.output_name(), r.output_name())),
144
InvalidOperation: "joining with repeated key names; already joined on {} and {}",
145
l.output_name(),
146
r.output_name()
147
)
148
}
149
}
150
drop(joined_on);
151
152
ctxt.conversion_optimizer
153
.fill_scratch(&left_on, ctxt.expr_arena);
154
ctxt.conversion_optimizer
155
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_left, true)
156
.map_err(|e| e.context("'join' failed".into()))?;
157
ctxt.conversion_optimizer
158
.fill_scratch(&right_on, ctxt.expr_arena);
159
ctxt.conversion_optimizer
160
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_right, true)
161
.map_err(|e| e.context("'join' failed".into()))?;
162
163
// Re-evaluate because of mutable borrows earlier.
164
let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
165
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);
166
167
// # Resolve scalars
168
//
169
// Scalars need to be expanded. We translate them to temporary columns added with
170
// `with_columns` and remove them later with `project`
171
// This way the backends don't have to expand the literals in the join implementation
172
173
let has_scalars = left_on
174
.iter()
175
.chain(right_on.iter())
176
.any(|e| e.is_scalar(ctxt.expr_arena));
177
178
let (schema_left, schema_right) = if has_scalars {
179
let mut as_with_columns_l = vec![];
180
let mut as_with_columns_r = vec![];
181
for (i, e) in left_on.iter().enumerate() {
182
if e.is_scalar(ctxt.expr_arena) {
183
as_with_columns_l.push((i, e.clone()));
184
}
185
}
186
for (i, e) in right_on.iter().enumerate() {
187
if e.is_scalar(ctxt.expr_arena) {
188
as_with_columns_r.push((i, e.clone()));
189
}
190
}
191
192
let mut count = 0;
193
let get_tmp_name = |i| format_pl_smallstr!("{POLARS_TMP_PREFIX}{i}");
194
195
// Early clone because of bck.
196
let mut schema_right_new = if !as_with_columns_r.is_empty() {
197
(**schema_right).clone()
198
} else {
199
Default::default()
200
};
201
if !as_with_columns_l.is_empty() {
202
let mut schema_left_new = (**schema_left).clone();
203
204
let mut exprs = Vec::with_capacity(as_with_columns_l.len());
205
for (i, mut e) in as_with_columns_l {
206
let tmp_name = get_tmp_name(count);
207
count += 1;
208
e.set_alias(tmp_name.clone());
209
let dtype = e.dtype(&schema_left_new, ctxt.expr_arena)?;
210
schema_left_new.with_column(tmp_name.clone(), dtype.clone());
211
212
let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));
213
left_on[i] = ExprIR::from_node(col, ctxt.expr_arena);
214
exprs.push(e);
215
}
216
input_left = ctxt.lp_arena.add(IR::HStack {
217
input: input_left,
218
exprs,
219
schema: Arc::new(schema_left_new),
220
options: ProjectionOptions::default(),
221
})
222
}
223
if !as_with_columns_r.is_empty() {
224
let mut exprs = Vec::with_capacity(as_with_columns_r.len());
225
for (i, mut e) in as_with_columns_r {
226
let tmp_name = get_tmp_name(count);
227
count += 1;
228
e.set_alias(tmp_name.clone());
229
let dtype = e.dtype(&schema_right_new, ctxt.expr_arena)?;
230
schema_right_new.with_column(tmp_name.clone(), dtype.clone());
231
232
let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));
233
right_on[i] = ExprIR::from_node(col, ctxt.expr_arena);
234
exprs.push(e);
235
}
236
input_right = ctxt.lp_arena.add(IR::HStack {
237
input: input_right,
238
exprs,
239
schema: Arc::new(schema_right_new),
240
options: ProjectionOptions::default(),
241
})
242
}
243
244
(
245
ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena),
246
ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena),
247
)
248
} else {
249
(schema_left, schema_right)
250
};
251
252
// Not a closure to avoid borrow issues because we mutate expr_arena as well.
253
macro_rules! get_dtype {
254
($expr:expr, $schema:expr) => {
255
ctxt.expr_arena
256
.get($expr.node())
257
.to_dtype(&ToFieldContext::new(ctxt.expr_arena, $schema))
258
};
259
}
260
261
// As an optimization, when inserting casts for coalescing joins we only insert them beforehand for full-join.
262
// This means for e.g. left-join, the LHS key preserves its dtype in the output even if it is joined
263
// with an RHS key of wider type.
264
let key_cols_coalesced =
265
options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full);
266
let mut as_with_columns_l = vec![];
267
let mut as_with_columns_r = vec![];
268
for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) {
269
let ltype = get_dtype!(lnode, &schema_left)?;
270
let rtype = get_dtype!(rnode, &schema_right)?;
271
272
if let Some(dtype) = get_numeric_upcast_supertype_lossless(&ltype, &rtype) {
273
// We use overflowing cast to allow better optimization as we are casting to a known
274
// lossless supertype.
275
//
276
// We have unique references to these nodes (they are created by this function),
277
// so we can mutate in-place without causing side effects somewhere else.
278
let casted_l = ctxt.expr_arena.add(AExpr::Cast {
279
expr: lnode.node(),
280
dtype: dtype.clone(),
281
options: CastOptions::Overflowing,
282
});
283
let casted_r = ctxt.expr_arena.add(AExpr::Cast {
284
expr: rnode.node(),
285
dtype,
286
options: CastOptions::Overflowing,
287
});
288
289
if key_cols_coalesced {
290
let mut lnode = lnode.clone();
291
let mut rnode = rnode.clone();
292
293
let ae_l = ctxt.expr_arena.get(lnode.node());
294
let ae_r = ctxt.expr_arena.get(rnode.node());
295
296
polars_ensure!(
297
ae_l.is_col() && ae_r.is_col(),
298
SchemaMismatch: "can only 'coalesce' full join if join keys are column expressions",
299
);
300
301
lnode.set_node(casted_l);
302
rnode.set_node(casted_r);
303
304
as_with_columns_r.push(rnode);
305
as_with_columns_l.push(lnode);
306
} else {
307
lnode.set_node(casted_l);
308
rnode.set_node(casted_r);
309
}
310
} else {
311
polars_ensure!(
312
ltype == rtype,
313
SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right (and no other type was available to cast to)",
314
lnode.output_name(), ltype.pretty_format(), rnode.output_name(), rtype.pretty_format()
315
);
316
}
317
}
318
319
// Every expression must be elementwise so that we are
320
// guaranteed the keys for a join are all the same length.
321
322
polars_ensure!(
323
all_elementwise(&left_on, ctxt.expr_arena) && all_elementwise(&right_on, ctxt.expr_arena),
324
InvalidOperation: "all join key expressions must be elementwise."
325
);
326
327
#[cfg(feature = "asof_join")]
328
if let JoinType::AsOf(options) = &mut options.args.how {
329
use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY;
330
331
// prepare the tolerance
332
// we must ensure that we use the right units
333
if let Some(tol) = &options.tolerance_str {
334
let duration = polars_time::Duration::try_parse(tol)?;
335
polars_ensure!(
336
duration.months() == 0,
337
ComputeError: "cannot use month offset in timedelta of an asof join; \
338
consider using 4 weeks"
339
);
340
use DataType::*;
341
match ctxt
342
.expr_arena
343
.get(left_on[0].node())
344
.to_dtype(&ToFieldContext::new(ctxt.expr_arena, &schema_left))?
345
{
346
Datetime(tu, _) | Duration(tu) => {
347
let tolerance = match tu {
348
TimeUnit::Nanoseconds => duration.duration_ns(),
349
TimeUnit::Microseconds => duration.duration_us(),
350
TimeUnit::Milliseconds => duration.duration_ms(),
351
};
352
options.tolerance = Some(Scalar::from(tolerance))
353
},
354
Date => {
355
let days = (duration.duration_ms() / MILLISECONDS_IN_DAY) as i32;
356
options.tolerance = Some(Scalar::from(days))
357
},
358
Time => {
359
let tolerance = duration.duration_ns();
360
options.tolerance = Some(Scalar::from(tolerance))
361
},
362
_ => {
363
panic!(
364
"can only use timedelta string language with Date/Datetime/Duration/Time dtypes"
365
)
366
},
367
}
368
}
369
}
370
371
// These are Arc<Schema>, into_owned is free.
372
let schema_left = schema_left.into_owned();
373
let schema_right = schema_right.into_owned();
374
375
let join_schema = det_join_schema(
376
&schema_left,
377
&schema_right,
378
&left_on,
379
&right_on,
380
&options,
381
ctxt.expr_arena,
382
)
383
.map_err(|e| e.context(failed_here!(join schema resolving)))?;
384
385
if key_cols_coalesced {
386
input_left = if as_with_columns_l.is_empty() {
387
input_left
388
} else {
389
ctxt.lp_arena.add(IR::HStack {
390
input: input_left,
391
exprs: as_with_columns_l,
392
schema: schema_left,
393
options: ProjectionOptions::default(),
394
})
395
};
396
397
input_right = if as_with_columns_r.is_empty() {
398
input_right
399
} else {
400
ctxt.lp_arena.add(IR::HStack {
401
input: input_right,
402
exprs: as_with_columns_r,
403
schema: schema_right,
404
options: ProjectionOptions::default(),
405
})
406
};
407
}
408
409
let ir = IR::Join {
410
input_left,
411
input_right,
412
schema: join_schema.clone(),
413
left_on,
414
right_on,
415
options: Arc::new(options),
416
};
417
let join_node = ctxt.lp_arena.add(ir);
418
419
if has_scalars {
420
let names = join_schema
421
.iter_names()
422
.filter_map(|n| {
423
if n.starts_with(POLARS_TMP_PREFIX) {
424
None
425
} else {
426
Some(n.clone())
427
}
428
})
429
.collect_vec();
430
431
let builder = IRBuilder::new(join_node, ctxt.expr_arena, ctxt.lp_arena);
432
let ir = builder.project_simple(names).map(|b| b.build())?;
433
let select_node = ctxt.lp_arena.add(ir);
434
435
Ok((select_node, join_node))
436
} else {
437
Ok((join_node, join_node))
438
}
439
}
440
441
#[cfg(feature = "iejoin")]
impl From<InequalityOperator> for Operator {
    /// Maps each inequality operator to the binary `Operator` variant with the same name.
    fn from(op: InequalityOperator) -> Self {
        use InequalityOperator as Ineq;
        match op {
            Ineq::Lt => Self::Lt,
            Ineq::LtEq => Self::LtEq,
            Ineq::Gt => Self::Gt,
            Ineq::GtEq => Self::GtEq,
        }
    }
}
452
453
#[cfg(feature = "iejoin")]
/// Resolves a `join_where` as a cross join of both inputs followed by one `Filter` per predicate.
///
/// Each predicate is converted against the cross join's output schema. It must resolve to
/// `Boolean`. Comparisons between a left-table operand and a right-table operand are rewritten
/// so both sides share a lossless common dtype (see `ensure_lossless_binary_comparisons`).
///
/// Returns `(last_node, join_node)`:
/// - `last_node` is the last `Filter` node added on top of the cross join.
/// - `join_node` is the join node returned by `resolve_join`.
///
/// # Errors
/// Returns an error if an input fails to convert, a predicate contains an `alias`, a predicate
/// does not resolve to `Boolean`, or the operands of a cross-table comparison cannot be cast
/// losslessly to a common dtype.
fn resolve_join_where(
    input_left: Arc<DslPlan>,
    input_right: Arc<DslPlan>,
    predicates: Vec<Expr>,
    mut options: JoinOptionsIR,
    ctxt: &mut DslConversionContext,
) -> PolarsResult<(Node, Node)> {
    // In eager mode, always enable predicate pushdown. NOTE(review): presumably this lets the
    // filters below be pushed into the cross join. In lazy mode, respect the user's flag.
    if ctxt.opt_flags.eager() {
        ctxt.opt_flags.set(OptFlags::PREDICATE_PUSHDOWN, true);
    }
    check_join_keys(&predicates)?;
    let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)
        .map_err(|e| e.context(failed_here!(join left)))?;
    let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)
        .map_err(|e| e.context(failed_here!(join right)))?;

    let schema_left = ctxt
        .lp_arena
        .get(input_left)
        .schema(ctxt.lp_arena)
        .into_owned();

    options.args.how = JoinType::Cross;

    let (mut last_node, join_node) = resolve_join(
        Either::Right(input_left),
        Either::Right(input_right),
        vec![],
        vec![],
        vec![],
        options,
        ctxt,
    )?;

    let schema_merged = ctxt
        .lp_arena
        .get(last_node)
        .schema(ctxt.lp_arena)
        .into_owned();

    // Perform predicate validation.
    // The buffer is reused across predicates; `ensure_lossless_binary_comparisons` drains it.
    let mut upcast_exprs = Vec::<(Node, DataType)>::new();
    for e in predicates {
        let arena = &mut ctxt.expr_arena;
        let predicate = to_expr_ir_materialized_lit(
            e,
            &mut ExprToIRContext::new_with_opt_eager(arena, &schema_merged, ctxt.opt_flags),
        )?;
        let node = predicate.node();

        // Ensure the predicate dtype output of the root node is Boolean.
        let ae = arena.get(node);
        let dt_out = ae.to_dtype(&ToFieldContext::new(arena, &schema_merged))?;
        polars_ensure!(
            dt_out == DataType::Boolean,
            ComputeError: "'join_where' predicates must resolve to boolean"
        );

        ensure_lossless_binary_comparisons(
            &node,
            &schema_left,
            &schema_merged,
            arena,
            &mut upcast_exprs,
        )?;

        ctxt.conversion_optimizer
            .push_scratch(predicate.node(), ctxt.expr_arena);

        let ir = IR::Filter {
            input: last_node,
            predicate,
        };

        last_node = ctxt.lp_arena.add(ir);
    }

    ctxt.conversion_optimizer
        .optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, last_node, false)
        .map_err(|e| e.context("'join_where' failed".into()))?;

    Ok((last_node, join_node))
}
539
540
/// Locate nodes that are operands in a binary comparison involving both tables, and ensure that
541
/// these nodes are losslessly upcast to a safe dtype.
542
fn ensure_lossless_binary_comparisons(
543
node: &Node,
544
schema_left: &Schema,
545
schema_merged: &Schema,
546
expr_arena: &mut Arena<AExpr>,
547
upcast_exprs: &mut Vec<(Node, DataType)>,
548
) -> PolarsResult<()> {
549
// let mut upcast_exprs = Vec::<(Node, DataType)>::new();
550
// Ensure that all binary comparisons that use both tables are lossless.
551
build_upcast_node_list(node, schema_left, schema_merged, expr_arena, upcast_exprs)?;
552
// Replace each node with its casted counterpart
553
for (expr, dtype) in upcast_exprs.drain(..) {
554
let old_expr = expr_arena.duplicate(expr);
555
let new_aexpr = AExpr::Cast {
556
expr: old_expr,
557
dtype,
558
options: CastOptions::Overflowing,
559
};
560
expr_arena.replace(expr, new_aexpr);
561
}
562
Ok(())
563
}
564
565
/// If we are dealing with a binary comparison involving columns from exclusively the left table
566
/// on the LHS and the right table on the RHS side, ensure that the cast is lossless.
567
/// Expressions involving binaries using either table alone we leave up to the user to verify
568
/// that they are valid, as they could theoretically be pushed outside of the join.
569
#[recursive]
570
fn build_upcast_node_list(
571
node: &Node,
572
schema_left: &Schema,
573
schema_merged: &Schema,
574
expr_arena: &Arena<AExpr>,
575
to_replace: &mut Vec<(Node, DataType)>,
576
) -> PolarsResult<ExprOrigin> {
577
let expr_origin = match expr_arena.get(*node) {
578
AExpr::Column(name) => {
579
if schema_left.contains(name) {
580
ExprOrigin::Left
581
} else if schema_merged.contains(name) {
582
ExprOrigin::Right
583
} else {
584
polars_bail!(ColumnNotFound: "{name}");
585
}
586
},
587
AExpr::Literal(..) => ExprOrigin::None,
588
AExpr::Cast { expr: node, .. } => {
589
build_upcast_node_list(node, schema_left, schema_merged, expr_arena, to_replace)?
590
},
591
AExpr::BinaryExpr {
592
left: left_node,
593
op,
594
right: right_node,
595
} => {
596
// If left and right node has both, ensure the dtypes are valid.
597
let left_origin = build_upcast_node_list(
598
left_node,
599
schema_left,
600
schema_merged,
601
expr_arena,
602
to_replace,
603
)?;
604
let right_origin = build_upcast_node_list(
605
right_node,
606
schema_left,
607
schema_merged,
608
expr_arena,
609
to_replace,
610
)?;
611
// We only update casts during comparisons if the operands are from different tables.
612
if op.is_comparison() {
613
match (left_origin, right_origin) {
614
(ExprOrigin::Left, ExprOrigin::Right)
615
| (ExprOrigin::Right, ExprOrigin::Left) => {
616
// Ensure our dtype casts are lossless
617
let left = expr_arena.get(*left_node);
618
let right = expr_arena.get(*right_node);
619
let dtype_left =
620
left.to_dtype(&ToFieldContext::new(expr_arena, schema_merged))?;
621
let dtype_right =
622
right.to_dtype(&ToFieldContext::new(expr_arena, schema_merged))?;
623
if dtype_left != dtype_right {
624
// Ensure that we have a lossless cast between the two types.
625
let dt = if dtype_left.is_primitive_numeric()
626
|| dtype_right.is_primitive_numeric()
627
{
628
get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right)
629
.ok_or(PolarsError::SchemaMismatch(
630
format!(
631
"'join_where' cannot compare {dtype_left:?} with {dtype_right:?}"
632
)
633
.into(),
634
))
635
} else {
636
try_get_supertype(&dtype_left, &dtype_right)
637
}?;
638
639
// Store the nodes and their replacements if a cast is required.
640
let replace_left = dt != dtype_left;
641
let replace_right = dt != dtype_right;
642
if replace_left && replace_right {
643
to_replace.push((*left_node, dt.clone()));
644
to_replace.push((*right_node, dt));
645
} else if replace_left {
646
to_replace.push((*left_node, dt));
647
} else if replace_right {
648
to_replace.push((*right_node, dt));
649
}
650
}
651
},
652
_ => (),
653
}
654
}
655
left_origin | right_origin
656
},
657
_ => ExprOrigin::None,
658
};
659
Ok(expr_origin)
660
}
661
662