Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/collapse_joins.rs
6940 views
1
//! Optimization that collapses a cross join followed by one or more filters into a faster join.
2
//!
3
//! For example, `join(how='cross').filter(pl.col.l == pl.col.r)` can be collapsed to
4
//! `join(how='inner', left_on=pl.col.l, right_on=pl.col.r)`.
5
6
use std::sync::Arc;
7
8
use polars_core::schema::*;
9
#[cfg(feature = "iejoin")]
10
use polars_ops::frame::{IEJoinOptions, InequalityOperator};
11
use polars_ops::frame::{JoinCoalesce, JoinType, MaintainOrderJoin};
12
use polars_utils::arena::{Arena, Node};
13
14
use super::{AExpr, ExprOrigin, IR, JoinOptionsIR, aexpr_to_leaf_names_iter};
15
use crate::dsl::{JoinTypeOptionsIR, Operator};
16
use crate::plans::optimizer::join_utils::remove_suffix;
17
use crate::plans::{ExprIR, MintermIter, is_elementwise_rec};
18
19
fn and_expr(left: Node, right: Node, expr_arena: &mut Arena<AExpr>) -> Node {
20
expr_arena.add(AExpr::BinaryExpr {
21
left,
22
op: Operator::And,
23
right,
24
})
25
}
26
27
pub fn optimize(
28
root: Node,
29
lp_arena: &mut Arena<IR>,
30
expr_arena: &mut Arena<AExpr>,
31
streaming: bool,
32
) {
33
let mut predicates = Vec::with_capacity(4);
34
35
// Partition to:
36
// - equality predicates
37
// - IEjoin supported inequality predicates
38
// - remaining predicates
39
#[cfg(feature = "iejoin")]
40
let mut ie_op = Vec::new();
41
let mut remaining_predicates = Vec::new();
42
43
let mut ir_stack = Vec::with_capacity(16);
44
ir_stack.push(root);
45
46
while let Some(current) = ir_stack.pop() {
47
let current_ir = lp_arena.get(current);
48
current_ir.copy_inputs(&mut ir_stack);
49
50
match current_ir {
51
IR::Filter {
52
input: _,
53
predicate,
54
} => {
55
predicates.push((current, predicate.node()));
56
},
57
IR::Join {
58
input_left,
59
input_right,
60
schema,
61
left_on,
62
right_on,
63
options,
64
} if options.args.how.is_cross() => {
65
if predicates.is_empty() {
66
continue;
67
}
68
69
let suffix = options.args.suffix();
70
71
debug_assert!(left_on.is_empty());
72
debug_assert!(right_on.is_empty());
73
74
let mut eq_left_on = Vec::new();
75
let mut eq_right_on = Vec::new();
76
77
#[cfg(feature = "iejoin")]
78
let mut ie_left_on = Vec::new();
79
#[cfg(feature = "iejoin")]
80
let mut ie_right_on = Vec::new();
81
82
#[cfg(feature = "iejoin")]
83
{
84
ie_op.clear();
85
}
86
87
remaining_predicates.clear();
88
89
#[cfg(feature = "iejoin")]
90
fn to_inequality_operator(op: &Operator) -> Option<InequalityOperator> {
91
match op {
92
Operator::Lt => Some(InequalityOperator::Lt),
93
Operator::LtEq => Some(InequalityOperator::LtEq),
94
Operator::Gt => Some(InequalityOperator::Gt),
95
Operator::GtEq => Some(InequalityOperator::GtEq),
96
_ => None,
97
}
98
}
99
100
let left_schema = lp_arena.get(*input_left).schema(lp_arena);
101
let right_schema = lp_arena.get(*input_right).schema(lp_arena);
102
103
let left_schema = left_schema.as_ref();
104
let right_schema = right_schema.as_ref();
105
106
for (_, predicate_node) in &predicates {
107
for node in MintermIter::new(*predicate_node, expr_arena) {
108
let AExpr::BinaryExpr { left, op, right } = expr_arena.get(node) else {
109
remaining_predicates.push(node);
110
continue;
111
};
112
113
if !op.is_comparison_or_bitwise() {
114
// @NOTE: This is not a valid predicate, but we should not handle that
115
// here.
116
remaining_predicates.push(node);
117
continue;
118
}
119
120
let mut left = *left;
121
let mut op = *op;
122
let mut right = *right;
123
124
let left_origin = ExprOrigin::get_expr_origin(
125
left,
126
expr_arena,
127
left_schema,
128
right_schema,
129
suffix.as_str(),
130
None,
131
)
132
.unwrap();
133
let right_origin = ExprOrigin::get_expr_origin(
134
right,
135
expr_arena,
136
left_schema,
137
right_schema,
138
suffix.as_str(),
139
None,
140
)
141
.unwrap();
142
143
use ExprOrigin as EO;
144
145
// We can only join if both sides of the binary expression stem from
146
// different sides of the join.
147
match (left_origin, right_origin) {
148
(EO::Both, _) | (_, EO::Both) => {
149
// If either expression originates from the both sides, we need to
150
// filter it afterwards.
151
remaining_predicates.push(node);
152
continue;
153
},
154
(EO::None, _) | (_, EO::None) => {
155
// @TODO: This should probably be pushed down
156
remaining_predicates.push(node);
157
continue;
158
},
159
(EO::Left, EO::Left) | (EO::Right, EO::Right) => {
160
// @TODO: This can probably be pushed down in the predicate
161
// pushdown, but for now just take it as is.
162
remaining_predicates.push(node);
163
continue;
164
},
165
(EO::Right, EO::Left) => {
166
// Swap around the expressions so they match with the left_on and
167
// right_on.
168
std::mem::swap(&mut left, &mut right);
169
op = op.swap_operands();
170
},
171
(EO::Left, EO::Right) => {},
172
}
173
174
if matches!(op, Operator::Eq) {
175
eq_left_on.push(ExprIR::from_node(left, expr_arena));
176
eq_right_on.push(ExprIR::from_node(right, expr_arena));
177
} else {
178
#[cfg(feature = "iejoin")]
179
if let Some(ie_op_) = to_inequality_operator(&op) {
180
fn is_numeric(
181
node: Node,
182
expr_arena: &Arena<AExpr>,
183
schema: &Schema,
184
) -> bool {
185
aexpr_to_leaf_names_iter(node, expr_arena).any(|name| {
186
if let Some(dt) = schema.get(name.as_str()) {
187
dt.to_physical().is_primitive_numeric()
188
} else {
189
false
190
}
191
})
192
}
193
194
// We fallback to remaining if:
195
// - we already have an IEjoin or Inner join
196
// - we already have an Inner join
197
// - data is not numeric (our iejoin doesn't yet implement that)
198
if ie_op.len() >= 2
199
|| !eq_left_on.is_empty()
200
|| !is_numeric(left, expr_arena, left_schema)
201
{
202
remaining_predicates.push(node);
203
} else {
204
ie_left_on.push(ExprIR::from_node(left, expr_arena));
205
ie_right_on.push(ExprIR::from_node(right, expr_arena));
206
ie_op.push(ie_op_);
207
}
208
} else {
209
remaining_predicates.push(node);
210
}
211
212
#[cfg(not(feature = "iejoin"))]
213
remaining_predicates.push(node);
214
}
215
}
216
}
217
218
let mut can_simplify_join = false;
219
220
if !eq_left_on.is_empty() {
221
for expr in eq_right_on.iter_mut() {
222
remove_suffix(expr, expr_arena, right_schema, suffix.as_str());
223
}
224
can_simplify_join = true;
225
} else {
226
#[cfg(feature = "iejoin")]
227
if !ie_op.is_empty() {
228
for expr in ie_right_on.iter_mut() {
229
remove_suffix(expr, expr_arena, right_schema, suffix.as_str());
230
}
231
can_simplify_join = true;
232
}
233
can_simplify_join |= options.args.how.is_cross();
234
}
235
236
if can_simplify_join {
237
let new_join = insert_fitting_join(
238
eq_left_on,
239
eq_right_on,
240
#[cfg(feature = "iejoin")]
241
ie_left_on,
242
#[cfg(feature = "iejoin")]
243
ie_right_on,
244
#[cfg(feature = "iejoin")]
245
&ie_op,
246
&remaining_predicates,
247
lp_arena,
248
expr_arena,
249
options.as_ref().clone(),
250
*input_left,
251
*input_right,
252
schema.clone(),
253
streaming,
254
);
255
256
lp_arena.swap(predicates[0].0, new_join);
257
}
258
259
predicates.clear();
260
},
261
_ => {
262
predicates.clear();
263
},
264
}
265
}
266
}
267
268
#[allow(clippy::too_many_arguments)]
269
fn insert_fitting_join(
270
eq_left_on: Vec<ExprIR>,
271
eq_right_on: Vec<ExprIR>,
272
#[cfg(feature = "iejoin")] ie_left_on: Vec<ExprIR>,
273
#[cfg(feature = "iejoin")] ie_right_on: Vec<ExprIR>,
274
#[cfg(feature = "iejoin")] ie_op: &[InequalityOperator],
275
remaining_predicates: &[Node],
276
lp_arena: &mut Arena<IR>,
277
expr_arena: &mut Arena<AExpr>,
278
mut options: JoinOptionsIR,
279
input_left: Node,
280
input_right: Node,
281
schema: SchemaRef,
282
streaming: bool,
283
) -> Node {
284
debug_assert_eq!(eq_left_on.len(), eq_right_on.len());
285
#[cfg(feature = "iejoin")]
286
{
287
debug_assert_eq!(ie_op.len(), ie_left_on.len());
288
debug_assert_eq!(ie_left_on.len(), ie_right_on.len());
289
debug_assert!(ie_op.len() <= 2);
290
}
291
debug_assert!(matches!(options.args.how, JoinType::Cross));
292
293
let remaining_predicates = remaining_predicates
294
.iter()
295
.copied()
296
.reduce(|left, right| and_expr(left, right, expr_arena));
297
298
let (left_on, right_on, remaining_predicates) = match () {
299
_ if !eq_left_on.is_empty() => {
300
options.args.how = JoinType::Inner;
301
// We need to make sure not to delete any columns
302
options.args.coalesce = JoinCoalesce::KeepColumns;
303
304
#[cfg(feature = "iejoin")]
305
let remaining_predicates = ie_left_on.into_iter().zip(ie_op).zip(ie_right_on).fold(
306
remaining_predicates,
307
|acc, ((left, op), right)| {
308
let e = expr_arena.add(AExpr::BinaryExpr {
309
left: left.node(),
310
op: (*op).into(),
311
right: right.node(),
312
});
313
Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena)))
314
},
315
);
316
317
(eq_left_on, eq_right_on, remaining_predicates)
318
},
319
#[cfg(feature = "iejoin")]
320
_ if !ie_op.is_empty() => {
321
// We can only IE join up to 2 operators
322
323
let operator1 = ie_op[0];
324
let operator2 = ie_op.get(1).copied();
325
326
// Do an IEjoin.
327
options.args.how = JoinType::IEJoin;
328
options.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions {
329
operator1,
330
operator2,
331
}));
332
// We need to make sure not to delete any columns
333
options.args.coalesce = JoinCoalesce::KeepColumns;
334
335
(ie_left_on, ie_right_on, remaining_predicates)
336
},
337
// If anything just fall back to a cross join.
338
_ => {
339
options.args.how = JoinType::Cross;
340
// We need to make sure not to delete any columns
341
options.args.coalesce = JoinCoalesce::KeepColumns;
342
343
#[cfg(feature = "iejoin")]
344
let remaining_predicates = ie_left_on.into_iter().zip(ie_op).zip(ie_right_on).fold(
345
remaining_predicates,
346
|acc, ((left, op), right)| {
347
let e = expr_arena.add(AExpr::BinaryExpr {
348
left: left.node(),
349
op: (*op).into(),
350
right: right.node(),
351
});
352
Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena)))
353
},
354
);
355
356
let mut remaining_predicates = remaining_predicates;
357
if let Some(pred) = remaining_predicates.take_if(|pred| {
358
matches!(options.args.maintain_order, MaintainOrderJoin::None)
359
&& !streaming
360
&& is_elementwise_rec(*pred, expr_arena)
361
}) {
362
options.options = Some(JoinTypeOptionsIR::CrossAndFilter {
363
predicate: ExprIR::from_node(pred, expr_arena),
364
})
365
}
366
367
(Vec::new(), Vec::new(), remaining_predicates)
368
},
369
};
370
371
// Note: We expect key type upcasting / expression optimizations have already been done during
372
// DSL->IR conversion.
373
374
let join_ir = IR::Join {
375
input_left,
376
input_right,
377
schema,
378
left_on,
379
right_on,
380
options: Arc::new(options),
381
};
382
383
let join_node = lp_arena.add(join_ir);
384
385
if let Some(predicate) = remaining_predicates {
386
lp_arena.add(IR::Filter {
387
input: join_node,
388
predicate: ExprIR::from_node(predicate, &*expr_arena),
389
})
390
} else {
391
join_node
392
}
393
}
394
395