Path: blob/main/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs
8458 views
use std::sync::Arc;12use polars_core::prelude::{PlHashSet, PlIndexMap};3use polars_utils::aliases::InitHashMaps;4use polars_utils::arena::{Arena, Node};56use super::aexpr::AExpr;7use super::ir::IR;8use super::{PlSmallStr, aexpr_to_leaf_names_iter};9use crate::plans::ExprIR;1011pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &Arena<AExpr>) {12let mut ir_stack = Vec::with_capacity(16);13ir_stack.push(root);1415// key: output_name, value: (expr, is_original)16let mut input_name_to_expr_map: PlIndexMap<PlSmallStr, (ExprIR, bool)> = PlIndexMap::new();17let mut input_names_accessed_by_non_candidates: PlHashSet<PlSmallStr> = PlHashSet::new();18let mut push_candidate_idxs: Vec<usize> = vec![];19let mut new_current_exprs: Vec<ExprIR> = vec![];20let mut visited_caches = PlHashSet::new();2122while let Some(current_node) = ir_stack.pop() {23let current_ir = lp_arena.get(current_node);2425if let IR::Cache { id, .. } = current_ir {26if !visited_caches.insert(*id) {27continue;28}29}3031current_ir.copy_inputs(&mut ir_stack);3233let IR::HStack { input, .. } = current_ir else {34continue;35};3637let input_node = *input;3839let [current_ir, input_ir] = lp_arena.get_disjoint_mut([current_node, input_node]);4041let IR::HStack {42input: _,43exprs: current_exprs,44schema: current_schema,45options: _,46} = current_ir47else {48unreachable!();49};5051let IR::HStack {52input: _,53exprs: input_exprs,54schema: input_schema,55options: _,56} = input_ir57else {58continue;59};6061input_name_to_expr_map.clear();62input_names_accessed_by_non_candidates.clear();63push_candidate_idxs.clear();64new_current_exprs.clear();6566input_name_to_expr_map.extend(67input_exprs68.iter()69.map(|e| (e.output_name().clone(), (e.clone(), true))),70);7172if input_name_to_expr_map.len() != input_exprs.len() {73if cfg!(debug_assertions) {74panic!()75};7677continue;78}7980for (i, e) in current_exprs.iter().enumerate() {81// Ignore col()82if let AExpr::Column(name) = expr_arena.get(e.node())83&& name == e.output_name()84{85continue;86}8788if aexpr_to_leaf_names_iter(e.node(), expr_arena)89.all(|name| !input_name_to_expr_map.contains_key(name))90{91push_candidate_idxs.push(i);92}93}9495let mut candidate_idx: usize = 0;9697for (i, e) in current_exprs.iter().enumerate() {98if push_candidate_idxs.get(candidate_idx) == Some(&i) {99candidate_idx += 1;100continue;101}102103for name in aexpr_to_leaf_names_iter(e.node(), expr_arena) {104input_names_accessed_by_non_candidates.insert(name.clone());105}106}107108push_candidate_idxs.retain(|&i| {109let e = ¤t_exprs[i];110!input_names_accessed_by_non_candidates.contains(e.output_name())111});112113let mut candidate_idx: usize = 0;114115for (i, e) in current_exprs.iter().enumerate() {116// Prune col()117if let AExpr::Column(name) = expr_arena.get(e.node())118&& name == e.output_name()119{120continue;121}122123if push_candidate_idxs.get(candidate_idx) == Some(&i) {124candidate_idx += 1;125input_name_to_expr_map.insert(e.output_name().clone(), (e.clone(), false));126continue;127}128129new_current_exprs.push(e.clone());130}131132if new_current_exprs.len() == current_exprs.len() {133continue;134}135136input_exprs.clear();137138for (output_name, (e, is_original)) in input_name_to_expr_map139.iter()140.map(|x| (x.0.clone(), x.1.clone()))141{142input_exprs.push(e);143144if !is_original {145let dtype = current_schema.get(&output_name).unwrap().clone();146Arc::make_mut(input_schema).insert(output_name, dtype);147}148}149150if new_current_exprs.is_empty() {151let input_ir = input_ir.clone();152lp_arena.replace(current_node, input_ir);153*ir_stack.last_mut().unwrap() = current_node;154continue;155}156157let fix_output_order = current_exprs.iter().any(|e| {158input_schema159.index_of(e.output_name())160.is_some_and(|i| i != current_schema.index_of(e.output_name()).unwrap())161});162163current_exprs.clear();164std::mem::swap(current_exprs, &mut new_current_exprs);165166if fix_output_order {167let projection = current_schema.clone();168169Arc::make_mut(current_schema)170.sort_by_key(|name, _| input_schema.index_of(name).unwrap_or(usize::MAX));171172let current_ir = lp_arena.replace(current_node, IR::Invalid);173let moved_current_node = lp_arena.add(current_ir);174lp_arena.replace(175current_node,176IR::SimpleProjection {177input: moved_current_node,178columns: projection,179},180);181}182}183}184185186