GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/cache_states.rs

use std::collections::BTreeMap;

use polars_utils::unique_id::UniqueId;

use super::*;

fn get_upper_projections(
    parent: Node,
    lp_arena: &Arena<IR>,
    expr_arena: &Arena<AExpr>,
    names_scratch: &mut Vec<PlSmallStr>,
    found_required_columns: &mut bool,
) -> bool {
    let parent = lp_arena.get(parent);

    use IR::*;
    // During projection pushdown all required columns have already been accumulated.
    match parent {
        SimpleProjection { columns, .. } => {
            let iter = columns.iter_names_cloned();
            names_scratch.extend(iter);
            *found_required_columns = true;
            false
        },
        Filter { predicate, .. } => {
            // Also add the predicate's columns, as the projection sits above the filter node.
            names_scratch.extend(aexpr_to_leaf_names(predicate.node(), expr_arena));

            true
        },
        // Only filter and projection nodes are allowed; at any other node we stop.
        _ => false,
    }
}

fn get_upper_predicates(
    parent: Node,
    lp_arena: &Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    predicate_scratch: &mut Vec<Expr>,
) -> bool {
    let parent = lp_arena.get(parent);

    use IR::*;
    match parent {
        Filter { predicate, .. } => {
            let expr = predicate.to_expr(expr_arena);
            predicate_scratch.push(expr);
            false
        },
        SimpleProjection { .. } => true,
        // Only filter and projection nodes are allowed; at any other node we stop.
        _ => false,
    }
}

type TwoParents = [Option<Node>; 2];

// 1. This ensures that all equal caches communicate which columns
//    they need to project.
// 2. This ensures we apply the predicates in the subtrees below the caches.
//    If the predicate above the cache is the same for all matching caches, that filter will be
//    applied as well.
//
// # Example
// Consider this tree, where `SUB-TREE` is duplicated and can be cached.
//
//                   Tree
//                     |
//                     |
//     |---------------|---------------|
//     |                               |
//  SUB-TREE                        SUB-TREE
//
// STEPS:
// - 1. CSE will run and insert cache nodes.
//
//                   Tree
//                     |
//                     |
//     |---------------|---------------|
//     |                               |
//     | CACHE 0                       | CACHE 0
//     |                               |
//  SUB-TREE                        SUB-TREE
//
// - 2. Predicate and projection pushdown will run and insert optional FILTER and PROJECTION nodes above the caches.
//
//                   Tree
//                     |
//                     |
//     |---------------|---------------|
//     | FILTER (optional)             | FILTER (optional)
//     | PROJ (optional)               | PROJ (optional)
//     |                               |
//     | CACHE 0                       | CACHE 0
//     |                               |
//  SUB-TREE                        SUB-TREE
//
// # Projection optimization
// The union of the projections is determined and pushed down below the caches.
//
//                   Tree
//                     |
//                     |
//     |---------------|---------------|
//     | FILTER (optional)             | FILTER (optional)
//     | CACHE 0                       | CACHE 0
//     |                               |
//  SUB-TREE                        SUB-TREE
//  UNION PROJ (optional)           UNION PROJ (optional)
//
// # Filter optimization
// Depending on the predicates, the predicate pushdown optimization will run.
// Possible cases:
// - NO FILTERS: run predicate pushdown from the cache nodes -> finish
// - The filters above all matching caches are the same -> run predicate pushdown from the filter nodes -> finish
// - There is a cache node without a predicate above it -> run predicate pushdown from the cache nodes -> finish
// - The predicates above the cache nodes all differ -> remove the cache nodes -> finish
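//
// For intuition, a query shape that typically produces such duplicated subtrees is a
// self-join on a shared intermediate. The sketch below uses the user-facing lazy API and is
// purely illustrative (the names `lf`, `k` and `v` are made up; nothing here is part of this
// module):
//
//     let shared = lf.group_by([col("k")]).agg([col("v").sum()]);
//     let q = shared
//         .clone()
//         .join(shared, [col("k")], [col("k")], JoinArgs::new(JoinType::Inner));
//
// After CSE both join inputs point at `CACHE 0` over the same aggregation subtree. If the two
// sides later project different columns, this pass pushes the union of those columns below the
// cache; if both sides carry the same filter, that filter is pushed down as well.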
#[expect(clippy::too_many_arguments)]
pub(super) fn set_cache_states(
    root: Node,
    lp_arena: &mut Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    scratch: &mut Vec<Node>,
    expr_eval: ExprEval<'_>,
    verbose: bool,
    pushdown_maintain_errors: bool,
    new_streaming: bool,
) -> PolarsResult<()> {
    let mut stack = Vec::with_capacity(4);
    let mut names_scratch = vec![];
    let mut predicates_scratch = vec![];

    scratch.clear();
    stack.clear();

    #[derive(Default)]
    struct Value {
        // All the children of the cache per cache-id.
        children: Vec<Node>,
        parents: Vec<TwoParents>,
        cache_nodes: Vec<Node>,
        // Union over projected names.
        names_union: PlHashSet<PlSmallStr>,
        // Union over predicates.
        predicate_union: PlHashMap<Expr, u32>,
    }
    let mut cache_schema_and_children = BTreeMap::new();

    // Stack frame
    #[derive(Default, Clone)]
    struct Frame {
        current: Node,
        cache_id: Option<UniqueId>,
        parent: TwoParents,
        previous_cache: Option<UniqueId>,
    }
    let init = Frame {
        current: root,
        ..Default::default()
    };

    stack.push(init);

    // # First traversal.
    // Collect the union of columns per cache id.
    // And find the cache parents.
    while let Some(mut frame) = stack.pop() {
        let lp = lp_arena.get(frame.current);
        lp.copy_inputs(scratch);

        use IR::*;

        if let Cache { input, id, .. } = lp {
            if let Some(cache_id) = frame.cache_id {
                frame.previous_cache = Some(cache_id)
            }
            if frame.parent[0].is_some() {
                // Projection pushdown has already run and was blocked on the cache nodes;
                // the pushed-down columns are projected just above this cache.
                // If there were no pushed-down columns, we just take the current
                // node's schema.
                // We never want to naively take the parents' schema, as a join or aggregate,
                // for instance, changes the schema.

                let v = cache_schema_and_children
                    .entry(*id)
                    .or_insert_with(Value::default);
                v.children.push(*input);
                v.parents.push(frame.parent);
                v.cache_nodes.push(frame.current);

                let mut found_required_columns = false;

                for parent_node in frame.parent.into_iter().flatten() {
                    let keep_going = get_upper_projections(
                        parent_node,
                        lp_arena,
                        expr_arena,
                        &mut names_scratch,
                        &mut found_required_columns,
                    );
                    if !names_scratch.is_empty() {
                        v.names_union.extend(names_scratch.drain(..));
                    }
                    // We stop early as we want to find the first projection node above the cache.
                    if !keep_going {
                        break;
                    }
                }

                for parent_node in frame.parent.into_iter().flatten() {
                    let keep_going = get_upper_predicates(
                        parent_node,
                        lp_arena,
                        expr_arena,
                        &mut predicates_scratch,
                    );
                    if !predicates_scratch.is_empty() {
                        for pred in predicates_scratch.drain(..) {
                            let count = v.predicate_union.entry(pred).or_insert(0);
                            *count += 1;
                        }
                    }
                    // We stop early as we want to find the first predicate node above the cache.
                    if !keep_going {
                        break;
                    }
                }

                // There was no explicit projection, so we must take
                // all columns.
                if !found_required_columns {
                    let schema = lp.schema(lp_arena);
                    v.names_union.extend(schema.iter_names_cloned());
                }
            }
            frame.cache_id = Some(*id);
        };

        // Shift parents.
        frame.parent[1] = frame.parent[0];
        frame.parent[0] = Some(frame.current);
        for n in scratch.iter() {
            let mut new_frame = frame.clone();
            new_frame.current = *n;
            stack.push(new_frame);
        }
        scratch.clear();
    }

    // # Second pass.
    // We create a subtree where we project the columns
    // just before the cache. Then we do another projection pushdown,
    // and finally remove that last projection and stitch the subplan
    // back onto the cache node again.
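    //
    // Schematically, per cache id (illustrative shape, not literal IR):
    //
    //     before:  CACHE -> SUB-TREE
    //     after:   CACHE -> SimpleProjection(union of names) -> SUB-TREE (pushdown applied)
    //
    // The same optimized subplan is spliced into every child slot of the matching caches, and a
    // double projection produced by the pushdown is collapsed into a single one.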
    if !cache_schema_and_children.is_empty() {
        let mut proj_pd = ProjectionPushDown::new();
        let mut pred_pd =
            PredicatePushDown::new(expr_eval, pushdown_maintain_errors, new_streaming)
                .block_at_cache(false);
        for (_cache_id, v) in cache_schema_and_children {
            // # CHECK IF WE NEED TO REMOVE CACHES
            // If we encounter multiple predicates we remove the cache nodes completely, as we don't
            // want to lose predicate pushdown in favor of scan sharing.
            if v.predicate_union.len() > 1 {
                if verbose {
                    eprintln!("cache nodes will be removed because predicates don't match")
                }
                for ((&child, cache), parents) in
                    v.children.iter().zip(v.cache_nodes).zip(v.parents)
                {
                    // Remove the cache and assign the child the cache location.
                    lp_arena.swap(child, cache);

                    // Restart predicate and projection pushdown from the topmost parent.
                    // This ensures we continue the optimization where it was blocked initially.
                    // We pick up the blocked filter and projection.
                    let mut node = cache;
                    for p_node in parents.into_iter().flatten() {
                        if matches!(
                            lp_arena.get(p_node),
                            IR::Filter { .. } | IR::SimpleProjection { .. }
                        ) {
                            node = p_node
                        } else {
                            break;
                        }
                    }

                    let lp = lp_arena.take(node);
                    let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;
                    let lp = pred_pd.optimize(lp, lp_arena, expr_arena)?;
                    lp_arena.replace(node, lp);
                }
                return Ok(());
            }
            // Below we restart projection and predicate pushdown
            // on the first cache node. As these are cache nodes, the others are the same
            // and we can reuse the optimized state for all inputs.
            // See #21637

            // # RUN PROJECTION PUSHDOWN
            if !v.names_union.is_empty() {
                let first_child = *v.children.first().expect("at least one child");

                let columns = &v.names_union;
                let child_lp = lp_arena.take(first_child);

                // Make sure we project in the order of the schema;
                // if we don't, a union may fail, as we would project in the
                // order in which we discovered the values.
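                // For instance (illustrative column names): if the union was discovered as
                // {"b", "a"} while the child schema is ["a", "b"], we must still emit ["a", "b"]
                // so that every cache consumer sees the columns in the same, schema-given order.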
                let child_schema = child_lp.schema(lp_arena);
                let child_schema = child_schema.as_ref();
                let projection = child_schema
                    .iter_names()
                    .flat_map(|name| columns.get(name.as_str()).cloned())
                    .collect::<Vec<_>>();

                let new_child = lp_arena.add(child_lp);

                let lp = IRBuilder::new(new_child, expr_arena, lp_arena)
                    .project_simple(projection)
                    .expect("unique names")
                    .build();

                let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;
                // Optimization can lead to a double projection. Only keep the outer one.
                let lp = if let IR::SimpleProjection { input, columns } = lp {
                    let input =
                        if let IR::SimpleProjection { input: input2, .. } = lp_arena.get(input) {
                            *input2
                        } else {
                            input
                        };
                    IR::SimpleProjection { input, columns }
                } else {
                    lp
                };
                lp_arena.replace(first_child, lp.clone());

                // Set the remaining children to the same node.
                for &child in &v.children[1..] {
                    lp_arena.replace(child, lp.clone());
                }
            } else {
                // No upper projections to include; run projection pushdown from the cache node.
                let first_child = *v.children.first().expect("at least one child");
                let child_lp = lp_arena.take(first_child);
                let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?;
                lp_arena.replace(first_child, lp.clone());

                for &child in &v.children[1..] {
                    lp_arena.replace(child, lp.clone());
                }
            }

            // # RUN PREDICATE PUSHDOWN
            // Run this after projection pushdown, otherwise the predicate columns will not be projected.

            // - If all parents' predicates are the same, we restart predicate pushdown from the parent FILTER node.
            // - Otherwise we start predicate pushdown from the cache node.
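            // For example (illustrative predicate): with two cache consumers that both carry the
            // filter `a > 1`, `predicate_union` has a single entry with a count of 2, which equals
            // `v.children.len()`, so we restart pushdown from the parent FILTER nodes. If only one
            // consumer has that filter, the count is 1 != 2 and we restart from the cache node
            // instead, leaving the filter above the cache.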
            let allow_parent_predicate_pushdown = v.predicate_union.len() == 1 && {
                let (_pred, count) = v.predicate_union.iter().next().unwrap();
                *count == v.children.len() as u32
            };

            if allow_parent_predicate_pushdown {
                let parents = *v.parents.first().unwrap();
                let node = get_filter_node(parents, lp_arena)
                    .expect("expected filter; this is an optimizer bug");
                let start_lp = lp_arena.take(node);
                let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?;
                lp_arena.replace(node, lp.clone());
                for &parents in &v.parents[1..] {
                    let node = get_filter_node(parents, lp_arena)
                        .expect("expected filter; this is an optimizer bug");
                    lp_arena.replace(node, lp.clone());
                }
            } else {
                let child = *v.children.first().unwrap();
                let child_lp = lp_arena.take(child);
                let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?;
                lp_arena.replace(child, lp.clone());
                for &child in &v.children[1..] {
                    lp_arena.replace(child, lp.clone());
                }
            }
        }
    }
    Ok(())
}

fn get_filter_node(parents: TwoParents, lp_arena: &Arena<IR>) -> Option<Node> {
    parents
        .into_iter()
        .flatten()
        .find(|&parent| matches!(lp_arena.get(parent), IR::Filter { .. }))
}