Path: blob/main/crates/polars-plan/src/plans/aexpr/equality.rs
6940 views
use polars_core::prelude::SortOptions;1use polars_utils::arena::{Arena, Node};23use super::{AExpr, IRAggExpr};45impl AExpr {6pub fn is_expr_equal_to(&self, other: &Self, arena: &Arena<AExpr>) -> bool {7let mut l_stack = Vec::new();8let mut r_stack = Vec::new();9self.is_expr_equal_to_amortized(other, arena, &mut l_stack, &mut r_stack)10}1112pub fn is_expr_equal_to_amortized(13&self,14other: &Self,15arena: &Arena<AExpr>,16l_stack: &mut Vec<Node>,17r_stack: &mut Vec<Node>,18) -> bool {19l_stack.clear();20r_stack.clear();2122// Top-Level node.23if !self.is_expr_equal_top_level(other) {24return false;25}26self.children_rev(l_stack);27other.children_rev(r_stack);2829// Traverse node in N R L order30loop {31assert_eq!(l_stack.len(), r_stack.len());3233let (Some(l_node), Some(r_node)) = (l_stack.pop(), r_stack.pop()) else {34break;35};3637let l_expr = arena.get(l_node);38let r_expr = arena.get(r_node);3940if !l_expr.is_expr_equal_top_level(r_expr) {41return false;42}43l_expr.children_rev(l_stack);44r_expr.children_rev(r_stack);45}4647true48}4950pub fn is_expr_equal_top_level(&self, other: &Self) -> bool {51if std::mem::discriminant(self) != std::mem::discriminant(other) {52// Fast path: different kind of expression.53return false;54}5556use AExpr as E;5758// @NOTE: Intentionally written as a match statement over only `self` as it forces the59// match to be exhaustive.60#[rustfmt::skip]61let is_equal = match self {62E::Explode { expr: _, skip_empty: l_skip_empty } => matches!(other, E::Explode { expr: _, skip_empty: r_skip_empty } if l_skip_empty == r_skip_empty),63E::Column(l_name) => matches!(other, E::Column(r_name) if l_name == r_name),64E::Literal(l_lit) => matches!(other, E::Literal(r_lit) if l_lit == r_lit),65E::BinaryExpr { left: _, op: l_op, right: _ } => matches!(other, E::BinaryExpr { left: _, op: r_op, right: _ } if l_op == r_op),66E::Cast { expr: _, dtype: l_dtype, options: l_options } => matches!(other, E::Cast { expr: _, dtype: r_dtype, options: r_options } if l_dtype == r_dtype && l_options == r_options),67E::Sort { expr: _, options: l_options } => matches!(other, E::Sort { expr: _, options: r_options } if l_options == r_options),68E::Gather { expr: _, idx: l_idx, returns_scalar: l_returns_scalar } => matches!(other, E::Gather { expr: _, idx: r_idx, returns_scalar: r_returns_scalar } if l_idx == r_idx && l_returns_scalar == r_returns_scalar),69E::SortBy { expr: _, by: l_by, sort_options: l_sort_options } => matches!(other, E::SortBy { expr: _, by: r_by, sort_options: r_sort_options } if l_by.len() == r_by.len() && l_sort_options == r_sort_options),70E::Agg(l_agg) => matches!(other, E::Agg(r_agg) if l_agg.is_agg_equal_top_level(r_agg)),71E::AnonymousFunction { input: l_input, function: l_function, options: l_options, fmt_str: l_fmt_str } => matches!(other, E::AnonymousFunction { input: r_input, function: r_function, options: r_options, fmt_str: r_fmt_str } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options && l_fmt_str == r_fmt_str),72E::Eval { expr: _, evaluation: _, variant: l_variant } => matches!(other, E::Eval { expr: _, evaluation: _, variant: r_variant } if l_variant == r_variant),73E::Function { input: l_input, function: l_function, options: l_options } => matches!(other, E::Function { input: r_input, function: r_function, options: r_options } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options),74E::Window { function: _, partition_by: l_partition_by, order_by: l_order_by, options: l_options } => matches!(other, E::Window { function: _, partition_by: r_partition_by, order_by: r_order_by, options: r_options } if l_partition_by.len() == r_partition_by.len() && l_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) == r_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) && l_options == r_options),7576// Discriminant check done above.77E::Filter { input: _, by: _ } |78E::Ternary { predicate: _, truthy: _, falsy: _ } |79E::Slice { input: _, offset: _, length: _ } |80E::Len => true,81};8283is_equal84}85}8687impl IRAggExpr {88pub fn is_agg_equal_top_level(&self, other: &Self) -> bool {89if std::mem::discriminant(self) != std::mem::discriminant(other) {90// Fast path: different kind of expression.91return false;92}9394use IRAggExpr as A;9596// @NOTE: Intentionally written as a match statement over only `self` as it forces the97// match to be exhaustive.98#[rustfmt::skip]99let is_equal = match self {100A::Min { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Min { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),101A::Max { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Max { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),102A::Quantile { expr: _, quantile: _, method: l_method } => matches!(other, A::Quantile { expr: _, quantile: _, method: r_method } if l_method == r_method),103A::Count { input: _, include_nulls: l_include_nulls } => matches!(other, A::Count { input: _, include_nulls: r_include_nulls } if l_include_nulls == r_include_nulls),104A::Std(_, l_ddof) => matches!(other, A::Std(_, r_ddof) if l_ddof == r_ddof),105A::Var(_, l_ddof) => matches!(other, A::Var(_, r_ddof) if l_ddof == r_ddof),106107// Discriminant check done above.108A::Median(_) |109A::NUnique(_) |110A::First(_) |111A::Last(_) |112A::Mean(_) |113A::Implode(_) |114A::Sum(_) |115A::AggGroups(_) => true,116};117118is_equal119}120}121122123