Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/prune.rs
6940 views
1
//! IR pruning. Pruning copies the reachable IR and expressions into a set of destination arenas.
2
3
use polars_core::prelude::{InitHashMaps as _, PlHashMap};
4
use polars_utils::arena::{Arena, Node};
5
use polars_utils::unique_id::UniqueId;
6
use recursive::recursive;
7
8
use crate::plans::{AExpr, IR, IRPlan, IRPlanRef};
9
10
/// Returns a pruned copy of this plan with new arenas (without unreachable nodes).
11
///
12
/// The cache hit count is updated based on the number of consumers in the pruned plan.
13
///
14
/// The original plan and arenas are not modified.
15
pub fn prune_plan(ir_plan: IRPlanRef<'_>) -> IRPlan {
16
let mut ir_arena = Arena::new();
17
let mut expr_arena = Arena::new();
18
let [root] = prune(
19
&[ir_plan.lp_top],
20
ir_plan.lp_arena,
21
ir_plan.expr_arena,
22
&mut ir_arena,
23
&mut expr_arena,
24
)
25
.try_into()
26
.unwrap();
27
IRPlan {
28
lp_top: root,
29
lp_arena: ir_arena,
30
expr_arena,
31
}
32
}
33
34
/// Prunes a subgraph reachable from the supplied roots into the supplied arenas.
35
///
36
/// The returned nodes point to the pruned copies of the supplied roots in the same order.
37
///
38
/// The cache hit count is updated based on the number of consumers in the pruned subgraph.
39
///
40
/// The original plan and arenas are not modified.
41
pub fn prune(
42
roots: &[Node],
43
src_ir: &Arena<IR>,
44
src_expr: &Arena<AExpr>,
45
dst_ir: &mut Arena<IR>,
46
dst_expr: &mut Arena<AExpr>,
47
) -> Vec<Node> {
48
let mut ctx = CopyContext {
49
src_ir,
50
src_expr,
51
dst_ir,
52
dst_expr,
53
dst_caches: PlHashMap::new(),
54
roots: PlHashMap::from_iter(roots.iter().map(|node| (*node, None))),
55
};
56
57
let dst_roots: Vec<Node> = roots.iter().map(|&root| ctx.copy_ir(root)).collect();
58
59
assert!(ctx.roots.values().all(|v| v.is_some()));
60
61
dst_roots
62
}
63
64
/// State for one pruning pass: copies the reachable IR/expression subgraph
/// from the source arenas into the destination arenas.
struct CopyContext<'a> {
    // Source arenas (read-only; never modified by pruning).
    src_ir: &'a Arena<IR>,
    src_expr: &'a Arena<AExpr>,
    // Destination arenas that receive the pruned copies.
    dst_ir: &'a mut Arena<IR>,
    dst_expr: &'a mut Arena<AExpr>,
    // Caches and the matching dst nodes. Keyed by cache id so that a cache
    // reachable via multiple consumers is copied only once.
    dst_caches: PlHashMap<UniqueId, Node>,
    // Root nodes and the matching dst nodes. Needed to ensure they are visited only once,
    // in case they are reachable from other root nodes.
    roots: PlHashMap<Node, Option<Node>>,
}
75
76
impl<'a> CopyContext<'a> {
    // Copies the IR subgraph from src to dst.
    //
    // Returns the dst node corresponding to `src_node`.
    #[recursive]
    fn copy_ir(&mut self, src_node: Node) -> Node {
        // If this cache was already visited, bump the cache hits and don't traverse it.
        // This is before the root node check, so that the hit count gets bumped for every visit.
        // NOTE(review): no explicit hit-count increment is visible in this branch —
        // presumably the consumer count follows from sharing the dst node; confirm
        // this comment is still current.
        if let IR::Cache { id, .. } = self.src_ir.get(src_node) {
            if let Some(cache) = self.dst_caches.get(id) {
                return *cache;
            }
        }

        // If this is one of the root nodes and was already visited, don't visit again, just return
        // the matching dst node.
        if let Some(&Some(root_node)) = self.roots.get(&src_node) {
            return root_node;
        }

        let src_ir = self.src_ir.get(src_node);

        let mut dst_ir = src_ir.clone();

        // Recurse into inputs, replacing each src input node with its pruned dst copy.
        dst_ir = dst_ir.with_inputs(src_ir.inputs().map(|input| self.copy_ir(input)));

        // Recurse into expressions, rewriting each expression root to point into dst_expr.
        dst_ir = dst_ir.with_exprs(src_ir.exprs().map(|expr| {
            let mut expr = expr.clone();
            expr.set_node(self.copy_expr(expr.node()));
            expr
        }));

        // Add this node
        let dst_node = self.dst_ir.add(dst_ir);

        // If this is a cache, reset the hit count and store the dst node.
        // Each distinct cache id must be registered at most once; later visits take
        // the early return at the top of this function instead of re-copying.
        if let IR::Cache { id, .. } = self.dst_ir.get_mut(dst_node) {
            let prev = self.dst_caches.insert(*id, dst_node);
            assert!(prev.is_none(), "cache {id} was traversed twice");
        }

        // If this is one of the root nodes, store the dst node.
        self.roots.entry(src_node).and_modify(|e| {
            assert!(
                e.replace(dst_node).is_none(),
                "root node was traversed twice"
            )
        });

        dst_node
    }

    /// Copies the expression subgraph from src to dst.
    #[recursive]
    fn copy_expr(&mut self, node: Node) -> Node {
        let expr = self.src_expr.get(node);

        // `inputs_rev` yields the child nodes in reverse order: copy each child
        // first, then restore the original order before rebuilding the expression.
        let mut inputs = vec![];
        expr.inputs_rev(&mut inputs);
        for input in &mut inputs {
            *input = self.copy_expr(*input);
        }
        inputs.reverse();

        let mut dst_expr = expr.clone().replace_inputs(&inputs);

        // Fix up eval, the evaluation subtree is not treated as an input,
        // so it needs to be copied manually.
        if let AExpr::Eval { evaluation, .. } = &mut dst_expr {
            *evaluation = self.copy_expr(*evaluation);
        }

        self.dst_expr.add(dst_expr)
    }
}
152
153
#[cfg(test)]
mod tests {
    use polars_core::prelude::*;

    use super::*;
    use crate::dsl::{SinkTypeIR, col, lit};
    use crate::plans::{ArenaLpIter as _, ExprToIRContext, to_expr_ir};

    // Plan shape shared by all tests. Two unreachable `IR::Invalid` nodes are
    // also inserted into the arena (see `new`) to exercise pruning of dead nodes.
    //
    //                 SINK[right]
    //                     |
    //  SINK[left]       SORT    SINK[extra]
    //      |            /        /
    //    CACHE ----+---+--------+
    //      |
    //    FILTER
    //      |
    //    SCAN
    struct BranchedPlan {
        ir_arena: Arena<IR>,
        expr_arena: Arena<AExpr>,
        // Handles to every reachable node, so tests can prune from any point.
        scan: Node,
        filter: Node,
        cache: Node,
        left_sink: Node,
        sort: Node,
        right_sink: Node,
        extra_sink: Node,
    }

    // Pruning any set of roots must yield plans structurally equal to the originals.
    #[test]
    fn test_pruned_subgraph_matches() {
        let p = BranchedPlan::new();

        #[rustfmt::skip]
        let cases: &[&[Node]] = &[
            // Single
            &[p.scan],
            &[p.cache],
            &[p.left_sink],
            &[p.right_sink],
            // Multiple
            &[p.left_sink, p.right_sink],
            &[p.left_sink, p.right_sink, p.extra_sink],
            // Duplicate
            &[p.left_sink, p.left_sink],
            &[p.cache, p.cache],
            // A mess
            &[p.filter, p.scan, p.left_sink, p.cache, p.right_sink, p.sort, p.cache, p.right_sink],
        ];

        for &case in cases.iter() {
            let (pruned, arenas) = p.prune(case);
            for (&orig, pruned) in case.iter().zip(pruned) {
                let orig_plan = p.plan(orig);
                let pruned_plan = arenas.plan(pruned);
                assert!(
                    plans_equal(orig_plan, pruned_plan),
                    "orig: {}, pruned: {}",
                    orig_plan.display(),
                    pruned_plan.display()
                );
            }
        }
    }

    // The pruned arena must contain exactly the reachable nodes: the Invalid
    // nodes are dropped and shared subgraphs (the cache) are copied only once.
    #[test]
    fn test_pruned_arena_size() {
        let p = BranchedPlan::new();

        #[rustfmt::skip]
        let cases: &[(&[Node], usize)] = &[
            (&[p.scan], 1),
            (&[p.cache], 3),
            (&[p.cache, p.cache], 3),
            (&[p.left_sink], 4),
            (&[p.left_sink, p.left_sink], 4),
            (&[p.right_sink], 5),
            (&[p.left_sink, p.right_sink], 6),
            (&[p.filter, p.scan, p.left_sink, p.cache, p.right_sink, p.sort, p.cache, p.right_sink], 6),
            (&[p.left_sink, p.right_sink, p.extra_sink], 7),
        ];

        for (i, &(case, expected_arena_size)) in cases.iter().enumerate() {
            let (_, arenas) = p.prune(case);
            assert_eq!(
                arenas.ir.len(),
                expected_arena_size,
                "case: {i}, pruned_ir: {:?}",
                arenas.ir
            );
        }
    }

    // Structural equality: same IR variants in traversal order with equal
    // expressions. Node ids may differ between arenas, so they are ignored.
    fn plans_equal(a: IRPlanRef<'_>, b: IRPlanRef<'_>) -> bool {
        let iter_a = a.lp_arena.iter(a.lp_top);
        let iter_b = b.lp_arena.iter(b.lp_top);
        for ((_, ir_a), (_, ir_b)) in iter_a.zip(iter_b) {
            if std::mem::discriminant(ir_a) != std::mem::discriminant(ir_b)
                || !exprs_equal(ir_a, a.expr_arena, ir_b, b.expr_arena)
            {
                return false;
            }
        }
        true
    }

    // Compares the expressions of two IR nodes by materializing them to `Expr`,
    // so that differing arena node ids don't affect the comparison.
    fn exprs_equal(ir_a: &IR, arena_a: &Arena<AExpr>, ir_b: &IR, arena_b: &Arena<AExpr>) -> bool {
        let [a, b] = [(ir_a, arena_a), (ir_b, arena_b)].map(|(ir, arena)| {
            ir.exprs()
                .map(|e| (e.output_name_inner().clone(), e.to_expr(arena)))
        });
        a.eq(b)
    }

    impl BranchedPlan {
        // Builds the branched plan pictured above, including two unreachable
        // `IR::Invalid` nodes and one sink that is unused by most test cases.
        pub fn new() -> Self {
            let mut ir_arena = Arena::new();
            let mut expr_arena = Arena::new();
            let schema = Schema::from_iter([Field::new("a".into(), DataType::UInt8)]);

            let scan = ir_arena.add(IR::DataFrameScan {
                df: Arc::new(DataFrame::empty_with_schema(&schema)),
                schema: Arc::new(schema.clone()),
                output_schema: None,
            });

            let mut ctx = ExprToIRContext::new(&mut expr_arena, &schema);
            ctx.allow_unknown = true;
            let filter = ir_arena.add(IR::Filter {
                input: scan,
                predicate: to_expr_ir(col("a").gt_eq(lit(10)), &mut ctx).unwrap(),
            });

            // Throw in an unreachable node
            ir_arena.add(IR::Invalid);

            let cache = ir_arena.add(IR::Cache {
                input: filter,
                id: UniqueId::new(),
            });

            let left_sink = ir_arena.add(IR::Sink {
                input: cache,
                payload: SinkTypeIR::Memory,
            });

            // Throw in an unreachable node
            ir_arena.add(IR::Invalid);

            let mut ctx = ExprToIRContext::new(&mut expr_arena, &schema);
            ctx.allow_unknown = true;
            let sort = ir_arena.add(IR::Sort {
                input: cache,
                by_column: vec![to_expr_ir(col("a"), &mut ctx).unwrap()],
                slice: None,
                sort_options: Default::default(),
            });

            let right_sink = ir_arena.add(IR::Sink {
                input: sort,
                payload: SinkTypeIR::Memory,
            });

            // Throw in an unused sink
            let extra_sink = ir_arena.add(IR::Sink {
                input: cache,
                payload: SinkTypeIR::Memory,
            });

            Self {
                ir_arena,
                expr_arena,
                scan,
                filter,
                cache,
                left_sink,
                sort,
                right_sink,
                extra_sink,
            }
        }

        // Prunes the given roots out of this plan into a fresh pair of arenas.
        pub fn prune(&self, roots: &[Node]) -> (Vec<Node>, Arenas) {
            let mut arenas = Arenas {
                ir: Arena::new(),
                expr: Arena::new(),
            };
            let pruned = prune(
                roots,
                &self.ir_arena,
                &self.expr_arena,
                &mut arenas.ir,
                &mut arenas.expr,
            );
            (pruned, arenas)
        }

        // Views this plan's arenas rooted at `node`.
        pub fn plan(&'_ self, node: Node) -> IRPlanRef<'_> {
            IRPlanRef {
                lp_top: node,
                lp_arena: &self.ir_arena,
                expr_arena: &self.expr_arena,
            }
        }
    }

    // Destination arenas produced by a pruning run.
    struct Arenas {
        ir: Arena<IR>,
        expr: Arena<AExpr>,
    }

    impl Arenas {
        // Views the pruned arenas rooted at `root`.
        pub fn plan(&'_ self, root: Node) -> IRPlanRef<'_> {
            IRPlanRef {
                lp_top: root,
                lp_arena: &self.ir,
                expr_arena: &self.expr,
            }
        }
    }
}
374
375