Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/cse/csee.rs
7889 views
1
use std::hash::BuildHasher;
2
3
use hashbrown::hash_map::RawEntryMut;
4
use polars_core::CHEAP_SERIES_HASH_LIMIT;
5
use polars_utils::aliases::PlFixedStateQuality;
6
use polars_utils::format_pl_smallstr;
7
use polars_utils::hashing::_boost_hash_combine;
8
use polars_utils::vec::CapacityByFactor;
9
10
use super::*;
11
use crate::constants::CSE_REPLACED;
12
use crate::prelude::visitor::AexprNode;
13
14
/// A projection's expressions, with any temporary CSE columns appended at the back.
#[derive(Debug, Clone)]
struct ProjectionExprs {
    /// All expressions; the trailing `common_sub_offset` entries are CSE temporaries.
    expr: Vec<ExprIR>,
    /// offset from the back
    /// `expr[expr.len() - common_sub_offset..]`
    /// are the common sub expressions
    common_sub_offset: usize,
}
22
23
impl ProjectionExprs {
24
fn default_exprs(&self) -> &[ExprIR] {
25
&self.expr[..self.expr.len() - self.common_sub_offset]
26
}
27
28
fn cse_exprs(&self) -> &[ExprIR] {
29
&self.expr[self.expr.len() - self.common_sub_offset..]
30
}
31
32
fn new_with_cse(expr: Vec<ExprIR>, common_sub_offset: usize) -> Self {
33
Self {
34
expr,
35
common_sub_offset,
36
}
37
}
38
}
39
40
/// Identifier that shows the sub-expression path.
/// Must implement hash and equality and ideally
/// have little collisions
/// We will do a full expression comparison to check if the
/// expressions with equal identifiers are truly equal
#[derive(Clone, Debug)]
pub(super) struct Identifier {
    /// Combined hash of the expression trail; `None` while nothing was added.
    inner: Option<u64>,
    /// Last expression node folded in; used for the deep equality check.
    last_node: Option<AexprNode>,
    /// Hash builder; fixed seed so identifiers are stable across runs.
    hb: PlFixedStateQuality,
}
51
52
impl Identifier {
53
fn new() -> Self {
54
Self {
55
inner: None,
56
last_node: None,
57
hb: PlFixedStateQuality::with_seed(0),
58
}
59
}
60
61
fn hash(&self) -> u64 {
62
self.inner.unwrap_or(0)
63
}
64
65
fn ae_node(&self) -> AexprNode {
66
self.last_node.unwrap()
67
}
68
69
fn is_equal(&self, other: &Self, arena: &Arena<AExpr>) -> bool {
70
self.inner == other.inner
71
&& self.last_node.map(|v| v.hashable_and_cmp(arena))
72
== other.last_node.map(|v| v.hashable_and_cmp(arena))
73
}
74
75
fn is_valid(&self) -> bool {
76
self.inner.is_some()
77
}
78
79
fn materialize(&self) -> PlSmallStr {
80
format_pl_smallstr!("{}{:#x}", CSE_REPLACED, self.materialized_hash())
81
}
82
83
fn materialized_hash(&self) -> u64 {
84
self.inner.unwrap_or(0)
85
}
86
87
fn combine(&mut self, other: &Identifier) {
88
let inner = match (self.inner, other.inner) {
89
(Some(l), Some(r)) => _boost_hash_combine(l, r),
90
(None, Some(r)) => r,
91
(Some(l), None) => l,
92
_ => return,
93
};
94
self.inner = Some(inner);
95
}
96
97
fn add_ae_node(&self, ae: &AexprNode, arena: &Arena<AExpr>) -> Self {
98
let hashed = self.hb.hash_one(ae.to_aexpr(arena));
99
let inner = Some(
100
self.inner
101
.map_or(hashed, |l| _boost_hash_combine(l, hashed)),
102
);
103
Self {
104
inner,
105
last_node: Some(*ae),
106
hb: self.hb.clone(),
107
}
108
}
109
}
110
111
/// Hash map keyed by `Identifier`.
/// Key equality needs access to the expression arena, so lookups go
/// through hashbrown's raw-entry API instead of the normal map interface.
#[derive(Default)]
struct IdentifierMap<V> {
    inner: PlHashMap<Identifier, V>,
}
115
116
impl<V> IdentifierMap<V> {
    /// Look up `id` using its precomputed hash and an arena-aware deep
    /// equality check (hashes alone may collide).
    fn get(&self, id: &Identifier, arena: &Arena<AExpr>) -> Option<&V> {
        self.inner
            .raw_entry()
            .from_hash(id.hash(), |k| k.is_equal(id, arena))
            .map(|(_k, v)| v)
    }

    /// Get a mutable reference to the value for `id`, inserting `v()` when absent.
    fn entry<'a, F: FnOnce() -> V>(
        &'a mut self,
        id: Identifier,
        v: F,
        arena: &Arena<AExpr>,
    ) -> &'a mut V {
        let h = id.hash();
        match self
            .inner
            .raw_entry_mut()
            .from_hash(h, |k| k.is_equal(&id, arena))
        {
            RawEntryMut::Occupied(entry) => entry.into_mut(),
            RawEntryMut::Vacant(entry) => {
                // Supply the hasher closure explicitly: `Identifier`
                // carries its own hash rather than implementing `Hash`.
                let (_, v) = entry.insert_with_hasher(h, id, v(), |id| id.hash());
                v
            },
        }
    }
    /// Insert `v` for `id`; an existing value is left untouched.
    fn insert(&mut self, id: Identifier, v: V, arena: &Arena<AExpr>) {
        self.entry(id, || v, arena);
    }

    fn iter(&self) -> impl Iterator<Item = (&Identifier, &V)> {
        self.inner.iter()
    }
}
151
152
/// Merges identical expressions into identical IDs.
///
/// Does no analysis whether this leads to legal substitutions.
#[derive(Default)]
pub struct NaiveExprMerger {
    /// Maps every visited node to its unique expression id.
    node_to_uniq_id: PlHashMap<Node, u32>,
    /// Representative node per unique id (index == id).
    uniq_id_to_node: Vec<Node>,
    /// Structural identifier -> unique id.
    identifier_to_uniq_id: IdentifierMap<u32>,
    /// Per-child identifier stack used during traversal
    /// (`None` acts as a sentinel between sibling groups).
    arg_stack: Vec<Option<Identifier>>,
}
162
163
impl NaiveExprMerger {
164
pub fn add_expr(&mut self, node: Node, arena: &Arena<AExpr>) {
165
let node = AexprNode::new(node);
166
node.visit(self, arena).unwrap();
167
}
168
169
pub fn get_uniq_id(&self, node: Node) -> Option<u32> {
170
self.node_to_uniq_id.get(&node).copied()
171
}
172
173
pub fn get_node(&self, uniq_id: u32) -> Option<Node> {
174
self.uniq_id_to_node.get(uniq_id as usize).copied()
175
}
176
}
177
178
impl Visitor for NaiveExprMerger {
179
type Node = AexprNode;
180
type Arena = Arena<AExpr>;
181
182
fn pre_visit(
183
&mut self,
184
_node: &Self::Node,
185
_arena: &Self::Arena,
186
) -> PolarsResult<VisitRecursion> {
187
self.arg_stack.push(None);
188
Ok(VisitRecursion::Continue)
189
}
190
191
fn post_visit(
192
&mut self,
193
node: &Self::Node,
194
arena: &Self::Arena,
195
) -> PolarsResult<VisitRecursion> {
196
let mut identifier = Identifier::new();
197
while let Some(Some(arg)) = self.arg_stack.pop() {
198
identifier.combine(&arg);
199
}
200
identifier = identifier.add_ae_node(node, arena);
201
let uniq_id = *self.identifier_to_uniq_id.entry(
202
identifier,
203
|| {
204
let uniq_id = self.uniq_id_to_node.len() as u32;
205
self.uniq_id_to_node.push(node.node());
206
uniq_id
207
},
208
arena,
209
);
210
self.node_to_uniq_id.insert(node.node(), uniq_id);
211
Ok(VisitRecursion::Continue)
212
}
213
}
214
215
/// Identifier maps to Expr Node and count (how often the sub-expression occurred).
type SubExprCount = IdentifierMap<(Node, u32)>;
/// (post_visit_idx, identifier); one entry per pre-visited sub-expression.
type IdentifierArray = Vec<(usize, Identifier)>;
219
220
#[derive(Debug)]
enum VisitRecord {
    /// entered a new expression
    Entered(usize),
    /// Every visited sub-expression pushes their identifier to the stack.
    // The `bool` indicates if this expression is valid.
    // This can be `AND` accumulated by the lineage of the expression to determine
    // whether the whole expression can be added.
    // For instance in a group_by we only want to use elementwise operations in cse:
    // - `(col("a") * 2).sum(), (col("a") * 2)` -> we want to do `col("a") * 2` on a `with_columns`
    // - `col("a").sum() * col("a").sum()` -> we don't want `sum` to run on `with_columns`
    // as that doesn't have groups context. If we encounter a `sum` it should be flagged as `false`
    //
    // This should have the following stack
    // id       valid
    // col(a)   true
    // sum      false
    // col(a)   true
    // sum      false
    // binary   true
    // -------------- accumulated
    //          false
    SubExprId(Identifier, bool),
}
244
245
fn skip_pre_visit(ae: &AExpr, is_groupby: bool) -> bool {
246
match ae {
247
#[cfg(feature = "dynamic_group_by")]
248
AExpr::Rolling { .. } => true,
249
AExpr::Over { .. } => true,
250
#[cfg(feature = "dtype-struct")]
251
AExpr::Ternary { .. } => is_groupby,
252
_ => false,
253
}
254
}
255
256
/// Goes through an expression and generates an identifier
///
/// The visitor uses a `visit_stack` to track traversal order.
///
/// # Entering a node
/// When `pre-visit` is called we enter a new (sub)-expression and
/// we add `Entered` to the stack.
/// # Leaving a node
/// On `post-visit` when we leave the node and we pop all `SubExprIds` nodes.
/// Those are considered sub-expression of the leaving node
///
/// We also record an `id_array` that followed the pre-visit order. This
/// is used to cache the `Identifiers`.
//
// # Example (this is not a docstring as clippy complains about spacing)
// Say we have the expression: `(col("f00").min() * col("bar")).sum()`
// with the following call tree:
//
//     sum
//
//       |
//
//     binary: *
//
//       |    |
//
//     col(bar)    min
//
//                  |
//
//                 col(f00)
//
// # call order
// function-called              stack                stack-after(pop until E, push I)   # ID
// pre-visit: sum                E                        -
// pre-visit: binary: *          EE                       -
// pre-visit: col(bar)           EEE                      -
// post-visit: col(bar)          EEE                      EEI                           id: col(bar)
// pre-visit: min                EEIE                     -
// pre-visit: col(f00)           EEIEE                    -
// post-visit: col(f00)          EEIEE                    EEIEI                         id: col(f00)
// post-visit: min               EEIEI                    EEII                          id: min!col(f00)
// post-visit: binary: *         EEII                     EI                            id: binary: *!min!col(f00)!col(bar)
// post-visit: sum               EI                       I                             id: sum!binary: *!min!col(f00)!col(bar)
struct ExprIdentifierVisitor<'a> {
    se_count: &'a mut SubExprCount,
    /// Materialized `CSE` materialized (name) hashes can collide. So we validate that all CSE counts
    /// match name hash counts.
    name_validation: &'a mut PlHashMap<u64, u32>,
    identifier_array: &'a mut IdentifierArray,
    // Index in pre-visit traversal order.
    pre_visit_idx: usize,
    post_visit_idx: usize,
    visit_stack: &'a mut Vec<VisitRecord>,
    /// Offset in the identifier array
    /// this allows us to use a single `vec` on multiple expressions
    id_array_offset: usize,
    // Whether the expression replaced a subexpression.
    has_sub_expr: bool,
    // During aggregation we only identify element-wise operations
    is_group_by: bool,
}
318
319
impl ExprIdentifierVisitor<'_> {
    /// Create a visitor sharing the caller's scratch buffers.
    fn new<'a>(
        se_count: &'a mut SubExprCount,
        identifier_array: &'a mut IdentifierArray,
        visit_stack: &'a mut Vec<VisitRecord>,
        is_group_by: bool,
        name_validation: &'a mut PlHashMap<u64, u32>,
    ) -> ExprIdentifierVisitor<'a> {
        // Remember where this expression's ids start, so one shared
        // `identifier_array` can serve multiple expressions.
        let id_array_offset = identifier_array.len();
        ExprIdentifierVisitor {
            se_count,
            name_validation,
            identifier_array,
            pre_visit_idx: 0,
            post_visit_idx: 0,
            visit_stack,
            id_array_offset,
            has_sub_expr: false,
            is_group_by,
        }
    }

    /// pop all visit-records until an `Entered` is found. We accumulate a `SubExprId`s
    /// to `id`. Finally we return the expression `idx` and `Identifier`.
    /// This works due to the stack.
    /// If we traverse another expression in the mean time, it will get popped of the stack first
    /// so the returned identifier belongs to a single sub-expression
    fn pop_until_entered(&mut self) -> (usize, Identifier, bool) {
        let mut id = Identifier::new();
        let mut is_valid_accumulated = true;

        while let Some(item) = self.visit_stack.pop() {
            match item {
                VisitRecord::Entered(idx) => return (idx, id, is_valid_accumulated),
                VisitRecord::SubExprId(s, valid) => {
                    id.combine(&s);
                    is_valid_accumulated &= valid
                },
            }
        }
        // Every post-visit has a matching pre-visit that pushed `Entered`.
        unreachable!()
    }

    /// return `None` -> node is accepted
    /// return `Some(_)` node is not accepted and apply the given recursion operation
    /// `Some(_, true)` don't accept this node, but can be a member of a cse.
    /// `Some(_, false)` don't accept this node, and don't allow as a member of a cse.
    fn accept_node_post_visit(&self, ae: &AExpr) -> Accepted {
        match ae {
            // window expressions should `evaluate_on_groups`, not `evaluate`
            // so we shouldn't cache the children as they are evaluated incorrectly
            #[cfg(feature = "dynamic_group_by")]
            AExpr::Rolling { .. } => REFUSE_SKIP,
            AExpr::Over { .. } => REFUSE_SKIP,
            // Don't allow this for now, as we can get `null().cast()` in ternary expressions.
            // TODO! Add a typed null
            AExpr::Literal(LiteralValue::Scalar(sc)) if sc.is_null() => REFUSE_NO_MEMBER,
            AExpr::Literal(s) => {
                match s {
                    LiteralValue::Series(s) => {
                        let dtype = s.dtype();

                        // Object and nested types are harder to hash and compare.
                        let allow = !(dtype.is_nested() | dtype.is_object());

                        // Only cheap-to-hash series may participate in a cse.
                        if s.len() < CHEAP_SERIES_HASH_LIMIT && allow {
                            REFUSE_ALLOW_MEMBER
                        } else {
                            REFUSE_NO_MEMBER
                        }
                    },
                    _ => REFUSE_ALLOW_MEMBER,
                }
            },
            AExpr::Column(_) => REFUSE_ALLOW_MEMBER,
            AExpr::Len => {
                // `len` has group semantics under group_by, so it cannot be
                // hoisted to a `with_columns` there.
                if self.is_group_by {
                    REFUSE_NO_MEMBER
                } else {
                    REFUSE_ALLOW_MEMBER
                }
            },
            // Non-deterministic; caching would change results.
            #[cfg(feature = "random")]
            AExpr::Function {
                function: IRFunctionExpr::Random { .. },
                ..
            } => REFUSE_NO_MEMBER,
            #[cfg(feature = "rolling_window")]
            AExpr::Function {
                function: IRFunctionExpr::RollingExpr { .. },
                ..
            } => REFUSE_NO_MEMBER,
            AExpr::AnonymousFunction { .. } => REFUSE_NO_MEMBER,
            _ => {
                // During aggregation we only store elementwise operation in the state
                // other operations we cannot add to the state as they have the output size of the
                // groups, not the original dataframe
                if self.is_group_by {
                    if !ae.is_elementwise_top_level() {
                        return REFUSE_NO_MEMBER;
                    }
                    match ae {
                        AExpr::AnonymousFunction { .. } => REFUSE_NO_MEMBER,
                        AExpr::Cast { .. } => REFUSE_ALLOW_MEMBER,
                        _ => ACCEPT,
                    }
                } else {
                    ACCEPT
                }
            },
        }
    }
}
432
433
impl Visitor for ExprIdentifierVisitor<'_> {
    type Node = AexprNode;
    type Arena = Arena<AExpr>;

    fn pre_visit(
        &mut self,
        node: &Self::Node,
        arena: &Self::Arena,
    ) -> PolarsResult<VisitRecursion> {
        if skip_pre_visit(node.to_aexpr(arena), self.is_group_by) {
            // Still add to the stack so that a parent becomes invalidated.
            self.visit_stack
                .push(VisitRecord::SubExprId(Identifier::new(), false));
            return Ok(VisitRecursion::Skip);
        }

        self.visit_stack
            .push(VisitRecord::Entered(self.pre_visit_idx));
        self.pre_visit_idx += 1;

        // implement default placeholders; overwritten in `post_visit`
        // when the node turns out to be a cse candidate
        self.identifier_array
            .push((self.id_array_offset, Identifier::new()));

        Ok(VisitRecursion::Continue)
    }

    fn post_visit(
        &mut self,
        node: &Self::Node,
        arena: &Self::Arena,
    ) -> PolarsResult<VisitRecursion> {
        let ae = node.to_aexpr(arena);
        self.post_visit_idx += 1;

        let (pre_visit_idx, sub_expr_id, is_valid_accumulated) = self.pop_until_entered();
        // Create the Id of this node.
        let id: Identifier = sub_expr_id.add_ae_node(node, arena);

        // A descendant was flagged invalid: record only the post-visit
        // index (the placeholder identifier stays invalid) and propagate.
        if !is_valid_accumulated {
            self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;
            self.visit_stack.push(VisitRecord::SubExprId(id, false));
            return Ok(VisitRecursion::Continue);
        }

        // If we don't store this node
        // we only push the visit_stack, so the parents know the trail.
        if let Some((recurse, local_is_valid)) = self.accept_node_post_visit(ae) {
            self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;

            self.visit_stack
                .push(VisitRecord::SubExprId(id, local_is_valid));
            return Ok(recurse);
        }

        // Store the created id.
        self.identifier_array[pre_visit_idx + self.id_array_offset] =
            (self.post_visit_idx, id.clone());

        // We popped until entered, push this Id on the stack so the trail
        // is available for the parent expression.
        self.visit_stack
            .push(VisitRecord::SubExprId(id.clone(), true));

        // Track both the structural count and the materialized-name count;
        // a mismatch later signals a name-hash collision.
        let mat_h = id.materialized_hash();
        let (_, se_count) = self.se_count.entry(id, || (node.node(), 0), arena);

        *se_count += 1;
        *self.name_validation.entry(mat_h).or_insert(0) += 1;
        self.has_sub_expr |= *se_count > 1;

        Ok(VisitRecursion::Continue)
    }
}
507
508
/// Rewrites expressions by replacing counted common sub-expressions
/// with references to temporary `CSE_REPLACED` columns.
struct CommonSubExprRewriter<'a> {
    sub_expr_map: &'a SubExprCount,
    identifier_array: &'a IdentifierArray,
    /// keep track of the replaced identifiers.
    replaced_identifiers: &'a mut IdentifierMap<()>,

    max_post_visit_idx: usize,
    /// index in traversal order in which `identifier_array`
    /// was written. This is the index in `identifier_array`.
    visited_idx: usize,
    /// Offset in the identifier array.
    /// This allows us to use a single `vec` on multiple expressions
    id_array_offset: usize,
    /// Indicates if this expression is rewritten.
    rewritten: bool,
    is_group_by: bool,
}
525
526
impl<'a> CommonSubExprRewriter<'a> {
527
fn new(
528
sub_expr_map: &'a SubExprCount,
529
identifier_array: &'a IdentifierArray,
530
replaced_identifiers: &'a mut IdentifierMap<()>,
531
id_array_offset: usize,
532
is_group_by: bool,
533
) -> Self {
534
Self {
535
sub_expr_map,
536
identifier_array,
537
replaced_identifiers,
538
max_post_visit_idx: 0,
539
visited_idx: 0,
540
id_array_offset,
541
rewritten: false,
542
is_group_by,
543
}
544
}
545
}
546
547
// # Example
// Expression tree with [pre-visit,post-visit] indices
// counted from 1
//     [1,8] binary: +
//
//       |      |
//
//     [2,2] sum    [4,7] sum
//
//       |             |
//
//     [3,1] col(foo)  [5,6] binary: *
//
//                        |       |
//
//                 [6,3] col(bar)  [7,5] sum
//
//                                   |
//
//                                 [8,4] col(foo)
//
// in this tree `col(foo).sum()` should be post-visited/mutated
// so if we are at `[2,2]`
//
// call stack
// pre-visit    [1,8] binary    -> no_mutate_and_continue -> visits children
// pre-visit    [2,2] sum       -> mutate_and_stop -> does not visit children
// post-visit   [2,2] sum       -> skip index to [4,7] (because we didn't visit children)
// pre-visit    [4,7] sum       -> no_mutate_and_continue -> visits children
// pre-visit    [5,6] binary    -> no_mutate_and_continue -> visits children
// pre-visit    [6,3] col       -> stop_recursion -> does not mutate
// pre-visit    [7,5] sum       -> mutate_and_stop -> does not visit children
// post-visit   [7,5]           -> skip index to end
impl RewritingVisitor for CommonSubExprRewriter<'_> {
    type Node = AexprNode;
    type Arena = Arena<AExpr>;

    fn pre_visit(
        &mut self,
        ae_node: &Self::Node,
        arena: &mut Self::Arena,
    ) -> PolarsResult<RewriteRecursion> {
        let ae = ae_node.to_aexpr(arena);
        // Stop when the id array is exhausted, when this subtree was already
        // covered by an earlier mutation, or when the node kind is skipped.
        if self.visited_idx + self.id_array_offset >= self.identifier_array.len()
            || self.max_post_visit_idx
                > self.identifier_array[self.visited_idx + self.id_array_offset].0
            || skip_pre_visit(ae, self.is_group_by)
        {
            return Ok(RewriteRecursion::Stop);
        }

        let id = &self.identifier_array[self.visited_idx + self.id_array_offset].1;

        // Id placeholder not overwritten, so we can skip this sub-expression.
        if !id.is_valid() {
            self.visited_idx += 1;
            let recurse = if ae_node.is_leaf(arena) {
                RewriteRecursion::Stop
            } else {
                // continue visit its children to see
                // if there are cse
                RewriteRecursion::NoMutateAndContinue
            };
            return Ok(recurse);
        }

        // Because some expressions don't have hash / equality guarantee (e.g. floats)
        // we can get none here. This must be changed later.
        let Some((_, count)) = self.sub_expr_map.get(id, arena) else {
            self.visited_idx += 1;
            return Ok(RewriteRecursion::NoMutateAndContinue);
        };
        if *count > 1 {
            self.replaced_identifiers.insert(id.clone(), (), arena);
            // rewrite this sub-expression, don't visit its children
            Ok(RewriteRecursion::MutateAndStop)
        } else {
            // This is a unique expression
            // visit its children to see if they are cse
            self.visited_idx += 1;
            Ok(RewriteRecursion::NoMutateAndContinue)
        }
    }

    fn mutate(
        &mut self,
        mut node: Self::Node,
        arena: &mut Self::Arena,
    ) -> PolarsResult<Self::Node> {
        let (post_visit_count, id) =
            &self.identifier_array[self.visited_idx + self.id_array_offset];
        self.visited_idx += 1;

        // TODO!: check if we ever hit this branch
        if *post_visit_count < self.max_post_visit_idx {
            return Ok(node);
        }

        self.max_post_visit_idx = *post_visit_count;
        // DFS, so every post_visit that is smaller than `post_visit_count`
        // is a subexpression of this node and we can skip that
        //
        // `self.visited_idx` will influence recursion strategy in `pre_visit`
        // see call-stack comment above
        while self.visited_idx < self.identifier_array.len() - self.id_array_offset
            && *post_visit_count > self.identifier_array[self.visited_idx + self.id_array_offset].0
        {
            self.visited_idx += 1;
        }
        // If this is not true, the traversal order in the visitor was different from the rewriter.
        debug_assert_eq!(
            node.hashable_and_cmp(arena),
            id.ae_node().hashable_and_cmp(arena)
        );

        // Replace the whole sub-expression with a reference to its
        // temporary CSE column.
        let name = id.materialize();
        node.assign(AExpr::col(name), arena);
        self.rewritten = true;

        Ok(node)
    }
}
669
670
/// Plan-level optimizer that eliminates common sub-expressions from
/// `Select`, `HStack` and `GroupBy` nodes.
pub(crate) struct CommonSubExprOptimizer {
    // amortize allocations
    // these are cleared per lp node
    se_count: SubExprCount,
    id_array: IdentifierArray,
    id_array_offsets: Vec<u32>,
    replaced_identifiers: IdentifierMap<()>,
    // these are cleared per expr node
    visit_stack: Vec<VisitRecord>,
    name_validation: PlHashMap<u64, u32>,
}
681
682
impl CommonSubExprOptimizer {
    pub(crate) fn new() -> Self {
        Self {
            se_count: Default::default(),
            id_array: Default::default(),
            visit_stack: Default::default(),
            id_array_offsets: Default::default(),
            replaced_identifiers: Default::default(),
            name_validation: Default::default(),
        }
    }

    /// Visit one expression and collect sub-expression counts.
    /// Returns this expression's offset into the shared id array and
    /// whether it contains any repeated sub-expression.
    fn visit_expression(
        &mut self,
        ae_node: AexprNode,
        is_group_by: bool,
        expr_arena: &mut Arena<AExpr>,
    ) -> PolarsResult<(usize, bool)> {
        let mut visitor = ExprIdentifierVisitor::new(
            &mut self.se_count,
            &mut self.id_array,
            &mut self.visit_stack,
            is_group_by,
            &mut self.name_validation,
        );
        ae_node.visit(&mut visitor, expr_arena).map(|_| ())?;
        Ok((visitor.id_array_offset, visitor.has_sub_expr))
    }

    /// Mutate the expression.
    /// Returns a new expression and a `bool` indicating if it was rewritten or not.
    fn mutate_expression(
        &mut self,
        ae_node: AexprNode,
        id_array_offset: usize,
        is_group_by: bool,
        expr_arena: &mut Arena<AExpr>,
    ) -> PolarsResult<(AexprNode, bool)> {
        let mut rewriter = CommonSubExprRewriter::new(
            &self.se_count,
            &self.id_array,
            &mut self.replaced_identifiers,
            id_array_offset,
            is_group_by,
        );
        ae_node
            .rewrite(&mut rewriter, expr_arena)
            .map(|out| (out, rewriter.rewritten))
    }

    /// Find and eliminate common sub-expressions in `expr`.
    /// Returns `None` when there is nothing to rewrite (or when a
    /// materialized-name collision forces a fallback), otherwise the
    /// rewritten expressions with the CSE temporaries appended.
    fn find_cse(
        &mut self,
        expr: &[ExprIR],
        expr_arena: &mut Arena<AExpr>,
        id_array_offsets: &mut Vec<u32>,
        is_group_by: bool,
        schema: &Schema,
    ) -> PolarsResult<Option<ProjectionExprs>> {
        let mut has_sub_expr = false;

        // First get all cse's.
        for e in expr {
            // The visitor can return early thus depleted its stack
            // on a previous iteration.
            self.visit_stack.clear();

            // Visit expressions and collect sub-expression counts.
            let ae_node = AexprNode::new(e.node());
            let (id_array_offset, this_expr_has_se) =
                self.visit_expression(ae_node, is_group_by, expr_arena)?;
            id_array_offsets.push(id_array_offset as u32);
            has_sub_expr |= this_expr_has_se;
        }

        // Ensure that the `materialized hashes` count matches that of the CSE count.
        // It can happen that CSE collide and in that case we fallback and skip CSE.
        for (id, (_, count)) in self.se_count.iter() {
            let mat_h = id.materialized_hash();
            let valid = if let Some(name_count) = self.name_validation.get(&mat_h) {
                *name_count == *count
            } else {
                false
            };

            if !valid {
                if verbose() {
                    eprintln!(
                        "materialized names collided in common subexpression elimination.\n backtrace and run without CSE"
                    )
                }
                return Ok(None);
            }
        }

        if has_sub_expr {
            let mut new_expr = Vec::with_capacity_by_factor(expr.len(), 1.3);

            // Then rewrite the expressions that have a cse count > 1.
            for (e, offset) in expr.iter().zip(id_array_offsets.iter()) {
                let ae_node = AexprNode::new(e.node());

                let (out, rewritten) =
                    self.mutate_expression(ae_node, *offset as usize, is_group_by, expr_arena)?;

                let out_node = out.node();
                let mut out_e = e.clone();
                let new_node = if !rewritten {
                    out_e
                } else {
                    out_e.set_node(out_node);

                    // Ensure the function ExprIR's have the proper names.
                    // This is needed for structs to get the proper field
                    let mut scratch = vec![];
                    let mut stack = vec![(e.node(), out_node)];
                    while let Some((original, new)) = stack.pop() {
                        // Don't follow identical nodes.
                        if original == new {
                            continue;
                        }
                        scratch.clear();
                        let aes = expr_arena.get_many_mut([original, new]);

                        // Only follow paths that are the same.
                        if std::mem::discriminant(aes[0]) != std::mem::discriminant(aes[1]) {
                            continue;
                        }

                        // Collect inputs of both nodes back-to-back in `scratch`.
                        aes[0].inputs_rev(&mut scratch);
                        let offset = scratch.len();
                        aes[1].inputs_rev(&mut scratch);

                        // If they have a different number of inputs, we don't follow the nodes.
                        if scratch.len() != offset * 2 {
                            continue;
                        }

                        // Pair up original/new inputs for further descent.
                        for i in 0..scratch.len() / 2 {
                            stack.push((scratch[i], scratch[i + offset]));
                        }

                        // Copy the original input names onto the rewritten
                        // function inputs.
                        match expr_arena.get_many_mut([original, new]) {
                            [
                                AExpr::Function {
                                    input: input_original,
                                    ..
                                },
                                AExpr::Function {
                                    input: input_new, ..
                                },
                            ] => {
                                for (new, original) in input_new.iter_mut().zip(input_original) {
                                    new.set_alias(original.output_name().clone());
                                }
                            },
                            [
                                AExpr::AnonymousFunction {
                                    input: input_original,
                                    ..
                                },
                                AExpr::AnonymousFunction {
                                    input: input_new, ..
                                },
                            ] => {
                                for (new, original) in input_new.iter_mut().zip(input_original) {
                                    new.set_alias(original.output_name().clone());
                                }
                            },
                            _ => {},
                        }
                    }

                    // If we don't end with an alias we add an alias. Because the normal left-hand
                    // rule we apply for determining the name will not work we now refer to
                    // intermediate temporary names starting with the `CSE_REPLACED` constant.
                    if !e.has_alias() {
                        let name = ae_node.to_field(schema, expr_arena)?.name;
                        out_e.set_alias(name.clone());
                    }
                    out_e
                };
                new_expr.push(new_node)
            }
            // Add the tmp columns
            for id in self.replaced_identifiers.inner.keys() {
                let (node, _count) = self.se_count.get(id, expr_arena).unwrap();
                let name = id.materialize();
                let out_e = ExprIR::new(*node, OutputName::Alias(name));
                new_expr.push(out_e)
            }
            let expr =
                ProjectionExprs::new_with_cse(new_expr, self.replaced_identifiers.inner.len());
            Ok(Some(expr))
        } else {
            Ok(None)
        }
    }
}
880
881
impl RewritingVisitor for CommonSubExprOptimizer {
    type Node = IRNode;
    type Arena = IRNodeArena;

    fn pre_visit(
        &mut self,
        node: &Self::Node,
        arena: &mut Self::Arena,
    ) -> PolarsResult<RewriteRecursion> {
        use IR::*;
        // Only projection-like nodes carry expressions worth deduplicating.
        Ok(match node.to_alp(&arena.0) {
            Select { .. } | HStack { .. } | GroupBy { .. } => RewriteRecursion::MutateAndContinue,
            _ => RewriteRecursion::NoMutateAndContinue,
        })
    }

    fn mutate(&mut self, node: Self::Node, arena: &mut Self::Arena) -> PolarsResult<Self::Node> {
        // Take the offsets buffer to avoid borrowing `self` twice below.
        let mut id_array_offsets = std::mem::take(&mut self.id_array_offsets);

        // Reset per-plan-node state; buffers keep their capacity.
        self.se_count.inner.clear();
        self.name_validation.clear();
        self.id_array.clear();
        id_array_offsets.clear();
        self.replaced_identifiers.inner.clear();

        let arena_idx = node.node();
        let alp = arena.0.get(arena_idx);

        match alp {
            IR::Select {
                input,
                expr,
                schema,
                options,
            } => {
                let input_schema = arena.0.get(*input).schema(&arena.0);
                if let Some(expr) = self.find_cse(
                    expr,
                    &mut arena.1,
                    &mut id_array_offsets,
                    false,
                    input_schema.as_ref().as_ref(),
                )? {
                    let schema = schema.clone();
                    let options = *options;

                    // Evaluate the CSE temporaries in a `with_columns`
                    // feeding the (rewritten) select.
                    let lp = IRBuilder::new(*input, &mut arena.1, &mut arena.0)
                        .with_columns(
                            expr.cse_exprs().to_vec(),
                            ProjectionOptions {
                                run_parallel: options.run_parallel,
                                duplicate_check: options.duplicate_check,
                                // These columns might have different
                                // lengths from the dataframe, but
                                // they are only temporaries that will
                                // be removed by the evaluation of the
                                // default_exprs and the subsequent
                                // projection.
                                should_broadcast: false,
                            },
                        )
                        .build();
                    let input = arena.0.add(lp);

                    let lp = IR::Select {
                        input,
                        expr: expr.default_exprs().to_vec(),
                        schema,
                        options,
                    };
                    arena.0.replace(arena_idx, lp);
                }
            },
            IR::HStack {
                input,
                exprs,
                schema,
                options,
            } => {
                let input_schema = arena.0.get(*input).schema(&arena.0);
                if let Some(exprs) = self.find_cse(
                    exprs,
                    &mut arena.1,
                    &mut id_array_offsets,
                    false,
                    input_schema.as_ref().as_ref(),
                )? {
                    let schema = schema.clone();
                    let options = *options;
                    let input = *input;

                    let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)
                        .with_columns(
                            exprs.cse_exprs().to_vec(),
                            // These columns might have different
                            // lengths from the dataframe, but they
                            // are only temporaries that will be
                            // removed by the evaluation of the
                            // default_exprs and the subsequent
                            // projection.
                            ProjectionOptions {
                                run_parallel: options.run_parallel,
                                duplicate_check: options.duplicate_check,
                                should_broadcast: false,
                            },
                        )
                        .with_columns(exprs.default_exprs().to_vec(), options)
                        .build();
                    let input = arena.0.add(lp);

                    // Project back to the original schema, dropping the
                    // CSE temporaries.
                    let lp = IR::SimpleProjection {
                        input,
                        columns: schema,
                    };
                    arena.0.replace(arena_idx, lp);
                }
            },
            IR::GroupBy {
                input,
                keys,
                aggs,
                options,
                maintain_order,
                apply,
                schema,
            } => {
                let input_schema = arena.0.get(*input).schema(&arena.0);
                if let Some(aggs) = self.find_cse(
                    aggs,
                    &mut arena.1,
                    &mut id_array_offsets,
                    true,
                    input_schema.as_ref().as_ref(),
                )? {
                    let keys = keys.clone();
                    let options = options.clone();
                    let schema = schema.clone();
                    let apply = apply.clone();
                    let maintain_order = *maintain_order;
                    let input = *input;

                    // Elementwise CSE temporaries run before the group_by.
                    let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)
                        .with_columns(aggs.cse_exprs().to_vec(), Default::default())
                        .build();
                    let input = arena.0.add(lp);

                    let lp = IR::GroupBy {
                        input,
                        keys,
                        aggs: aggs.default_exprs().to_vec(),
                        options,
                        schema,
                        maintain_order,
                        apply,
                    };
                    arena.0.replace(arena_idx, lp);
                }
            },
            _ => {},
        }

        // Hand the offsets buffer back for the next plan node.
        self.id_array_offsets = id_array_offsets;
        Ok(node)
    }
}
1046
1047