Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/optimizer/cse/csee.rs
8479 views
1
use std::hash::BuildHasher;
2
3
use hashbrown::hash_map::RawEntryMut;
4
use polars_core::CHEAP_SERIES_HASH_LIMIT;
5
use polars_utils::aliases::PlFixedStateQuality;
6
use polars_utils::format_pl_smallstr;
7
use polars_utils::hashing::_boost_hash_combine;
8
use polars_utils::vec::CapacityByFactor;
9
10
use super::*;
11
use crate::constants::CSE_REPLACED;
12
use crate::prelude::visitor::AexprNode;
13
14
/// A projection's expressions split into the regular (default) expressions
/// and the temporary common-sub-expression columns appended at the back.
#[derive(Debug, Clone)]
struct ProjectionExprs {
    expr: Vec<ExprIR>,
    /// offset from the back
    /// `expr[expr.len() - common_sub_offset..]`
    /// are the common sub expressions
    common_sub_offset: usize,
}
22
23
impl ProjectionExprs {
24
fn default_exprs(&self) -> &[ExprIR] {
25
&self.expr[..self.expr.len() - self.common_sub_offset]
26
}
27
28
fn cse_exprs(&self) -> &[ExprIR] {
29
&self.expr[self.expr.len() - self.common_sub_offset..]
30
}
31
32
fn new_with_cse(expr: Vec<ExprIR>, common_sub_offset: usize) -> Self {
33
Self {
34
expr,
35
common_sub_offset,
36
}
37
}
38
}
39
40
/// Identifier that shows the sub-expression path.
/// Must implement hash and equality and ideally
/// have little collisions
/// We will do a full expression comparison to check if the
/// expressions with equal identifiers are truly equal
#[derive(Clone, Debug)]
pub(super) struct Identifier {
    // Combined hash over the sub-expression path; `None` means "empty"
    // (no nodes added yet).
    inner: Option<u64>,
    // The most recently added expression node; used for the full
    // (hash-collision-proof) equality comparison.
    last_node: Option<AexprNode>,
    // Hash builder used for all nodes, fixed seed so identifiers are
    // comparable across instances.
    hb: PlFixedStateQuality,
}
51
52
impl Identifier {
    /// Empty identifier: no hash, no node, fixed-seed hasher.
    fn new() -> Self {
        Self {
            inner: None,
            last_node: None,
            hb: PlFixedStateQuality::with_seed(0),
        }
    }

    /// Hash used for the `IdentifierMap` raw-entry lookups.
    fn hash(&self) -> u64 {
        self.inner.unwrap_or(0)
    }

    /// The last expression node added to this identifier.
    /// Panics if the identifier is still empty.
    fn ae_node(&self) -> AexprNode {
        self.last_node.unwrap()
    }

    /// Equality = equal combined hash AND full structural comparison of the
    /// last node (guards against hash collisions).
    fn is_equal(&self, other: &Self, arena: &Arena<AExpr>) -> bool {
        self.inner == other.inner
            && self.last_node.map(|v| v.hashable_and_cmp(arena))
                == other.last_node.map(|v| v.hashable_and_cmp(arena))
    }

    /// An identifier is valid once at least one node has been added.
    fn is_valid(&self) -> bool {
        self.inner.is_some()
    }

    /// Materialize the temporary column name for the replaced sub-expression,
    /// e.g. `__POLARS_CSER_0x…`. Note: these names can collide; the optimizer
    /// validates counts via `materialized_hash` and bails out on mismatch.
    fn materialize(&self) -> PlSmallStr {
        format_pl_smallstr!("{}{:#x}", CSE_REPLACED, self.materialized_hash())
    }

    /// Hash embedded in the materialized name. Currently the same value as
    /// `hash()`, but kept as a separate concept: this one defines the
    /// *name*, the other the map lookup.
    fn materialized_hash(&self) -> u64 {
        self.inner.unwrap_or(0)
    }

    /// Fold another identifier's hash into this one (order-sensitive).
    /// Empty operands are treated as identity; two empties stay empty.
    fn combine(&mut self, other: &Identifier) {
        let inner = match (self.inner, other.inner) {
            (Some(l), Some(r)) => _boost_hash_combine(l, r),
            (None, Some(r)) => r,
            (Some(l), None) => l,
            _ => return,
        };
        self.inner = Some(inner);
    }

    /// Return a new identifier extended with `ae`; also records `ae` as the
    /// `last_node` used for full equality checks.
    fn add_ae_node(&self, ae: &AexprNode, arena: &Arena<AExpr>) -> Self {
        let hashed = self.hb.hash_one(ae.to_aexpr(arena));
        let inner = Some(
            self.inner
                .map_or(hashed, |l| _boost_hash_combine(l, hashed)),
        );
        Self {
            inner,
            last_node: Some(*ae),
            hb: self.hb.clone(),
        }
    }
}
110
111
/// Hash map keyed by [`Identifier`]. Because identifier equality needs the
/// expression arena (see `Identifier::is_equal`), lookups go through the
/// raw-entry API instead of the normal `HashMap` interface.
#[derive(Default)]
struct IdentifierMap<V> {
    inner: PlHashMap<Identifier, V>,
}
115
116
impl<V> IdentifierMap<V> {
    /// Arena-aware lookup: hashes via `Identifier::hash` and compares with
    /// the full structural equality in `Identifier::is_equal`.
    fn get(&self, id: &Identifier, arena: &Arena<AExpr>) -> Option<&V> {
        self.inner
            .raw_entry()
            .from_hash(id.hash(), |k| k.is_equal(id, arena))
            .map(|(_k, v)| v)
    }

    /// Arena-aware entry API: returns the existing value or inserts `v()`.
    /// The closure is only called when the key is vacant.
    fn entry<'a, F: FnOnce() -> V>(
        &'a mut self,
        id: Identifier,
        v: F,
        arena: &Arena<AExpr>,
    ) -> &'a mut V {
        let h = id.hash();
        match self
            .inner
            .raw_entry_mut()
            .from_hash(h, |k| k.is_equal(&id, arena))
        {
            RawEntryMut::Occupied(entry) => entry.into_mut(),
            RawEntryMut::Vacant(entry) => {
                // `insert_with_hasher` needs a way to re-derive the hash for
                // future resizes; identifiers hash to their stored value.
                let (_, v) = entry.insert_with_hasher(h, id, v(), |id| id.hash());
                v
            },
        }
    }
    /// Insert (or overwrite-by-ignore: keeps the existing value if present).
    fn insert(&mut self, id: Identifier, v: V, arena: &Arena<AExpr>) {
        self.entry(id, || v, arena);
    }

    /// Iterate over all (identifier, value) pairs in arbitrary order.
    fn iter(&self) -> impl Iterator<Item = (&Identifier, &V)> {
        self.inner.iter()
    }
}
151
152
/// Merges identical expressions into identical IDs.
///
/// Does no analysis whether this leads to legal substitutions.
#[derive(Default)]
pub struct NaiveExprMerger {
    // Maps an arena node to its unique (structural) id.
    node_to_uniq_id: PlHashMap<Node, u32>,
    // Reverse map: unique id -> a representative arena node.
    uniq_id_to_node: Vec<Node>,
    // Structural identifier -> unique id, deduplicating equal expressions.
    identifier_to_uniq_id: IdentifierMap<u32>,
    // Traversal stack; `None` marks the boundary pushed at pre-visit,
    // `Some(id)` is a finished child identifier awaiting its parent.
    arg_stack: Vec<Option<Identifier>>,
}
162
163
impl NaiveExprMerger {
    /// Visit `node` and register unique ids for it and all its sub-expressions.
    pub fn add_expr(&mut self, node: Node, arena: &Arena<AExpr>) {
        let node = AexprNode::new(node);
        node.visit(self, arena).unwrap();
    }

    /// Unique id of a previously added node, if any.
    pub fn get_uniq_id(&self, node: Node) -> Option<u32> {
        self.node_to_uniq_id.get(&node).copied()
    }

    /// Representative arena node for a unique id, if in range.
    pub fn get_node(&self, uniq_id: u32) -> Option<Node> {
        self.uniq_id_to_node.get(uniq_id as usize).copied()
    }
}
177
178
impl Visitor for NaiveExprMerger {
179
type Node = AexprNode;
180
type Arena = Arena<AExpr>;
181
182
fn pre_visit(
183
&mut self,
184
_node: &Self::Node,
185
_arena: &Self::Arena,
186
) -> PolarsResult<VisitRecursion> {
187
self.arg_stack.push(None);
188
Ok(VisitRecursion::Continue)
189
}
190
191
fn post_visit(
192
&mut self,
193
node: &Self::Node,
194
arena: &Self::Arena,
195
) -> PolarsResult<VisitRecursion> {
196
let mut identifier = Identifier::new();
197
while let Some(Some(arg)) = self.arg_stack.pop() {
198
identifier.combine(&arg);
199
}
200
identifier = identifier.add_ae_node(node, arena);
201
let uniq_id = *self.identifier_to_uniq_id.entry(
202
identifier,
203
|| {
204
let uniq_id = self.uniq_id_to_node.len() as u32;
205
self.uniq_id_to_node.push(node.node());
206
uniq_id
207
},
208
arena,
209
);
210
self.node_to_uniq_id.insert(node.node(), uniq_id);
211
Ok(VisitRecursion::Continue)
212
}
213
}
214
215
/// Identifier maps to Expr Node and count.
type SubExprCount = IdentifierMap<(Node, u32)>;
/// Per pre-visit-ordered slot: (post_visit_idx, identifier);
/// shared across expressions via `id_array_offset`.
type IdentifierArray = Vec<(usize, Identifier)>;
219
220
/// Record pushed onto the visit stack while traversing an expression tree.
#[derive(Debug)]
enum VisitRecord {
    /// entered a new expression; payload is the pre-visit index.
    Entered(usize),
    /// Every visited sub-expression pushes their identifier to the stack.
    // The `bool` indicates if this expression is valid.
    // This can be `AND` accumulated by the lineage of the expression to determine
    // of the whole expression can be added.
    // For instance a in a group_by we only want to use elementwise operation in cse:
    // - `(col("a") * 2).sum(), (col("a") * 2)` -> we want to do `col("a") * 2` on a `with_columns`
    // - `col("a").sum() * col("a").sum()` -> we don't want `sum` to run on `with_columns`
    // as that doesn't have groups context. If we encounter a `sum` it should be flagged as `false`
    //
    // This should have the following stack
    // id        valid
    // col(a)   true
    // sum      false
    // col(a)   true
    // sum      false
    // binary   true
    // -------------- accumulated
    //          false
    SubExprId(Identifier, bool),
}
244
245
fn skip_pre_visit(ae: &AExpr, is_groupby: bool, element_wise_select_only: bool) -> bool {
246
match ae {
247
#[cfg(feature = "dynamic_group_by")]
248
AExpr::Rolling { .. } => true,
249
AExpr::Over { .. } => true,
250
#[cfg(feature = "dtype-struct")]
251
AExpr::Ternary { .. } => is_groupby,
252
ae => {
253
if element_wise_select_only {
254
if is_groupby {
255
true
256
} else {
257
!ae.is_elementwise_top_level()
258
}
259
} else {
260
false
261
}
262
},
263
}
264
}
265
266
/// Goes through an expression and generates a identifier
267
///
268
/// The visitor uses a `visit_stack` to track traversal order.
269
///
270
/// # Entering a node
271
/// When `pre-visit` is called we enter a new (sub)-expression and
272
/// we add `Entered` to the stack.
273
/// # Leaving a node
274
/// On `post-visit` when we leave the node and we pop all `SubExprIds` nodes.
275
/// Those are considered sub-expression of the leaving node
276
///
277
/// We also record an `id_array` that followed the pre-visit order. This
278
/// is used to cache the `Identifiers`.
279
//
280
// # Example (this is not a docstring as clippy complains about spacing)
281
// Say we have the expression: `(col("f00").min() * col("bar")).sum()`
282
// with the following call tree:
283
//
284
// sum
285
//
286
// |
287
//
288
// binary: *
289
//
290
// | |
291
//
292
// col(bar) min
293
//
294
// |
295
//
296
// col(f00)
297
//
298
// # call order
299
// function-called stack stack-after(pop until E, push I) # ID
300
// pre-visit: sum E -
301
// pre-visit: binary: * EE -
302
// pre-visit: col(bar) EEE -
303
// post-visit: col(bar) EEE EEI id: col(bar)
304
// pre-visit: min EEIE -
305
// pre-visit: col(f00) EEIEE -
306
// post-visit: col(f00) EEIEE EEIEI id: col(f00)
307
// post-visit: min EEIEI EEII id: min!col(f00)
308
// post-visit: binary: * EEII EI id: binary: *!min!col(f00)!col(bar)
309
// post-visit: sum EI I id: sum!binary: *!min!col(f00)!col(bar)
310
/// Visitor that assigns an [`Identifier`] to every accepted sub-expression
/// and counts occurrences in `se_count` (see the traversal example above).
struct ExprIdentifierVisitor<'a> {
    se_count: &'a mut SubExprCount,
    /// Materialized `CSE` materialized (name) hashes can collide. So we validate that all CSE counts
    /// match name hash counts.
    name_validation: &'a mut PlHashMap<u64, u32>,
    identifier_array: &'a mut IdentifierArray,
    // Index in pre-visit traversal order.
    pre_visit_idx: usize,
    post_visit_idx: usize,
    visit_stack: &'a mut Vec<VisitRecord>,
    /// Offset in the identifier array
    /// this allows us to use a single `vec` on multiple expressions
    id_array_offset: usize,
    // Whether the expression replaced a subexpression.
    has_sub_expr: bool,
    // During aggregation we only identify element-wise operations
    is_group_by: bool,
    // Restrict CSE to element-wise expressions only (streaming engine mode);
    // forwarded to `skip_pre_visit`.
    element_wise_only: bool,
}
330
331
impl ExprIdentifierVisitor<'_> {
    /// Create a visitor appending into the shared `identifier_array`;
    /// `id_array_offset` records where this expression's slots begin.
    fn new<'a>(
        se_count: &'a mut SubExprCount,
        identifier_array: &'a mut IdentifierArray,
        visit_stack: &'a mut Vec<VisitRecord>,
        is_group_by: bool,
        name_validation: &'a mut PlHashMap<u64, u32>,
        element_wise_select_only: bool,
    ) -> ExprIdentifierVisitor<'a> {
        let id_array_offset = identifier_array.len();
        ExprIdentifierVisitor {
            se_count,
            name_validation,
            identifier_array,
            pre_visit_idx: 0,
            post_visit_idx: 0,
            visit_stack,
            id_array_offset,
            has_sub_expr: false,
            is_group_by,
            element_wise_only: element_wise_select_only,
        }
    }

    /// pop all visit-records until an `Entered` is found. We accumulate a `SubExprId`s
    /// to `id`. Finally we return the expression `idx` and `Identifier`.
    /// This works due to the stack.
    /// If we traverse another expression in the mean time, it will get popped of the stack first
    /// so the returned identifier belongs to a single sub-expression
    fn pop_until_entered(&mut self) -> (usize, Identifier, bool) {
        let mut id = Identifier::new();
        let mut is_valid_accumulated = true;

        while let Some(item) = self.visit_stack.pop() {
            match item {
                VisitRecord::Entered(idx) => return (idx, id, is_valid_accumulated),
                VisitRecord::SubExprId(s, valid) => {
                    id.combine(&s);
                    // One invalid child poisons the whole lineage.
                    is_valid_accumulated &= valid
                },
            }
        }
        // Every post-visit is preceded by a matching pre-visit that pushed
        // `Entered`, so the stack cannot run dry here.
        unreachable!()
    }

    /// return `None` -> node is accepted
    /// return `Some(_)` node is not accepted and apply the given recursion operation
    /// `Some(_, true)` don't accept this node, but can be a member of a cse.
    /// `Some(_, false)` don't accept this node, and don't allow as a member of a cse.
    fn accept_node_post_visit(&self, ae: &AExpr) -> Accepted {
        match ae {
            // window expressions should `evaluate_on_groups`, not `evaluate`
            // so we shouldn't cache the children as they are evaluated incorrectly
            #[cfg(feature = "dynamic_group_by")]
            AExpr::Rolling { .. } => REFUSE_SKIP,
            AExpr::Over { .. } => REFUSE_SKIP,
            // Don't allow this for now, as we can get `null().cast()` in ternary expressions.
            // TODO! Add a typed null
            AExpr::Literal(LiteralValue::Scalar(sc)) if sc.is_null() => REFUSE_NO_MEMBER,
            AExpr::Literal(s) => {
                match s {
                    LiteralValue::Series(s) => {
                        let dtype = s.dtype();

                        // Object and nested types are harder to hash and compare.
                        let allow = !(dtype.is_nested() | dtype.is_object());

                        // Only cheap-to-hash series may participate in CSE.
                        if s.len() < CHEAP_SERIES_HASH_LIMIT && allow {
                            REFUSE_ALLOW_MEMBER
                        } else {
                            REFUSE_NO_MEMBER
                        }
                    },
                    _ => REFUSE_ALLOW_MEMBER,
                }
            },
            AExpr::Column(_) => REFUSE_ALLOW_MEMBER,
            AExpr::Len => {
                // `len()` has group semantics under group_by, so it cannot be
                // hoisted to a `with_columns`.
                if self.is_group_by {
                    REFUSE_NO_MEMBER
                } else {
                    REFUSE_ALLOW_MEMBER
                }
            },
            // Non-deterministic: caching would change results.
            #[cfg(feature = "random")]
            AExpr::Function {
                function: IRFunctionExpr::Random { .. },
                ..
            } => REFUSE_NO_MEMBER,
            #[cfg(feature = "rolling_window")]
            AExpr::Function {
                function: IRFunctionExpr::RollingExpr { .. },
                ..
            } => REFUSE_NO_MEMBER,
            _ => {
                // During aggregation we only store elementwise operation in the state
                // other operations we cannot add to the state as they have the output size of the
                // groups, not the original dataframe
                if self.is_group_by {
                    if !ae.is_elementwise_top_level() {
                        return REFUSE_NO_MEMBER;
                    }
                    match ae {
                        AExpr::Cast { .. } => REFUSE_ALLOW_MEMBER,
                        _ => ACCEPT,
                    }
                } else {
                    ACCEPT
                }
            },
        }
    }
}
444
445
impl Visitor for ExprIdentifierVisitor<'_> {
    type Node = AexprNode;
    type Arena = Arena<AExpr>;

    /// Push an `Entered` marker and reserve an (initially invalid) slot in
    /// `identifier_array`; skipped nodes instead push an invalid `SubExprId`
    /// so their parents become invalid too.
    fn pre_visit(
        &mut self,
        node: &Self::Node,
        arena: &Self::Arena,
    ) -> PolarsResult<VisitRecursion> {
        if skip_pre_visit(
            node.to_aexpr(arena),
            self.is_group_by,
            self.element_wise_only,
        ) {
            // Still add to the stack so that a parent becomes invalidated.
            self.visit_stack
                .push(VisitRecord::SubExprId(Identifier::new(), false));
            return Ok(VisitRecursion::Skip);
        }

        self.visit_stack
            .push(VisitRecord::Entered(self.pre_visit_idx));
        self.pre_visit_idx += 1;

        // implement default placeholders
        self.identifier_array
            .push((self.id_array_offset, Identifier::new()));

        Ok(VisitRecursion::Continue)
    }

    /// Build this node's identifier from its children, decide whether it is
    /// a CSE candidate, and update counts / the identifier array.
    fn post_visit(
        &mut self,
        node: &Self::Node,
        arena: &Self::Arena,
    ) -> PolarsResult<VisitRecursion> {
        let ae = node.to_aexpr(arena);
        self.post_visit_idx += 1;

        let (pre_visit_idx, sub_expr_id, is_valid_accumulated) = self.pop_until_entered();
        // Create the Id of this node.
        let id: Identifier = sub_expr_id.add_ae_node(node, arena);

        // A poisoned lineage: record only the traversal index (identifier
        // slot stays invalid) and propagate invalidity upward.
        if !is_valid_accumulated {
            self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;
            self.visit_stack.push(VisitRecord::SubExprId(id, false));
            return Ok(VisitRecursion::Continue);
        }

        // If we don't store this node
        // we only push the visit_stack, so the parents know the trail.
        if let Some((recurse, local_is_valid)) = self.accept_node_post_visit(ae) {
            self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;

            self.visit_stack
                .push(VisitRecord::SubExprId(id, local_is_valid));
            return Ok(recurse);
        }

        // Store the created id.
        self.identifier_array[pre_visit_idx + self.id_array_offset] =
            (self.post_visit_idx, id.clone());

        // We popped until entered, push this Id on the stack so the trail
        // is available for the parent expression.
        self.visit_stack
            .push(VisitRecord::SubExprId(id.clone(), true));

        // Track both the structural count and the materialized-name count;
        // a mismatch later signals a name-hash collision.
        let mat_h = id.materialized_hash();
        let (_, se_count) = self.se_count.entry(id, || (node.node(), 0), arena);

        *se_count += 1;
        *self.name_validation.entry(mat_h).or_insert(0) += 1;
        self.has_sub_expr |= *se_count > 1;

        Ok(VisitRecursion::Continue)
    }
}
523
524
/// Rewriter that replaces sub-expressions counted more than once by
/// [`ExprIdentifierVisitor`] with a column reference to the materialized
/// temporary name (see the worked example below).
struct CommonSubExprRewriter<'a> {
    sub_expr_map: &'a SubExprCount,
    identifier_array: &'a IdentifierArray,
    /// keep track of the replaced identifiers.
    replaced_identifiers: &'a mut IdentifierMap<()>,

    // Highest post-visit index rewritten so far; used to skip whole subtrees.
    max_post_visit_idx: usize,
    /// index in traversal order in which `identifier_array`
    /// was written. This is the index in `identifier_array`.
    visited_idx: usize,
    /// Offset in the identifier array.
    /// This allows us to use a single `vec` on multiple expressions
    id_array_offset: usize,
    /// Indicates if this expression is rewritten.
    rewritten: bool,
    is_group_by: bool,
    // Must mirror the flag the visitor ran with so both traverse identically.
    is_element_wise_select_only: bool,
}
542
543
impl<'a> CommonSubExprRewriter<'a> {
    /// Create a rewriter over the visitor's output for one expression;
    /// `id_array_offset` must match the offset the visitor used.
    fn new(
        sub_expr_map: &'a SubExprCount,
        identifier_array: &'a IdentifierArray,
        replaced_identifiers: &'a mut IdentifierMap<()>,
        id_array_offset: usize,
        is_group_by: bool,
        is_element_wise_select_only: bool,
    ) -> Self {
        Self {
            sub_expr_map,
            identifier_array,
            replaced_identifiers,
            max_post_visit_idx: 0,
            visited_idx: 0,
            id_array_offset,
            rewritten: false,
            is_group_by,
            is_element_wise_select_only,
        }
    }
}
565
566
// # Example
567
// Expression tree with [pre-visit,post-visit] indices
568
// counted from 1
569
// [1,8] binary: +
570
//
571
// | |
572
//
573
// [2,2] sum [4,7] sum
574
//
575
// | |
576
//
577
// [3,1] col(foo) [5,6] binary: *
578
//
579
// | |
580
//
581
// [6,3] col(bar) [7,5] sum
582
//
583
// |
584
//
585
// [8,4] col(foo)
586
//
587
// in this tree `col(foo).sum()` should be post-visited/mutated
588
// so if we are at `[2,2]`
589
//
590
// call stack
591
// pre-visit [1,8] binary -> no_mutate_and_continue -> visits children
592
// pre-visit [2,2] sum -> mutate_and_stop -> does not visit children
593
// post-visit [2,2] sum -> skip index to [4,7] (because we didn't visit children)
594
// pre-visit [4,7] sum -> no_mutate_and_continue -> visits children
595
// pre-visit [5,6] binary -> no_mutate_and_continue -> visits children
596
// pre-visit [6,3] col -> stop_recursion -> does not mutate
597
// pre-visit [7,5] sum -> mutate_and_stop -> does not visit children
598
// post-visit [7,5] -> skip index to end
599
impl RewritingVisitor for CommonSubExprRewriter<'_> {
    type Node = AexprNode;
    type Arena = Arena<AExpr>;

    /// Walk in the same pre-visit order as the identifier visitor and decide
    /// per node: stop, descend without mutating, or mutate (replace by a
    /// temporary column) and stop.
    fn pre_visit(
        &mut self,
        ae_node: &Self::Node,
        arena: &mut Self::Arena,
    ) -> PolarsResult<RewriteRecursion> {
        let ae = ae_node.to_aexpr(arena);
        // Stop when: past the end of this expression's slots, inside an
        // already-rewritten subtree, or at a node the visitor also skipped.
        if self.visited_idx + self.id_array_offset >= self.identifier_array.len()
            || self.max_post_visit_idx
                > self.identifier_array[self.visited_idx + self.id_array_offset].0
            || skip_pre_visit(ae, self.is_group_by, self.is_element_wise_select_only)
        {
            return Ok(RewriteRecursion::Stop);
        }

        let id = &self.identifier_array[self.visited_idx + self.id_array_offset].1;

        // Id placeholder not overwritten, so we can skip this sub-expression.
        if !id.is_valid() {
            self.visited_idx += 1;
            let recurse = if ae_node.is_leaf(arena) {
                RewriteRecursion::Stop
            } else {
                // continue visit its children to see
                // if there are cse
                RewriteRecursion::NoMutateAndContinue
            };
            return Ok(recurse);
        }

        // Because some expressions don't have hash / equality guarantee (e.g. floats)
        // we can get none here. This must be changed later.
        let Some((_, count)) = self.sub_expr_map.get(id, arena) else {
            self.visited_idx += 1;
            return Ok(RewriteRecursion::NoMutateAndContinue);
        };
        if *count > 1 {
            self.replaced_identifiers.insert(id.clone(), (), arena);
            // rewrite this sub-expression, don't visit its children
            Ok(RewriteRecursion::MutateAndStop)
        } else {
            // This is a unique expression
            // visit its children to see if they are cse
            self.visited_idx += 1;
            Ok(RewriteRecursion::NoMutateAndContinue)
        }
    }

    /// Replace the current node with a column reference to the materialized
    /// CSE name, then fast-forward `visited_idx` past the replaced subtree.
    fn mutate(
        &mut self,
        mut node: Self::Node,
        arena: &mut Self::Arena,
    ) -> PolarsResult<Self::Node> {
        let (post_visit_count, id) =
            &self.identifier_array[self.visited_idx + self.id_array_offset];
        self.visited_idx += 1;

        // TODO!: check if we ever hit this branch
        if *post_visit_count < self.max_post_visit_idx {
            return Ok(node);
        }

        self.max_post_visit_idx = *post_visit_count;
        // DFS, so every post_visit that is smaller than `post_visit_count`
        // is a subexpression of this node and we can skip that
        //
        // `self.visited_idx` will influence recursion strategy in `pre_visit`
        // see call-stack comment above
        while self.visited_idx < self.identifier_array.len() - self.id_array_offset
            && *post_visit_count > self.identifier_array[self.visited_idx + self.id_array_offset].0
        {
            self.visited_idx += 1;
        }
        // If this is not true, the traversal order in the visitor was different from the rewriter.
        debug_assert_eq!(
            node.hashable_and_cmp(arena),
            id.ae_node().hashable_and_cmp(arena)
        );

        let name = id.materialize();
        node.assign(AExpr::col(name), arena);
        self.rewritten = true;

        Ok(node)
    }
}
688
689
/// Plan-level optimizer that extracts common sub-expressions of
/// `SELECT`/`HSTACK`/`GROUP BY` nodes into a preceding `with_columns`.
pub(crate) struct CommonSubExprOptimizer {
    // amortize allocations
    // these are cleared per lp node
    se_count: SubExprCount,
    id_array: IdentifierArray,
    id_array_offsets: Vec<u32>,
    replaced_identifiers: IdentifierMap<()>,
    // these are cleared per expr node
    visit_stack: Vec<VisitRecord>,
    name_validation: PlHashMap<u64, u32>,
    // Set by the streaming engine
    // Only supports element-wise CSEE
    // on SELECT/HSTACK
    element_wise_select_only: bool,
}
704
705
impl CommonSubExprOptimizer {
    /// Create an optimizer; `element_wise_select_only` restricts CSE to
    /// element-wise expressions (streaming engine mode).
    pub(crate) fn new(element_wise_select_only: bool) -> Self {
        Self {
            se_count: Default::default(),
            id_array: Default::default(),
            visit_stack: Default::default(),
            id_array_offsets: Default::default(),
            replaced_identifiers: Default::default(),
            name_validation: Default::default(),
            element_wise_select_only,
        }
    }

    /// Count sub-expressions of one expression tree.
    /// Returns its offset in `id_array` and whether any sub-expression repeats.
    fn visit_expression(
        &mut self,
        ae_node: AexprNode,
        is_group_by: bool,
        expr_arena: &mut Arena<AExpr>,
        element_wise_select_only: bool,
    ) -> PolarsResult<(usize, bool)> {
        let mut visitor = ExprIdentifierVisitor::new(
            &mut self.se_count,
            &mut self.id_array,
            &mut self.visit_stack,
            is_group_by,
            &mut self.name_validation,
            element_wise_select_only,
        );
        ae_node.visit(&mut visitor, expr_arena).map(|_| ())?;
        Ok((visitor.id_array_offset, visitor.has_sub_expr))
    }

    /// Mutate the expression.
    /// Returns a new expression and a `bool` indicating if it was rewritten or not.
    fn mutate_expression(
        &mut self,
        ae_node: AexprNode,
        id_array_offset: usize,
        is_group_by: bool,
        expr_arena: &mut Arena<AExpr>,
        element_wise_select_only: bool,
    ) -> PolarsResult<(AexprNode, bool)> {
        let mut rewriter = CommonSubExprRewriter::new(
            &self.se_count,
            &self.id_array,
            &mut self.replaced_identifiers,
            id_array_offset,
            is_group_by,
            element_wise_select_only,
        );
        ae_node
            .rewrite(&mut rewriter, expr_arena)
            .map(|out| (out, rewriter.rewritten))
    }

    /// Run CSE over a projection's expressions.
    /// Returns `None` when nothing repeats or when materialized names collide
    /// (in which case CSE is skipped entirely for safety).
    fn find_cse(
        &mut self,
        expr: &[ExprIR],
        expr_arena: &mut Arena<AExpr>,
        id_array_offsets: &mut Vec<u32>,
        is_group_by: bool,
        schema: &Schema,
        element_wise_select_only: bool,
    ) -> PolarsResult<Option<ProjectionExprs>> {
        let mut has_sub_expr = false;

        // First get all cse's.
        for e in expr {
            // The visitor can return early thus depleted its stack
            // on a previous iteration.
            self.visit_stack.clear();

            // Visit expressions and collect sub-expression counts.
            let ae_node = AexprNode::new(e.node());
            let (id_array_offset, this_expr_has_se) =
                self.visit_expression(ae_node, is_group_by, expr_arena, element_wise_select_only)?;
            id_array_offsets.push(id_array_offset as u32);
            has_sub_expr |= this_expr_has_se;
        }

        // Ensure that the `materialized hashes` count matches that of the CSE count.
        // It can happen that CSE collide and in that case we fallback and skip CSE.
        for (id, (_, count)) in self.se_count.iter() {
            let mat_h = id.materialized_hash();
            let valid = if let Some(name_count) = self.name_validation.get(&mat_h) {
                *name_count == *count
            } else {
                false
            };

            if !valid {
                if verbose() {
                    eprintln!(
                        "materialized names collided in common subexpression elimination.\n backtrace and run without CSE"
                    )
                }
                return Ok(None);
            }
        }

        if has_sub_expr {
            let mut new_expr = Vec::with_capacity_by_factor(expr.len(), 1.3);

            // Then rewrite the expressions that have a cse count > 1.
            for (e, offset) in expr.iter().zip(id_array_offsets.iter()) {
                let ae_node = AexprNode::new(e.node());

                let (out, rewritten) = self.mutate_expression(
                    ae_node,
                    *offset as usize,
                    is_group_by,
                    expr_arena,
                    element_wise_select_only,
                )?;

                let out_node = out.node();
                let mut out_e = e.clone();
                let new_node = if !rewritten {
                    out_e
                } else {
                    out_e.set_node(out_node);

                    // Ensure the function ExprIR's have the proper names.
                    // This is needed for structs to get the proper field
                    let mut scratch = vec![];
                    let mut stack = vec![(e.node(), out_node)];
                    // Walk original and rewritten trees in lock-step,
                    // copying output names onto the rewritten inputs.
                    while let Some((original, new)) = stack.pop() {
                        // Don't follow identical nodes.
                        if original == new {
                            continue;
                        }
                        scratch.clear();
                        let aes = expr_arena.get_disjoint_mut([original, new]);

                        // Only follow paths that are the same.
                        if std::mem::discriminant(aes[0]) != std::mem::discriminant(aes[1]) {
                            continue;
                        }

                        aes[0].inputs_rev(&mut scratch);
                        let offset = scratch.len();
                        aes[1].inputs_rev(&mut scratch);

                        // If they have a different number of inputs, we don't follow the nodes.
                        if scratch.len() != offset * 2 {
                            continue;
                        }

                        for i in 0..scratch.len() / 2 {
                            stack.push((scratch[i], scratch[i + offset]));
                        }

                        match expr_arena.get_disjoint_mut([original, new]) {
                            [
                                AExpr::Function {
                                    input: input_original,
                                    ..
                                },
                                AExpr::Function {
                                    input: input_new, ..
                                },
                            ] => {
                                for (new, original) in input_new.iter_mut().zip(input_original) {
                                    new.set_alias(original.output_name().clone());
                                }
                            },
                            [
                                AExpr::AnonymousFunction {
                                    input: input_original,
                                    ..
                                },
                                AExpr::AnonymousFunction {
                                    input: input_new, ..
                                },
                            ] => {
                                for (new, original) in input_new.iter_mut().zip(input_original) {
                                    new.set_alias(original.output_name().clone());
                                }
                            },
                            _ => {},
                        }
                    }

                    // If we don't end with an alias we add an alias. Because the normal left-hand
                    // rule we apply for determining the name will not work we now refer to
                    // intermediate temporary names starting with the `CSE_REPLACED` constant.
                    if !e.has_alias() {
                        let name = ae_node.to_field(schema, expr_arena)?.name;
                        out_e.set_alias(name.clone());
                    }
                    out_e
                };
                new_expr.push(new_node)
            }
            // Add the tmp columns
            for id in self.replaced_identifiers.inner.keys() {
                let (node, _count) = self.se_count.get(id, expr_arena).unwrap();
                let name = id.materialize();
                let out_e = ExprIR::new(*node, OutputName::Alias(name));
                new_expr.push(out_e)
            }
            let expr =
                ProjectionExprs::new_with_cse(new_expr, self.replaced_identifiers.inner.len());
            Ok(Some(expr))
        } else {
            Ok(None)
        }
    }
}
914
915
impl RewritingVisitor for CommonSubExprOptimizer {
    type Node = IRNode;
    type Arena = IRNodeArena;

    /// Only Select/HStack/GroupBy plan nodes are candidates for CSE.
    fn pre_visit(
        &mut self,
        node: &Self::Node,
        arena: &mut Self::Arena,
    ) -> PolarsResult<RewriteRecursion> {
        use IR::*;
        Ok(match node.to_alp(&arena.0) {
            Select { .. } | HStack { .. } | GroupBy { .. } => RewriteRecursion::MutateAndContinue,
            _ => RewriteRecursion::NoMutateAndContinue,
        })
    }

    /// If the node's expressions contain common sub-expressions, rebuild the
    /// plan node as: input -> `with_columns(cse temporaries)` -> original op
    /// over the rewritten default expressions.
    fn mutate(&mut self, node: Self::Node, arena: &mut Self::Arena) -> PolarsResult<Self::Node> {
        // Temporarily take the offsets vec to avoid borrowing `self` twice.
        let mut id_array_offsets = std::mem::take(&mut self.id_array_offsets);

        // Per-plan-node state reset (buffers are reused across nodes).
        self.se_count.inner.clear();
        self.name_validation.clear();
        self.id_array.clear();
        id_array_offsets.clear();
        self.replaced_identifiers.inner.clear();

        let arena_idx = node.node();
        let alp = arena.0.get(arena_idx);

        match alp {
            IR::Select {
                input,
                expr,
                schema,
                options,
            } => {
                let input_schema = arena.0.get(*input).schema(&arena.0);
                if let Some(expr) = self.find_cse(
                    expr,
                    &mut arena.1,
                    &mut id_array_offsets,
                    false,
                    input_schema.as_ref().as_ref(),
                    self.element_wise_select_only,
                )? {
                    let schema = schema.clone();
                    let options = *options;

                    let lp = IRBuilder::new(*input, &mut arena.1, &mut arena.0)
                        .with_columns(
                            expr.cse_exprs().to_vec(),
                            ProjectionOptions {
                                run_parallel: options.run_parallel,
                                duplicate_check: options.duplicate_check,
                                // These columns might have different
                                // lengths from the dataframe, but
                                // they are only temporaries that will
                                // be removed by the evaluation of the
                                // default_exprs and the subsequent
                                // projection.
                                should_broadcast: false,
                            },
                        )
                        .build();
                    let input = arena.0.add(lp);

                    let lp = IR::Select {
                        input,
                        expr: expr.default_exprs().to_vec(),
                        schema,
                        options,
                    };
                    arena.0.replace(arena_idx, lp);
                }
            },
            IR::HStack {
                input,
                exprs,
                schema,
                options,
            } => {
                let input_schema = arena.0.get(*input).schema(&arena.0);
                if let Some(exprs) = self.find_cse(
                    exprs,
                    &mut arena.1,
                    &mut id_array_offsets,
                    false,
                    input_schema.as_ref().as_ref(),
                    self.element_wise_select_only,
                )? {
                    let schema = schema.clone();
                    let options = *options;
                    let input = *input;

                    let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)
                        .with_columns(
                            exprs.cse_exprs().to_vec(),
                            // These columns might have different
                            // lengths from the dataframe, but they
                            // are only temporaries that will be
                            // removed by the evaluation of the
                            // default_exprs and the subsequent
                            // projection.
                            ProjectionOptions {
                                run_parallel: options.run_parallel,
                                duplicate_check: options.duplicate_check,
                                should_broadcast: false,
                            },
                        )
                        .with_columns(exprs.default_exprs().to_vec(), options)
                        .build();
                    let input = arena.0.add(lp);

                    // Final projection drops the CSE temporaries again.
                    let lp = IR::SimpleProjection {
                        input,
                        columns: schema,
                    };
                    arena.0.replace(arena_idx, lp);
                }
            },
            // GroupBy CSE is unsupported in element-wise-only (streaming) mode.
            IR::GroupBy {
                input,
                keys,
                aggs,
                options,
                maintain_order,
                apply,
                schema,
            } if !self.element_wise_select_only => {
                let input_schema = arena.0.get(*input).schema(&arena.0);
                if let Some(aggs) = self.find_cse(
                    aggs,
                    &mut arena.1,
                    &mut id_array_offsets,
                    true,
                    input_schema.as_ref().as_ref(),
                    self.element_wise_select_only,
                )? {
                    let keys = keys.clone();
                    let options = options.clone();
                    let schema = schema.clone();
                    let apply = apply.clone();
                    let maintain_order = *maintain_order;
                    let input = *input;

                    let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)
                        .with_columns(aggs.cse_exprs().to_vec(), Default::default())
                        .build();
                    let input = arena.0.add(lp);

                    let lp = IR::GroupBy {
                        input,
                        keys,
                        aggs: aggs.default_exprs().to_vec(),
                        options,
                        schema,
                        maintain_order,
                        apply,
                    };
                    arena.0.replace(arena_idx, lp);
                }
            },
            _ => {},
        }

        // Hand the offsets buffer back for the next plan node.
        self.id_array_offsets = id_array_offsets;
        Ok(node)
    }
}
1083
1084