Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/prune.rs
8458 views
1
//! IR pruning. Pruning copies the reachable IR and expressions into a set of destination arenas.
2
3
use polars_core::prelude::{InitHashMaps as _, PlHashMap};
4
use polars_utils::arena::{Arena, Node};
5
use polars_utils::unique_id::UniqueId;
6
use recursive::recursive;
7
8
use crate::plans::{AExpr, ExprIR, IR, IRPlan, IRPlanRef};
9
10
/// Returns a pruned copy of this plan with new arenas (without unreachable nodes).
11
///
12
/// The cache hit count is updated based on the number of consumers in the pruned plan.
13
///
14
/// The original plan and arenas are not modified.
15
pub fn prune_plan(ir_plan: IRPlanRef<'_>) -> IRPlan {
16
let mut ir_arena = Arena::new();
17
let mut expr_arena = Arena::new();
18
let [root] = prune(
19
&[ir_plan.lp_top],
20
ir_plan.lp_arena,
21
ir_plan.expr_arena,
22
&mut ir_arena,
23
&mut expr_arena,
24
)
25
.try_into()
26
.unwrap();
27
IRPlan {
28
lp_top: root,
29
lp_arena: ir_arena,
30
expr_arena,
31
}
32
}
33
34
/// Prunes a subgraph reachable from the supplied roots into the supplied arenas.
35
///
36
/// The returned nodes point to the pruned copies of the supplied roots in the same order.
37
///
38
/// The cache hit count is updated based on the number of consumers in the pruned subgraph.
39
///
40
/// The original plan and arenas are not modified.
41
pub fn prune(
42
roots: &[Node],
43
src_ir: &Arena<IR>,
44
src_expr: &Arena<AExpr>,
45
dst_ir: &mut Arena<IR>,
46
dst_expr: &mut Arena<AExpr>,
47
) -> Vec<Node> {
48
let mut ctx = CopyContext {
49
src_ir,
50
src_expr,
51
dst_ir,
52
dst_expr,
53
dst_caches: PlHashMap::new(),
54
roots: PlHashMap::from_iter(roots.iter().map(|node| (*node, None))),
55
};
56
57
let dst_roots: Vec<Node> = roots.iter().map(|&root| ctx.copy_ir(root)).collect();
58
59
assert!(ctx.roots.values().all(|v| v.is_some()));
60
61
dst_roots
62
}
63
64
struct CopyContext<'a> {
65
src_ir: &'a Arena<IR>,
66
src_expr: &'a Arena<AExpr>,
67
dst_ir: &'a mut Arena<IR>,
68
dst_expr: &'a mut Arena<AExpr>,
69
// Caches and the matching dst nodes.
70
dst_caches: PlHashMap<UniqueId, Node>,
71
// Root nodes and the matching dst nodes. Needed to ensure they are visited only once,
72
// in case they are reachable from other root nodes.
73
roots: PlHashMap<Node, Option<Node>>,
74
}
75
76
impl<'a> CopyContext<'a> {
77
// Copies the IR subgraph from src to dst.
78
#[recursive]
79
fn copy_ir(&mut self, src_node: Node) -> Node {
80
// If this cache was already visited, bump the cache hits and don't traverse it.
81
// This is before the root node check, so that the hit count gets bumped for every visit.
82
if let IR::Cache { id, .. } = self.src_ir.get(src_node) {
83
if let Some(cache) = self.dst_caches.get(id) {
84
return *cache;
85
}
86
}
87
88
// If this is one of the root nodes and was already visited, don't visit again, just return
89
// the matching dst node.
90
if let Some(&Some(root_node)) = self.roots.get(&src_node) {
91
return root_node;
92
}
93
94
let src_ir = self.src_ir.get(src_node);
95
96
let mut dst_ir = src_ir.clone();
97
98
// Recurse into inputs
99
dst_ir = dst_ir.with_inputs(src_ir.inputs().map(|input| self.copy_ir(input)));
100
101
// Recurse into expressions
102
dst_ir = dst_ir.with_exprs(src_ir.exprs().map(|expr| {
103
let mut expr = expr.clone();
104
expr.set_node(self.copy_expr(expr.node()));
105
expr
106
}));
107
108
// Add this node
109
let dst_node = self.dst_ir.add(dst_ir);
110
111
// If this is a cache, reset the hit count and store the dst node.
112
if let IR::Cache { id, .. } = self.dst_ir.get_mut(dst_node) {
113
let prev = self.dst_caches.insert(*id, dst_node);
114
assert!(prev.is_none(), "cache {id} was traversed twice");
115
}
116
117
// If this is one of the root nodes, store the dst node.
118
self.roots.entry(src_node).and_modify(|e| {
119
assert!(
120
e.replace(dst_node).is_none(),
121
"root node was traversed twice"
122
)
123
});
124
125
dst_node
126
}
127
128
/// Copies the expression subgraph from src to dst.
129
#[recursive]
130
fn copy_expr(&mut self, node: Node) -> Node {
131
let expr = self.src_expr.get(node);
132
133
let mut inputs = vec![];
134
expr.inputs_rev(&mut inputs);
135
136
for input in &mut inputs {
137
*input = self.copy_expr(*input);
138
}
139
inputs.reverse();
140
141
let mut dst_expr = expr.clone().replace_inputs(&inputs);
142
143
// Fix up eval, the evaluation subtree is not treated as an input,
144
// so it needs to be copied manually.
145
if let AExpr::Eval { evaluation, .. } = &mut dst_expr {
146
*evaluation = self.copy_expr(*evaluation);
147
}
148
#[cfg(feature = "dtype-struct")]
149
if let AExpr::StructEval { evaluation, .. } = &mut dst_expr {
150
for e in evaluation.iter_mut() {
151
*e = ExprIR::new(
152
self.copy_expr(e.node()),
153
crate::plans::OutputName::Alias(e.output_name().clone()),
154
);
155
}
156
}
157
158
self.dst_expr.add(dst_expr)
159
}
160
}
161
162
#[cfg(test)]
mod tests {
    use polars_core::prelude::*;

    use super::*;
    use crate::dsl::SinkTypeIR;
    use crate::dsl::functions::{col, lit};
    use crate::plans::{ArenaLpIter as _, ExprToIRContext, to_expr_ir};

    //               SINK[right]
    //                    |
    // SINK[left]       SORT       SINK[extra]
    //     |            /           /
    //   CACHE --------+-----------+
    //     |
    //   FILTER
    //     |
    //    SCAN
    struct BranchedPlan {
        ir_arena: Arena<IR>,
        expr_arena: Arena<AExpr>,
        scan: Node,
        filter: Node,
        cache: Node,
        left_sink: Node,
        sort: Node,
        right_sink: Node,
        extra_sink: Node,
    }

    #[test]
    fn test_pruned_subgraph_matches() {
        let p = BranchedPlan::new();

        #[rustfmt::skip]
        let cases: &[&[Node]] = &[
            // Single
            &[p.scan],
            &[p.cache],
            &[p.left_sink],
            &[p.right_sink],
            // Multiple
            &[p.left_sink, p.right_sink],
            &[p.left_sink, p.right_sink, p.extra_sink],
            // Duplicate
            &[p.left_sink, p.left_sink],
            &[p.cache, p.cache],
            // A mess
            &[p.filter, p.scan, p.left_sink, p.cache, p.right_sink, p.sort, p.cache, p.right_sink],
        ];

        for &case in cases.iter() {
            let (pruned, arenas) = p.prune(case);
            for (&orig, pruned) in case.iter().zip(pruned) {
                let orig_plan = p.plan(orig);
                let pruned_plan = arenas.plan(pruned);
                assert!(
                    plans_equal(orig_plan, pruned_plan),
                    "orig: {}, pruned: {}",
                    orig_plan.display(),
                    pruned_plan.display()
                );
            }
        }
    }

    #[test]
    fn test_pruned_arena_size() {
        let p = BranchedPlan::new();

        #[rustfmt::skip]
        let cases: &[(&[Node], usize)] = &[
            (&[p.scan], 1),
            (&[p.cache], 3),
            (&[p.cache, p.cache], 3),
            (&[p.left_sink], 4),
            (&[p.left_sink, p.left_sink], 4),
            (&[p.right_sink], 5),
            (&[p.left_sink, p.right_sink], 6),
            (&[p.filter, p.scan, p.left_sink, p.cache, p.right_sink, p.sort, p.cache, p.right_sink], 6),
            (&[p.left_sink, p.right_sink, p.extra_sink], 7),
        ];

        for (i, &(case, expected_arena_size)) in cases.iter().enumerate() {
            let (_, arenas) = p.prune(case);
            assert_eq!(
                arenas.ir.len(),
                expected_arena_size,
                "case: {i}, pruned_ir: {:?}",
                arenas.ir
            );
        }
    }

    /// Returns whether both plans visit the same sequence of IR variants, with equal
    /// expressions at every node, and have the same number of nodes.
    fn plans_equal(a: IRPlanRef<'_>, b: IRPlanRef<'_>) -> bool {
        let mut iter_a = a.lp_arena.iter(a.lp_top);
        let mut iter_b = b.lp_arena.iter(b.lp_top);
        loop {
            match (iter_a.next(), iter_b.next()) {
                (None, None) => return true,
                (Some((_, ir_a)), Some((_, ir_b))) => {
                    if std::mem::discriminant(ir_a) != std::mem::discriminant(ir_b)
                        || !exprs_equal(ir_a, a.expr_arena, ir_b, b.expr_arena)
                    {
                        return false;
                    }
                },
                // One traversal ended before the other: the plans differ in size.
                // (A plain `zip` would silently stop at the shorter one.)
                _ => return false,
            }
        }
    }

    /// Returns whether both IR nodes carry the same expressions (output name and
    /// expression tree), in the same order.
    fn exprs_equal(ir_a: &IR, arena_a: &Arena<AExpr>, ir_b: &IR, arena_b: &Arena<AExpr>) -> bool {
        let [a, b] = [(ir_a, arena_a), (ir_b, arena_b)].map(|(ir, arena)| {
            ir.exprs()
                .map(|e| (e.output_name_inner().clone(), e.to_expr(arena)))
        });
        a.eq(b)
    }

    impl BranchedPlan {
        pub fn new() -> Self {
            let mut ir_arena = Arena::new();
            let mut expr_arena = Arena::new();
            let schema = Schema::from_iter([Field::new("a".into(), DataType::UInt8)]);

            let scan = ir_arena.add(IR::DataFrameScan {
                df: Arc::new(DataFrame::empty_with_schema(&schema)),
                schema: Arc::new(schema.clone()),
                output_schema: None,
            });

            let mut ctx = ExprToIRContext::new(&mut expr_arena, &schema);
            ctx.allow_unknown = true;
            let filter = ir_arena.add(IR::Filter {
                input: scan,
                predicate: to_expr_ir(col("a").gt_eq(lit(10)), &mut ctx).unwrap(),
            });

            // Throw in an unreachable node
            ir_arena.add(IR::Invalid);

            let cache = ir_arena.add(IR::Cache {
                input: filter,
                id: UniqueId::new(),
            });

            let left_sink = ir_arena.add(IR::Sink {
                input: cache,
                payload: SinkTypeIR::Memory,
            });

            // Throw in an unreachable node
            ir_arena.add(IR::Invalid);

            let mut ctx = ExprToIRContext::new(&mut expr_arena, &schema);
            ctx.allow_unknown = true;
            let sort = ir_arena.add(IR::Sort {
                input: cache,
                by_column: vec![to_expr_ir(col("a"), &mut ctx).unwrap()],
                slice: None,
                sort_options: Default::default(),
            });

            let right_sink = ir_arena.add(IR::Sink {
                input: sort,
                payload: SinkTypeIR::Memory,
            });

            // Throw in an extra sink on the cache that is only reached when it is passed
            // as a root itself
            let extra_sink = ir_arena.add(IR::Sink {
                input: cache,
                payload: SinkTypeIR::Memory,
            });

            Self {
                ir_arena,
                expr_arena,
                scan,
                filter,
                cache,
                left_sink,
                sort,
                right_sink,
                extra_sink,
            }
        }

        pub fn prune(&self, roots: &[Node]) -> (Vec<Node>, Arenas) {
            let mut arenas = Arenas {
                ir: Arena::new(),
                expr: Arena::new(),
            };
            let pruned = prune(
                roots,
                &self.ir_arena,
                &self.expr_arena,
                &mut arenas.ir,
                &mut arenas.expr,
            );
            (pruned, arenas)
        }

        pub fn plan(&'_ self, node: Node) -> IRPlanRef<'_> {
            IRPlanRef {
                lp_top: node,
                lp_arena: &self.ir_arena,
                expr_arena: &self.expr_arena,
            }
        }
    }

    struct Arenas {
        ir: Arena<IR>,
        expr: Arena<AExpr>,
    }

    impl Arenas {
        pub fn plan(&'_ self, root: Node) -> IRPlanRef<'_> {
            IRPlanRef {
                lp_top: root,
                lp_arena: &self.ir,
                expr_arena: &self.expr,
            }
        }
    }
}
384
385