// Path: blob/main/crates/polars-plan/src/plans/optimizer/cse/cache_states.rs
// 7889 views
use std::collections::BTreeMap;12use polars_utils::unique_id::UniqueId;34use super::*;56fn get_upper_projections(7parent: Node,8lp_arena: &Arena<IR>,9expr_arena: &Arena<AExpr>,10names_scratch: &mut Vec<PlSmallStr>,11found_required_columns: &mut bool,12) -> bool {13let parent = lp_arena.get(parent);1415use IR::*;16// During projection pushdown all accumulated.17match parent {18SimpleProjection { columns, .. } => {19let iter = columns.iter_names_cloned();20names_scratch.extend(iter);21*found_required_columns = true;22false23},24Filter { predicate, .. } => {25// Also add predicate, as the projection is above the filter node.26names_scratch.extend(aexpr_to_leaf_names(predicate.node(), expr_arena));2728true29},30// Only filter and projection nodes are allowed, any other node we stop.31_ => false,32}33}3435fn get_upper_predicates(36parent: Node,37lp_arena: &Arena<IR>,38expr_arena: &mut Arena<AExpr>,39predicate_scratch: &mut Vec<Expr>,40) -> bool {41let parent = lp_arena.get(parent);4243use IR::*;44match parent {45Filter { predicate, .. } => {46let expr = predicate.to_expr(expr_arena);47predicate_scratch.push(expr);48false49},50SimpleProjection { .. } => true,51// Only filter and projection nodes are allowed, any other node we stop.52_ => false,53}54}5556type TwoParents = [Option<Node>; 2];5758// 1. This will ensure that all equal caches communicate the amount of columns59// they need to project.60// 2. This will ensure we apply predicate in the subtrees below the caches.61// If the predicate above the cache is the same for all matching caches, that filter will be62// applied as well.63//64// # Example65// Consider this tree, where `SUB-TREE` is duplicate and can be cached.66//67//68// Tree69// |70// |71// |--------------------|-------------------|72// | |73// SUB-TREE SUB-TREE74//75// STEPS:76// - 1. 
CSE will run and will insert cache nodes77//78// Tree79// |80// |81// |--------------------|-------------------|82// | |83// | CACHE 0 | CACHE 084// | |85// SUB-TREE SUB-TREE86//87// - 2. predicate and projection pushdown will run and will insert optional FILTER and PROJECTION above the caches88//89// Tree90// |91// |92// |--------------------|-------------------|93// | FILTER (optional) | FILTER (optional)94// | PROJ (optional) | PROJ (optional)95// | |96// | CACHE 0 | CACHE 097// | |98// SUB-TREE SUB-TREE99//100// # Projection optimization101// The union of the projection is determined and the projection will be pushed down.102//103// Tree104// |105// |106// |--------------------|-------------------|107// | FILTER (optional) | FILTER (optional)108// | CACHE 0 | CACHE 0109// | |110// SUB-TREE SUB-TREE111// UNION PROJ (optional) UNION PROJ (optional)112//113// # Filter optimization114// Depending on the predicates the predicate pushdown optimization will run.115// Possible cases:116// - NO FILTERS: run predicate pd from the cache nodes -> finish117// - Above the filters the caches are the same -> run predicate pd from the filter node -> finish118// - There is a cache without predicates above the cache node -> run predicate form the cache nodes -> finish119// - The predicates above the cache nodes are all different -> remove the cache nodes -> finish120pub(super) fn set_cache_states(121root: Node,122lp_arena: &mut Arena<IR>,123expr_arena: &mut Arena<AExpr>,124scratch: &mut Vec<Node>,125verbose: bool,126pushdown_maintain_errors: bool,127new_streaming: bool,128) -> PolarsResult<()> {129let mut stack = Vec::with_capacity(4);130let mut names_scratch = vec![];131let mut predicates_scratch = vec![];132133scratch.clear();134stack.clear();135136#[derive(Default)]137struct Value {138// All the children of the cache per cache-id.139children: Vec<Node>,140parents: Vec<TwoParents>,141cache_nodes: Vec<Node>,142// Union over projected names.143names_union: 
PlHashSet<PlSmallStr>,144// Union over predicates.145predicate_union: PlHashMap<Expr, u32>,146}147let mut cache_schema_and_children = BTreeMap::new();148149// Stack frame150#[derive(Default, Clone)]151struct Frame {152current: Node,153cache_id: Option<UniqueId>,154parent: TwoParents,155previous_cache: Option<UniqueId>,156}157let init = Frame {158current: root,159..Default::default()160};161162stack.push(init);163164// # First traversal.165// Collect the union of columns per cache id.166// And find the cache parents.167while let Some(mut frame) = stack.pop() {168let lp = lp_arena.get(frame.current);169lp.copy_inputs(scratch);170171use IR::*;172173if let Cache { input, id, .. } = lp {174if let Some(cache_id) = frame.cache_id {175frame.previous_cache = Some(cache_id)176}177if frame.parent[0].is_some() {178// Projection pushdown has already run and blocked on cache nodes179// the pushed down columns are projected just above this cache180// if there were no pushed down column, we just take the current181// nodes schema182// we never want to naively take parents, as a join or aggregate for instance183// change the schema184185let v = cache_schema_and_children186.entry(*id)187.or_insert_with(Value::default);188v.children.push(*input);189v.parents.push(frame.parent);190v.cache_nodes.push(frame.current);191192let mut found_required_columns = false;193194for parent_node in frame.parent.into_iter().flatten() {195let keep_going = get_upper_projections(196parent_node,197lp_arena,198expr_arena,199&mut names_scratch,200&mut found_required_columns,201);202if !names_scratch.is_empty() {203v.names_union.extend(names_scratch.drain(..));204}205// We stop early as we want to find the first projection node above the cache.206if !keep_going {207break;208}209}210211for parent_node in frame.parent.into_iter().flatten() {212let keep_going = get_upper_predicates(213parent_node,214lp_arena,215expr_arena,216&mut predicates_scratch,217);218if !predicates_scratch.is_empty() {219for pred in 
predicates_scratch.drain(..) {220let count = v.predicate_union.entry(pred).or_insert(0);221*count += 1;222}223}224// We stop early as we want to find the first predicate node above the cache.225if !keep_going {226break;227}228}229230// There was no explicit projection and we must take231// all columns232if !found_required_columns {233let schema = lp.schema(lp_arena);234v.names_union.extend(schema.iter_names_cloned());235}236}237frame.cache_id = Some(*id);238};239240// Shift parents.241frame.parent[1] = frame.parent[0];242frame.parent[0] = Some(frame.current);243for n in scratch.iter() {244let mut new_frame = frame.clone();245new_frame.current = *n;246stack.push(new_frame);247}248scratch.clear();249}250251// # Second pass.252// we create a subtree where we project the columns253// just before the cache. Then we do another projection pushdown254// and finally remove that last projection and stitch the subplan255// back to the cache node again256if !cache_schema_and_children.is_empty() {257let mut proj_pd = ProjectionPushDown::new();258let mut pred_pd = PredicatePushDown::new(pushdown_maintain_errors, new_streaming);259for (_cache_id, v) in cache_schema_and_children {260// # CHECK IF WE NEED TO REMOVE CACHES261// If we encounter multiple predicates we remove the cache nodes completely as we don't262// want to loose predicate pushdown in favor of scan sharing.263if v.predicate_union.len() > 1 {264if verbose {265eprintln!("cache nodes will be removed because predicates don't match")266}267for ((&child, cache), parents) in268v.children.iter().zip(v.cache_nodes).zip(v.parents)269{270// Remove the cache and assign the child the cache location.271lp_arena.swap(child, cache);272273// Restart predicate and projection pushdown from most top parent.274// This to ensure we continue the optimization where it was blocked initially.275// We pick up the blocked filter and projection.276let mut node = cache;277for p_node in parents.into_iter().flatten() {278if 
matches!(279lp_arena.get(p_node),280IR::Filter { .. } | IR::SimpleProjection { .. }281) {282node = p_node283} else {284break;285}286}287288let lp = lp_arena.take(node);289let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;290let lp = pred_pd.optimize(lp, lp_arena, expr_arena)?;291lp_arena.replace(node, lp);292}293return Ok(());294}295// Below we restart projection and predicates pushdown296// on the first cache node. As it are cache nodes, the others are the same297// and we can reuse the optimized state for all inputs.298// See #21637299300// # RUN PROJECTION PUSHDOWN301if !v.names_union.is_empty() {302let first_child = *v.children.first().expect("at least on child");303304let columns = &v.names_union;305let child_lp = lp_arena.take(first_child);306307// Make sure we project in the order of the schema308// if we don't a union may fail as we would project by the309// order we discovered all values.310let child_schema = child_lp.schema(lp_arena);311let child_schema = child_schema.as_ref();312let projection = child_schema313.iter_names()314.flat_map(|name| columns.get(name.as_str()).cloned())315.collect::<Vec<_>>();316317let new_child = lp_arena.add(child_lp);318319let lp = IRBuilder::new(new_child, expr_arena, lp_arena)320.project_simple(projection)321.expect("unique names")322.build();323324let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;325// Optimization can lead to a double projection. Only take the last.326let lp = if let IR::SimpleProjection { input, columns } = lp {327let input =328if let IR::SimpleProjection { input: input2, .. } = lp_arena.get(input) {329*input2330} else {331input332};333IR::SimpleProjection { input, columns }334} else {335lp336};337lp_arena.replace(first_child, lp.clone());338339// Set the remaining children to the same node.340for &child in &v.children[1..] 
{341lp_arena.replace(child, lp.clone());342}343} else {344// No upper projections to include, run projection pushdown from cache node.345let first_child = *v.children.first().expect("at least on child");346let child_lp = lp_arena.take(first_child);347let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?;348lp_arena.replace(first_child, lp.clone());349350for &child in &v.children[1..] {351lp_arena.replace(child, lp.clone());352}353}354355// # RUN PREDICATE PUSHDOWN356// Run this after projection pushdown, otherwise the predicate columns will not be projected.357358// - If all predicates of parent are the same we will restart predicate pushdown from the parent FILTER node.359// - Otherwise we will start predicate pushdown from the cache node.360let allow_parent_predicate_pushdown = v.predicate_union.len() == 1 && {361let (_pred, count) = v.predicate_union.iter().next().unwrap();362*count == v.children.len() as u32363};364365if allow_parent_predicate_pushdown {366let parents = *v.parents.first().unwrap();367let node = get_filter_node(parents, lp_arena)368.expect("expected filter; this is an optimizer bug");369let start_lp = lp_arena.take(node);370371let mut pred_pd = PredicatePushDown::new(pushdown_maintain_errors, new_streaming)372.block_at_cache(1);373let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?;374lp_arena.replace(node, lp.clone());375for &parents in &v.parents[1..] {376let node = get_filter_node(parents, lp_arena)377.expect("expected filter; this is an optimizer bug");378lp_arena.replace(node, lp.clone());379}380} else {381let child = *v.children.first().unwrap();382let child_lp = lp_arena.take(child);383let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?;384lp_arena.replace(child, lp.clone());385for &child in &v.children[1..] 
{386lp_arena.replace(child, lp.clone());387}388}389}390}391Ok(())392}393394fn get_filter_node(parents: TwoParents, lp_arena: &Arena<IR>) -> Option<Node> {395parents396.into_iter()397.flatten()398.find(|&parent| matches!(lp_arena.get(parent), IR::Filter { .. }))399}400401402