Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/cse/cache_states.rs
7889 views
1
use std::collections::BTreeMap;
2
3
use polars_utils::unique_id::UniqueId;
4
5
use super::*;
6
7
fn get_upper_projections(
8
parent: Node,
9
lp_arena: &Arena<IR>,
10
expr_arena: &Arena<AExpr>,
11
names_scratch: &mut Vec<PlSmallStr>,
12
found_required_columns: &mut bool,
13
) -> bool {
14
let parent = lp_arena.get(parent);
15
16
use IR::*;
17
// During projection pushdown all accumulated.
18
match parent {
19
SimpleProjection { columns, .. } => {
20
let iter = columns.iter_names_cloned();
21
names_scratch.extend(iter);
22
*found_required_columns = true;
23
false
24
},
25
Filter { predicate, .. } => {
26
// Also add predicate, as the projection is above the filter node.
27
names_scratch.extend(aexpr_to_leaf_names(predicate.node(), expr_arena));
28
29
true
30
},
31
// Only filter and projection nodes are allowed, any other node we stop.
32
_ => false,
33
}
34
}
35
36
fn get_upper_predicates(
37
parent: Node,
38
lp_arena: &Arena<IR>,
39
expr_arena: &mut Arena<AExpr>,
40
predicate_scratch: &mut Vec<Expr>,
41
) -> bool {
42
let parent = lp_arena.get(parent);
43
44
use IR::*;
45
match parent {
46
Filter { predicate, .. } => {
47
let expr = predicate.to_expr(expr_arena);
48
predicate_scratch.push(expr);
49
false
50
},
51
SimpleProjection { .. } => true,
52
// Only filter and projection nodes are allowed, any other node we stop.
53
_ => false,
54
}
55
}
56
57
/// The (at most two) nearest ancestors of a plan node: index 0 is the direct
/// parent and index 1 is the parent's parent; `None` where no such ancestor
/// exists (e.g. near the root).
type TwoParents = [Option<Node>; 2];
58
59
// 1. This will ensure that all equal caches communicate the amount of columns
60
// they need to project.
61
// 2. This will ensure we apply predicate in the subtrees below the caches.
62
// If the predicate above the cache is the same for all matching caches, that filter will be
63
// applied as well.
64
//
65
// # Example
66
// Consider this tree, where `SUB-TREE` is duplicate and can be cached.
67
//
68
//
69
// Tree
70
// |
71
// |
72
// |--------------------|-------------------|
73
// | |
74
// SUB-TREE SUB-TREE
75
//
76
// STEPS:
77
// - 1. CSE will run and will insert cache nodes
78
//
79
// Tree
80
// |
81
// |
82
// |--------------------|-------------------|
83
// | |
84
// | CACHE 0 | CACHE 0
85
// | |
86
// SUB-TREE SUB-TREE
87
//
88
// - 2. predicate and projection pushdown will run and will insert optional FILTER and PROJECTION above the caches
89
//
90
// Tree
91
// |
92
// |
93
// |--------------------|-------------------|
94
// | FILTER (optional) | FILTER (optional)
95
// | PROJ (optional) | PROJ (optional)
96
// | |
97
// | CACHE 0 | CACHE 0
98
// | |
99
// SUB-TREE SUB-TREE
100
//
101
// # Projection optimization
102
// The union of the projection is determined and the projection will be pushed down.
103
//
104
// Tree
105
// |
106
// |
107
// |--------------------|-------------------|
108
// | FILTER (optional) | FILTER (optional)
109
// | CACHE 0 | CACHE 0
110
// | |
111
// SUB-TREE SUB-TREE
112
// UNION PROJ (optional) UNION PROJ (optional)
113
//
114
// # Filter optimization
115
// Depending on the predicates the predicate pushdown optimization will run.
116
// Possible cases:
117
// - NO FILTERS: run predicate pd from the cache nodes -> finish
118
// - Above the filters the caches are the same -> run predicate pd from the filter node -> finish
119
// - There is a cache without predicates above the cache node -> run predicate form the cache nodes -> finish
120
// - The predicates above the cache nodes are all different -> remove the cache nodes -> finish
121
pub(super) fn set_cache_states(
122
root: Node,
123
lp_arena: &mut Arena<IR>,
124
expr_arena: &mut Arena<AExpr>,
125
scratch: &mut Vec<Node>,
126
verbose: bool,
127
pushdown_maintain_errors: bool,
128
new_streaming: bool,
129
) -> PolarsResult<()> {
130
let mut stack = Vec::with_capacity(4);
131
let mut names_scratch = vec![];
132
let mut predicates_scratch = vec![];
133
134
scratch.clear();
135
stack.clear();
136
137
#[derive(Default)]
138
struct Value {
139
// All the children of the cache per cache-id.
140
children: Vec<Node>,
141
parents: Vec<TwoParents>,
142
cache_nodes: Vec<Node>,
143
// Union over projected names.
144
names_union: PlHashSet<PlSmallStr>,
145
// Union over predicates.
146
predicate_union: PlHashMap<Expr, u32>,
147
}
148
let mut cache_schema_and_children = BTreeMap::new();
149
150
// Stack frame
151
#[derive(Default, Clone)]
152
struct Frame {
153
current: Node,
154
cache_id: Option<UniqueId>,
155
parent: TwoParents,
156
previous_cache: Option<UniqueId>,
157
}
158
let init = Frame {
159
current: root,
160
..Default::default()
161
};
162
163
stack.push(init);
164
165
// # First traversal.
166
// Collect the union of columns per cache id.
167
// And find the cache parents.
168
while let Some(mut frame) = stack.pop() {
169
let lp = lp_arena.get(frame.current);
170
lp.copy_inputs(scratch);
171
172
use IR::*;
173
174
if let Cache { input, id, .. } = lp {
175
if let Some(cache_id) = frame.cache_id {
176
frame.previous_cache = Some(cache_id)
177
}
178
if frame.parent[0].is_some() {
179
// Projection pushdown has already run and blocked on cache nodes
180
// the pushed down columns are projected just above this cache
181
// if there were no pushed down column, we just take the current
182
// nodes schema
183
// we never want to naively take parents, as a join or aggregate for instance
184
// change the schema
185
186
let v = cache_schema_and_children
187
.entry(*id)
188
.or_insert_with(Value::default);
189
v.children.push(*input);
190
v.parents.push(frame.parent);
191
v.cache_nodes.push(frame.current);
192
193
let mut found_required_columns = false;
194
195
for parent_node in frame.parent.into_iter().flatten() {
196
let keep_going = get_upper_projections(
197
parent_node,
198
lp_arena,
199
expr_arena,
200
&mut names_scratch,
201
&mut found_required_columns,
202
);
203
if !names_scratch.is_empty() {
204
v.names_union.extend(names_scratch.drain(..));
205
}
206
// We stop early as we want to find the first projection node above the cache.
207
if !keep_going {
208
break;
209
}
210
}
211
212
for parent_node in frame.parent.into_iter().flatten() {
213
let keep_going = get_upper_predicates(
214
parent_node,
215
lp_arena,
216
expr_arena,
217
&mut predicates_scratch,
218
);
219
if !predicates_scratch.is_empty() {
220
for pred in predicates_scratch.drain(..) {
221
let count = v.predicate_union.entry(pred).or_insert(0);
222
*count += 1;
223
}
224
}
225
// We stop early as we want to find the first predicate node above the cache.
226
if !keep_going {
227
break;
228
}
229
}
230
231
// There was no explicit projection and we must take
232
// all columns
233
if !found_required_columns {
234
let schema = lp.schema(lp_arena);
235
v.names_union.extend(schema.iter_names_cloned());
236
}
237
}
238
frame.cache_id = Some(*id);
239
};
240
241
// Shift parents.
242
frame.parent[1] = frame.parent[0];
243
frame.parent[0] = Some(frame.current);
244
for n in scratch.iter() {
245
let mut new_frame = frame.clone();
246
new_frame.current = *n;
247
stack.push(new_frame);
248
}
249
scratch.clear();
250
}
251
252
// # Second pass.
253
// we create a subtree where we project the columns
254
// just before the cache. Then we do another projection pushdown
255
// and finally remove that last projection and stitch the subplan
256
// back to the cache node again
257
if !cache_schema_and_children.is_empty() {
258
let mut proj_pd = ProjectionPushDown::new();
259
let mut pred_pd = PredicatePushDown::new(pushdown_maintain_errors, new_streaming);
260
for (_cache_id, v) in cache_schema_and_children {
261
// # CHECK IF WE NEED TO REMOVE CACHES
262
// If we encounter multiple predicates we remove the cache nodes completely as we don't
263
// want to loose predicate pushdown in favor of scan sharing.
264
if v.predicate_union.len() > 1 {
265
if verbose {
266
eprintln!("cache nodes will be removed because predicates don't match")
267
}
268
for ((&child, cache), parents) in
269
v.children.iter().zip(v.cache_nodes).zip(v.parents)
270
{
271
// Remove the cache and assign the child the cache location.
272
lp_arena.swap(child, cache);
273
274
// Restart predicate and projection pushdown from most top parent.
275
// This to ensure we continue the optimization where it was blocked initially.
276
// We pick up the blocked filter and projection.
277
let mut node = cache;
278
for p_node in parents.into_iter().flatten() {
279
if matches!(
280
lp_arena.get(p_node),
281
IR::Filter { .. } | IR::SimpleProjection { .. }
282
) {
283
node = p_node
284
} else {
285
break;
286
}
287
}
288
289
let lp = lp_arena.take(node);
290
let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;
291
let lp = pred_pd.optimize(lp, lp_arena, expr_arena)?;
292
lp_arena.replace(node, lp);
293
}
294
return Ok(());
295
}
296
// Below we restart projection and predicates pushdown
297
// on the first cache node. As it are cache nodes, the others are the same
298
// and we can reuse the optimized state for all inputs.
299
// See #21637
300
301
// # RUN PROJECTION PUSHDOWN
302
if !v.names_union.is_empty() {
303
let first_child = *v.children.first().expect("at least on child");
304
305
let columns = &v.names_union;
306
let child_lp = lp_arena.take(first_child);
307
308
// Make sure we project in the order of the schema
309
// if we don't a union may fail as we would project by the
310
// order we discovered all values.
311
let child_schema = child_lp.schema(lp_arena);
312
let child_schema = child_schema.as_ref();
313
let projection = child_schema
314
.iter_names()
315
.flat_map(|name| columns.get(name.as_str()).cloned())
316
.collect::<Vec<_>>();
317
318
let new_child = lp_arena.add(child_lp);
319
320
let lp = IRBuilder::new(new_child, expr_arena, lp_arena)
321
.project_simple(projection)
322
.expect("unique names")
323
.build();
324
325
let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;
326
// Optimization can lead to a double projection. Only take the last.
327
let lp = if let IR::SimpleProjection { input, columns } = lp {
328
let input =
329
if let IR::SimpleProjection { input: input2, .. } = lp_arena.get(input) {
330
*input2
331
} else {
332
input
333
};
334
IR::SimpleProjection { input, columns }
335
} else {
336
lp
337
};
338
lp_arena.replace(first_child, lp.clone());
339
340
// Set the remaining children to the same node.
341
for &child in &v.children[1..] {
342
lp_arena.replace(child, lp.clone());
343
}
344
} else {
345
// No upper projections to include, run projection pushdown from cache node.
346
let first_child = *v.children.first().expect("at least on child");
347
let child_lp = lp_arena.take(first_child);
348
let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?;
349
lp_arena.replace(first_child, lp.clone());
350
351
for &child in &v.children[1..] {
352
lp_arena.replace(child, lp.clone());
353
}
354
}
355
356
// # RUN PREDICATE PUSHDOWN
357
// Run this after projection pushdown, otherwise the predicate columns will not be projected.
358
359
// - If all predicates of parent are the same we will restart predicate pushdown from the parent FILTER node.
360
// - Otherwise we will start predicate pushdown from the cache node.
361
let allow_parent_predicate_pushdown = v.predicate_union.len() == 1 && {
362
let (_pred, count) = v.predicate_union.iter().next().unwrap();
363
*count == v.children.len() as u32
364
};
365
366
if allow_parent_predicate_pushdown {
367
let parents = *v.parents.first().unwrap();
368
let node = get_filter_node(parents, lp_arena)
369
.expect("expected filter; this is an optimizer bug");
370
let start_lp = lp_arena.take(node);
371
372
let mut pred_pd = PredicatePushDown::new(pushdown_maintain_errors, new_streaming)
373
.block_at_cache(1);
374
let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?;
375
lp_arena.replace(node, lp.clone());
376
for &parents in &v.parents[1..] {
377
let node = get_filter_node(parents, lp_arena)
378
.expect("expected filter; this is an optimizer bug");
379
lp_arena.replace(node, lp.clone());
380
}
381
} else {
382
let child = *v.children.first().unwrap();
383
let child_lp = lp_arena.take(child);
384
let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?;
385
lp_arena.replace(child, lp.clone());
386
for &child in &v.children[1..] {
387
lp_arena.replace(child, lp.clone());
388
}
389
}
390
}
391
}
392
Ok(())
393
}
394
395
fn get_filter_node(parents: TwoParents, lp_arena: &Arena<IR>) -> Option<Node> {
396
parents
397
.into_iter()
398
.flatten()
399
.find(|&parent| matches!(lp_arena.get(parent), IR::Filter { .. }))
400
}
401
402