// Path: blob/main/crates/polars-plan/src/plans/optimizer/cse/csee.rs
use std::hash::BuildHasher;12use hashbrown::hash_map::RawEntryMut;3use polars_core::CHEAP_SERIES_HASH_LIMIT;4use polars_utils::aliases::PlFixedStateQuality;5use polars_utils::format_pl_smallstr;6use polars_utils::hashing::_boost_hash_combine;7use polars_utils::vec::CapacityByFactor;89use super::*;10use crate::constants::CSE_REPLACED;11use crate::prelude::visitor::AexprNode;1213#[derive(Debug, Clone)]14struct ProjectionExprs {15expr: Vec<ExprIR>,16/// offset from the back17/// `expr[expr.len() - common_sub_offset..]`18/// are the common sub expressions19common_sub_offset: usize,20}2122impl ProjectionExprs {23fn default_exprs(&self) -> &[ExprIR] {24&self.expr[..self.expr.len() - self.common_sub_offset]25}2627fn cse_exprs(&self) -> &[ExprIR] {28&self.expr[self.expr.len() - self.common_sub_offset..]29}3031fn new_with_cse(expr: Vec<ExprIR>, common_sub_offset: usize) -> Self {32Self {33expr,34common_sub_offset,35}36}37}3839/// Identifier that shows the sub-expression path.40/// Must implement hash and equality and ideally41/// have little collisions42/// We will do a full expression comparison to check if the43/// expressions with equal identifiers are truly equal44#[derive(Clone, Debug)]45pub(super) struct Identifier {46inner: Option<u64>,47last_node: Option<AexprNode>,48hb: PlFixedStateQuality,49}5051impl Identifier {52fn new() -> Self {53Self {54inner: None,55last_node: None,56hb: PlFixedStateQuality::with_seed(0),57}58}5960fn hash(&self) -> u64 {61self.inner.unwrap_or(0)62}6364fn ae_node(&self) -> AexprNode {65self.last_node.unwrap()66}6768fn is_equal(&self, other: &Self, arena: &Arena<AExpr>) -> bool {69self.inner == other.inner70&& self.last_node.map(|v| v.hashable_and_cmp(arena))71== other.last_node.map(|v| v.hashable_and_cmp(arena))72}7374fn is_valid(&self) -> bool {75self.inner.is_some()76}7778fn materialize(&self) -> PlSmallStr {79format_pl_smallstr!("{}{:#x}", CSE_REPLACED, self.materialized_hash())80}8182fn materialized_hash(&self) -> u64 
{83self.inner.unwrap_or(0)84}8586fn combine(&mut self, other: &Identifier) {87let inner = match (self.inner, other.inner) {88(Some(l), Some(r)) => _boost_hash_combine(l, r),89(None, Some(r)) => r,90(Some(l), None) => l,91_ => return,92};93self.inner = Some(inner);94}9596fn add_ae_node(&self, ae: &AexprNode, arena: &Arena<AExpr>) -> Self {97let hashed = self.hb.hash_one(ae.to_aexpr(arena));98let inner = Some(99self.inner100.map_or(hashed, |l| _boost_hash_combine(l, hashed)),101);102Self {103inner,104last_node: Some(*ae),105hb: self.hb.clone(),106}107}108}109110#[derive(Default)]111struct IdentifierMap<V> {112inner: PlHashMap<Identifier, V>,113}114115impl<V> IdentifierMap<V> {116fn get(&self, id: &Identifier, arena: &Arena<AExpr>) -> Option<&V> {117self.inner118.raw_entry()119.from_hash(id.hash(), |k| k.is_equal(id, arena))120.map(|(_k, v)| v)121}122123fn entry<'a, F: FnOnce() -> V>(124&'a mut self,125id: Identifier,126v: F,127arena: &Arena<AExpr>,128) -> &'a mut V {129let h = id.hash();130match self131.inner132.raw_entry_mut()133.from_hash(h, |k| k.is_equal(&id, arena))134{135RawEntryMut::Occupied(entry) => entry.into_mut(),136RawEntryMut::Vacant(entry) => {137let (_, v) = entry.insert_with_hasher(h, id, v(), |id| id.hash());138v139},140}141}142fn insert(&mut self, id: Identifier, v: V, arena: &Arena<AExpr>) {143self.entry(id, || v, arena);144}145146fn iter(&self) -> impl Iterator<Item = (&Identifier, &V)> {147self.inner.iter()148}149}150151/// Merges identical expressions into identical IDs.152///153/// Does no analysis whether this leads to legal substitutions.154#[derive(Default)]155pub struct NaiveExprMerger {156node_to_uniq_id: PlHashMap<Node, u32>,157uniq_id_to_node: Vec<Node>,158identifier_to_uniq_id: IdentifierMap<u32>,159arg_stack: Vec<Option<Identifier>>,160}161162impl NaiveExprMerger {163pub fn add_expr(&mut self, node: Node, arena: &Arena<AExpr>) {164let node = AexprNode::new(node);165node.visit(self, arena).unwrap();166}167168pub fn get_uniq_id(&self, 
node: Node) -> Option<u32> {169self.node_to_uniq_id.get(&node).copied()170}171172pub fn get_node(&self, uniq_id: u32) -> Option<Node> {173self.uniq_id_to_node.get(uniq_id as usize).copied()174}175}176177impl Visitor for NaiveExprMerger {178type Node = AexprNode;179type Arena = Arena<AExpr>;180181fn pre_visit(182&mut self,183_node: &Self::Node,184_arena: &Self::Arena,185) -> PolarsResult<VisitRecursion> {186self.arg_stack.push(None);187Ok(VisitRecursion::Continue)188}189190fn post_visit(191&mut self,192node: &Self::Node,193arena: &Self::Arena,194) -> PolarsResult<VisitRecursion> {195let mut identifier = Identifier::new();196while let Some(Some(arg)) = self.arg_stack.pop() {197identifier.combine(&arg);198}199identifier = identifier.add_ae_node(node, arena);200let uniq_id = *self.identifier_to_uniq_id.entry(201identifier,202|| {203let uniq_id = self.uniq_id_to_node.len() as u32;204self.uniq_id_to_node.push(node.node());205uniq_id206},207arena,208);209self.node_to_uniq_id.insert(node.node(), uniq_id);210Ok(VisitRecursion::Continue)211}212}213214/// Identifier maps to Expr Node and count.215type SubExprCount = IdentifierMap<(Node, u32)>;216/// (post_visit_idx, identifier);217type IdentifierArray = Vec<(usize, Identifier)>;218219#[derive(Debug)]220enum VisitRecord {221/// entered a new expression222Entered(usize),223/// Every visited sub-expression pushes their identifier to the stack.224// The `bool` indicates if this expression is valid.225// This can be `AND` accumulated by the lineage of the expression to determine226// of the whole expression can be added.227// For instance a in a group_by we only want to use elementwise operation in cse:228// - `(col("a") * 2).sum(), (col("a") * 2)` -> we want to do `col("a") * 2` on a `with_columns`229// - `col("a").sum() * col("a").sum()` -> we don't want `sum` to run on `with_columns`230// as that doesn't have groups context. 
If we encounter a `sum` it should be flagged as `false`231//232// This should have the following stack233// id valid234// col(a) true235// sum false236// col(a) true237// sum false238// binary true239// -------------- accumulated240// false241SubExprId(Identifier, bool),242}243244fn skip_pre_visit(ae: &AExpr, is_groupby: bool, element_wise_select_only: bool) -> bool {245match ae {246#[cfg(feature = "dynamic_group_by")]247AExpr::Rolling { .. } => true,248AExpr::Over { .. } => true,249#[cfg(feature = "dtype-struct")]250AExpr::Ternary { .. } => is_groupby,251ae => {252if element_wise_select_only {253if is_groupby {254true255} else {256!ae.is_elementwise_top_level()257}258} else {259false260}261},262}263}264265/// Goes through an expression and generates a identifier266///267/// The visitor uses a `visit_stack` to track traversal order.268///269/// # Entering a node270/// When `pre-visit` is called we enter a new (sub)-expression and271/// we add `Entered` to the stack.272/// # Leaving a node273/// On `post-visit` when we leave the node and we pop all `SubExprIds` nodes.274/// Those are considered sub-expression of the leaving node275///276/// We also record an `id_array` that followed the pre-visit order. 
This277/// is used to cache the `Identifiers`.278//279// # Example (this is not a docstring as clippy complains about spacing)280// Say we have the expression: `(col("f00").min() * col("bar")).sum()`281// with the following call tree:282//283// sum284//285// |286//287// binary: *288//289// | |290//291// col(bar) min292//293// |294//295// col(f00)296//297// # call order298// function-called stack stack-after(pop until E, push I) # ID299// pre-visit: sum E -300// pre-visit: binary: * EE -301// pre-visit: col(bar) EEE -302// post-visit: col(bar) EEE EEI id: col(bar)303// pre-visit: min EEIE -304// pre-visit: col(f00) EEIEE -305// post-visit: col(f00) EEIEE EEIEI id: col(f00)306// post-visit: min EEIEI EEII id: min!col(f00)307// post-visit: binary: * EEII EI id: binary: *!min!col(f00)!col(bar)308// post-visit: sum EI I id: sum!binary: *!min!col(f00)!col(bar)309struct ExprIdentifierVisitor<'a> {310se_count: &'a mut SubExprCount,311/// Materialized `CSE` materialized (name) hashes can collide. 
So we validate that all CSE counts312/// match name hash counts.313name_validation: &'a mut PlHashMap<u64, u32>,314identifier_array: &'a mut IdentifierArray,315// Index in pre-visit traversal order.316pre_visit_idx: usize,317post_visit_idx: usize,318visit_stack: &'a mut Vec<VisitRecord>,319/// Offset in the identifier array320/// this allows us to use a single `vec` on multiple expressions321id_array_offset: usize,322// Whether the expression replaced a subexpression.323has_sub_expr: bool,324// During aggregation we only identify element-wise operations325is_group_by: bool,326//327element_wise_only: bool,328}329330impl ExprIdentifierVisitor<'_> {331fn new<'a>(332se_count: &'a mut SubExprCount,333identifier_array: &'a mut IdentifierArray,334visit_stack: &'a mut Vec<VisitRecord>,335is_group_by: bool,336name_validation: &'a mut PlHashMap<u64, u32>,337element_wise_select_only: bool,338) -> ExprIdentifierVisitor<'a> {339let id_array_offset = identifier_array.len();340ExprIdentifierVisitor {341se_count,342name_validation,343identifier_array,344pre_visit_idx: 0,345post_visit_idx: 0,346visit_stack,347id_array_offset,348has_sub_expr: false,349is_group_by,350element_wise_only: element_wise_select_only,351}352}353354/// pop all visit-records until an `Entered` is found. We accumulate a `SubExprId`s355/// to `id`. 
Finally we return the expression `idx` and `Identifier`.356/// This works due to the stack.357/// If we traverse another expression in the mean time, it will get popped of the stack first358/// so the returned identifier belongs to a single sub-expression359fn pop_until_entered(&mut self) -> (usize, Identifier, bool) {360let mut id = Identifier::new();361let mut is_valid_accumulated = true;362363while let Some(item) = self.visit_stack.pop() {364match item {365VisitRecord::Entered(idx) => return (idx, id, is_valid_accumulated),366VisitRecord::SubExprId(s, valid) => {367id.combine(&s);368is_valid_accumulated &= valid369},370}371}372unreachable!()373}374375/// return `None` -> node is accepted376/// return `Some(_)` node is not accepted and apply the given recursion operation377/// `Some(_, true)` don't accept this node, but can be a member of a cse.378/// `Some(_, false)` don't accept this node, and don't allow as a member of a cse.379fn accept_node_post_visit(&self, ae: &AExpr) -> Accepted {380match ae {381// window expressions should `evaluate_on_groups`, not `evaluate`382// so we shouldn't cache the children as they are evaluated incorrectly383#[cfg(feature = "dynamic_group_by")]384AExpr::Rolling { .. } => REFUSE_SKIP,385AExpr::Over { .. } => REFUSE_SKIP,386// Don't allow this for now, as we can get `null().cast()` in ternary expressions.387// TODO! 
Add a typed null388AExpr::Literal(LiteralValue::Scalar(sc)) if sc.is_null() => REFUSE_NO_MEMBER,389AExpr::Literal(s) => {390match s {391LiteralValue::Series(s) => {392let dtype = s.dtype();393394// Object and nested types are harder to hash and compare.395let allow = !(dtype.is_nested() | dtype.is_object());396397if s.len() < CHEAP_SERIES_HASH_LIMIT && allow {398REFUSE_ALLOW_MEMBER399} else {400REFUSE_NO_MEMBER401}402},403_ => REFUSE_ALLOW_MEMBER,404}405},406AExpr::Column(_) => REFUSE_ALLOW_MEMBER,407AExpr::Len => {408if self.is_group_by {409REFUSE_NO_MEMBER410} else {411REFUSE_ALLOW_MEMBER412}413},414#[cfg(feature = "random")]415AExpr::Function {416function: IRFunctionExpr::Random { .. },417..418} => REFUSE_NO_MEMBER,419#[cfg(feature = "rolling_window")]420AExpr::Function {421function: IRFunctionExpr::RollingExpr { .. },422..423} => REFUSE_NO_MEMBER,424_ => {425// During aggregation we only store elementwise operation in the state426// other operations we cannot add to the state as they have the output size of the427// groups, not the original dataframe428if self.is_group_by {429if !ae.is_elementwise_top_level() {430return REFUSE_NO_MEMBER;431}432match ae {433AExpr::Cast { .. 
} => REFUSE_ALLOW_MEMBER,434_ => ACCEPT,435}436} else {437ACCEPT438}439},440}441}442}443444impl Visitor for ExprIdentifierVisitor<'_> {445type Node = AexprNode;446type Arena = Arena<AExpr>;447448fn pre_visit(449&mut self,450node: &Self::Node,451arena: &Self::Arena,452) -> PolarsResult<VisitRecursion> {453if skip_pre_visit(454node.to_aexpr(arena),455self.is_group_by,456self.element_wise_only,457) {458// Still add to the stack so that a parent becomes invalidated.459self.visit_stack460.push(VisitRecord::SubExprId(Identifier::new(), false));461return Ok(VisitRecursion::Skip);462}463464self.visit_stack465.push(VisitRecord::Entered(self.pre_visit_idx));466self.pre_visit_idx += 1;467468// implement default placeholders469self.identifier_array470.push((self.id_array_offset, Identifier::new()));471472Ok(VisitRecursion::Continue)473}474475fn post_visit(476&mut self,477node: &Self::Node,478arena: &Self::Arena,479) -> PolarsResult<VisitRecursion> {480let ae = node.to_aexpr(arena);481self.post_visit_idx += 1;482483let (pre_visit_idx, sub_expr_id, is_valid_accumulated) = self.pop_until_entered();484// Create the Id of this node.485let id: Identifier = sub_expr_id.add_ae_node(node, arena);486487if !is_valid_accumulated {488self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;489self.visit_stack.push(VisitRecord::SubExprId(id, false));490return Ok(VisitRecursion::Continue);491}492493// If we don't store this node494// we only push the visit_stack, so the parents know the trail.495if let Some((recurse, local_is_valid)) = self.accept_node_post_visit(ae) {496self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;497498self.visit_stack499.push(VisitRecord::SubExprId(id, local_is_valid));500return Ok(recurse);501}502503// Store the created id.504self.identifier_array[pre_visit_idx + self.id_array_offset] =505(self.post_visit_idx, id.clone());506507// We popped until entered, push this Id on the stack so the trail508// is 
available for the parent expression.509self.visit_stack510.push(VisitRecord::SubExprId(id.clone(), true));511512let mat_h = id.materialized_hash();513let (_, se_count) = self.se_count.entry(id, || (node.node(), 0), arena);514515*se_count += 1;516*self.name_validation.entry(mat_h).or_insert(0) += 1;517self.has_sub_expr |= *se_count > 1;518519Ok(VisitRecursion::Continue)520}521}522523struct CommonSubExprRewriter<'a> {524sub_expr_map: &'a SubExprCount,525identifier_array: &'a IdentifierArray,526/// keep track of the replaced identifiers.527replaced_identifiers: &'a mut IdentifierMap<()>,528529max_post_visit_idx: usize,530/// index in traversal order in which `identifier_array`531/// was written. This is the index in `identifier_array`.532visited_idx: usize,533/// Offset in the identifier array.534/// This allows us to use a single `vec` on multiple expressions535id_array_offset: usize,536/// Indicates if this expression is rewritten.537rewritten: bool,538is_group_by: bool,539is_element_wise_select_only: bool,540}541542impl<'a> CommonSubExprRewriter<'a> {543fn new(544sub_expr_map: &'a SubExprCount,545identifier_array: &'a IdentifierArray,546replaced_identifiers: &'a mut IdentifierMap<()>,547id_array_offset: usize,548is_group_by: bool,549is_element_wise_select_only: bool,550) -> Self {551Self {552sub_expr_map,553identifier_array,554replaced_identifiers,555max_post_visit_idx: 0,556visited_idx: 0,557id_array_offset,558rewritten: false,559is_group_by,560is_element_wise_select_only,561}562}563}564565// # Example566// Expression tree with [pre-visit,post-visit] indices567// counted from 1568// [1,8] binary: +569//570// | |571//572// [2,2] sum [4,7] sum573//574// | |575//576// [3,1] col(foo) [5,6] binary: *577//578// | |579//580// [6,3] col(bar) [7,5] sum581//582// |583//584// [8,4] col(foo)585//586// in this tree `col(foo).sum()` should be post-visited/mutated587// so if we are at `[2,2]`588//589// call stack590// pre-visit [1,8] binary -> no_mutate_and_continue -> visits 
children591// pre-visit [2,2] sum -> mutate_and_stop -> does not visit children592// post-visit [2,2] sum -> skip index to [4,7] (because we didn't visit children)593// pre-visit [4,7] sum -> no_mutate_and_continue -> visits children594// pre-visit [5,6] binary -> no_mutate_and_continue -> visits children595// pre-visit [6,3] col -> stop_recursion -> does not mutate596// pre-visit [7,5] sum -> mutate_and_stop -> does not visit children597// post-visit [7,5] -> skip index to end598impl RewritingVisitor for CommonSubExprRewriter<'_> {599type Node = AexprNode;600type Arena = Arena<AExpr>;601602fn pre_visit(603&mut self,604ae_node: &Self::Node,605arena: &mut Self::Arena,606) -> PolarsResult<RewriteRecursion> {607let ae = ae_node.to_aexpr(arena);608if self.visited_idx + self.id_array_offset >= self.identifier_array.len()609|| self.max_post_visit_idx610> self.identifier_array[self.visited_idx + self.id_array_offset].0611|| skip_pre_visit(ae, self.is_group_by, self.is_element_wise_select_only)612{613return Ok(RewriteRecursion::Stop);614}615616let id = &self.identifier_array[self.visited_idx + self.id_array_offset].1;617618// Id placeholder not overwritten, so we can skip this sub-expression.619if !id.is_valid() {620self.visited_idx += 1;621let recurse = if ae_node.is_leaf(arena) {622RewriteRecursion::Stop623} else {624// continue visit its children to see625// if there are cse626RewriteRecursion::NoMutateAndContinue627};628return Ok(recurse);629}630631// Because some expressions don't have hash / equality guarantee (e.g. floats)632// we can get none here. 
This must be changed later.633let Some((_, count)) = self.sub_expr_map.get(id, arena) else {634self.visited_idx += 1;635return Ok(RewriteRecursion::NoMutateAndContinue);636};637if *count > 1 {638self.replaced_identifiers.insert(id.clone(), (), arena);639// rewrite this sub-expression, don't visit its children640Ok(RewriteRecursion::MutateAndStop)641} else {642// This is a unique expression643// visit its children to see if they are cse644self.visited_idx += 1;645Ok(RewriteRecursion::NoMutateAndContinue)646}647}648649fn mutate(650&mut self,651mut node: Self::Node,652arena: &mut Self::Arena,653) -> PolarsResult<Self::Node> {654let (post_visit_count, id) =655&self.identifier_array[self.visited_idx + self.id_array_offset];656self.visited_idx += 1;657658// TODO!: check if we ever hit this branch659if *post_visit_count < self.max_post_visit_idx {660return Ok(node);661}662663self.max_post_visit_idx = *post_visit_count;664// DFS, so every post_visit that is smaller than `post_visit_count`665// is a subexpression of this node and we can skip that666//667// `self.visited_idx` will influence recursion strategy in `pre_visit`668// see call-stack comment above669while self.visited_idx < self.identifier_array.len() - self.id_array_offset670&& *post_visit_count > self.identifier_array[self.visited_idx + self.id_array_offset].0671{672self.visited_idx += 1;673}674// If this is not true, the traversal order in the visitor was different from the rewriter.675debug_assert_eq!(676node.hashable_and_cmp(arena),677id.ae_node().hashable_and_cmp(arena)678);679680let name = id.materialize();681node.assign(AExpr::col(name), arena);682self.rewritten = true;683684Ok(node)685}686}687688pub(crate) struct CommonSubExprOptimizer {689// amortize allocations690// these are cleared per lp node691se_count: SubExprCount,692id_array: IdentifierArray,693id_array_offsets: Vec<u32>,694replaced_identifiers: IdentifierMap<()>,695// these are cleared per expr node696visit_stack: 
Vec<VisitRecord>,697name_validation: PlHashMap<u64, u32>,698// Set by the streaming engine699// Only supports element-wise CSEE700// on SELECT/HSTACK701element_wise_select_only: bool,702}703704impl CommonSubExprOptimizer {705pub(crate) fn new(element_wise_select_only: bool) -> Self {706Self {707se_count: Default::default(),708id_array: Default::default(),709visit_stack: Default::default(),710id_array_offsets: Default::default(),711replaced_identifiers: Default::default(),712name_validation: Default::default(),713element_wise_select_only,714}715}716717fn visit_expression(718&mut self,719ae_node: AexprNode,720is_group_by: bool,721expr_arena: &mut Arena<AExpr>,722element_wise_select_only: bool,723) -> PolarsResult<(usize, bool)> {724let mut visitor = ExprIdentifierVisitor::new(725&mut self.se_count,726&mut self.id_array,727&mut self.visit_stack,728is_group_by,729&mut self.name_validation,730element_wise_select_only,731);732ae_node.visit(&mut visitor, expr_arena).map(|_| ())?;733Ok((visitor.id_array_offset, visitor.has_sub_expr))734}735736/// Mutate the expression.737/// Returns a new expression and a `bool` indicating if it was rewritten or not.738fn mutate_expression(739&mut self,740ae_node: AexprNode,741id_array_offset: usize,742is_group_by: bool,743expr_arena: &mut Arena<AExpr>,744element_wise_select_only: bool,745) -> PolarsResult<(AexprNode, bool)> {746let mut rewriter = CommonSubExprRewriter::new(747&self.se_count,748&self.id_array,749&mut self.replaced_identifiers,750id_array_offset,751is_group_by,752element_wise_select_only,753);754ae_node755.rewrite(&mut rewriter, expr_arena)756.map(|out| (out, rewriter.rewritten))757}758759fn find_cse(760&mut self,761expr: &[ExprIR],762expr_arena: &mut Arena<AExpr>,763id_array_offsets: &mut Vec<u32>,764is_group_by: bool,765schema: &Schema,766element_wise_select_only: bool,767) -> PolarsResult<Option<ProjectionExprs>> {768let mut has_sub_expr = false;769770// First get all cse's.771for e in expr {772// The visitor can return 
early thus depleted its stack773// on a previous iteration.774self.visit_stack.clear();775776// Visit expressions and collect sub-expression counts.777let ae_node = AexprNode::new(e.node());778let (id_array_offset, this_expr_has_se) =779self.visit_expression(ae_node, is_group_by, expr_arena, element_wise_select_only)?;780id_array_offsets.push(id_array_offset as u32);781has_sub_expr |= this_expr_has_se;782}783784// Ensure that the `materialized hashes` count matches that of the CSE count.785// It can happen that CSE collide and in that case we fallback and skip CSE.786for (id, (_, count)) in self.se_count.iter() {787let mat_h = id.materialized_hash();788let valid = if let Some(name_count) = self.name_validation.get(&mat_h) {789*name_count == *count790} else {791false792};793794if !valid {795if verbose() {796eprintln!(797"materialized names collided in common subexpression elimination.\n backtrace and run without CSE"798)799}800return Ok(None);801}802}803804if has_sub_expr {805let mut new_expr = Vec::with_capacity_by_factor(expr.len(), 1.3);806807// Then rewrite the expressions that have a cse count > 1.808for (e, offset) in expr.iter().zip(id_array_offsets.iter()) {809let ae_node = AexprNode::new(e.node());810811let (out, rewritten) = self.mutate_expression(812ae_node,813*offset as usize,814is_group_by,815expr_arena,816element_wise_select_only,817)?;818819let out_node = out.node();820let mut out_e = e.clone();821let new_node = if !rewritten {822out_e823} else {824out_e.set_node(out_node);825826// Ensure the function ExprIR's have the proper names.827// This is needed for structs to get the proper field828let mut scratch = vec![];829let mut stack = vec![(e.node(), out_node)];830while let Some((original, new)) = stack.pop() {831// Don't follow identical nodes.832if original == new {833continue;834}835scratch.clear();836let aes = expr_arena.get_disjoint_mut([original, new]);837838// Only follow paths that are the same.839if std::mem::discriminant(aes[0]) != 
std::mem::discriminant(aes[1]) {840continue;841}842843aes[0].inputs_rev(&mut scratch);844let offset = scratch.len();845aes[1].inputs_rev(&mut scratch);846847// If they have a different number of inputs, we don't follow the nodes.848if scratch.len() != offset * 2 {849continue;850}851852for i in 0..scratch.len() / 2 {853stack.push((scratch[i], scratch[i + offset]));854}855856match expr_arena.get_disjoint_mut([original, new]) {857[858AExpr::Function {859input: input_original,860..861},862AExpr::Function {863input: input_new, ..864},865] => {866for (new, original) in input_new.iter_mut().zip(input_original) {867new.set_alias(original.output_name().clone());868}869},870[871AExpr::AnonymousFunction {872input: input_original,873..874},875AExpr::AnonymousFunction {876input: input_new, ..877},878] => {879for (new, original) in input_new.iter_mut().zip(input_original) {880new.set_alias(original.output_name().clone());881}882},883_ => {},884}885}886887// If we don't end with an alias we add an alias. 
Because the normal left-hand888// rule we apply for determining the name will not work we now refer to889// intermediate temporary names starting with the `CSE_REPLACED` constant.890if !e.has_alias() {891let name = ae_node.to_field(schema, expr_arena)?.name;892out_e.set_alias(name.clone());893}894out_e895};896new_expr.push(new_node)897}898// Add the tmp columns899for id in self.replaced_identifiers.inner.keys() {900let (node, _count) = self.se_count.get(id, expr_arena).unwrap();901let name = id.materialize();902let out_e = ExprIR::new(*node, OutputName::Alias(name));903new_expr.push(out_e)904}905let expr =906ProjectionExprs::new_with_cse(new_expr, self.replaced_identifiers.inner.len());907Ok(Some(expr))908} else {909Ok(None)910}911}912}913914impl RewritingVisitor for CommonSubExprOptimizer {915type Node = IRNode;916type Arena = IRNodeArena;917918fn pre_visit(919&mut self,920node: &Self::Node,921arena: &mut Self::Arena,922) -> PolarsResult<RewriteRecursion> {923use IR::*;924Ok(match node.to_alp(&arena.0) {925Select { .. } | HStack { .. } | GroupBy { .. } => RewriteRecursion::MutateAndContinue,926_ => RewriteRecursion::NoMutateAndContinue,927})928}929930fn mutate(&mut self, node: Self::Node, arena: &mut Self::Arena) -> PolarsResult<Self::Node> {931let mut id_array_offsets = std::mem::take(&mut self.id_array_offsets);932933self.se_count.inner.clear();934self.name_validation.clear();935self.id_array.clear();936id_array_offsets.clear();937self.replaced_identifiers.inner.clear();938939let arena_idx = node.node();940let alp = arena.0.get(arena_idx);941942match alp {943IR::Select {944input,945expr,946schema,947options,948} => {949let input_schema = arena.0.get(*input).schema(&arena.0);950if let Some(expr) = self.find_cse(951expr,952&mut arena.1,953&mut id_array_offsets,954false,955input_schema.as_ref().as_ref(),956self.element_wise_select_only,957)? 
{958let schema = schema.clone();959let options = *options;960961let lp = IRBuilder::new(*input, &mut arena.1, &mut arena.0)962.with_columns(963expr.cse_exprs().to_vec(),964ProjectionOptions {965run_parallel: options.run_parallel,966duplicate_check: options.duplicate_check,967// These columns might have different968// lengths from the dataframe, but969// they are only temporaries that will970// be removed by the evaluation of the971// default_exprs and the subsequent972// projection.973should_broadcast: false,974},975)976.build();977let input = arena.0.add(lp);978979let lp = IR::Select {980input,981expr: expr.default_exprs().to_vec(),982schema,983options,984};985arena.0.replace(arena_idx, lp);986}987},988IR::HStack {989input,990exprs,991schema,992options,993} => {994let input_schema = arena.0.get(*input).schema(&arena.0);995if let Some(exprs) = self.find_cse(996exprs,997&mut arena.1,998&mut id_array_offsets,999false,1000input_schema.as_ref().as_ref(),1001self.element_wise_select_only,1002)? 
{1003let schema = schema.clone();1004let options = *options;1005let input = *input;10061007let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)1008.with_columns(1009exprs.cse_exprs().to_vec(),1010// These columns might have different1011// lengths from the dataframe, but they1012// are only temporaries that will be1013// removed by the evaluation of the1014// default_exprs and the subsequent1015// projection.1016ProjectionOptions {1017run_parallel: options.run_parallel,1018duplicate_check: options.duplicate_check,1019should_broadcast: false,1020},1021)1022.with_columns(exprs.default_exprs().to_vec(), options)1023.build();1024let input = arena.0.add(lp);10251026let lp = IR::SimpleProjection {1027input,1028columns: schema,1029};1030arena.0.replace(arena_idx, lp);1031}1032},1033IR::GroupBy {1034input,1035keys,1036aggs,1037options,1038maintain_order,1039apply,1040schema,1041} if !self.element_wise_select_only => {1042let input_schema = arena.0.get(*input).schema(&arena.0);1043if let Some(aggs) = self.find_cse(1044aggs,1045&mut arena.1,1046&mut id_array_offsets,1047true,1048input_schema.as_ref().as_ref(),1049self.element_wise_select_only,1050)? {1051let keys = keys.clone();1052let options = options.clone();1053let schema = schema.clone();1054let apply = apply.clone();1055let maintain_order = *maintain_order;1056let input = *input;10571058let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)1059.with_columns(aggs.cse_exprs().to_vec(), Default::default())1060.build();1061let input = arena.0.add(lp);10621063let lp = IR::GroupBy {1064input,1065keys,1066aggs: aggs.default_exprs().to_vec(),1067options,1068schema,1069maintain_order,1070apply,1071};1072arena.0.replace(arena_idx, lp);1073}1074},1075_ => {},1076}10771078self.id_array_offsets = id_array_offsets;1079Ok(node)1080}1081}108210831084