Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/visitor/expr.rs
6940 views
1
use std::fmt::{Debug, Formatter};
2
3
use polars_core::prelude::{Field, Schema};
4
use polars_utils::unitvec;
5
6
use super::*;
7
use crate::prelude::*;
8
9
impl TreeWalker for Expr {
10
type Arena = ();
11
12
fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
13
&self,
14
op: &mut F,
15
arena: &Self::Arena,
16
) -> PolarsResult<VisitRecursion> {
17
let mut scratch = unitvec![];
18
19
self.nodes(&mut scratch);
20
21
for &child in scratch.as_slice() {
22
match op(child, arena)? {
23
// let the recursion continue
24
VisitRecursion::Continue | VisitRecursion::Skip => {},
25
// early stop
26
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
27
}
28
}
29
Ok(VisitRecursion::Continue)
30
}
31
32
fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
33
self,
34
f: &mut F,
35
_arena: &mut Self::Arena,
36
) -> PolarsResult<Self> {
37
use polars_utils::functions::try_arc_map as am;
38
let mut f = |expr| f(expr, &mut ());
39
use AggExpr::*;
40
use Expr::*;
41
#[rustfmt::skip]
42
let ret = match self {
43
Alias(l, r) => Alias(am(l, f)?, r),
44
Column(_) => self,
45
Literal(_) => self,
46
DataTypeFunction(_) => self,
47
#[cfg(feature = "dtype-struct")]
48
Field(_) => self,
49
BinaryExpr { left, op, right } => {
50
BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?}
51
},
52
Cast { expr, dtype, options: strict } => Cast { expr: am(expr, f)?, dtype, options: strict },
53
Sort { expr, options } => Sort { expr: am(expr, f)?, options },
54
Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar },
55
SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::<Result<_, _>>()?, sort_options },
56
Agg(agg_expr) => Agg(match agg_expr {
57
Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans },
58
Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans },
59
Median(x) => Median(am(x, f)?),
60
NUnique(x) => NUnique(am(x, f)?),
61
First(x) => First(am(x, f)?),
62
Last(x) => Last(am(x, f)?),
63
Mean(x) => Mean(am(x, f)?),
64
Implode(x) => Implode(am(x, f)?),
65
Count { input, include_nulls } => Count { input: am(input, f)?, include_nulls },
66
Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol },
67
Sum(x) => Sum(am(x, f)?),
68
AggGroups(x) => AggGroups(am(x, f)?),
69
Std(x, ddf) => Std(am(x, f)?, ddf),
70
Var(x, ddf) => Var(am(x, f)?, ddf),
71
}),
72
Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? },
73
Function { input, function } => Function { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function },
74
Explode { input, skip_empty } => Explode { input: am(input, f)?, skip_empty },
75
Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? },
76
Window { function, partition_by, order_by, options } => {
77
let partition_by = partition_by.into_iter().map(&mut f).collect::<Result<_, _>>()?;
78
Window { function: am(function, f)?, partition_by, order_by, options }
79
},
80
Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? },
81
KeepName(expr) => KeepName(am(expr, f)?),
82
Len => Len,
83
RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? },
84
AnonymousFunction { input, function, options, fmt_str } => {
85
AnonymousFunction { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function, options, fmt_str }
86
},
87
Eval { expr: input, evaluation, variant } => Eval { expr: am(input, &mut f)?, evaluation: am(evaluation, f)?, variant },
88
SubPlan(_, _) => self,
89
Selector(_) => self,
90
};
91
Ok(ret)
92
}
93
}
94
95
#[derive(Copy, Clone, Debug)]
96
pub struct AexprNode {
97
node: Node,
98
}
99
100
impl AexprNode {
101
pub fn new(node: Node) -> Self {
102
Self { node }
103
}
104
105
/// Get the `Node`.
106
pub fn node(&self) -> Node {
107
self.node
108
}
109
110
pub fn to_aexpr<'a>(&self, arena: &'a Arena<AExpr>) -> &'a AExpr {
111
arena.get(self.node)
112
}
113
114
pub fn to_expr(&self, arena: &Arena<AExpr>) -> Expr {
115
node_to_expr(self.node, arena)
116
}
117
118
pub fn to_field(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<Field> {
119
let aexpr = arena.get(self.node);
120
aexpr.to_field(schema, arena)
121
}
122
123
pub fn assign(&mut self, ae: AExpr, arena: &mut Arena<AExpr>) {
124
let node = arena.add(ae);
125
self.node = node;
126
}
127
128
pub(crate) fn is_leaf(&self, arena: &Arena<AExpr>) -> bool {
129
matches!(self.to_aexpr(arena), AExpr::Column(_) | AExpr::Literal(_))
130
}
131
132
pub(crate) fn hashable_and_cmp<'a>(&self, arena: &'a Arena<AExpr>) -> AExprArena<'a> {
133
AExprArena {
134
node: self.node,
135
arena,
136
}
137
}
138
}
139
140
pub struct AExprArena<'a> {
141
node: Node,
142
arena: &'a Arena<AExpr>,
143
}
144
145
impl Debug for AExprArena<'_> {
146
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
147
write!(f, "AexprArena: {}", self.node.0)
148
}
149
}
150
151
impl AExpr {
152
fn is_equal_node(&self, other: &Self) -> bool {
153
use AExpr::*;
154
match (self, other) {
155
(Column(l), Column(r)) => l == r,
156
(Literal(l), Literal(r)) => l == r,
157
(Window { options: l, .. }, Window { options: r, .. }) => l == r,
158
(
159
Cast {
160
options: strict_l,
161
dtype: dtl,
162
..
163
},
164
Cast {
165
options: strict_r,
166
dtype: dtr,
167
..
168
},
169
) => strict_l == strict_r && dtl == dtr,
170
(Sort { options: l, .. }, Sort { options: r, .. }) => l == r,
171
(Gather { .. }, Gather { .. })
172
| (Filter { .. }, Filter { .. })
173
| (Ternary { .. }, Ternary { .. })
174
| (Len, Len)
175
| (Slice { .. }, Slice { .. }) => true,
176
(
177
Explode {
178
expr: _,
179
skip_empty: l_skip_empty,
180
},
181
Explode {
182
expr: _,
183
skip_empty: r_skip_empty,
184
},
185
) => l_skip_empty == r_skip_empty,
186
(
187
SortBy {
188
sort_options: l_sort_options,
189
..
190
},
191
SortBy {
192
sort_options: r_sort_options,
193
..
194
},
195
) => l_sort_options == r_sort_options,
196
(Agg(l), Agg(r)) => l.equal_nodes(r),
197
(
198
Function {
199
input: il,
200
function: fl,
201
options: ol,
202
},
203
Function {
204
input: ir,
205
function: fr,
206
options: or,
207
},
208
) => {
209
fl == fr && ol == or && {
210
let mut all_same_name = true;
211
for (l, r) in il.iter().zip(ir) {
212
all_same_name &= l.output_name() == r.output_name()
213
}
214
215
all_same_name
216
}
217
},
218
(AnonymousFunction { .. }, AnonymousFunction { .. }) => false,
219
(BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r,
220
_ => false,
221
}
222
}
223
}
224
225
impl<'a> AExprArena<'a> {
226
pub fn new(node: Node, arena: &'a Arena<AExpr>) -> Self {
227
Self { node, arena }
228
}
229
pub fn to_aexpr(&self) -> &'a AExpr {
230
self.arena.get(self.node)
231
}
232
233
// Check single node on equality
234
pub fn is_equal_single(&self, other: &Self) -> bool {
235
let self_ae = self.to_aexpr();
236
let other_ae = other.to_aexpr();
237
self_ae.is_equal_node(other_ae)
238
}
239
}
240
241
impl PartialEq for AExprArena<'_> {
242
fn eq(&self, other: &Self) -> bool {
243
let mut scratch1 = unitvec![];
244
let mut scratch2 = unitvec![];
245
246
scratch1.push(self.node);
247
scratch2.push(other.node);
248
249
loop {
250
match (scratch1.pop(), scratch2.pop()) {
251
(Some(l), Some(r)) => {
252
let l = Self::new(l, self.arena);
253
let r = Self::new(r, self.arena);
254
255
if !l.is_equal_single(&r) {
256
return false;
257
}
258
259
l.to_aexpr().inputs_rev(&mut scratch1);
260
r.to_aexpr().inputs_rev(&mut scratch2);
261
},
262
(None, None) => return true,
263
_ => return false,
264
}
265
}
266
}
267
}
268
269
impl TreeWalker for AexprNode {
270
type Arena = Arena<AExpr>;
271
fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
272
&self,
273
op: &mut F,
274
arena: &Self::Arena,
275
) -> PolarsResult<VisitRecursion> {
276
let mut scratch = unitvec![];
277
278
self.to_aexpr(arena).inputs_rev(&mut scratch);
279
for node in scratch.as_slice() {
280
let aenode = AexprNode::new(*node);
281
match op(&aenode, arena)? {
282
// let the recursion continue
283
VisitRecursion::Continue | VisitRecursion::Skip => {},
284
// early stop
285
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
286
}
287
}
288
Ok(VisitRecursion::Continue)
289
}
290
291
fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
292
mut self,
293
op: &mut F,
294
arena: &mut Self::Arena,
295
) -> PolarsResult<Self> {
296
let mut scratch = unitvec![];
297
298
let ae = arena.get(self.node).clone();
299
ae.inputs_rev(&mut scratch);
300
301
// rewrite the nodes
302
for node in scratch.as_mut_slice() {
303
let aenode = AexprNode::new(*node);
304
*node = op(aenode, arena)?.node;
305
}
306
307
scratch.as_mut_slice().reverse();
308
let ae = ae.replace_inputs(&scratch);
309
self.node = arena.add(ae);
310
Ok(self)
311
}
312
}
313
314