// Source: pola-rs/polars, crates/polars-plan/src/plans/optimizer/cse/cspe.rs
use std::hash::{Hash, Hasher};

use hashbrown::hash_map::RawEntryMut;
use polars_utils::unique_id::UniqueId;

use super::*;
use crate::prelude::visitor::IRNode;

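/// Adapts `blake3` to the `std::hash::Hasher` interface so the existing `Hash`
/// impls of plan nodes can stream their bytes into it. `finish` is never used;
/// callers take the full 32-byte digest via `finalize` instead.
///
/// Illustrative sketch (added annotation, not a doctest; the type is private):
/// ```ignore
/// let mut h = Blake3Hasher::new();
/// "some hashable value".hash(&mut h);
/// let digest: [u8; 32] = h.finalize();
/// ```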
struct Blake3Hasher {
    hasher: blake3::Hasher,
}

impl Blake3Hasher {
    fn new() -> Self {
        Self {
            hasher: blake3::Hasher::new(),
        }
    }

    fn finalize(self) -> [u8; 32] {
        self.hasher.finalize().into()
    }
}

impl Hasher for Blake3Hasher {
    fn finish(&self) -> u64 {
        // Not used - we'll call finalize() instead
        0
    }

    fn write(&mut self, bytes: &[u8]) {
        self.hasher.update(bytes);
    }
}
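/// An identifier is a 256-bit fingerprint of a sub-plan. It is built bottom-up:
/// the identifiers of the children are combined first, and the current node is
/// then folded in on top. Keeping `Identifier` in its own module with a private
/// field ensures it can only be constructed through that scheme.
///
/// Hedged sketch of the intended fold (added for illustration; the child
/// identifiers are hypothetical names):
/// ```ignore
/// let mut id = Identifier::new();                 // starts out invalid/empty
/// id.combine(&left_child_id);                     // fold in child sub-plans...
/// id.combine(&right_child_id);
/// let id = id.add_alp_node(&node, lp_arena, expr_arena); // ...then this node
/// ```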
mod identifier_impl {
    use super::*;
    #[derive(Clone)]
    pub(super) struct Identifier {
        inner: Option<[u8; 32]>,
    }

    impl Identifier {
        pub fn hash(&self) -> u64 {
            self.inner
                .map(|inner| u64::from_le_bytes(inner[0..8].try_into().unwrap()))
                .unwrap_or(0)
        }

        pub fn is_equal(&self, other: &Self) -> bool {
            self.inner.map(blake3::Hash::from_bytes) == other.inner.map(blake3::Hash::from_bytes)
        }

        pub fn new() -> Self {
            Self { inner: None }
        }

        pub fn is_valid(&self) -> bool {
            self.inner.is_some()
        }

        pub fn combine(&mut self, other: &Identifier) {
            let inner = match (self.inner, other.inner) {
                (Some(l), Some(r)) => {
                    let mut h = blake3::Hasher::new();
                    h.update(&l);
                    h.update(&r);
                    *h.finalize().as_bytes()
                },
                (None, Some(r)) => r,
                (Some(l), None) => l,
                _ => return,
            };
            self.inner = Some(inner);
        }

        pub fn add_alp_node(
            &self,
            alp: &IRNode,
            lp_arena: &Arena<IR>,
            expr_arena: &Arena<AExpr>,
        ) -> Self {
            let mut h = Blake3Hasher::new();
            alp.hashable_and_cmp(lp_arena, expr_arena)
                .hash_as_equality()
                .hash(&mut h);
            let hashed = h.finalize();

            let inner = Some(self.inner.map_or(hashed, |l| {
                let mut h = blake3::Hasher::new();
                h.update(&l);
                h.update(&hashed);
                *h.finalize().as_bytes()
            }));
            Self { inner }
        }
    }
}
use identifier_impl::*;

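/// A hash map keyed by [`Identifier`] that sidesteps the default `Hash`/`Eq`
/// machinery: the table hash is taken from the first 8 bytes of the digest,
/// while equality compares the full 256 bits via `is_equal`.
///
/// Hedged usage sketch (added for illustration; `some_id` is hypothetical):
/// ```ignore
/// let mut counts: IdentifierMap<u32> = IdentifierMap::new();
/// *counts.entry(some_id.clone(), || 0) += 1;
/// assert_eq!(counts.get(&some_id), Some(&1));
/// ```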
struct IdentifierMap<V> {
    inner: PlHashMap<Identifier, V>,
}

impl<V> IdentifierMap<V> {
    fn new() -> Self {
        Self {
            inner: Default::default(),
        }
    }

    fn get(&self, id: &Identifier) -> Option<&V> {
        self.inner
            .raw_entry()
            .from_hash(id.hash(), |k| k.is_equal(id))
            .map(|(_k, v)| v)
    }

    fn entry<F: FnOnce() -> V>(&mut self, id: Identifier, v: F) -> &mut V {
        let h = id.hash();
        match self.inner.raw_entry_mut().from_hash(h, |k| k.is_equal(&id)) {
            RawEntryMut::Occupied(entry) => entry.into_mut(),
            RawEntryMut::Vacant(entry) => {
                let (_, v) = entry.insert_with_hasher(h, id, v(), |id| id.hash());
                v
            },
        }
    }
}

impl<V> Default for IdentifierMap<V> {
    fn default() -> Self {
        Self::new()
    }
}
/// Maps an identifier to its IR node and the number of times that sub-plan occurs.
type SubPlanCount = IdentifierMap<(Node, u32)>;
/// `(post_visit_idx, identifier)` pairs, indexed by pre-visit traversal order.
type IdentifierArray = Vec<(usize, Identifier)>;

/// See Expr based CSE for explanations.
enum VisitRecord {
    /// Entered a new plan node
    Entered(usize),
    SubPlanId(Identifier),
}

struct LpIdentifierVisitor<'a> {
    sp_count: &'a mut SubPlanCount,
    identifier_array: &'a mut IdentifierArray,
    // Index in pre-visit traversal order.
    pre_visit_idx: usize,
    post_visit_idx: usize,
    visit_stack: Vec<VisitRecord>,
    has_subplan: bool,
}

impl LpIdentifierVisitor<'_> {
    fn new<'a>(
        sp_count: &'a mut SubPlanCount,
        identifier_array: &'a mut IdentifierArray,
    ) -> LpIdentifierVisitor<'a> {
        LpIdentifierVisitor {
            sp_count,
            identifier_array,
            pre_visit_idx: 0,
            post_visit_idx: 0,
            visit_stack: vec![],
            has_subplan: false,
        }
    }

    fn pop_until_entered(&mut self) -> (usize, Identifier) {
        let mut id = Identifier::new();

        while let Some(item) = self.visit_stack.pop() {
            match item {
                VisitRecord::Entered(idx) => return (idx, id),
                VisitRecord::SubPlanId(s) => {
                    id.combine(&s);
                },
            }
        }
        unreachable!()
    }
}

fn skip_children(lp: &IR) -> bool {
    match lp {
        // Don't visit all the files in a `scan *` operation.
        // Use an arbitrary limit of 20 files for now.
        IR::Union {
            options, inputs, ..
        } => options.from_partitioned_ds && inputs.len() > 20,
        _ => false,
    }
}

impl Visitor for LpIdentifierVisitor<'_> {
    type Node = IRNode;
    type Arena = IRNodeArena;

    fn pre_visit(
        &mut self,
        node: &Self::Node,
        arena: &Self::Arena,
    ) -> PolarsResult<VisitRecursion> {
        self.visit_stack
            .push(VisitRecord::Entered(self.pre_visit_idx));
        self.pre_visit_idx += 1;

        self.identifier_array.push((0, Identifier::new()));

        if skip_children(node.to_alp(&arena.0)) {
            Ok(VisitRecursion::Skip)
        } else {
            Ok(VisitRecursion::Continue)
        }
    }

    fn post_visit(
        &mut self,
        node: &Self::Node,
        arena: &Self::Arena,
    ) -> PolarsResult<VisitRecursion> {
        self.post_visit_idx += 1;

        let (pre_visit_idx, sub_plan_id) = self.pop_until_entered();

        // Create the Id of this node.
        let id = sub_plan_id.add_alp_node(node, &arena.0, &arena.1);

        // Store the created id.
        self.identifier_array[pre_visit_idx] = (self.post_visit_idx, id.clone());

        // We popped until entered, push this Id on the stack so the trail
        // is available for the parent plan.
        self.visit_stack.push(VisitRecord::SubPlanId(id.clone()));

        let (_, sp_count) = self.sp_count.entry(id, || (node.node(), 0));
        *sp_count += 1;
        self.has_subplan |= *sp_count > 1;
        Ok(VisitRecursion::Continue)
    }
}
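
// Added walk-through (not from the original source): for a plan `A(B, C)` the
// visitor pushes Entered(0) for A and Entered(1) for B; B's post-visit pops
// back to Entered(1), computes id_B from an empty identifier plus B itself,
// and pushes SubPlanId(id_B). C is handled the same way. A's post-visit then
// pops SubPlanId(id_C) and SubPlanId(id_B), combines them, stops at
// Entered(0), and folds A in on top, so id_A fingerprints the whole sub-plan.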

pub(super) type CacheId2Caches = PlHashMap<UniqueId, (u32, Vec<Node>)>;

struct CommonSubPlanRewriter<'a> {
    sp_count: &'a SubPlanCount,
    identifier_array: &'a IdentifierArray,

    max_post_visit_idx: usize,
    /// Index into `identifier_array`, following the traversal order in which
    /// it was written.
    visited_idx: usize,
    /// Indicates if this plan has been rewritten.
    rewritten: bool,
    cache_id: IdentifierMap<UniqueId>,
    // Maps cache_id to (cache_count, cache_nodes).
    cache_id_to_caches: CacheId2Caches,
}

impl<'a> CommonSubPlanRewriter<'a> {
    fn new(sp_count: &'a SubPlanCount, identifier_array: &'a IdentifierArray) -> Self {
        Self {
            sp_count,
            identifier_array,
            max_post_visit_idx: 0,
            visited_idx: 0,
            rewritten: false,
            cache_id: Default::default(),
            cache_id_to_caches: Default::default(),
        }
    }
}

impl RewritingVisitor for CommonSubPlanRewriter<'_> {
    type Node = IRNode;
    type Arena = IRNodeArena;

    fn pre_visit(
        &mut self,
        lp_node: &Self::Node,
        arena: &mut Self::Arena,
    ) -> PolarsResult<RewriteRecursion> {
        if self.visited_idx >= self.identifier_array.len()
            || self.max_post_visit_idx > self.identifier_array[self.visited_idx].0
        {
            return Ok(RewriteRecursion::Stop);
        }

        let id = &self.identifier_array[self.visited_idx].1;

        // Id placeholder not overwritten, so we can skip this sub-plan.
        if !id.is_valid() {
            self.visited_idx += 1;
            return Ok(RewriteRecursion::NoMutateAndContinue);
        }

        let Some((_, count)) = self.sp_count.get(id) else {
            self.visited_idx += 1;
            return Ok(RewriteRecursion::NoMutateAndContinue);
        };

        if *count > 1 {
            // Rewrite this sub-plan, don't visit its children.
            Ok(RewriteRecursion::MutateAndStop)
        }
        // Never mutate if count <= 1. The post-visit would search for the node
        // and not be able to find it.
        else {
            // Don't traverse the children.
            if skip_children(lp_node.to_alp(&arena.0)) {
                return Ok(RewriteRecursion::Stop);
            }
            // This is a unique plan; visit its children to see if they
            // contain a common sub-plan.
            self.visited_idx += 1;
            Ok(RewriteRecursion::NoMutateAndContinue)
        }
    }

    fn mutate(
        &mut self,
        mut node: Self::Node,
        arena: &mut Self::Arena,
    ) -> PolarsResult<Self::Node> {
        let (post_visit_count, id) = &self.identifier_array[self.visited_idx];
        self.visited_idx += 1;

        if *post_visit_count < self.max_post_visit_idx {
            return Ok(node);
        }
        self.max_post_visit_idx = *post_visit_count;
        while self.visited_idx < self.identifier_array.len()
            && *post_visit_count > self.identifier_array[self.visited_idx].0
        {
            self.visited_idx += 1;
        }

        let cache_id = *self.cache_id.entry(id.clone(), UniqueId::new);
        let cache_count = self.sp_count.get(id).unwrap().1;

        let cache_node = IR::Cache {
            input: node.node(),
            id: cache_id,
        };
        node.assign(cache_node, &mut arena.0);
        let (_count, nodes) = self
            .cache_id_to_caches
            .entry(cache_id)
            .or_insert_with(|| (cache_count, vec![]));
        nodes.push(node.node());
        self.rewritten = true;
        Ok(node)
    }
}
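
// Added sketch of the rewrite above (illustrative, not from the original
// source): every occurrence of a duplicated sub-plan P is replaced in place by
//
//     IR::Cache { input: P, id }
//
// where all occurrences of P share the same `id`, so later stages can compute
// P once and serve the remaining hits from the cache.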
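/// Runs the two passes over the plan: `LpIdentifierVisitor` fingerprints every
/// sub-plan and counts duplicates; only if any duplicate was found does
/// `CommonSubPlanRewriter` make a second pass to splice in the cache nodes.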
fn insert_caches(
    root: Node,
    lp_arena: &mut Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
) -> (Node, bool, CacheId2Caches) {
    let mut sp_count = Default::default();
    let mut id_array = Default::default();

    with_ir_arena(lp_arena, expr_arena, |arena| {
        let lp_node = IRNode::new_mutate(root);
        let mut visitor = LpIdentifierVisitor::new(&mut sp_count, &mut id_array);

        lp_node.visit(&mut visitor, arena).map(|_| ()).unwrap();

        if visitor.has_subplan {
            let lp_node = IRNode::new_mutate(root);
            let mut rewriter = CommonSubPlanRewriter::new(&sp_count, &id_array);
            lp_node.rewrite(&mut rewriter, arena).unwrap();

            (root, rewriter.rewritten, rewriter.cache_id_to_caches)
        } else {
            (root, false, Default::default())
        }
    })
}

/// Prune unused caches.
/// In the query below, cache 0 will be inserted with a count of 2 on `lf.select`
/// and cache 1 with a count of 3 on `lf`. But because cache 0 sits higher in the
/// chain, cache 1 will never be hit as often as its count expects. So we prune
/// caches whose number of inserted nodes doesn't match their count.
///
/// `concat([lf.select(), lf.select(), lf])`
fn prune_unused_caches(lp_arena: &mut Arena<IR>, cid2c: &CacheId2Caches) {
    for (count, nodes) in cid2c.values() {
        if *count == nodes.len() as u32 {
            continue;
        }

        for node in nodes {
            let IR::Cache { input, .. } = lp_arena.get(*node) else {
                unreachable!()
            };
            lp_arena.swap(*input, *node)
        }
    }
}
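/// Entry point of common sub-plan elimination: insert `IR::Cache` nodes over
/// duplicated sub-plans, then prune the caches that can never be hit as often
/// as their count requires.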
pub(super) fn elim_cmn_subplans(
    root: Node,
    lp_arena: &mut Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
) -> (Node, bool, CacheId2Caches) {
    let (lp, changed, cid2c) = insert_caches(root, lp_arena, expr_arena);
    if changed {
        prune_unused_caches(lp_arena, &cid2c);
    }

    (lp, changed, cid2c)
}