Path: blob/main/crates/polars-plan/src/plans/visitor/hash.rs
6940 views
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use polars_utils::arena::Arena;

use super::*;
#[cfg(feature = "python")]
use crate::plans::PythonOptions;
use crate::plans::{AExpr, IR};
use crate::prelude::ExprIR;
use crate::prelude::aexpr::traverse_and_hash_aexpr;

impl IRNode {
    /// Wraps this node together with the IR arena and the expression arena
    /// into a [`HashableEqLP`], which implements `Hash` and `PartialEq`.
    ///
    /// The wrapper is created with `ignore_cache: false`; call
    /// `ignore_caches` on the result to have equality skip `IR::Cache` nodes.
    pub(crate) fn hashable_and_cmp<'a>(
        &'a self,
        lp_arena: &'a Arena<IR>,
        expr_arena: &'a Arena<AExpr>,
    ) -> HashableEqLP<'a> {
        HashableEqLP {
            node: *self,
            lp_arena,
            expr_arena,
            ignore_cache: false,
        }
    }
}

/// An IR node bundled with the arenas needed to resolve it, so that it can be
/// hashed and compared.
///
/// `Hash` only looks at the node itself (its variant and non-input fields),
/// whereas `PartialEq` also walks the nodes pushed by `copy_inputs`.
pub(crate) struct HashableEqLP<'a> {
    /// The node being hashed / the root of the comparison.
    node: IRNode,
    /// Arena used to resolve IR nodes.
    lp_arena: &'a Arena<IR>,
    /// Arena used to resolve the expressions referenced by IR nodes.
    expr_arena: &'a Arena<AExpr>,
    /// When `true`, `PartialEq::eq` replaces an `IR::Cache` node by its input
    /// before comparing. Not consulted by the `Hash` implementation.
    ignore_cache: bool,
}

impl HashableEqLP<'_> {
    /// When encountering a Cache node during equality comparison, ignore it
    /// and take its input instead.
    ///
    /// Consumes `self` and returns it with `ignore_cache` set to `true`.
    #[cfg(feature = "cse")]
    pub(crate) fn ignore_caches(mut self) -> Self {
        self.ignore_cache = true;
        self
    }
}

/// Hashes `expr` into `state` when it is `Some`; a `None` contributes nothing
/// to the hash.
fn hash_option_expr<H: Hasher>(expr: &Option<ExprIR>, expr_arena: &Arena<AExpr>, state: &mut H) {
    if let Some(e) = expr {
        e.traverse_and_hash(expr_arena, state)
    }
}

/// Hashes every expression of `exprs` into `state`, in slice order.
///
/// The slice length itself is not hashed.
fn hash_exprs<H: Hasher>(exprs: &[ExprIR], expr_arena: &Arena<AExpr>, state: &mut H) {
    for e in exprs {
        e.traverse_and_hash(expr_arena, state);
    }
}

/// Hashes a Python scan predicate: first its enum discriminant, then the
/// payload of the variant (nothing for `None`).
#[cfg(feature = "python")]
fn hash_python_predicate<H: Hasher>(
    pred: &crate::prelude::PythonPredicate,
    expr_arena: &Arena<AExpr>,
    state: &mut H,
) {
    use crate::prelude::PythonPredicate;
    std::mem::discriminant(pred).hash(state);
    match pred {
        PythonPredicate::None => {},
        PythonPredicate::PyArrow(s) => s.hash(state),
        PythonPredicate::Polars(e) => e.traverse_and_hash(expr_arena, state),
    }
}

/// Equality counterpart of `hash_python_predicate`.
///
/// Two predicates are equal when they are the same variant and their payloads
/// are equal; `Polars` payloads are compared with `expr_ir_eq`.
#[cfg(feature = "python")]
fn pred_eq(
    l: &crate::prelude::PythonPredicate,
    r: &crate::prelude::PythonPredicate,
    expr_arena: &Arena<AExpr>,
) -> bool {
    use crate::prelude::PythonPredicate;
    match (l, r) {
        (PythonPredicate::None, PythonPredicate::None) => true,
        (PythonPredicate::PyArrow(a), PythonPredicate::PyArrow(b)) => a == b,
        (PythonPredicate::Polars(a), PythonPredicate::Polars(b)) => expr_ir_eq(a, b, expr_arena),
        // Different variants.
        _ => false,
    }
}

impl Hash for HashableEqLP<'_> {
    // This hashes the variant, not the whole plan: every `input`/`inputs`
    // field below is bound to `_` and never visited.
    //
    // `ignore_cache` is not read here, so an `IR::Cache` node is always hashed
    // by its `id`.
    // NOTE(review): with `ignore_cache` set, `eq` can report a `Cache`-rooted
    // plan equal to its uncached input while the two hash differently —
    // confirm that callers do not rely on Hash/Eq agreeing in that case.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let alp = self.node.to_alp(self.lp_arena);
        // Start with the variant so different node kinds hash differently.
        std::mem::discriminant(alp).hash(state);
        match alp {
            #[cfg(feature = "python")]
            IR::PythonScan {
                options:
                    PythonOptions {
                        scan_fn,
                        schema,
                        output_schema,
                        with_columns,
                        python_source,
                        n_rows,
                        predicate,
                        validate_schema,
                        is_pure,
                    },
            } => {
                // Hash the Python function object using the pointer to the object
                // This should be the same as calling id() in python, but we don't need the GIL
                // When there is no scan function, nothing is hashed for it.
                if let Some(scan_fn) = scan_fn {
                    let ptr_addr = scan_fn.0.as_ptr() as usize;
                    ptr_addr.hash(state);
                }
                // Hash the stable fields
                // We include the schema since it can be set by the user
                schema.hash(state);
                output_schema.hash(state);
                with_columns.hash(state);
                python_source.hash(state);
                n_rows.hash(state);
                hash_python_predicate(predicate, self.expr_arena, state);
                validate_schema.hash(state);
                is_pure.hash(state);
            },
            IR::Slice {
                offset,
                len,
                input: _,
            } => {
                len.hash(state);
                offset.hash(state);
            },
            IR::Filter {
                input: _,
                predicate,
            } => {
                predicate.traverse_and_hash(self.expr_arena, state);
            },
            IR::Scan {
                sources,
                file_info: _,
                hive_parts: _,
                predicate,
                output_schema: _,
                scan_type,
                unified_scan_args,
            } => {
                // We don't have to traverse the schema, hive partitions etc. as they are
                // derived from the paths.
                scan_type.hash(state);
                sources.hash(state);
                hash_option_expr(predicate, self.expr_arena, state);
                unified_scan_args.hash(state);
            },
            IR::DataFrameScan {
                df,
                schema: _,
                output_schema,
                ..
            } => {
                // Hash the address the `Arc` points to, not the frame's contents.
                (Arc::as_ptr(df) as usize).hash(state);
                output_schema.hash(state);
            },
            IR::SimpleProjection { columns, input: _ } => {
                columns.hash(state);
            },
            IR::Select {
                input: _,
                expr,
                schema: _,
                options,
            } => {
                hash_exprs(expr, self.expr_arena, state);
                options.hash(state);
            },
            IR::Sort {
                input: _,
                by_column,
                slice,
                sort_options,
            } => {
                hash_exprs(by_column, self.expr_arena, state);
                slice.hash(state);
                sort_options.hash(state);
            },
            IR::GroupBy {
                input: _,
                keys,
                aggs,
                schema: _,
                apply,
                maintain_order,
                options,
            } => {
                hash_exprs(keys, self.expr_arena, state);
                hash_exprs(aggs, self.expr_arena, state);
                // Only the presence/absence of `apply` is hashed, not its value.
                apply.is_none().hash(state);
                maintain_order.hash(state);
                options.hash(state);
            },
            IR::Join {
                input_left: _,
                input_right: _,
                schema: _,
                left_on,
                right_on,
                options,
            } => {
                hash_exprs(left_on, self.expr_arena, state);
                hash_exprs(right_on, self.expr_arena, state);
                options.hash(state);
            },
            IR::HStack {
                input: _,
                exprs,
                schema: _,
                options,
            } => {
                hash_exprs(exprs, self.expr_arena, state);
                options.hash(state);
            },
            IR::Distinct { input: _, options } => {
                options.hash(state);
            },
            IR::MapFunction { input: _, function } => {
                function.hash(state);
            },
            IR::Union { inputs: _, options } => options.hash(state),
            IR::HConcat {
                inputs: _,
                schema: _,
                options,
            } => {
                options.hash(state);
            },
            IR::ExtContext {
                input: _,
                contexts,
                schema: _,
            } => {
                // Each context node is hashed as an expression resolved in `expr_arena`.
                for node in contexts {
                    traverse_and_hash_aexpr(*node, self.expr_arena, state);
                }
            },
            IR::Sink { input: _, payload } => {
                payload.traverse_and_hash(self.expr_arena, state);
            },
            // Nothing beyond the discriminant is hashed.
            IR::SinkMultiple { .. } => {},
            IR::Cache { input: _, id } => {
                id.hash(state);
            },
            #[cfg(feature = "merge_sorted")]
            IR::MergeSorted {
                input_left: _,
                input_right: _,
                key,
            } => {
                key.hash(state);
            },
            // Panics if an `Invalid` node is ever hashed.
            IR::Invalid => unreachable!(),
        }
    }
}

/// Returns `true` when both slices have the same length and every pair of
/// expressions at the same position is equal according to `expr_ir_eq`.
fn expr_irs_eq(l: &[ExprIR], r: &[ExprIR], expr_arena: &Arena<AExpr>) -> bool {
    l.len() == r.len() && l.iter().zip(r).all(|(l, r)| expr_ir_eq(l, r, expr_arena))
}

/// Returns `true` when both expressions report the same `get_alias()` and
/// their nodes compare equal through `AexprNode::hashable_and_cmp`.
///
/// Both nodes are resolved in the single `expr_arena` passed in.
fn expr_ir_eq(l: &ExprIR, r: &ExprIR, expr_arena: &Arena<AExpr>) -> bool {
    // The alias check comes first and short-circuits the node comparison.
    l.get_alias() == r.get_alias() && {
        let l = AexprNode::new(l.node());
        let r = AexprNode::new(r.node());
        l.hashable_and_cmp(expr_arena) == r.hashable_and_cmp(expr_arena)
    }
}

/// `Option`-aware variant of `expr_ir_eq`: two `None`s are equal, two `Some`s
/// are compared with `expr_ir_eq`, and a `None`/`Some` mix is unequal.
fn opt_expr_ir_eq(l: &Option<ExprIR>, r: &Option<ExprIR>, expr_arena: &Arena<AExpr>) -> bool {
    match (l, r) {
        (None, None) => true,
        (Some(l), Some(r)) => expr_ir_eq(l, r, expr_arena),
        _ => false,
    }
}

impl HashableEqLP<'_> {
    /// Shallow equality of the two wrapped nodes: same variant and equal
    /// non-input fields. Inputs are bound to `_` in every arm and are not
    /// compared here; `PartialEq::eq` takes care of walking them.
    ///
    /// Both nodes are resolved through `self.lp_arena`, and expressions
    /// through `self.expr_arena`; `other`'s arenas are not read.
    /// NOTE(review): assumes `other` refers to the same arenas as `self` —
    /// confirm against callers.
    ///
    /// Variants without an arm below (`Sink`, `SinkMultiple`, `Cache`,
    /// `MergeSorted`, `Invalid`) fall through to `_ => false` and therefore
    /// never compare equal.
    fn is_equal(&self, other: &Self) -> bool {
        let alp_l = self.node.to_alp(self.lp_arena);
        let alp_r = other.node.to_alp(self.lp_arena);
        // Cheap early exit when the node kinds differ.
        if std::mem::discriminant(alp_l) != std::mem::discriminant(alp_r) {
            return false;
        }
        match (alp_l, alp_r) {
            #[cfg(feature = "python")]
            (
                IR::PythonScan {
                    options:
                        PythonOptions {
                            scan_fn: scan_fn_l,
                            schema: schema_l,
                            output_schema: output_schema_l,
                            with_columns: with_columns_l,
                            python_source: python_source_l,
                            n_rows: n_rows_l,
                            predicate: predicate_l,
                            validate_schema: validate_schema_l,
                            is_pure: is_pure_l,
                        },
                },
                IR::PythonScan {
                    options:
                        PythonOptions {
                            scan_fn: scan_fn_r,
                            schema: schema_r,
                            output_schema: output_schema_r,
                            with_columns: with_columns_r,
                            python_source: python_source_r,
                            n_rows: n_rows_r,
                            predicate: predicate_r,
                            validate_schema: validate_schema_r,
                            is_pure: is_pure_r,
                        },
                },
            ) => {
                // Require both to be pure to compare equal for CSE.
                // A scan that is not pure is therefore unequal to everything.
                if !(*is_pure_l && *is_pure_r) {
                    return false;
                }

                // Scan functions are compared by object pointer, mirroring the hash.
                let scan_fn_eq = match (scan_fn_l, scan_fn_r) {
                    (None, None) => true,
                    (Some(a), Some(b)) => a.0.as_ptr() == b.0.as_ptr(),
                    _ => false,
                };

                scan_fn_eq
                    && schema_l == schema_r
                    && output_schema_l == output_schema_r
                    && with_columns_l == with_columns_r
                    && python_source_l == python_source_r
                    && n_rows_l == n_rows_r
                    && validate_schema_l == validate_schema_r
                    && pred_eq(predicate_l, predicate_r, self.expr_arena)
            },
            (
                IR::Slice {
                    input: _,
                    offset: ol,
                    len: ll,
                },
                IR::Slice {
                    input: _,
                    offset: or,
                    len: lr,
                },
            ) => ol == or && ll == lr,
            (
                IR::Filter {
                    input: _,
                    predicate: l,
                },
                IR::Filter {
                    input: _,
                    predicate: r,
                },
            ) => expr_ir_eq(l, r, self.expr_arena),
            (
                IR::Scan {
                    sources: pl,
                    file_info: _,
                    hive_parts: _,
                    predicate: pred_l,
                    output_schema: _,
                    scan_type: stl,
                    unified_scan_args: ol,
                },
                IR::Scan {
                    sources: pr,
                    file_info: _,
                    hive_parts: _,
                    predicate: pred_r,
                    output_schema: _,
                    scan_type: str,
                    unified_scan_args: or,
                },
            ) => {
                // Same set of fields as the `Scan` arm of the hash implementation.
                pl == pr
                    && stl == str
                    && ol == or
                    && opt_expr_ir_eq(pred_l, pred_r, self.expr_arena)
            },
            (
                IR::DataFrameScan {
                    df: dfl,
                    schema: _,
                    output_schema: s_l,
                },
                IR::DataFrameScan {
                    df: dfr,
                    schema: _,
                    output_schema: s_r,
                },
                // Frames are equal only when both `Arc`s point at the same allocation.
            ) => std::ptr::eq(Arc::as_ptr(dfl), Arc::as_ptr(dfr)) && s_l == s_r,
            (
                IR::SimpleProjection {
                    input: _,
                    columns: cl,
                },
                IR::SimpleProjection {
                    input: _,
                    columns: cr,
                },
            ) => cl == cr,
            (
                IR::Select {
                    input: _,
                    expr: el,
                    options: ol,
                    schema: _,
                },
                IR::Select {
                    input: _,
                    expr: er,
                    options: or,
                    schema: _,
                },
            ) => ol == or && expr_irs_eq(el, er, self.expr_arena),
            (
                IR::Sort {
                    input: _,
                    by_column: cl,
                    slice: l_slice,
                    sort_options: l_options,
                },
                IR::Sort {
                    input: _,
                    by_column: cr,
                    slice: r_slice,
                    sort_options: r_options,
                },
            ) => {
                // Plain field comparisons first, expression comparison last.
                (l_slice == r_slice && l_options == r_options)
                    && expr_irs_eq(cl, cr, self.expr_arena)
            },
            (
                IR::GroupBy {
                    input: _,
                    keys: keys_l,
                    aggs: aggs_l,
                    schema: _,
                    apply: apply_l,
                    maintain_order: maintain_l,
                    options: ol,
                },
                IR::GroupBy {
                    input: _,
                    keys: keys_r,
                    aggs: aggs_r,
                    schema: _,
                    apply: apply_r,
                    maintain_order: maintain_r,
                    options: or,
                },
            ) => {
                // A group-by with an `apply` set never compares equal.
                apply_l.is_none()
                    && apply_r.is_none()
                    && ol == or
                    && maintain_l == maintain_r
                    && expr_irs_eq(keys_l, keys_r, self.expr_arena)
                    && expr_irs_eq(aggs_l, aggs_r, self.expr_arena)
            },
            (
                IR::Join {
                    input_left: _,
                    input_right: _,
                    schema: _,
                    left_on: ll,
                    right_on: rl,
                    options: ol,
                },
                IR::Join {
                    input_left: _,
                    input_right: _,
                    schema: _,
                    left_on: lr,
                    right_on: rr,
                    options: or,
                },
            ) => {
                // Naming: `ll`/`rl` are the left node's left_on/right_on,
                // `lr`/`rr` are the right node's.
                ol == or
                    && expr_irs_eq(ll, lr, self.expr_arena)
                    && expr_irs_eq(rl, rr, self.expr_arena)
            },
            (
                IR::HStack {
                    input: _,
                    exprs: el,
                    schema: _,
                    options: ol,
                },
                IR::HStack {
                    input: _,
                    exprs: er,
                    schema: _,
                    options: or,
                },
            ) => ol == or && expr_irs_eq(el, er, self.expr_arena),
            (
                IR::Distinct {
                    input: _,
                    options: ol,
                },
                IR::Distinct {
                    input: _,
                    options: or,
                },
            ) => ol == or,
            (
                IR::MapFunction {
                    input: _,
                    function: l,
                },
                IR::MapFunction {
                    input: _,
                    function: r,
                },
            ) => l == r,
            (
                IR::Union {
                    inputs: _,
                    options: l,
                },
                IR::Union {
                    inputs: _,
                    options: r,
                },
            ) => l == r,
            (
                IR::HConcat {
                    inputs: _,
                    schema: _,
                    options: l,
                },
                IR::HConcat {
                    inputs: _,
                    schema: _,
                    options: r,
                },
            ) => l == r,
            (
                IR::ExtContext {
                    input: _,
                    contexts: l,
                    schema: _,
                },
                IR::ExtContext {
                    input: _,
                    contexts: r,
                    schema: _,
                },
            ) => {
                // Same number of contexts, pairwise equal as expression nodes.
                l.len() == r.len()
                    && l.iter().zip(r.iter()).all(|(l, r)| {
                        let l = AexprNode::new(*l).hashable_and_cmp(self.expr_arena);
                        let r = AexprNode::new(*r).hashable_and_cmp(self.expr_arena);
                        l == r
                    })
            },
            // Any variant without a dedicated arm is treated as not equal.
            _ => false,
        }
    }
}

impl PartialEq for HashableEqLP<'_> {
    /// Compares two plans node by node, without recursion.
    ///
    /// Two explicit stacks hold the nodes still to be compared on each side.
    /// Each iteration pops one node from both stacks, compares the pair with
    /// `is_equal`, and then lets `copy_inputs` extend the stacks (presumably
    /// with the nodes' inputs — `copy_inputs` is defined elsewhere).
    ///
    /// Only `self.ignore_cache` is consulted, and nodes of both sides are
    /// resolved through `self.lp_arena` / `self.expr_arena`.
    /// NOTE(review): assumes `other` shares `self`'s arenas and cache
    /// setting — confirm against callers.
    fn eq(&self, other: &Self) -> bool {
        let mut scratch_1 = vec![];
        let mut scratch_2 = vec![];

        scratch_1.push(self.node.node());
        scratch_2.push(other.node.node());

        loop {
            match (scratch_1.pop(), scratch_2.pop()) {
                (Some(l), Some(r)) => {
                    let l = IRNode::new(l);
                    let r = IRNode::new(r);
                    let l_alp = l.to_alp(self.lp_arena);
                    let r_alp = r.to_alp(self.lp_arena);

                    if self.ignore_cache {
                        // Step through `Cache` nodes: push the cache's input (and put
                        // the other side back unchanged), then re-examine the pair.
                        match (l_alp, r_alp) {
                            (IR::Cache { input: l, .. }, IR::Cache { input: r, .. }) => {
                                scratch_1.push(*l);
                                scratch_2.push(*r);
                                continue;
                            },
                            (IR::Cache { input: l, .. }, _) => {
                                scratch_1.push(*l);
                                // `r` is still the outer `IRNode`; only `l` is shadowed.
                                scratch_2.push(r.node());
                                continue;
                            },
                            (_, IR::Cache { input: r, .. }) => {
                                // `l` is still the outer `IRNode`; only `r` is shadowed.
                                scratch_1.push(l.node());
                                scratch_2.push(*r);
                                continue;
                            },
                            _ => {},
                        }
                    }

                    // Shallow comparison of the current pair of nodes.
                    if !l
                        .hashable_and_cmp(self.lp_arena, self.expr_arena)
                        .is_equal(&r.hashable_and_cmp(self.lp_arena, self.expr_arena))
                    {
                        return false;
                    }

                    l_alp.copy_inputs(&mut scratch_1);
                    r_alp.copy_inputs(&mut scratch_2);
                },
                // Both stacks ran out together: every pair matched.
                (None, None) => return true,
                // One stack ran out before the other: the plans differ in shape.
                _ => return false,
            }
        }
    }
}