Path: blob/main/crates/polars-plan/src/plans/aexpr/equality.rs
8422 views
use polars_core::prelude::SortOptions;1use polars_utils::arena::{Arena, Node};23use super::{AExpr, IRAggExpr};45impl AExpr {6pub fn is_expr_equal_to(&self, other: &Self, arena: &Arena<AExpr>) -> bool {7let mut l_stack = Vec::new();8let mut r_stack = Vec::new();9self.is_expr_equal_to_amortized(other, arena, &mut l_stack, &mut r_stack)10}1112pub fn is_expr_equal_to_amortized(13&self,14other: &Self,15arena: &Arena<AExpr>,16l_stack: &mut Vec<Node>,17r_stack: &mut Vec<Node>,18) -> bool {19l_stack.clear();20r_stack.clear();2122// Top-Level node.23if !self.is_expr_equal_top_level(other) {24return false;25}26self.children_rev(l_stack);27other.children_rev(r_stack);2829// Traverse node in N R L order30loop {31assert_eq!(l_stack.len(), r_stack.len());3233let (Some(l_node), Some(r_node)) = (l_stack.pop(), r_stack.pop()) else {34break;35};3637let l_expr = arena.get(l_node);38let r_expr = arena.get(r_node);3940if !l_expr.is_expr_equal_top_level(r_expr) {41return false;42}43l_expr.children_rev(l_stack);44r_expr.children_rev(r_stack);45}4647true48}4950pub fn is_expr_equal_top_level(&self, other: &Self) -> bool {51if std::mem::discriminant(self) != std::mem::discriminant(other) {52// Fast path: different kind of expression.53return false;54}5556use AExpr as E;5758// @NOTE: Intentionally written as a match statement over only `self` as it forces the59// match to be exhaustive.60#[rustfmt::skip]61let is_equal = match self {62E::Explode { expr: _, options: l_options } => matches!(other, E::Explode { expr: _, options: r_options } if l_options == r_options),63E::Column(l_name) => matches!(other, E::Column(r_name) if l_name == r_name),64#[cfg(feature = "dtype-struct")]65E::StructField (l_name) => matches!(other, E::StructField(r_name) if l_name == r_name),66E::Literal(l_lit) => matches!(other, E::Literal(r_lit) if l_lit == r_lit),67E::BinaryExpr { left: _, op: l_op, right: _ } => matches!(other, E::BinaryExpr { left: _, op: r_op, right: _ } if l_op == r_op),68E::Cast { expr: _, dtype: l_dtype, options: l_options } => matches!(other, E::Cast { expr: _, dtype: r_dtype, options: r_options } if l_dtype == r_dtype && l_options == r_options),69E::Sort { expr: _, options: l_options } => matches!(other, E::Sort { expr: _, options: r_options } if l_options == r_options),70E::Gather { expr: _, idx: l_idx, returns_scalar: l_returns_scalar, null_on_oob: l_null_on_oob } => matches!(other, E::Gather { expr: _, idx: r_idx, returns_scalar: r_returns_scalar, null_on_oob: r_null_on_oob } if l_idx == r_idx && l_returns_scalar == r_returns_scalar && l_null_on_oob == r_null_on_oob),71E::SortBy { expr: _, by: l_by, sort_options: l_sort_options } => matches!(other, E::SortBy { expr: _, by: r_by, sort_options: r_sort_options } if l_by.len() == r_by.len() && l_sort_options == r_sort_options),72E::Agg(l_agg) => matches!(other, E::Agg(r_agg) if l_agg.is_agg_equal_top_level(r_agg)),73E::AnonymousAgg { input: input_l, fmt_str: fmt_str_l, function: function_l } => matches!(other, E::AnonymousAgg { input: input_r, fmt_str: fmt_str_r, function: function_r} if input_l == input_r && function_l == function_r && fmt_str_l == fmt_str_r),74E::AnonymousFunction { input: l_input, function: l_function, options: l_options, fmt_str: l_fmt_str } => matches!(other, E::AnonymousFunction { input: r_input, function: r_function, options: r_options, fmt_str: r_fmt_str } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options && l_fmt_str == r_fmt_str),75E::Eval { expr: _, evaluation: _, variant: l_variant } => matches!(other, E::Eval { expr: _, evaluation: _, variant: r_variant } if l_variant == r_variant),76E::Function { input: l_input, function: l_function, options: l_options } => matches!(other, E::Function { input: r_input, function: r_function, options: r_options } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options),77#[cfg(feature = "dynamic_group_by")]78E::Rolling { function: _, index_column: _, period: l_period, offset: l_offset, closed_window: l_closed_window } => matches!(other, E::Rolling { function: _, index_column: _, period: r_period, offset: r_offset, closed_window: r_closed_window } if l_period == r_period && l_offset == r_offset && l_closed_window == r_closed_window),79E::Over { function: _, partition_by: l_partition_by, order_by: l_order_by, mapping: l_mapping } => matches!(other, E::Over { function: _, partition_by: r_partition_by, order_by: r_order_by, mapping: r_mapping } if l_partition_by.len() == r_partition_by.len() && l_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) == r_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) && l_mapping == r_mapping),8081// Discriminant check done above.82E::Element |83E::Filter { input: _, by: _ } |84E::Ternary { predicate: _, truthy: _, falsy: _ } |85E::Slice { input: _, offset: _, length: _ } |86E::Len => true,87#[cfg(feature = "dtype-struct")]88E::StructEval { expr: _, evaluation: _} => true89};9091is_equal92}93}9495impl IRAggExpr {96pub fn is_agg_equal_top_level(&self, other: &Self) -> bool {97if std::mem::discriminant(self) != std::mem::discriminant(other) {98// Fast path: different kind of expression.99return false;100}101102use IRAggExpr as A;103104// @NOTE: Intentionally written as a match statement over only `self` as it forces the105// match to be exhaustive.106#[rustfmt::skip]107let is_equal = match self {108A::Min { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Min { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),109A::Max { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Max { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),110A::Quantile { expr: _, quantile: _, method: l_method } => matches!(other, A::Quantile { expr: _, quantile: _, method: r_method } if l_method == r_method),111A::Count { input: _, include_nulls: l_include_nulls } => matches!(other, A::Count { input: _, include_nulls: r_include_nulls } if l_include_nulls == r_include_nulls),112A::Item { input: _, allow_empty: l_allow_empty } => matches!(other, A::Item { input: _, allow_empty: r_allow_empty } if l_allow_empty == r_allow_empty),113A::Std(_, l_ddof) => matches!(other, A::Std(_, r_ddof) if l_ddof == r_ddof),114A::Var(_, l_ddof) => matches!(other, A::Var(_, r_ddof) if l_ddof == r_ddof),115116// Discriminant check done above.117A::Median(_) |118A::NUnique(_) |119A::First(_) |120A::FirstNonNull(_) |121A::Last(_) |122A::LastNonNull(_) |123A::Mean(_) |124A::Implode(_) |125A::Sum(_) |126A::AggGroups(_) => true,127};128129is_equal130}131}132133134