// Path: blob/main/crates/polars-plan/src/plans/optimizer/cache_states.rs
use std::collections::BTreeMap;12use polars_utils::unique_id::UniqueId;34use super::*;56fn get_upper_projections(7parent: Node,8lp_arena: &Arena<IR>,9expr_arena: &Arena<AExpr>,10names_scratch: &mut Vec<PlSmallStr>,11found_required_columns: &mut bool,12) -> bool {13let parent = lp_arena.get(parent);1415use IR::*;16// During projection pushdown all accumulated.17match parent {18SimpleProjection { columns, .. } => {19let iter = columns.iter_names_cloned();20names_scratch.extend(iter);21*found_required_columns = true;22false23},24Filter { predicate, .. } => {25// Also add predicate, as the projection is above the filter node.26names_scratch.extend(aexpr_to_leaf_names(predicate.node(), expr_arena));2728true29},30// Only filter and projection nodes are allowed, any other node we stop.31_ => false,32}33}3435fn get_upper_predicates(36parent: Node,37lp_arena: &Arena<IR>,38expr_arena: &mut Arena<AExpr>,39predicate_scratch: &mut Vec<Expr>,40) -> bool {41let parent = lp_arena.get(parent);4243use IR::*;44match parent {45Filter { predicate, .. } => {46let expr = predicate.to_expr(expr_arena);47predicate_scratch.push(expr);48false49},50SimpleProjection { .. } => true,51// Only filter and projection nodes are allowed, any other node we stop.52_ => false,53}54}5556type TwoParents = [Option<Node>; 2];5758// 1. This will ensure that all equal caches communicate the amount of columns59// they need to project.60// 2. This will ensure we apply predicate in the subtrees below the caches.61// If the predicate above the cache is the same for all matching caches, that filter will be62// applied as well.63//64// # Example65// Consider this tree, where `SUB-TREE` is duplicate and can be cached.66//67//68// Tree69// |70// |71// |--------------------|-------------------|72// | |73// SUB-TREE SUB-TREE74//75// STEPS:76// - 1. 
CSE will run and will insert cache nodes77//78// Tree79// |80// |81// |--------------------|-------------------|82// | |83// | CACHE 0 | CACHE 084// | |85// SUB-TREE SUB-TREE86//87// - 2. predicate and projection pushdown will run and will insert optional FILTER and PROJECTION above the caches88//89// Tree90// |91// |92// |--------------------|-------------------|93// | FILTER (optional) | FILTER (optional)94// | PROJ (optional) | PROJ (optional)95// | |96// | CACHE 0 | CACHE 097// | |98// SUB-TREE SUB-TREE99//100// # Projection optimization101// The union of the projection is determined and the projection will be pushed down.102//103// Tree104// |105// |106// |--------------------|-------------------|107// | FILTER (optional) | FILTER (optional)108// | CACHE 0 | CACHE 0109// | |110// SUB-TREE SUB-TREE111// UNION PROJ (optional) UNION PROJ (optional)112//113// # Filter optimization114// Depending on the predicates the predicate pushdown optimization will run.115// Possible cases:116// - NO FILTERS: run predicate pd from the cache nodes -> finish117// - Above the filters the caches are the same -> run predicate pd from the filter node -> finish118// - There is a cache without predicates above the cache node -> run predicate form the cache nodes -> finish119// - The predicates above the cache nodes are all different -> remove the cache nodes -> finish120#[expect(clippy::too_many_arguments)]121pub(super) fn set_cache_states(122root: Node,123lp_arena: &mut Arena<IR>,124expr_arena: &mut Arena<AExpr>,125scratch: &mut Vec<Node>,126expr_eval: ExprEval<'_>,127verbose: bool,128pushdown_maintain_errors: bool,129new_streaming: bool,130) -> PolarsResult<()> {131let mut stack = Vec::with_capacity(4);132let mut names_scratch = vec![];133let mut predicates_scratch = vec![];134135scratch.clear();136stack.clear();137138#[derive(Default)]139struct Value {140// All the children of the cache per cache-id.141children: Vec<Node>,142parents: Vec<TwoParents>,143cache_nodes: Vec<Node>,144// 
Union over projected names.145names_union: PlHashSet<PlSmallStr>,146// Union over predicates.147predicate_union: PlHashMap<Expr, u32>,148}149let mut cache_schema_and_children = BTreeMap::new();150151// Stack frame152#[derive(Default, Clone)]153struct Frame {154current: Node,155cache_id: Option<UniqueId>,156parent: TwoParents,157previous_cache: Option<UniqueId>,158}159let init = Frame {160current: root,161..Default::default()162};163164stack.push(init);165166// # First traversal.167// Collect the union of columns per cache id.168// And find the cache parents.169while let Some(mut frame) = stack.pop() {170let lp = lp_arena.get(frame.current);171lp.copy_inputs(scratch);172173use IR::*;174175if let Cache { input, id, .. } = lp {176if let Some(cache_id) = frame.cache_id {177frame.previous_cache = Some(cache_id)178}179if frame.parent[0].is_some() {180// Projection pushdown has already run and blocked on cache nodes181// the pushed down columns are projected just above this cache182// if there were no pushed down column, we just take the current183// nodes schema184// we never want to naively take parents, as a join or aggregate for instance185// change the schema186187let v = cache_schema_and_children188.entry(*id)189.or_insert_with(Value::default);190v.children.push(*input);191v.parents.push(frame.parent);192v.cache_nodes.push(frame.current);193194let mut found_required_columns = false;195196for parent_node in frame.parent.into_iter().flatten() {197let keep_going = get_upper_projections(198parent_node,199lp_arena,200expr_arena,201&mut names_scratch,202&mut found_required_columns,203);204if !names_scratch.is_empty() {205v.names_union.extend(names_scratch.drain(..));206}207// We stop early as we want to find the first projection node above the cache.208if !keep_going {209break;210}211}212213for parent_node in frame.parent.into_iter().flatten() {214let keep_going = get_upper_predicates(215parent_node,216lp_arena,217expr_arena,218&mut predicates_scratch,219);220if 
!predicates_scratch.is_empty() {221for pred in predicates_scratch.drain(..) {222let count = v.predicate_union.entry(pred).or_insert(0);223*count += 1;224}225}226// We stop early as we want to find the first predicate node above the cache.227if !keep_going {228break;229}230}231232// There was no explicit projection and we must take233// all columns234if !found_required_columns {235let schema = lp.schema(lp_arena);236v.names_union.extend(schema.iter_names_cloned());237}238}239frame.cache_id = Some(*id);240};241242// Shift parents.243frame.parent[1] = frame.parent[0];244frame.parent[0] = Some(frame.current);245for n in scratch.iter() {246let mut new_frame = frame.clone();247new_frame.current = *n;248stack.push(new_frame);249}250scratch.clear();251}252253// # Second pass.254// we create a subtree where we project the columns255// just before the cache. Then we do another projection pushdown256// and finally remove that last projection and stitch the subplan257// back to the cache node again258if !cache_schema_and_children.is_empty() {259let mut proj_pd = ProjectionPushDown::new();260let mut pred_pd =261PredicatePushDown::new(expr_eval, pushdown_maintain_errors, new_streaming)262.block_at_cache(false);263for (_cache_id, v) in cache_schema_and_children {264// # CHECK IF WE NEED TO REMOVE CACHES265// If we encounter multiple predicates we remove the cache nodes completely as we don't266// want to loose predicate pushdown in favor of scan sharing.267if v.predicate_union.len() > 1 {268if verbose {269eprintln!("cache nodes will be removed because predicates don't match")270}271for ((&child, cache), parents) in272v.children.iter().zip(v.cache_nodes).zip(v.parents)273{274// Remove the cache and assign the child the cache location.275lp_arena.swap(child, cache);276277// Restart predicate and projection pushdown from most top parent.278// This to ensure we continue the optimization where it was blocked initially.279// We pick up the blocked filter and projection.280let mut node 
= cache;281for p_node in parents.into_iter().flatten() {282if matches!(283lp_arena.get(p_node),284IR::Filter { .. } | IR::SimpleProjection { .. }285) {286node = p_node287} else {288break;289}290}291292let lp = lp_arena.take(node);293let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;294let lp = pred_pd.optimize(lp, lp_arena, expr_arena)?;295lp_arena.replace(node, lp);296}297return Ok(());298}299// Below we restart projection and predicates pushdown300// on the first cache node. As it are cache nodes, the others are the same301// and we can reuse the optimized state for all inputs.302// See #21637303304// # RUN PROJECTION PUSHDOWN305if !v.names_union.is_empty() {306let first_child = *v.children.first().expect("at least on child");307308let columns = &v.names_union;309let child_lp = lp_arena.take(first_child);310311// Make sure we project in the order of the schema312// if we don't a union may fail as we would project by the313// order we discovered all values.314let child_schema = child_lp.schema(lp_arena);315let child_schema = child_schema.as_ref();316let projection = child_schema317.iter_names()318.flat_map(|name| columns.get(name.as_str()).cloned())319.collect::<Vec<_>>();320321let new_child = lp_arena.add(child_lp);322323let lp = IRBuilder::new(new_child, expr_arena, lp_arena)324.project_simple(projection)325.expect("unique names")326.build();327328let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;329// Optimization can lead to a double projection. Only take the last.330let lp = if let IR::SimpleProjection { input, columns } = lp {331let input =332if let IR::SimpleProjection { input: input2, .. } = lp_arena.get(input) {333*input2334} else {335input336};337IR::SimpleProjection { input, columns }338} else {339lp340};341lp_arena.replace(first_child, lp.clone());342343// Set the remaining children to the same node.344for &child in &v.children[1..] 
{345lp_arena.replace(child, lp.clone());346}347} else {348// No upper projections to include, run projection pushdown from cache node.349let first_child = *v.children.first().expect("at least on child");350let child_lp = lp_arena.take(first_child);351let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?;352lp_arena.replace(first_child, lp.clone());353354for &child in &v.children[1..] {355lp_arena.replace(child, lp.clone());356}357}358359// # RUN PREDICATE PUSHDOWN360// Run this after projection pushdown, otherwise the predicate columns will not be projected.361362// - If all predicates of parent are the same we will restart predicate pushdown from the parent FILTER node.363// - Otherwise we will start predicate pushdown from the cache node.364let allow_parent_predicate_pushdown = v.predicate_union.len() == 1 && {365let (_pred, count) = v.predicate_union.iter().next().unwrap();366*count == v.children.len() as u32367};368369if allow_parent_predicate_pushdown {370let parents = *v.parents.first().unwrap();371let node = get_filter_node(parents, lp_arena)372.expect("expected filter; this is an optimizer bug");373let start_lp = lp_arena.take(node);374let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?;375lp_arena.replace(node, lp.clone());376for &parents in &v.parents[1..] {377let node = get_filter_node(parents, lp_arena)378.expect("expected filter; this is an optimizer bug");379lp_arena.replace(node, lp.clone());380}381} else {382let child = *v.children.first().unwrap();383let child_lp = lp_arena.take(child);384let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?;385lp_arena.replace(child, lp.clone());386for &child in &v.children[1..] {387lp_arena.replace(child, lp.clone());388}389}390}391}392Ok(())393}394395fn get_filter_node(parents: TwoParents, lp_arena: &Arena<IR>) -> Option<Node> {396parents397.into_iter()398.flatten()399.find(|&parent| matches!(lp_arena.get(parent), IR::Filter { .. }))400}401402403