Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/equality.rs
6940 views
1
use polars_core::prelude::SortOptions;
2
use polars_utils::arena::{Arena, Node};
3
4
use super::{AExpr, IRAggExpr};
5
6
impl AExpr {
7
pub fn is_expr_equal_to(&self, other: &Self, arena: &Arena<AExpr>) -> bool {
8
let mut l_stack = Vec::new();
9
let mut r_stack = Vec::new();
10
self.is_expr_equal_to_amortized(other, arena, &mut l_stack, &mut r_stack)
11
}
12
13
pub fn is_expr_equal_to_amortized(
14
&self,
15
other: &Self,
16
arena: &Arena<AExpr>,
17
l_stack: &mut Vec<Node>,
18
r_stack: &mut Vec<Node>,
19
) -> bool {
20
l_stack.clear();
21
r_stack.clear();
22
23
// Top-Level node.
24
if !self.is_expr_equal_top_level(other) {
25
return false;
26
}
27
self.children_rev(l_stack);
28
other.children_rev(r_stack);
29
30
// Traverse node in N R L order
31
loop {
32
assert_eq!(l_stack.len(), r_stack.len());
33
34
let (Some(l_node), Some(r_node)) = (l_stack.pop(), r_stack.pop()) else {
35
break;
36
};
37
38
let l_expr = arena.get(l_node);
39
let r_expr = arena.get(r_node);
40
41
if !l_expr.is_expr_equal_top_level(r_expr) {
42
return false;
43
}
44
l_expr.children_rev(l_stack);
45
r_expr.children_rev(r_stack);
46
}
47
48
true
49
}
50
51
pub fn is_expr_equal_top_level(&self, other: &Self) -> bool {
52
if std::mem::discriminant(self) != std::mem::discriminant(other) {
53
// Fast path: different kind of expression.
54
return false;
55
}
56
57
use AExpr as E;
58
59
// @NOTE: Intentionally written as a match statement over only `self` as it forces the
60
// match to be exhaustive.
61
#[rustfmt::skip]
62
let is_equal = match self {
63
E::Explode { expr: _, skip_empty: l_skip_empty } => matches!(other, E::Explode { expr: _, skip_empty: r_skip_empty } if l_skip_empty == r_skip_empty),
64
E::Column(l_name) => matches!(other, E::Column(r_name) if l_name == r_name),
65
E::Literal(l_lit) => matches!(other, E::Literal(r_lit) if l_lit == r_lit),
66
E::BinaryExpr { left: _, op: l_op, right: _ } => matches!(other, E::BinaryExpr { left: _, op: r_op, right: _ } if l_op == r_op),
67
E::Cast { expr: _, dtype: l_dtype, options: l_options } => matches!(other, E::Cast { expr: _, dtype: r_dtype, options: r_options } if l_dtype == r_dtype && l_options == r_options),
68
E::Sort { expr: _, options: l_options } => matches!(other, E::Sort { expr: _, options: r_options } if l_options == r_options),
69
E::Gather { expr: _, idx: l_idx, returns_scalar: l_returns_scalar } => matches!(other, E::Gather { expr: _, idx: r_idx, returns_scalar: r_returns_scalar } if l_idx == r_idx && l_returns_scalar == r_returns_scalar),
70
E::SortBy { expr: _, by: l_by, sort_options: l_sort_options } => matches!(other, E::SortBy { expr: _, by: r_by, sort_options: r_sort_options } if l_by.len() == r_by.len() && l_sort_options == r_sort_options),
71
E::Agg(l_agg) => matches!(other, E::Agg(r_agg) if l_agg.is_agg_equal_top_level(r_agg)),
72
E::AnonymousFunction { input: l_input, function: l_function, options: l_options, fmt_str: l_fmt_str } => matches!(other, E::AnonymousFunction { input: r_input, function: r_function, options: r_options, fmt_str: r_fmt_str } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options && l_fmt_str == r_fmt_str),
73
E::Eval { expr: _, evaluation: _, variant: l_variant } => matches!(other, E::Eval { expr: _, evaluation: _, variant: r_variant } if l_variant == r_variant),
74
E::Function { input: l_input, function: l_function, options: l_options } => matches!(other, E::Function { input: r_input, function: r_function, options: r_options } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options),
75
E::Window { function: _, partition_by: l_partition_by, order_by: l_order_by, options: l_options } => matches!(other, E::Window { function: _, partition_by: r_partition_by, order_by: r_order_by, options: r_options } if l_partition_by.len() == r_partition_by.len() && l_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) == r_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) && l_options == r_options),
76
77
// Discriminant check done above.
78
E::Filter { input: _, by: _ } |
79
E::Ternary { predicate: _, truthy: _, falsy: _ } |
80
E::Slice { input: _, offset: _, length: _ } |
81
E::Len => true,
82
};
83
84
is_equal
85
}
86
}
87
88
impl IRAggExpr {
89
pub fn is_agg_equal_top_level(&self, other: &Self) -> bool {
90
if std::mem::discriminant(self) != std::mem::discriminant(other) {
91
// Fast path: different kind of expression.
92
return false;
93
}
94
95
use IRAggExpr as A;
96
97
// @NOTE: Intentionally written as a match statement over only `self` as it forces the
98
// match to be exhaustive.
99
#[rustfmt::skip]
100
let is_equal = match self {
101
A::Min { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Min { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),
102
A::Max { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Max { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),
103
A::Quantile { expr: _, quantile: _, method: l_method } => matches!(other, A::Quantile { expr: _, quantile: _, method: r_method } if l_method == r_method),
104
A::Count { input: _, include_nulls: l_include_nulls } => matches!(other, A::Count { input: _, include_nulls: r_include_nulls } if l_include_nulls == r_include_nulls),
105
A::Std(_, l_ddof) => matches!(other, A::Std(_, r_ddof) if l_ddof == r_ddof),
106
A::Var(_, l_ddof) => matches!(other, A::Var(_, r_ddof) if l_ddof == r_ddof),
107
108
// Discriminant check done above.
109
A::Median(_) |
110
A::NUnique(_) |
111
A::First(_) |
112
A::Last(_) |
113
A::Mean(_) |
114
A::Implode(_) |
115
A::Sum(_) |
116
A::AggGroups(_) => true,
117
};
118
119
is_equal
120
}
121
}
122
123