Path: blob/main/crates/polars-plan/src/plans/optimizer/cse/cspe.rs
7889 views
use std::hash::{Hash, Hasher};12use hashbrown::hash_map::RawEntryMut;3use polars_utils::unique_id::UniqueId;45use super::*;6use crate::prelude::visitor::IRNode;78struct Blake3Hasher {9hasher: blake3::Hasher,10}1112impl Blake3Hasher {13fn new() -> Self {14Self {15hasher: blake3::Hasher::new(),16}17}1819fn finalize(self) -> [u8; 32] {20self.hasher.finalize().into()21}22}2324impl Hasher for Blake3Hasher {25fn finish(&self) -> u64 {26// Not used - we'll call finalize() instead27028}2930fn write(&mut self, bytes: &[u8]) {31self.hasher.update(bytes);32}33}3435mod identifier_impl {36use super::*;37#[derive(Clone)]38pub(super) struct Identifier {39inner: Option<[u8; 32]>,40}4142impl Identifier {43pub fn hash(&self) -> u64 {44self.inner45.map(|inner| u64::from_le_bytes(inner[0..8].try_into().unwrap()))46.unwrap_or(0)47}4849pub fn is_equal(&self, other: &Self) -> bool {50self.inner.map(blake3::Hash::from_bytes) == other.inner.map(blake3::Hash::from_bytes)51}5253pub fn new() -> Self {54Self { inner: None }55}5657pub fn is_valid(&self) -> bool {58self.inner.is_some()59}6061pub fn combine(&mut self, other: &Identifier) {62let inner = match (self.inner, other.inner) {63(Some(l), Some(r)) => {64let mut h = blake3::Hasher::new();65h.update(&l);66h.update(&r);67*h.finalize().as_bytes()68},69(None, Some(r)) => r,70(Some(l), None) => l,71_ => return,72};73self.inner = Some(inner);74}7576pub fn add_alp_node(77&self,78alp: &IRNode,79lp_arena: &Arena<IR>,80expr_arena: &Arena<AExpr>,81) -> Self {82let mut h = Blake3Hasher::new();83alp.hashable_and_cmp(lp_arena, expr_arena)84.hash_as_equality()85.hash(&mut h);86let hashed = h.finalize();8788let inner = Some(self.inner.map_or(hashed, |l| {89let mut h = blake3::Hasher::new();90h.update(&l);91h.update(&hashed);92*h.finalize().as_bytes()93}));94Self { inner }95}96}97}98use identifier_impl::*;99100struct IdentifierMap<V> {101inner: PlHashMap<Identifier, V>,102}103104impl<V> IdentifierMap<V> {105fn new() -> Self {106Self {107inner: Default::default(),108}109}110111fn get(&self, id: &Identifier) -> Option<&V> {112self.inner113.raw_entry()114.from_hash(id.hash(), |k| k.is_equal(id))115.map(|(_k, v)| v)116}117118fn entry<F: FnOnce() -> V>(&mut self, id: Identifier, v: F) -> &mut V {119let h = id.hash();120match self.inner.raw_entry_mut().from_hash(h, |k| k.is_equal(&id)) {121RawEntryMut::Occupied(entry) => entry.into_mut(),122RawEntryMut::Vacant(entry) => {123let (_, v) = entry.insert_with_hasher(h, id, v(), |id| id.hash());124v125},126}127}128}129130impl<V> Default for IdentifierMap<V> {131fn default() -> Self {132Self::new()133}134}135/// Identifier maps to Expr Node and count.136type SubPlanCount = IdentifierMap<(Node, u32)>;137/// (post_visit_idx, identifier);138type IdentifierArray = Vec<(usize, Identifier)>;139140/// See Expr based CSE for explanations.141enum VisitRecord {142/// Entered a new plan node143Entered(usize),144SubPlanId(Identifier),145}146147struct LpIdentifierVisitor<'a> {148sp_count: &'a mut SubPlanCount,149identifier_array: &'a mut IdentifierArray,150// Index in pre-visit traversal order.151pre_visit_idx: usize,152post_visit_idx: usize,153visit_stack: Vec<VisitRecord>,154has_subplan: bool,155}156157impl LpIdentifierVisitor<'_> {158fn new<'a>(159sp_count: &'a mut SubPlanCount,160identifier_array: &'a mut IdentifierArray,161) -> LpIdentifierVisitor<'a> {162LpIdentifierVisitor {163sp_count,164identifier_array,165pre_visit_idx: 0,166post_visit_idx: 0,167visit_stack: vec![],168has_subplan: false,169}170}171172fn pop_until_entered(&mut self) -> (usize, Identifier) {173let mut id = Identifier::new();174175while let Some(item) = self.visit_stack.pop() {176match item {177VisitRecord::Entered(idx) => return (idx, id),178VisitRecord::SubPlanId(s) => {179id.combine(&s);180},181}182}183unreachable!()184}185}186187fn skip_children(lp: &IR) -> bool {188match lp {189// Don't visit all the files in a `scan *` operation.190// Put an arbitrary limit to 20 files now.191IR::Union {192options, inputs, ..193} => options.from_partitioned_ds && inputs.len() > 20,194_ => false,195}196}197198impl Visitor for LpIdentifierVisitor<'_> {199type Node = IRNode;200type Arena = IRNodeArena;201202fn pre_visit(203&mut self,204node: &Self::Node,205arena: &Self::Arena,206) -> PolarsResult<VisitRecursion> {207self.visit_stack208.push(VisitRecord::Entered(self.pre_visit_idx));209self.pre_visit_idx += 1;210211self.identifier_array.push((0, Identifier::new()));212213if skip_children(node.to_alp(&arena.0)) {214Ok(VisitRecursion::Skip)215} else {216Ok(VisitRecursion::Continue)217}218}219220fn post_visit(221&mut self,222node: &Self::Node,223arena: &Self::Arena,224) -> PolarsResult<VisitRecursion> {225self.post_visit_idx += 1;226227let (pre_visit_idx, sub_plan_id) = self.pop_until_entered();228229// Create the Id of this node.230let id = sub_plan_id.add_alp_node(node, &arena.0, &arena.1);231232// Store the created id.233self.identifier_array[pre_visit_idx] = (self.post_visit_idx, id.clone());234235// We popped until entered, push this Id on the stack so the trail236// is available for the parent plan.237self.visit_stack.push(VisitRecord::SubPlanId(id.clone()));238239let (_, sp_count) = self.sp_count.entry(id, || (node.node(), 0));240*sp_count += 1;241self.has_subplan |= *sp_count > 1;242Ok(VisitRecursion::Continue)243}244}245246pub(super) type CacheId2Caches = PlHashMap<UniqueId, (u32, Vec<Node>)>;247248struct CommonSubPlanRewriter<'a> {249sp_count: &'a SubPlanCount,250identifier_array: &'a IdentifierArray,251252max_post_visit_idx: usize,253/// index in traversal order in which `identifier_array`254/// was written. This is the index in `identifier_array`.255visited_idx: usize,256/// Indicates if this expression is rewritten.257rewritten: bool,258cache_id: IdentifierMap<UniqueId>,259// Maps cache_id : (cache_count and cache_nodes)260cache_id_to_caches: CacheId2Caches,261}262263impl<'a> CommonSubPlanRewriter<'a> {264fn new(sp_count: &'a SubPlanCount, identifier_array: &'a IdentifierArray) -> Self {265Self {266sp_count,267identifier_array,268max_post_visit_idx: 0,269visited_idx: 0,270rewritten: false,271cache_id: Default::default(),272cache_id_to_caches: Default::default(),273}274}275}276277impl RewritingVisitor for CommonSubPlanRewriter<'_> {278type Node = IRNode;279type Arena = IRNodeArena;280281fn pre_visit(282&mut self,283lp_node: &Self::Node,284arena: &mut Self::Arena,285) -> PolarsResult<RewriteRecursion> {286if self.visited_idx >= self.identifier_array.len()287|| self.max_post_visit_idx > self.identifier_array[self.visited_idx].0288{289return Ok(RewriteRecursion::Stop);290}291292let id = &self.identifier_array[self.visited_idx].1;293294// Id placeholder not overwritten, so we can skip this sub-expression.295if !id.is_valid() {296self.visited_idx += 1;297return Ok(RewriteRecursion::NoMutateAndContinue);298}299300let Some((_, count)) = self.sp_count.get(id) else {301self.visited_idx += 1;302return Ok(RewriteRecursion::NoMutateAndContinue);303};304305if *count > 1 {306// Rewrite this sub-plan, don't visit its children307Ok(RewriteRecursion::MutateAndStop)308}309// Never mutate if count <= 1. The post-visit will search for the node, and not be able to find it310else {311// Don't traverse the children.312if skip_children(lp_node.to_alp(&arena.0)) {313return Ok(RewriteRecursion::Stop);314}315// This is a unique plan316// visit its children to see if they are cse317self.visited_idx += 1;318Ok(RewriteRecursion::NoMutateAndContinue)319}320}321322fn mutate(323&mut self,324mut node: Self::Node,325arena: &mut Self::Arena,326) -> PolarsResult<Self::Node> {327let (post_visit_count, id) = &self.identifier_array[self.visited_idx];328self.visited_idx += 1;329330if *post_visit_count < self.max_post_visit_idx {331return Ok(node);332}333self.max_post_visit_idx = *post_visit_count;334while self.visited_idx < self.identifier_array.len()335&& *post_visit_count > self.identifier_array[self.visited_idx].0336{337self.visited_idx += 1;338}339340let cache_id = *self.cache_id.entry(id.clone(), UniqueId::new);341let cache_count = self.sp_count.get(id).unwrap().1;342343let cache_node = IR::Cache {344input: node.node(),345id: cache_id,346};347node.assign(cache_node, &mut arena.0);348let (_count, nodes) = self349.cache_id_to_caches350.entry(cache_id)351.or_insert_with(|| (cache_count, vec![]));352nodes.push(node.node());353self.rewritten = true;354Ok(node)355}356}357358fn insert_caches(359root: Node,360lp_arena: &mut Arena<IR>,361expr_arena: &mut Arena<AExpr>,362) -> (Node, bool, CacheId2Caches) {363let mut sp_count = Default::default();364let mut id_array = Default::default();365366with_ir_arena(lp_arena, expr_arena, |arena| {367let lp_node = IRNode::new_mutate(root);368let mut visitor = LpIdentifierVisitor::new(&mut sp_count, &mut id_array);369370lp_node.visit(&mut visitor, arena).map(|_| ()).unwrap();371372if visitor.has_subplan {373let lp_node = IRNode::new_mutate(root);374let mut rewriter = CommonSubPlanRewriter::new(&sp_count, &id_array);375lp_node.rewrite(&mut rewriter, arena).unwrap();376377(root, rewriter.rewritten, rewriter.cache_id_to_caches)378} else {379(root, false, Default::default())380}381})382}383384/// Prune unused caches.385/// In the query below the query will be insert cache 0 with a count of 2 on `lf.select`386/// and cache 1 with a count of 3 on `lf`. But because cache 0 is higher in the chain cache 1387/// will never be used. So we prune caches that don't fit their count.388///389/// `conctat([lf.select(), lf.select(), lf])`390fn prune_unused_caches(lp_arena: &mut Arena<IR>, cid2c: &CacheId2Caches) {391for (count, nodes) in cid2c.values() {392if *count == nodes.len() as u32 {393continue;394}395396for node in nodes {397let IR::Cache { input, .. } = lp_arena.get(*node) else {398unreachable!()399};400lp_arena.swap(*input, *node)401}402}403}404405pub(super) fn elim_cmn_subplans(406root: Node,407lp_arena: &mut Arena<IR>,408expr_arena: &mut Arena<AExpr>,409) -> (Node, bool, CacheId2Caches) {410let (lp, changed, cid2c) = insert_caches(root, lp_arena, expr_arena);411if changed {412prune_unused_caches(lp_arena, &cid2c);413}414415(lp, changed, cid2c)416}417418419