Path: blob/main/crates/polars-plan/src/plans/optimizer/cse/csee.rs
7889 views
use std::hash::BuildHasher;12use hashbrown::hash_map::RawEntryMut;3use polars_core::CHEAP_SERIES_HASH_LIMIT;4use polars_utils::aliases::PlFixedStateQuality;5use polars_utils::format_pl_smallstr;6use polars_utils::hashing::_boost_hash_combine;7use polars_utils::vec::CapacityByFactor;89use super::*;10use crate::constants::CSE_REPLACED;11use crate::prelude::visitor::AexprNode;1213#[derive(Debug, Clone)]14struct ProjectionExprs {15expr: Vec<ExprIR>,16/// offset from the back17/// `expr[expr.len() - common_sub_offset..]`18/// are the common sub expressions19common_sub_offset: usize,20}2122impl ProjectionExprs {23fn default_exprs(&self) -> &[ExprIR] {24&self.expr[..self.expr.len() - self.common_sub_offset]25}2627fn cse_exprs(&self) -> &[ExprIR] {28&self.expr[self.expr.len() - self.common_sub_offset..]29}3031fn new_with_cse(expr: Vec<ExprIR>, common_sub_offset: usize) -> Self {32Self {33expr,34common_sub_offset,35}36}37}3839/// Identifier that shows the sub-expression path.40/// Must implement hash and equality and ideally41/// have little collisions42/// We will do a full expression comparison to check if the43/// expressions with equal identifiers are truly equal44#[derive(Clone, Debug)]45pub(super) struct Identifier {46inner: Option<u64>,47last_node: Option<AexprNode>,48hb: PlFixedStateQuality,49}5051impl Identifier {52fn new() -> Self {53Self {54inner: None,55last_node: None,56hb: PlFixedStateQuality::with_seed(0),57}58}5960fn hash(&self) -> u64 {61self.inner.unwrap_or(0)62}6364fn ae_node(&self) -> AexprNode {65self.last_node.unwrap()66}6768fn is_equal(&self, other: &Self, arena: &Arena<AExpr>) -> bool {69self.inner == other.inner70&& self.last_node.map(|v| v.hashable_and_cmp(arena))71== other.last_node.map(|v| v.hashable_and_cmp(arena))72}7374fn is_valid(&self) -> bool {75self.inner.is_some()76}7778fn materialize(&self) -> PlSmallStr {79format_pl_smallstr!("{}{:#x}", CSE_REPLACED, self.materialized_hash())80}8182fn materialized_hash(&self) -> u64 {83self.inner.unwrap_or(0)84}8586fn combine(&mut self, other: &Identifier) {87let inner = match (self.inner, other.inner) {88(Some(l), Some(r)) => _boost_hash_combine(l, r),89(None, Some(r)) => r,90(Some(l), None) => l,91_ => return,92};93self.inner = Some(inner);94}9596fn add_ae_node(&self, ae: &AexprNode, arena: &Arena<AExpr>) -> Self {97let hashed = self.hb.hash_one(ae.to_aexpr(arena));98let inner = Some(99self.inner100.map_or(hashed, |l| _boost_hash_combine(l, hashed)),101);102Self {103inner,104last_node: Some(*ae),105hb: self.hb.clone(),106}107}108}109110#[derive(Default)]111struct IdentifierMap<V> {112inner: PlHashMap<Identifier, V>,113}114115impl<V> IdentifierMap<V> {116fn get(&self, id: &Identifier, arena: &Arena<AExpr>) -> Option<&V> {117self.inner118.raw_entry()119.from_hash(id.hash(), |k| k.is_equal(id, arena))120.map(|(_k, v)| v)121}122123fn entry<'a, F: FnOnce() -> V>(124&'a mut self,125id: Identifier,126v: F,127arena: &Arena<AExpr>,128) -> &'a mut V {129let h = id.hash();130match self131.inner132.raw_entry_mut()133.from_hash(h, |k| k.is_equal(&id, arena))134{135RawEntryMut::Occupied(entry) => entry.into_mut(),136RawEntryMut::Vacant(entry) => {137let (_, v) = entry.insert_with_hasher(h, id, v(), |id| id.hash());138v139},140}141}142fn insert(&mut self, id: Identifier, v: V, arena: &Arena<AExpr>) {143self.entry(id, || v, arena);144}145146fn iter(&self) -> impl Iterator<Item = (&Identifier, &V)> {147self.inner.iter()148}149}150151/// Merges identical expressions into identical IDs.152///153/// Does no analysis whether this leads to legal substitutions.154#[derive(Default)]155pub struct NaiveExprMerger {156node_to_uniq_id: PlHashMap<Node, u32>,157uniq_id_to_node: Vec<Node>,158identifier_to_uniq_id: IdentifierMap<u32>,159arg_stack: Vec<Option<Identifier>>,160}161162impl NaiveExprMerger {163pub fn add_expr(&mut self, node: Node, arena: &Arena<AExpr>) {164let node = AexprNode::new(node);165node.visit(self, arena).unwrap();166}167168pub fn get_uniq_id(&self, node: Node) -> Option<u32> {169self.node_to_uniq_id.get(&node).copied()170}171172pub fn get_node(&self, uniq_id: u32) -> Option<Node> {173self.uniq_id_to_node.get(uniq_id as usize).copied()174}175}176177impl Visitor for NaiveExprMerger {178type Node = AexprNode;179type Arena = Arena<AExpr>;180181fn pre_visit(182&mut self,183_node: &Self::Node,184_arena: &Self::Arena,185) -> PolarsResult<VisitRecursion> {186self.arg_stack.push(None);187Ok(VisitRecursion::Continue)188}189190fn post_visit(191&mut self,192node: &Self::Node,193arena: &Self::Arena,194) -> PolarsResult<VisitRecursion> {195let mut identifier = Identifier::new();196while let Some(Some(arg)) = self.arg_stack.pop() {197identifier.combine(&arg);198}199identifier = identifier.add_ae_node(node, arena);200let uniq_id = *self.identifier_to_uniq_id.entry(201identifier,202|| {203let uniq_id = self.uniq_id_to_node.len() as u32;204self.uniq_id_to_node.push(node.node());205uniq_id206},207arena,208);209self.node_to_uniq_id.insert(node.node(), uniq_id);210Ok(VisitRecursion::Continue)211}212}213214/// Identifier maps to Expr Node and count.215type SubExprCount = IdentifierMap<(Node, u32)>;216/// (post_visit_idx, identifier);217type IdentifierArray = Vec<(usize, Identifier)>;218219#[derive(Debug)]220enum VisitRecord {221/// entered a new expression222Entered(usize),223/// Every visited sub-expression pushes their identifier to the stack.224// The `bool` indicates if this expression is valid.225// This can be `AND` accumulated by the lineage of the expression to determine226// of the whole expression can be added.227// For instance a in a group_by we only want to use elementwise operation in cse:228// - `(col("a") * 2).sum(), (col("a") * 2)` -> we want to do `col("a") * 2` on a `with_columns`229// - `col("a").sum() * col("a").sum()` -> we don't want `sum` to run on `with_columns`230// as that doesn't have groups context. If we encounter a `sum` it should be flagged as `false`231//232// This should have the following stack233// id valid234// col(a) true235// sum false236// col(a) true237// sum false238// binary true239// -------------- accumulated240// false241SubExprId(Identifier, bool),242}243244fn skip_pre_visit(ae: &AExpr, is_groupby: bool) -> bool {245match ae {246#[cfg(feature = "dynamic_group_by")]247AExpr::Rolling { .. } => true,248AExpr::Over { .. } => true,249#[cfg(feature = "dtype-struct")]250AExpr::Ternary { .. } => is_groupby,251_ => false,252}253}254255/// Goes through an expression and generates a identifier256///257/// The visitor uses a `visit_stack` to track traversal order.258///259/// # Entering a node260/// When `pre-visit` is called we enter a new (sub)-expression and261/// we add `Entered` to the stack.262/// # Leaving a node263/// On `post-visit` when we leave the node and we pop all `SubExprIds` nodes.264/// Those are considered sub-expression of the leaving node265///266/// We also record an `id_array` that followed the pre-visit order. This267/// is used to cache the `Identifiers`.268//269// # Example (this is not a docstring as clippy complains about spacing)270// Say we have the expression: `(col("f00").min() * col("bar")).sum()`271// with the following call tree:272//273// sum274//275// |276//277// binary: *278//279// | |280//281// col(bar) min282//283// |284//285// col(f00)286//287// # call order288// function-called stack stack-after(pop until E, push I) # ID289// pre-visit: sum E -290// pre-visit: binary: * EE -291// pre-visit: col(bar) EEE -292// post-visit: col(bar) EEE EEI id: col(bar)293// pre-visit: min EEIE -294// pre-visit: col(f00) EEIEE -295// post-visit: col(f00) EEIEE EEIEI id: col(f00)296// post-visit: min EEIEI EEII id: min!col(f00)297// post-visit: binary: * EEII EI id: binary: *!min!col(f00)!col(bar)298// post-visit: sum EI I id: sum!binary: *!min!col(f00)!col(bar)299struct ExprIdentifierVisitor<'a> {300se_count: &'a mut SubExprCount,301/// Materialized `CSE` materialized (name) hashes can collide. So we validate that all CSE counts302/// match name hash counts.303name_validation: &'a mut PlHashMap<u64, u32>,304identifier_array: &'a mut IdentifierArray,305// Index in pre-visit traversal order.306pre_visit_idx: usize,307post_visit_idx: usize,308visit_stack: &'a mut Vec<VisitRecord>,309/// Offset in the identifier array310/// this allows us to use a single `vec` on multiple expressions311id_array_offset: usize,312// Whether the expression replaced a subexpression.313has_sub_expr: bool,314// During aggregation we only identify element-wise operations315is_group_by: bool,316}317318impl ExprIdentifierVisitor<'_> {319fn new<'a>(320se_count: &'a mut SubExprCount,321identifier_array: &'a mut IdentifierArray,322visit_stack: &'a mut Vec<VisitRecord>,323is_group_by: bool,324name_validation: &'a mut PlHashMap<u64, u32>,325) -> ExprIdentifierVisitor<'a> {326let id_array_offset = identifier_array.len();327ExprIdentifierVisitor {328se_count,329name_validation,330identifier_array,331pre_visit_idx: 0,332post_visit_idx: 0,333visit_stack,334id_array_offset,335has_sub_expr: false,336is_group_by,337}338}339340/// pop all visit-records until an `Entered` is found. We accumulate a `SubExprId`s341/// to `id`. Finally we return the expression `idx` and `Identifier`.342/// This works due to the stack.343/// If we traverse another expression in the mean time, it will get popped of the stack first344/// so the returned identifier belongs to a single sub-expression345fn pop_until_entered(&mut self) -> (usize, Identifier, bool) {346let mut id = Identifier::new();347let mut is_valid_accumulated = true;348349while let Some(item) = self.visit_stack.pop() {350match item {351VisitRecord::Entered(idx) => return (idx, id, is_valid_accumulated),352VisitRecord::SubExprId(s, valid) => {353id.combine(&s);354is_valid_accumulated &= valid355},356}357}358unreachable!()359}360361/// return `None` -> node is accepted362/// return `Some(_)` node is not accepted and apply the given recursion operation363/// `Some(_, true)` don't accept this node, but can be a member of a cse.364/// `Some(_, false)` don't accept this node, and don't allow as a member of a cse.365fn accept_node_post_visit(&self, ae: &AExpr) -> Accepted {366match ae {367// window expressions should `evaluate_on_groups`, not `evaluate`368// so we shouldn't cache the children as they are evaluated incorrectly369#[cfg(feature = "dynamic_group_by")]370AExpr::Rolling { .. } => REFUSE_SKIP,371AExpr::Over { .. } => REFUSE_SKIP,372// Don't allow this for now, as we can get `null().cast()` in ternary expressions.373// TODO! Add a typed null374AExpr::Literal(LiteralValue::Scalar(sc)) if sc.is_null() => REFUSE_NO_MEMBER,375AExpr::Literal(s) => {376match s {377LiteralValue::Series(s) => {378let dtype = s.dtype();379380// Object and nested types are harder to hash and compare.381let allow = !(dtype.is_nested() | dtype.is_object());382383if s.len() < CHEAP_SERIES_HASH_LIMIT && allow {384REFUSE_ALLOW_MEMBER385} else {386REFUSE_NO_MEMBER387}388},389_ => REFUSE_ALLOW_MEMBER,390}391},392AExpr::Column(_) => REFUSE_ALLOW_MEMBER,393AExpr::Len => {394if self.is_group_by {395REFUSE_NO_MEMBER396} else {397REFUSE_ALLOW_MEMBER398}399},400#[cfg(feature = "random")]401AExpr::Function {402function: IRFunctionExpr::Random { .. },403..404} => REFUSE_NO_MEMBER,405#[cfg(feature = "rolling_window")]406AExpr::Function {407function: IRFunctionExpr::RollingExpr { .. },408..409} => REFUSE_NO_MEMBER,410AExpr::AnonymousFunction { .. } => REFUSE_NO_MEMBER,411_ => {412// During aggregation we only store elementwise operation in the state413// other operations we cannot add to the state as they have the output size of the414// groups, not the original dataframe415if self.is_group_by {416if !ae.is_elementwise_top_level() {417return REFUSE_NO_MEMBER;418}419match ae {420AExpr::AnonymousFunction { .. } => REFUSE_NO_MEMBER,421AExpr::Cast { .. } => REFUSE_ALLOW_MEMBER,422_ => ACCEPT,423}424} else {425ACCEPT426}427},428}429}430}431432impl Visitor for ExprIdentifierVisitor<'_> {433type Node = AexprNode;434type Arena = Arena<AExpr>;435436fn pre_visit(437&mut self,438node: &Self::Node,439arena: &Self::Arena,440) -> PolarsResult<VisitRecursion> {441if skip_pre_visit(node.to_aexpr(arena), self.is_group_by) {442// Still add to the stack so that a parent becomes invalidated.443self.visit_stack444.push(VisitRecord::SubExprId(Identifier::new(), false));445return Ok(VisitRecursion::Skip);446}447448self.visit_stack449.push(VisitRecord::Entered(self.pre_visit_idx));450self.pre_visit_idx += 1;451452// implement default placeholders453self.identifier_array454.push((self.id_array_offset, Identifier::new()));455456Ok(VisitRecursion::Continue)457}458459fn post_visit(460&mut self,461node: &Self::Node,462arena: &Self::Arena,463) -> PolarsResult<VisitRecursion> {464let ae = node.to_aexpr(arena);465self.post_visit_idx += 1;466467let (pre_visit_idx, sub_expr_id, is_valid_accumulated) = self.pop_until_entered();468// Create the Id of this node.469let id: Identifier = sub_expr_id.add_ae_node(node, arena);470471if !is_valid_accumulated {472self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;473self.visit_stack.push(VisitRecord::SubExprId(id, false));474return Ok(VisitRecursion::Continue);475}476477// If we don't store this node478// we only push the visit_stack, so the parents know the trail.479if let Some((recurse, local_is_valid)) = self.accept_node_post_visit(ae) {480self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;481482self.visit_stack483.push(VisitRecord::SubExprId(id, local_is_valid));484return Ok(recurse);485}486487// Store the created id.488self.identifier_array[pre_visit_idx + self.id_array_offset] =489(self.post_visit_idx, id.clone());490491// We popped until entered, push this Id on the stack so the trail492// is available for the parent expression.493self.visit_stack494.push(VisitRecord::SubExprId(id.clone(), true));495496let mat_h = id.materialized_hash();497let (_, se_count) = self.se_count.entry(id, || (node.node(), 0), arena);498499*se_count += 1;500*self.name_validation.entry(mat_h).or_insert(0) += 1;501self.has_sub_expr |= *se_count > 1;502503Ok(VisitRecursion::Continue)504}505}506507struct CommonSubExprRewriter<'a> {508sub_expr_map: &'a SubExprCount,509identifier_array: &'a IdentifierArray,510/// keep track of the replaced identifiers.511replaced_identifiers: &'a mut IdentifierMap<()>,512513max_post_visit_idx: usize,514/// index in traversal order in which `identifier_array`515/// was written. This is the index in `identifier_array`.516visited_idx: usize,517/// Offset in the identifier array.518/// This allows us to use a single `vec` on multiple expressions519id_array_offset: usize,520/// Indicates if this expression is rewritten.521rewritten: bool,522is_group_by: bool,523}524525impl<'a> CommonSubExprRewriter<'a> {526fn new(527sub_expr_map: &'a SubExprCount,528identifier_array: &'a IdentifierArray,529replaced_identifiers: &'a mut IdentifierMap<()>,530id_array_offset: usize,531is_group_by: bool,532) -> Self {533Self {534sub_expr_map,535identifier_array,536replaced_identifiers,537max_post_visit_idx: 0,538visited_idx: 0,539id_array_offset,540rewritten: false,541is_group_by,542}543}544}545546// # Example547// Expression tree with [pre-visit,post-visit] indices548// counted from 1549// [1,8] binary: +550//551// | |552//553// [2,2] sum [4,7] sum554//555// | |556//557// [3,1] col(foo) [5,6] binary: *558//559// | |560//561// [6,3] col(bar) [7,5] sum562//563// |564//565// [8,4] col(foo)566//567// in this tree `col(foo).sum()` should be post-visited/mutated568// so if we are at `[2,2]`569//570// call stack571// pre-visit [1,8] binary -> no_mutate_and_continue -> visits children572// pre-visit [2,2] sum -> mutate_and_stop -> does not visit children573// post-visit [2,2] sum -> skip index to [4,7] (because we didn't visit children)574// pre-visit [4,7] sum -> no_mutate_and_continue -> visits children575// pre-visit [5,6] binary -> no_mutate_and_continue -> visits children576// pre-visit [6,3] col -> stop_recursion -> does not mutate577// pre-visit [7,5] sum -> mutate_and_stop -> does not visit children578// post-visit [7,5] -> skip index to end579impl RewritingVisitor for CommonSubExprRewriter<'_> {580type Node = AexprNode;581type Arena = Arena<AExpr>;582583fn pre_visit(584&mut self,585ae_node: &Self::Node,586arena: &mut Self::Arena,587) -> PolarsResult<RewriteRecursion> {588let ae = ae_node.to_aexpr(arena);589if self.visited_idx + self.id_array_offset >= self.identifier_array.len()590|| self.max_post_visit_idx591> self.identifier_array[self.visited_idx + self.id_array_offset].0592|| skip_pre_visit(ae, self.is_group_by)593{594return Ok(RewriteRecursion::Stop);595}596597let id = &self.identifier_array[self.visited_idx + self.id_array_offset].1;598599// Id placeholder not overwritten, so we can skip this sub-expression.600if !id.is_valid() {601self.visited_idx += 1;602let recurse = if ae_node.is_leaf(arena) {603RewriteRecursion::Stop604} else {605// continue visit its children to see606// if there are cse607RewriteRecursion::NoMutateAndContinue608};609return Ok(recurse);610}611612// Because some expressions don't have hash / equality guarantee (e.g. floats)613// we can get none here. This must be changed later.614let Some((_, count)) = self.sub_expr_map.get(id, arena) else {615self.visited_idx += 1;616return Ok(RewriteRecursion::NoMutateAndContinue);617};618if *count > 1 {619self.replaced_identifiers.insert(id.clone(), (), arena);620// rewrite this sub-expression, don't visit its children621Ok(RewriteRecursion::MutateAndStop)622} else {623// This is a unique expression624// visit its children to see if they are cse625self.visited_idx += 1;626Ok(RewriteRecursion::NoMutateAndContinue)627}628}629630fn mutate(631&mut self,632mut node: Self::Node,633arena: &mut Self::Arena,634) -> PolarsResult<Self::Node> {635let (post_visit_count, id) =636&self.identifier_array[self.visited_idx + self.id_array_offset];637self.visited_idx += 1;638639// TODO!: check if we ever hit this branch640if *post_visit_count < self.max_post_visit_idx {641return Ok(node);642}643644self.max_post_visit_idx = *post_visit_count;645// DFS, so every post_visit that is smaller than `post_visit_count`646// is a subexpression of this node and we can skip that647//648// `self.visited_idx` will influence recursion strategy in `pre_visit`649// see call-stack comment above650while self.visited_idx < self.identifier_array.len() - self.id_array_offset651&& *post_visit_count > self.identifier_array[self.visited_idx + self.id_array_offset].0652{653self.visited_idx += 1;654}655// If this is not true, the traversal order in the visitor was different from the rewriter.656debug_assert_eq!(657node.hashable_and_cmp(arena),658id.ae_node().hashable_and_cmp(arena)659);660661let name = id.materialize();662node.assign(AExpr::col(name), arena);663self.rewritten = true;664665Ok(node)666}667}668669pub(crate) struct CommonSubExprOptimizer {670// amortize allocations671// these are cleared per lp node672se_count: SubExprCount,673id_array: IdentifierArray,674id_array_offsets: Vec<u32>,675replaced_identifiers: IdentifierMap<()>,676// these are cleared per expr node677visit_stack: Vec<VisitRecord>,678name_validation: PlHashMap<u64, u32>,679}680681impl CommonSubExprOptimizer {682pub(crate) fn new() -> Self {683Self {684se_count: Default::default(),685id_array: Default::default(),686visit_stack: Default::default(),687id_array_offsets: Default::default(),688replaced_identifiers: Default::default(),689name_validation: Default::default(),690}691}692693fn visit_expression(694&mut self,695ae_node: AexprNode,696is_group_by: bool,697expr_arena: &mut Arena<AExpr>,698) -> PolarsResult<(usize, bool)> {699let mut visitor = ExprIdentifierVisitor::new(700&mut self.se_count,701&mut self.id_array,702&mut self.visit_stack,703is_group_by,704&mut self.name_validation,705);706ae_node.visit(&mut visitor, expr_arena).map(|_| ())?;707Ok((visitor.id_array_offset, visitor.has_sub_expr))708}709710/// Mutate the expression.711/// Returns a new expression and a `bool` indicating if it was rewritten or not.712fn mutate_expression(713&mut self,714ae_node: AexprNode,715id_array_offset: usize,716is_group_by: bool,717expr_arena: &mut Arena<AExpr>,718) -> PolarsResult<(AexprNode, bool)> {719let mut rewriter = CommonSubExprRewriter::new(720&self.se_count,721&self.id_array,722&mut self.replaced_identifiers,723id_array_offset,724is_group_by,725);726ae_node727.rewrite(&mut rewriter, expr_arena)728.map(|out| (out, rewriter.rewritten))729}730731fn find_cse(732&mut self,733expr: &[ExprIR],734expr_arena: &mut Arena<AExpr>,735id_array_offsets: &mut Vec<u32>,736is_group_by: bool,737schema: &Schema,738) -> PolarsResult<Option<ProjectionExprs>> {739let mut has_sub_expr = false;740741// First get all cse's.742for e in expr {743// The visitor can return early thus depleted its stack744// on a previous iteration.745self.visit_stack.clear();746747// Visit expressions and collect sub-expression counts.748let ae_node = AexprNode::new(e.node());749let (id_array_offset, this_expr_has_se) =750self.visit_expression(ae_node, is_group_by, expr_arena)?;751id_array_offsets.push(id_array_offset as u32);752has_sub_expr |= this_expr_has_se;753}754755// Ensure that the `materialized hashes` count matches that of the CSE count.756// It can happen that CSE collide and in that case we fallback and skip CSE.757for (id, (_, count)) in self.se_count.iter() {758let mat_h = id.materialized_hash();759let valid = if let Some(name_count) = self.name_validation.get(&mat_h) {760*name_count == *count761} else {762false763};764765if !valid {766if verbose() {767eprintln!(768"materialized names collided in common subexpression elimination.\n backtrace and run without CSE"769)770}771return Ok(None);772}773}774775if has_sub_expr {776let mut new_expr = Vec::with_capacity_by_factor(expr.len(), 1.3);777778// Then rewrite the expressions that have a cse count > 1.779for (e, offset) in expr.iter().zip(id_array_offsets.iter()) {780let ae_node = AexprNode::new(e.node());781782let (out, rewritten) =783self.mutate_expression(ae_node, *offset as usize, is_group_by, expr_arena)?;784785let out_node = out.node();786let mut out_e = e.clone();787let new_node = if !rewritten {788out_e789} else {790out_e.set_node(out_node);791792// Ensure the function ExprIR's have the proper names.793// This is needed for structs to get the proper field794let mut scratch = vec![];795let mut stack = vec![(e.node(), out_node)];796while let Some((original, new)) = stack.pop() {797// Don't follow identical nodes.798if original == new {799continue;800}801scratch.clear();802let aes = expr_arena.get_many_mut([original, new]);803804// Only follow paths that are the same.805if std::mem::discriminant(aes[0]) != std::mem::discriminant(aes[1]) {806continue;807}808809aes[0].inputs_rev(&mut scratch);810let offset = scratch.len();811aes[1].inputs_rev(&mut scratch);812813// If they have a different number of inputs, we don't follow the nodes.814if scratch.len() != offset * 2 {815continue;816}817818for i in 0..scratch.len() / 2 {819stack.push((scratch[i], scratch[i + offset]));820}821822match expr_arena.get_many_mut([original, new]) {823[824AExpr::Function {825input: input_original,826..827},828AExpr::Function {829input: input_new, ..830},831] => {832for (new, original) in input_new.iter_mut().zip(input_original) {833new.set_alias(original.output_name().clone());834}835},836[837AExpr::AnonymousFunction {838input: input_original,839..840},841AExpr::AnonymousFunction {842input: input_new, ..843},844] => {845for (new, original) in input_new.iter_mut().zip(input_original) {846new.set_alias(original.output_name().clone());847}848},849_ => {},850}851}852853// If we don't end with an alias we add an alias. Because the normal left-hand854// rule we apply for determining the name will not work we now refer to855// intermediate temporary names starting with the `CSE_REPLACED` constant.856if !e.has_alias() {857let name = ae_node.to_field(schema, expr_arena)?.name;858out_e.set_alias(name.clone());859}860out_e861};862new_expr.push(new_node)863}864// Add the tmp columns865for id in self.replaced_identifiers.inner.keys() {866let (node, _count) = self.se_count.get(id, expr_arena).unwrap();867let name = id.materialize();868let out_e = ExprIR::new(*node, OutputName::Alias(name));869new_expr.push(out_e)870}871let expr =872ProjectionExprs::new_with_cse(new_expr, self.replaced_identifiers.inner.len());873Ok(Some(expr))874} else {875Ok(None)876}877}878}879880impl RewritingVisitor for CommonSubExprOptimizer {881type Node = IRNode;882type Arena = IRNodeArena;883884fn pre_visit(885&mut self,886node: &Self::Node,887arena: &mut Self::Arena,888) -> PolarsResult<RewriteRecursion> {889use IR::*;890Ok(match node.to_alp(&arena.0) {891Select { .. } | HStack { .. } | GroupBy { .. } => RewriteRecursion::MutateAndContinue,892_ => RewriteRecursion::NoMutateAndContinue,893})894}895896fn mutate(&mut self, node: Self::Node, arena: &mut Self::Arena) -> PolarsResult<Self::Node> {897let mut id_array_offsets = std::mem::take(&mut self.id_array_offsets);898899self.se_count.inner.clear();900self.name_validation.clear();901self.id_array.clear();902id_array_offsets.clear();903self.replaced_identifiers.inner.clear();904905let arena_idx = node.node();906let alp = arena.0.get(arena_idx);907908match alp {909IR::Select {910input,911expr,912schema,913options,914} => {915let input_schema = arena.0.get(*input).schema(&arena.0);916if let Some(expr) = self.find_cse(917expr,918&mut arena.1,919&mut id_array_offsets,920false,921input_schema.as_ref().as_ref(),922)? {923let schema = schema.clone();924let options = *options;925926let lp = IRBuilder::new(*input, &mut arena.1, &mut arena.0)927.with_columns(928expr.cse_exprs().to_vec(),929ProjectionOptions {930run_parallel: options.run_parallel,931duplicate_check: options.duplicate_check,932// These columns might have different933// lengths from the dataframe, but934// they are only temporaries that will935// be removed by the evaluation of the936// default_exprs and the subsequent937// projection.938should_broadcast: false,939},940)941.build();942let input = arena.0.add(lp);943944let lp = IR::Select {945input,946expr: expr.default_exprs().to_vec(),947schema,948options,949};950arena.0.replace(arena_idx, lp);951}952},953IR::HStack {954input,955exprs,956schema,957options,958} => {959let input_schema = arena.0.get(*input).schema(&arena.0);960if let Some(exprs) = self.find_cse(961exprs,962&mut arena.1,963&mut id_array_offsets,964false,965input_schema.as_ref().as_ref(),966)? {967let schema = schema.clone();968let options = *options;969let input = *input;970971let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)972.with_columns(973exprs.cse_exprs().to_vec(),974// These columns might have different975// lengths from the dataframe, but they976// are only temporaries that will be977// removed by the evaluation of the978// default_exprs and the subsequent979// projection.980ProjectionOptions {981run_parallel: options.run_parallel,982duplicate_check: options.duplicate_check,983should_broadcast: false,984},985)986.with_columns(exprs.default_exprs().to_vec(), options)987.build();988let input = arena.0.add(lp);989990let lp = IR::SimpleProjection {991input,992columns: schema,993};994arena.0.replace(arena_idx, lp);995}996},997IR::GroupBy {998input,999keys,1000aggs,1001options,1002maintain_order,1003apply,1004schema,1005} => {1006let input_schema = arena.0.get(*input).schema(&arena.0);1007if let Some(aggs) = self.find_cse(1008aggs,1009&mut arena.1,1010&mut id_array_offsets,1011true,1012input_schema.as_ref().as_ref(),1013)? {1014let keys = keys.clone();1015let options = options.clone();1016let schema = schema.clone();1017let apply = apply.clone();1018let maintain_order = *maintain_order;1019let input = *input;10201021let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0)1022.with_columns(aggs.cse_exprs().to_vec(), Default::default())1023.build();1024let input = arena.0.add(lp);10251026let lp = IR::GroupBy {1027input,1028keys,1029aggs: aggs.default_exprs().to_vec(),1030options,1031schema,1032maintain_order,1033apply,1034};1035arena.0.replace(arena_idx, lp);1036}1037},1038_ => {},1039}10401041self.id_array_offsets = id_array_offsets;1042Ok(node)1043}1044}104510461047