Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/join.rs
7889 views
1
use arrow::legacy::error::PolarsResult;
2
use either::Either;
3
use polars_core::chunked_array::cast::CastOptions;
4
use polars_core::error::feature_gated;
5
use polars_core::utils::{get_numeric_upcast_supertype_lossless, try_get_supertype};
6
use polars_utils::format_pl_smallstr;
7
use polars_utils::itertools::Itertools;
8
9
use super::*;
10
use crate::constants::POLARS_TMP_PREFIX;
11
use crate::dsl::Expr;
12
#[cfg(feature = "iejoin")]
13
use crate::plans::AExpr;
14
15
fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> {
16
for e in keys {
17
if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) {
18
polars_bail!(
19
InvalidOperation:
20
"'alias' is not allowed in a join key, use 'with_columns' first",
21
)
22
}
23
}
24
Ok(())
25
}
26
27
/// Returns: left: join_node, right: last_node (often both the same)
28
pub fn resolve_join(
29
input_left: Either<Arc<DslPlan>, Node>,
30
input_right: Either<Arc<DslPlan>, Node>,
31
left_on: Vec<Expr>,
32
right_on: Vec<Expr>,
33
predicates: Vec<Expr>,
34
mut options: JoinOptionsIR,
35
ctxt: &mut DslConversionContext,
36
) -> PolarsResult<(Node, Node)> {
37
if !predicates.is_empty() {
38
feature_gated!("iejoin", {
39
debug_assert!(left_on.is_empty() && right_on.is_empty());
40
return resolve_join_where(
41
input_left.unwrap_left(),
42
input_right.unwrap_left(),
43
predicates,
44
options,
45
ctxt,
46
);
47
})
48
}
49
50
let owned = Arc::unwrap_or_clone;
51
let mut input_left = input_left.map_right(Ok).right_or_else(|input| {
52
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))
53
})?;
54
let mut input_right = input_right.map_right(Ok).right_or_else(|input| {
55
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))
56
})?;
57
58
let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
59
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);
60
61
if options.args.how.is_cross() {
62
polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");
63
} else {
64
polars_ensure!(left_on.len() + right_on.len() > 0, InvalidOperation: "expected join keys/predicates");
65
check_join_keys(&left_on)?;
66
check_join_keys(&right_on)?;
67
68
let mut turn_off_coalesce = false;
69
for e in left_on.iter().chain(right_on.iter()) {
70
// Any expression that is not a simple column expression will turn of coalescing.
71
turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_)));
72
}
73
if turn_off_coalesce {
74
if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) {
75
polars_warn!(
76
"coalescing join requested but not all join keys are column references, turning off key coalescing"
77
);
78
}
79
options.args.coalesce = JoinCoalesce::KeepColumns;
80
}
81
82
options.args.validation.is_valid_join(&options.args.how)?;
83
84
#[cfg(feature = "asof_join")]
85
if let JoinType::AsOf(options) = &options.args.how {
86
match (&options.left_by, &options.right_by) {
87
(None, None) => {},
88
(Some(l), Some(r)) => {
89
polars_ensure!(l.len() == r.len(), InvalidOperation: "expected equal number of columns in 'by_left' and 'by_right' in 'asof_join'");
90
validate_columns_in_input(l, &schema_left, "asof_join")?;
91
validate_columns_in_input(r, &schema_right, "asof_join")?;
92
},
93
_ => {
94
polars_bail!(InvalidOperation: "expected both 'by_left' and 'by_right' to be set in 'asof_join'")
95
},
96
}
97
}
98
99
polars_ensure!(
100
left_on.len() == right_on.len(),
101
InvalidOperation:
102
format!(
103
"the number of columns given as join key (left: {}, right:{}) should be equal",
104
left_on.len(),
105
right_on.len()
106
)
107
);
108
}
109
110
let mut left_on = left_on
111
.into_iter()
112
.map(|e| {
113
to_expr_ir_materialized_lit(
114
e,
115
&mut ExprToIRContext::new_with_opt_eager(
116
ctxt.expr_arena,
117
&schema_left,
118
ctxt.opt_flags,
119
),
120
)
121
})
122
.collect::<PolarsResult<Vec<_>>>()?;
123
let mut right_on = right_on
124
.into_iter()
125
.map(|e| {
126
to_expr_ir_materialized_lit(
127
e,
128
&mut ExprToIRContext::new_with_opt_eager(
129
ctxt.expr_arena,
130
&schema_right,
131
ctxt.opt_flags,
132
),
133
)
134
})
135
.collect::<PolarsResult<Vec<_>>>()?;
136
let mut joined_on = PlHashSet::new();
137
138
#[cfg(feature = "iejoin")]
139
let check = !matches!(options.args.how, JoinType::IEJoin);
140
#[cfg(not(feature = "iejoin"))]
141
let check = true;
142
if check {
143
for (l, r) in left_on.iter().zip(right_on.iter()) {
144
polars_ensure!(
145
joined_on.insert((l.output_name(), r.output_name())),
146
InvalidOperation: "joining with repeated key names; already joined on {} and {}",
147
l.output_name(),
148
r.output_name()
149
)
150
}
151
}
152
drop(joined_on);
153
154
ctxt.conversion_optimizer
155
.fill_scratch(&left_on, ctxt.expr_arena);
156
ctxt.conversion_optimizer
157
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_left, true)
158
.map_err(|e| e.context("'join' failed".into()))?;
159
ctxt.conversion_optimizer
160
.fill_scratch(&right_on, ctxt.expr_arena);
161
ctxt.conversion_optimizer
162
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_right, true)
163
.map_err(|e| e.context("'join' failed".into()))?;
164
165
// Re-evaluate because of mutable borrows earlier.
166
let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
167
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);
168
169
// # Resolve scalars
170
//
171
// Scalars need to be expanded. We translate them to temporary columns added with
172
// `with_columns` and remove them later with `project`
173
// This way the backends don't have to expand the literals in the join implementation
174
175
let has_scalars = left_on
176
.iter()
177
.chain(right_on.iter())
178
.any(|e| e.is_scalar(ctxt.expr_arena));
179
180
let (schema_left, schema_right) = if has_scalars {
181
let mut as_with_columns_l = vec![];
182
let mut as_with_columns_r = vec![];
183
for (i, e) in left_on.iter().enumerate() {
184
if e.is_scalar(ctxt.expr_arena) {
185
as_with_columns_l.push((i, e.clone()));
186
}
187
}
188
for (i, e) in right_on.iter().enumerate() {
189
if e.is_scalar(ctxt.expr_arena) {
190
as_with_columns_r.push((i, e.clone()));
191
}
192
}
193
194
let mut count = 0;
195
let get_tmp_name = |i| format_pl_smallstr!("{POLARS_TMP_PREFIX}{i}");
196
197
// Early clone because of bck.
198
let mut schema_right_new = if !as_with_columns_r.is_empty() {
199
(**schema_right).clone()
200
} else {
201
Default::default()
202
};
203
if !as_with_columns_l.is_empty() {
204
let mut schema_left_new = (**schema_left).clone();
205
206
let mut exprs = Vec::with_capacity(as_with_columns_l.len());
207
for (i, mut e) in as_with_columns_l {
208
let tmp_name = get_tmp_name(count);
209
count += 1;
210
e.set_alias(tmp_name.clone());
211
let dtype = e.dtype(&schema_left_new, ctxt.expr_arena)?;
212
schema_left_new.with_column(tmp_name.clone(), dtype.clone());
213
214
let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));
215
left_on[i] = ExprIR::from_node(col, ctxt.expr_arena);
216
exprs.push(e);
217
}
218
input_left = ctxt.lp_arena.add(IR::HStack {
219
input: input_left,
220
exprs,
221
schema: Arc::new(schema_left_new),
222
options: ProjectionOptions::default(),
223
})
224
}
225
if !as_with_columns_r.is_empty() {
226
let mut exprs = Vec::with_capacity(as_with_columns_r.len());
227
for (i, mut e) in as_with_columns_r {
228
let tmp_name = get_tmp_name(count);
229
count += 1;
230
e.set_alias(tmp_name.clone());
231
let dtype = e.dtype(&schema_right_new, ctxt.expr_arena)?;
232
schema_right_new.with_column(tmp_name.clone(), dtype.clone());
233
234
let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));
235
right_on[i] = ExprIR::from_node(col, ctxt.expr_arena);
236
exprs.push(e);
237
}
238
input_right = ctxt.lp_arena.add(IR::HStack {
239
input: input_right,
240
exprs,
241
schema: Arc::new(schema_right_new),
242
options: ProjectionOptions::default(),
243
})
244
}
245
246
(
247
ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena),
248
ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena),
249
)
250
} else {
251
(schema_left, schema_right)
252
};
253
254
// Not a closure to avoid borrow issues because we mutate expr_arena as well.
255
macro_rules! get_dtype {
256
($expr:expr, $schema:expr) => {
257
ctxt.expr_arena
258
.get($expr.node())
259
.to_dtype(&ToFieldContext::new(ctxt.expr_arena, $schema))
260
};
261
}
262
263
// As an optimization, when inserting casts for coalescing joins we only insert them beforehand for full-join.
264
// This means for e.g. left-join, the LHS key preserves its dtype in the output even if it is joined
265
// with an RHS key of wider type.
266
let key_cols_coalesced =
267
options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full);
268
let mut as_with_columns_l = vec![];
269
let mut as_with_columns_r = vec![];
270
for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) {
271
let ltype = get_dtype!(lnode, &schema_left)?;
272
let rtype = get_dtype!(rnode, &schema_right)?;
273
274
if let Some(dtype) = get_numeric_upcast_supertype_lossless(&ltype, &rtype) {
275
// We use overflowing cast to allow better optimization as we are casting to a known
276
// lossless supertype.
277
//
278
// We have unique references to these nodes (they are created by this function),
279
// so we can mutate in-place without causing side effects somewhere else.
280
let casted_l = ctxt.expr_arena.add(AExpr::Cast {
281
expr: lnode.node(),
282
dtype: dtype.clone(),
283
options: CastOptions::Overflowing,
284
});
285
let casted_r = ctxt.expr_arena.add(AExpr::Cast {
286
expr: rnode.node(),
287
dtype,
288
options: CastOptions::Overflowing,
289
});
290
291
if key_cols_coalesced {
292
let mut lnode = lnode.clone();
293
let mut rnode = rnode.clone();
294
295
let ae_l = ctxt.expr_arena.get(lnode.node());
296
let ae_r = ctxt.expr_arena.get(rnode.node());
297
298
polars_ensure!(
299
ae_l.is_col() && ae_r.is_col(),
300
SchemaMismatch: "can only 'coalesce' full join if join keys are column expressions",
301
);
302
303
lnode.set_node(casted_l);
304
rnode.set_node(casted_r);
305
306
as_with_columns_r.push(rnode);
307
as_with_columns_l.push(lnode);
308
} else {
309
lnode.set_node(casted_l);
310
rnode.set_node(casted_r);
311
}
312
} else {
313
polars_ensure!(
314
ltype == rtype,
315
SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right (and no other type was available to cast to)",
316
lnode.output_name(), ltype, rnode.output_name(), rtype
317
)
318
}
319
}
320
321
// Every expression must be elementwise so that we are
322
// guaranteed the keys for a join are all the same length.
323
324
polars_ensure!(
325
all_elementwise(&left_on, ctxt.expr_arena) && all_elementwise(&right_on, ctxt.expr_arena),
326
InvalidOperation: "all join key expressions must be elementwise."
327
);
328
329
#[cfg(feature = "asof_join")]
330
if let JoinType::AsOf(options) = &mut options.args.how {
331
use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY;
332
333
// prepare the tolerance
334
// we must ensure that we use the right units
335
if let Some(tol) = &options.tolerance_str {
336
let duration = polars_time::Duration::try_parse(tol)?;
337
polars_ensure!(
338
duration.months() == 0,
339
ComputeError: "cannot use month offset in timedelta of an asof join; \
340
consider using 4 weeks"
341
);
342
use DataType::*;
343
match ctxt
344
.expr_arena
345
.get(left_on[0].node())
346
.to_dtype(&ToFieldContext::new(ctxt.expr_arena, &schema_left))?
347
{
348
Datetime(tu, _) | Duration(tu) => {
349
let tolerance = match tu {
350
TimeUnit::Nanoseconds => duration.duration_ns(),
351
TimeUnit::Microseconds => duration.duration_us(),
352
TimeUnit::Milliseconds => duration.duration_ms(),
353
};
354
options.tolerance = Some(Scalar::from(tolerance))
355
},
356
Date => {
357
let days = (duration.duration_ms() / MILLISECONDS_IN_DAY) as i32;
358
options.tolerance = Some(Scalar::from(days))
359
},
360
Time => {
361
let tolerance = duration.duration_ns();
362
options.tolerance = Some(Scalar::from(tolerance))
363
},
364
_ => {
365
panic!(
366
"can only use timedelta string language with Date/Datetime/Duration/Time dtypes"
367
)
368
},
369
}
370
}
371
}
372
373
// These are Arc<Schema>, into_owned is free.
374
let schema_left = schema_left.into_owned();
375
let schema_right = schema_right.into_owned();
376
377
let join_schema = det_join_schema(
378
&schema_left,
379
&schema_right,
380
&left_on,
381
&right_on,
382
&options,
383
ctxt.expr_arena,
384
)
385
.map_err(|e| e.context(failed_here!(join schema resolving)))?;
386
387
if key_cols_coalesced {
388
input_left = if as_with_columns_l.is_empty() {
389
input_left
390
} else {
391
ctxt.lp_arena.add(IR::HStack {
392
input: input_left,
393
exprs: as_with_columns_l,
394
schema: schema_left,
395
options: ProjectionOptions::default(),
396
})
397
};
398
399
input_right = if as_with_columns_r.is_empty() {
400
input_right
401
} else {
402
ctxt.lp_arena.add(IR::HStack {
403
input: input_right,
404
exprs: as_with_columns_r,
405
schema: schema_right,
406
options: ProjectionOptions::default(),
407
})
408
};
409
}
410
411
let ir = IR::Join {
412
input_left,
413
input_right,
414
schema: join_schema.clone(),
415
left_on,
416
right_on,
417
options: Arc::new(options),
418
};
419
let join_node = ctxt.lp_arena.add(ir);
420
421
if has_scalars {
422
let names = join_schema
423
.iter_names()
424
.filter_map(|n| {
425
if n.starts_with(POLARS_TMP_PREFIX) {
426
None
427
} else {
428
Some(n.clone())
429
}
430
})
431
.collect_vec();
432
433
let builder = IRBuilder::new(join_node, ctxt.expr_arena, ctxt.lp_arena);
434
let ir = builder.project_simple(names).map(|b| b.build())?;
435
let select_node = ctxt.lp_arena.add(ir);
436
437
Ok((select_node, join_node))
438
} else {
439
Ok((join_node, join_node))
440
}
441
}
442
443
#[cfg(feature = "iejoin")]
impl From<InequalityOperator> for Operator {
    /// Maps each inequality operator onto the general [`Operator`] variant of the same name.
    fn from(value: InequalityOperator) -> Self {
        use InequalityOperator as Ineq;

        match value {
            Ineq::Lt => Self::Lt,
            Ineq::LtEq => Self::LtEq,
            Ineq::Gt => Self::Gt,
            Ineq::GtEq => Self::GtEq,
        }
    }
}
454
455
#[cfg(feature = "iejoin")]
/// Lowers a `join_where` (a join described by `predicates`) into a cross join followed by
/// one `IR::Filter` node per predicate.
///
/// Both inputs are converted to IR, the join type in `options` is overwritten with
/// `JoinType::Cross`, and the join itself is built by `resolve_join` without keys. Each
/// predicate is then converted against the schema of the joined result, checked to produce
/// a Boolean, given lossless casts where a comparison mixes columns of both tables, and
/// stacked as a filter on top of the previous node.
///
/// # Returns
/// `(last_node, join_node)`: `last_node` is the top-most node (the last filter, or the join
/// result when there are no predicates) and `join_node` is the underlying join node.
///
/// # Errors
/// Returns an error if a predicate contains an alias, does not resolve to Boolean, compares
/// dtypes that have no lossless supertype, or if converting an input/expression fails.
fn resolve_join_where(
    input_left: Arc<DslPlan>,
    input_right: Arc<DslPlan>,
    predicates: Vec<Expr>,
    mut options: JoinOptionsIR,
    ctxt: &mut DslConversionContext,
) -> PolarsResult<(Node, Node)> {
    // In eager mode predicate pushdown is always switched on.
    // If not eager, respect the flag.
    if ctxt.opt_flags.eager() {
        ctxt.opt_flags.set(OptFlags::PREDICATE_PUSHDOWN, true);
    }
    check_join_keys(&predicates)?;
    let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)
        .map_err(|e| e.context(failed_here!(join left)))?;
    // Failures of the right input are reported as such (previously mislabelled as left).
    let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)
        .map_err(|e| e.context(failed_here!(join right)))?;

    let schema_left = ctxt
        .lp_arena
        .get(input_left)
        .schema(ctxt.lp_arena)
        .into_owned();

    options.args.how = JoinType::Cross;

    let (mut last_node, join_node) = resolve_join(
        Either::Right(input_left),
        Either::Right(input_right),
        vec![],
        vec![],
        vec![],
        options,
        ctxt,
    )?;

    let schema_merged = ctxt
        .lp_arena
        .get(last_node)
        .schema(ctxt.lp_arena)
        .into_owned();

    // Perform predicate validation.
    let mut upcast_exprs = Vec::<(Node, DataType)>::new();
    for e in predicates {
        let arena = &mut ctxt.expr_arena;
        let predicate = to_expr_ir_materialized_lit(
            e,
            &mut ExprToIRContext::new_with_opt_eager(arena, &schema_merged, ctxt.opt_flags),
        )?;
        let node = predicate.node();

        // Ensure the predicate dtype output of the root node is Boolean
        let ae = arena.get(node);
        let dt_out = ae.to_dtype(&ToFieldContext::new(arena, &schema_merged))?;
        polars_ensure!(
            dt_out == DataType::Boolean,
            ComputeError: "'join_where' predicates must resolve to boolean"
        );

        ensure_lossless_binary_comparisons(
            &node,
            &schema_left,
            &schema_merged,
            arena,
            &mut upcast_exprs,
        )?;

        ctxt.conversion_optimizer
            .push_scratch(predicate.node(), ctxt.expr_arena);

        // Every predicate becomes its own filter stacked on the previous node.
        let ir = IR::Filter {
            input: last_node,
            predicate,
        };

        last_node = ctxt.lp_arena.add(ir);
    }

    ctxt.conversion_optimizer
        .optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, last_node, false)
        .map_err(|e| e.context("'join_where' failed".into()))?;

    Ok((last_node, join_node))
}
541
542
/// Locate nodes that are operands in a binary comparison involving both tables, and ensure that
543
/// these nodes are losslessly upcast to a safe dtype.
544
fn ensure_lossless_binary_comparisons(
545
node: &Node,
546
schema_left: &Schema,
547
schema_merged: &Schema,
548
expr_arena: &mut Arena<AExpr>,
549
upcast_exprs: &mut Vec<(Node, DataType)>,
550
) -> PolarsResult<()> {
551
// let mut upcast_exprs = Vec::<(Node, DataType)>::new();
552
// Ensure that all binary comparisons that use both tables are lossless.
553
build_upcast_node_list(node, schema_left, schema_merged, expr_arena, upcast_exprs)?;
554
// Replace each node with its casted counterpart
555
for (expr, dtype) in upcast_exprs.drain(..) {
556
let old_expr = expr_arena.duplicate(expr);
557
let new_aexpr = AExpr::Cast {
558
expr: old_expr,
559
dtype,
560
options: CastOptions::Overflowing,
561
};
562
expr_arena.replace(expr, new_aexpr);
563
}
564
Ok(())
565
}
566
567
/// If we are dealing with a binary comparison involving columns from exclusively the left table
568
/// on the LHS and the right table on the RHS side, ensure that the cast is lossless.
569
/// Expressions involving binaries using either table alone we leave up to the user to verify
570
/// that they are valid, as they could theoretically be pushed outside of the join.
571
#[recursive]
572
fn build_upcast_node_list(
573
node: &Node,
574
schema_left: &Schema,
575
schema_merged: &Schema,
576
expr_arena: &Arena<AExpr>,
577
to_replace: &mut Vec<(Node, DataType)>,
578
) -> PolarsResult<ExprOrigin> {
579
let expr_origin = match expr_arena.get(*node) {
580
AExpr::Column(name) => {
581
if schema_left.contains(name) {
582
ExprOrigin::Left
583
} else if schema_merged.contains(name) {
584
ExprOrigin::Right
585
} else {
586
polars_bail!(ColumnNotFound: "{}", name);
587
}
588
},
589
AExpr::Literal(..) => ExprOrigin::None,
590
AExpr::Cast { expr: node, .. } => {
591
build_upcast_node_list(node, schema_left, schema_merged, expr_arena, to_replace)?
592
},
593
AExpr::BinaryExpr {
594
left: left_node,
595
op,
596
right: right_node,
597
} => {
598
// If left and right node has both, ensure the dtypes are valid.
599
let left_origin = build_upcast_node_list(
600
left_node,
601
schema_left,
602
schema_merged,
603
expr_arena,
604
to_replace,
605
)?;
606
let right_origin = build_upcast_node_list(
607
right_node,
608
schema_left,
609
schema_merged,
610
expr_arena,
611
to_replace,
612
)?;
613
// We only update casts during comparisons if the operands are from different tables.
614
if op.is_comparison() {
615
match (left_origin, right_origin) {
616
(ExprOrigin::Left, ExprOrigin::Right)
617
| (ExprOrigin::Right, ExprOrigin::Left) => {
618
// Ensure our dtype casts are lossless
619
let left = expr_arena.get(*left_node);
620
let right = expr_arena.get(*right_node);
621
let dtype_left =
622
left.to_dtype(&ToFieldContext::new(expr_arena, schema_merged))?;
623
let dtype_right =
624
right.to_dtype(&ToFieldContext::new(expr_arena, schema_merged))?;
625
if dtype_left != dtype_right {
626
// Ensure that we have a lossless cast between the two types.
627
let dt = if dtype_left.is_primitive_numeric()
628
|| dtype_right.is_primitive_numeric()
629
{
630
get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right)
631
.ok_or(PolarsError::SchemaMismatch(
632
format!(
633
"'join_where' cannot compare {dtype_left:?} with {dtype_right:?}"
634
)
635
.into(),
636
))
637
} else {
638
try_get_supertype(&dtype_left, &dtype_right)
639
}?;
640
641
// Store the nodes and their replacements if a cast is required.
642
let replace_left = dt != dtype_left;
643
let replace_right = dt != dtype_right;
644
if replace_left && replace_right {
645
to_replace.push((*left_node, dt.clone()));
646
to_replace.push((*right_node, dt));
647
} else if replace_left {
648
to_replace.push((*left_node, dt));
649
} else if replace_right {
650
to_replace.push((*right_node, dt));
651
}
652
}
653
},
654
_ => (),
655
}
656
}
657
left_origin | right_origin
658
},
659
_ => ExprOrigin::None,
660
};
661
Ok(expr_origin)
662
}
663
664