Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/visitor/expr.rs
8446 views
1
use std::fmt::{Debug, Formatter};
2
3
use polars_core::prelude::{Field, Schema};
4
use polars_utils::unitvec;
5
6
use super::*;
7
use crate::prelude::*;
8
9
impl TreeWalker for Expr {
10
type Arena = ();
11
12
fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
13
&self,
14
op: &mut F,
15
arena: &Self::Arena,
16
) -> PolarsResult<VisitRecursion> {
17
let mut scratch = unitvec![];
18
19
self.nodes(&mut scratch);
20
21
for &child in scratch.as_slice() {
22
match op(child, arena)? {
23
// let the recursion continue
24
VisitRecursion::Continue | VisitRecursion::Skip => {},
25
// early stop
26
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
27
}
28
}
29
Ok(VisitRecursion::Continue)
30
}
31
32
fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
33
self,
34
f: &mut F,
35
_arena: &mut Self::Arena,
36
) -> PolarsResult<Self> {
37
use polars_utils::functions::try_arc_map as am;
38
let mut f = |expr| f(expr, &mut ());
39
use AggExpr::*;
40
use Expr::*;
41
#[rustfmt::skip]
42
let ret = match self {
43
Alias(l, r) => Alias(am(l, f)?, r),
44
Column(_) => self,
45
Literal(_) => self,
46
DataTypeFunction(_) => self,
47
#[cfg(feature = "dtype-struct")]
48
Field(_) => self,
49
BinaryExpr { left, op, right } => {
50
BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?}
51
},
52
Cast { expr, dtype, options: strict } => Cast { expr: am(expr, f)?, dtype, options: strict },
53
Sort { expr, options } => Sort { expr: am(expr, f)?, options },
54
Gather { expr, idx, returns_scalar, null_on_oob } => Gather {
55
expr: am(expr, &mut f)?,
56
idx: am(idx, f)?,
57
returns_scalar,
58
null_on_oob,
59
},
60
SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::<Result<_, _>>()?, sort_options },
61
Agg(agg_expr) => Agg(match agg_expr {
62
Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans },
63
Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans },
64
Median(x) => Median(am(x, f)?),
65
NUnique(x) => NUnique(am(x, f)?),
66
First(x) => First(am(x, f)?),
67
FirstNonNull(x) => FirstNonNull(am(x, f)?),
68
Last(x) => Last(am(x, f)?),
69
LastNonNull(x) => LastNonNull(am(x, f)?),
70
Item { input, allow_empty } => Item { input: am(input, f)?, allow_empty },
71
Mean(x) => Mean(am(x, f)?),
72
Implode(x) => Implode(am(x, f)?),
73
Count { input, include_nulls } => Count { input: am(input, f)?, include_nulls },
74
Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol },
75
Sum(x) => Sum(am(x, f)?),
76
AggGroups(x) => AggGroups(am(x, f)?),
77
Std(x, ddf) => Std(am(x, f)?, ddf),
78
Var(x, ddf) => Var(am(x, f)?, ddf),
79
80
}),
81
Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? },
82
Function { input, function } => Function { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function },
83
Explode { input, options } => Explode { input: am(input, f)?, options },
84
Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? },
85
#[cfg(feature = "dynamic_group_by")]
86
Rolling { function, index_column, period, offset, closed_window } => Rolling { function: am(function, &mut f)?, index_column: am(index_column, &mut f)?, period, offset, closed_window },
87
Over { function, partition_by, order_by, mapping } => {
88
let partition_by = partition_by.into_iter().map(&mut f).collect::<Result<_, _>>()?;
89
Over { function: am(function, f)?, partition_by, order_by, mapping }
90
},
91
Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? },
92
KeepName(expr) => KeepName(am(expr, f)?),
93
Element => Element,
94
Len => Len,
95
RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? },
96
Display { inputs, fmt_str } => {
97
Display { inputs: inputs.into_iter().map(f).collect::<Result<_, _>>()?, fmt_str }
98
},
99
AnonymousFunction { input, function, options, fmt_str } => {
100
AnonymousFunction { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function, options, fmt_str }
101
},
102
Eval { expr: input, evaluation, variant } => Eval { expr: am(input, &mut f)?, evaluation: am(evaluation, f)?, variant },
103
#[cfg(feature = "dtype-struct")]
104
StructEval { expr: input, evaluation } => {
105
StructEval { expr: am(input, &mut f)?, evaluation: evaluation.into_iter().map(f).collect::<Result<_, _>>()? }
106
},
107
SubPlan(_, _) => self,
108
Selector(_) => self,
109
};
110
Ok(ret)
111
}
112
}
113
114
#[derive(Copy, Clone, Debug)]
115
pub struct AexprNode {
116
node: Node,
117
}
118
119
impl AexprNode {
120
pub fn new(node: Node) -> Self {
121
Self { node }
122
}
123
124
/// Get the `Node`.
125
pub fn node(&self) -> Node {
126
self.node
127
}
128
129
pub fn to_aexpr<'a>(&self, arena: &'a Arena<AExpr>) -> &'a AExpr {
130
arena.get(self.node)
131
}
132
133
pub fn to_expr(&self, arena: &Arena<AExpr>) -> Expr {
134
node_to_expr(self.node, arena)
135
}
136
137
pub fn to_field(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<Field> {
138
let aexpr = arena.get(self.node);
139
aexpr.to_field(&ToFieldContext::new(arena, schema))
140
}
141
142
pub fn assign(&mut self, ae: AExpr, arena: &mut Arena<AExpr>) {
143
let node = arena.add(ae);
144
self.node = node;
145
}
146
147
pub(crate) fn is_leaf(&self, arena: &Arena<AExpr>) -> bool {
148
matches!(self.to_aexpr(arena), AExpr::Column(_) | AExpr::Literal(_))
149
}
150
151
pub(crate) fn hashable_and_cmp<'a>(&self, arena: &'a Arena<AExpr>) -> AExprArena<'a> {
152
AExprArena {
153
node: self.node,
154
arena,
155
}
156
}
157
}
158
159
pub struct AExprArena<'a> {
160
node: Node,
161
arena: &'a Arena<AExpr>,
162
}
163
164
impl Debug for AExprArena<'_> {
165
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
166
write!(f, "AexprArena: {}", self.node.0)
167
}
168
}
169
170
impl AExpr {
171
fn is_equal_node(&self, other: &Self) -> bool {
172
use AExpr::*;
173
match (self, other) {
174
(Column(l), Column(r)) => l == r,
175
(Literal(l), Literal(r)) => l == r,
176
#[cfg(feature = "dynamic_group_by")]
177
(
178
Rolling {
179
function: _,
180
index_column: _,
181
period: l_period,
182
offset: l_offset,
183
closed_window: l_closed_window,
184
},
185
Rolling {
186
function: _,
187
index_column: _,
188
period: r_period,
189
offset: r_offset,
190
closed_window: r_closed_window,
191
},
192
) => l_period == r_period && l_offset == r_offset && l_closed_window == r_closed_window,
193
(Over { mapping: l, .. }, Over { mapping: r, .. }) => l == r,
194
(
195
Cast {
196
options: strict_l,
197
dtype: dtl,
198
..
199
},
200
Cast {
201
options: strict_r,
202
dtype: dtr,
203
..
204
},
205
) => strict_l == strict_r && dtl == dtr,
206
(Sort { options: l, .. }, Sort { options: r, .. }) => l == r,
207
(Gather { .. }, Gather { .. })
208
| (Filter { .. }, Filter { .. })
209
| (Ternary { .. }, Ternary { .. })
210
| (Len, Len)
211
| (Slice { .. }, Slice { .. }) => true,
212
(
213
Explode {
214
expr: _,
215
options: l_options,
216
},
217
Explode {
218
expr: _,
219
options: r_options,
220
},
221
) => l_options == r_options,
222
(
223
SortBy {
224
sort_options: l_sort_options,
225
..
226
},
227
SortBy {
228
sort_options: r_sort_options,
229
..
230
},
231
) => l_sort_options == r_sort_options,
232
(Agg(l), Agg(r)) => l.equal_nodes(r),
233
(
234
Function {
235
input: il,
236
function: fl,
237
options: ol,
238
},
239
Function {
240
input: ir,
241
function: fr,
242
options: or,
243
},
244
) => {
245
fl == fr && ol == or && {
246
let mut all_same_name = true;
247
for (l, r) in il.iter().zip(ir) {
248
all_same_name &= l.output_name() == r.output_name()
249
}
250
251
all_same_name
252
}
253
},
254
(
255
AnonymousFunction {
256
function: l1,
257
options: l2,
258
fmt_str: l3,
259
input: _,
260
},
261
AnonymousFunction {
262
function: r1,
263
options: r2,
264
fmt_str: r3,
265
input: _,
266
},
267
) => {
268
l2 == r2 && l3 == r3 && {
269
use LazySerde as L;
270
match (l1, r1) {
271
// We only check the pointers, so this works for python
272
// functions that are on the same address.
273
(L::Deserialized(l0), L::Deserialized(r0)) => l0 == r0,
274
(L::Bytes(l0), L::Bytes(r0)) => l0 == r0,
275
(
276
L::Named {
277
name: l_name,
278
payload: l_payload,
279
value: l_value,
280
},
281
L::Named {
282
name: r_name,
283
payload: r_payload,
284
value: r_value,
285
},
286
) => l_name == r_name && l_payload == r_payload && l_value == r_value,
287
_ => false,
288
}
289
}
290
},
291
(BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r,
292
_ => false,
293
}
294
}
295
}
296
297
impl<'a> AExprArena<'a> {
298
pub fn new(node: Node, arena: &'a Arena<AExpr>) -> Self {
299
Self { node, arena }
300
}
301
pub fn to_aexpr(&self) -> &'a AExpr {
302
self.arena.get(self.node)
303
}
304
305
// Check single node on equality
306
pub fn is_equal_single(&self, other: &Self) -> bool {
307
let self_ae = self.to_aexpr();
308
let other_ae = other.to_aexpr();
309
self_ae.is_equal_node(other_ae)
310
}
311
}
312
313
impl PartialEq for AExprArena<'_> {
314
fn eq(&self, other: &Self) -> bool {
315
let mut scratch1 = unitvec![];
316
let mut scratch2 = unitvec![];
317
318
scratch1.push(self.node);
319
scratch2.push(other.node);
320
321
loop {
322
match (scratch1.pop(), scratch2.pop()) {
323
(Some(l), Some(r)) => {
324
let l = Self::new(l, self.arena);
325
let r = Self::new(r, other.arena);
326
327
if !l.is_equal_single(&r) {
328
return false;
329
}
330
331
l.to_aexpr().inputs_rev(&mut scratch1);
332
r.to_aexpr().inputs_rev(&mut scratch2);
333
},
334
(None, None) => return true,
335
_ => return false,
336
}
337
}
338
}
339
}
340
341
impl TreeWalker for AexprNode {
342
type Arena = Arena<AExpr>;
343
fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(
344
&self,
345
op: &mut F,
346
arena: &Self::Arena,
347
) -> PolarsResult<VisitRecursion> {
348
let mut scratch = unitvec![];
349
350
self.to_aexpr(arena).inputs_rev(&mut scratch);
351
for node in scratch.as_slice() {
352
let aenode = AexprNode::new(*node);
353
match op(&aenode, arena)? {
354
// let the recursion continue
355
VisitRecursion::Continue | VisitRecursion::Skip => {},
356
// early stop
357
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
358
}
359
}
360
Ok(VisitRecursion::Continue)
361
}
362
363
fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(
364
mut self,
365
op: &mut F,
366
arena: &mut Self::Arena,
367
) -> PolarsResult<Self> {
368
let mut scratch = unitvec![];
369
370
let ae = arena.get(self.node).clone();
371
ae.inputs_rev(&mut scratch);
372
373
// rewrite the nodes
374
for node in scratch.as_mut_slice() {
375
let aenode = AexprNode::new(*node);
376
*node = op(aenode, arena)?.node;
377
}
378
379
scratch.as_mut_slice().reverse();
380
let ae = ae.replace_inputs(&scratch);
381
self.node = arena.add(ae);
382
Ok(self)
383
}
384
}
385
386