Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/equality.rs
8422 views
1
use polars_core::prelude::SortOptions;
2
use polars_utils::arena::{Arena, Node};
3
4
use super::{AExpr, IRAggExpr};
5
6
impl AExpr {
7
pub fn is_expr_equal_to(&self, other: &Self, arena: &Arena<AExpr>) -> bool {
8
let mut l_stack = Vec::new();
9
let mut r_stack = Vec::new();
10
self.is_expr_equal_to_amortized(other, arena, &mut l_stack, &mut r_stack)
11
}
12
13
pub fn is_expr_equal_to_amortized(
14
&self,
15
other: &Self,
16
arena: &Arena<AExpr>,
17
l_stack: &mut Vec<Node>,
18
r_stack: &mut Vec<Node>,
19
) -> bool {
20
l_stack.clear();
21
r_stack.clear();
22
23
// Top-Level node.
24
if !self.is_expr_equal_top_level(other) {
25
return false;
26
}
27
self.children_rev(l_stack);
28
other.children_rev(r_stack);
29
30
// Traverse node in N R L order
31
loop {
32
assert_eq!(l_stack.len(), r_stack.len());
33
34
let (Some(l_node), Some(r_node)) = (l_stack.pop(), r_stack.pop()) else {
35
break;
36
};
37
38
let l_expr = arena.get(l_node);
39
let r_expr = arena.get(r_node);
40
41
if !l_expr.is_expr_equal_top_level(r_expr) {
42
return false;
43
}
44
l_expr.children_rev(l_stack);
45
r_expr.children_rev(r_stack);
46
}
47
48
true
49
}
50
51
pub fn is_expr_equal_top_level(&self, other: &Self) -> bool {
52
if std::mem::discriminant(self) != std::mem::discriminant(other) {
53
// Fast path: different kind of expression.
54
return false;
55
}
56
57
use AExpr as E;
58
59
// @NOTE: Intentionally written as a match statement over only `self` as it forces the
60
// match to be exhaustive.
61
#[rustfmt::skip]
62
let is_equal = match self {
63
E::Explode { expr: _, options: l_options } => matches!(other, E::Explode { expr: _, options: r_options } if l_options == r_options),
64
E::Column(l_name) => matches!(other, E::Column(r_name) if l_name == r_name),
65
#[cfg(feature = "dtype-struct")]
66
E::StructField (l_name) => matches!(other, E::StructField(r_name) if l_name == r_name),
67
E::Literal(l_lit) => matches!(other, E::Literal(r_lit) if l_lit == r_lit),
68
E::BinaryExpr { left: _, op: l_op, right: _ } => matches!(other, E::BinaryExpr { left: _, op: r_op, right: _ } if l_op == r_op),
69
E::Cast { expr: _, dtype: l_dtype, options: l_options } => matches!(other, E::Cast { expr: _, dtype: r_dtype, options: r_options } if l_dtype == r_dtype && l_options == r_options),
70
E::Sort { expr: _, options: l_options } => matches!(other, E::Sort { expr: _, options: r_options } if l_options == r_options),
71
E::Gather { expr: _, idx: l_idx, returns_scalar: l_returns_scalar, null_on_oob: l_null_on_oob } => matches!(other, E::Gather { expr: _, idx: r_idx, returns_scalar: r_returns_scalar, null_on_oob: r_null_on_oob } if l_idx == r_idx && l_returns_scalar == r_returns_scalar && l_null_on_oob == r_null_on_oob),
72
E::SortBy { expr: _, by: l_by, sort_options: l_sort_options } => matches!(other, E::SortBy { expr: _, by: r_by, sort_options: r_sort_options } if l_by.len() == r_by.len() && l_sort_options == r_sort_options),
73
E::Agg(l_agg) => matches!(other, E::Agg(r_agg) if l_agg.is_agg_equal_top_level(r_agg)),
74
E::AnonymousAgg { input: input_l, fmt_str: fmt_str_l, function: function_l } => matches!(other, E::AnonymousAgg { input: input_r, fmt_str: fmt_str_r, function: function_r} if input_l == input_r && function_l == function_r && fmt_str_l == fmt_str_r),
75
E::AnonymousFunction { input: l_input, function: l_function, options: l_options, fmt_str: l_fmt_str } => matches!(other, E::AnonymousFunction { input: r_input, function: r_function, options: r_options, fmt_str: r_fmt_str } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options && l_fmt_str == r_fmt_str),
76
E::Eval { expr: _, evaluation: _, variant: l_variant } => matches!(other, E::Eval { expr: _, evaluation: _, variant: r_variant } if l_variant == r_variant),
77
E::Function { input: l_input, function: l_function, options: l_options } => matches!(other, E::Function { input: r_input, function: r_function, options: r_options } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options),
78
#[cfg(feature = "dynamic_group_by")]
79
E::Rolling { function: _, index_column: _, period: l_period, offset: l_offset, closed_window: l_closed_window } => matches!(other, E::Rolling { function: _, index_column: _, period: r_period, offset: r_offset, closed_window: r_closed_window } if l_period == r_period && l_offset == r_offset && l_closed_window == r_closed_window),
80
E::Over { function: _, partition_by: l_partition_by, order_by: l_order_by, mapping: l_mapping } => matches!(other, E::Over { function: _, partition_by: r_partition_by, order_by: r_order_by, mapping: r_mapping } if l_partition_by.len() == r_partition_by.len() && l_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) == r_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) && l_mapping == r_mapping),
81
82
// Discriminant check done above.
83
E::Element |
84
E::Filter { input: _, by: _ } |
85
E::Ternary { predicate: _, truthy: _, falsy: _ } |
86
E::Slice { input: _, offset: _, length: _ } |
87
E::Len => true,
88
#[cfg(feature = "dtype-struct")]
89
E::StructEval { expr: _, evaluation: _} => true
90
};
91
92
is_equal
93
}
94
}
95
96
impl IRAggExpr {
97
pub fn is_agg_equal_top_level(&self, other: &Self) -> bool {
98
if std::mem::discriminant(self) != std::mem::discriminant(other) {
99
// Fast path: different kind of expression.
100
return false;
101
}
102
103
use IRAggExpr as A;
104
105
// @NOTE: Intentionally written as a match statement over only `self` as it forces the
106
// match to be exhaustive.
107
#[rustfmt::skip]
108
let is_equal = match self {
109
A::Min { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Min { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),
110
A::Max { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Max { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),
111
A::Quantile { expr: _, quantile: _, method: l_method } => matches!(other, A::Quantile { expr: _, quantile: _, method: r_method } if l_method == r_method),
112
A::Count { input: _, include_nulls: l_include_nulls } => matches!(other, A::Count { input: _, include_nulls: r_include_nulls } if l_include_nulls == r_include_nulls),
113
A::Item { input: _, allow_empty: l_allow_empty } => matches!(other, A::Item { input: _, allow_empty: r_allow_empty } if l_allow_empty == r_allow_empty),
114
A::Std(_, l_ddof) => matches!(other, A::Std(_, r_ddof) if l_ddof == r_ddof),
115
A::Var(_, l_ddof) => matches!(other, A::Var(_, r_ddof) if l_ddof == r_ddof),
116
117
// Discriminant check done above.
118
A::Median(_) |
119
A::NUnique(_) |
120
A::First(_) |
121
A::FirstNonNull(_) |
122
A::Last(_) |
123
A::LastNonNull(_) |
124
A::Mean(_) |
125
A::Implode(_) |
126
A::Sum(_) |
127
A::AggGroups(_) => true,
128
};
129
130
is_equal
131
}
132
}
133
134