Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/traverse.rs
8427 views
1
use super::*;
2
3
impl AExpr {
4
/// Push the inputs of this node to the given container, in reverse order.
5
/// This ensures the primary node responsible for the name is pushed last.
6
///
7
/// This is subtly different from `children_rev` as this only includes the input expressions,
8
/// not expressions used during evaluation.
9
pub fn inputs_rev<E>(&self, container: &mut E)
10
where
11
E: Extend<Node>,
12
{
13
use AExpr::*;
14
15
match self {
16
Element | Column(_) | Literal(_) | Len => {},
17
#[cfg(feature = "dtype-struct")]
18
StructField(_) => {},
19
BinaryExpr { left, op: _, right } => {
20
container.extend([*right, *left]);
21
},
22
Cast { expr, .. } => container.extend([*expr]),
23
Sort { expr, .. } => container.extend([*expr]),
24
Gather { expr, idx, .. } => {
25
container.extend([*idx, *expr]);
26
},
27
SortBy { expr, by, .. } => {
28
container.extend(by.iter().cloned().rev());
29
container.extend([*expr]);
30
},
31
Filter { input, by } => {
32
container.extend([*by, *input]);
33
},
34
Agg(agg_e) => match agg_e.get_input() {
35
NodeInputs::Single(node) => container.extend([node]),
36
NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),
37
NodeInputs::Leaf => {},
38
},
39
Ternary {
40
truthy,
41
falsy,
42
predicate,
43
} => {
44
container.extend([*predicate, *falsy, *truthy]);
45
},
46
AnonymousFunction { input, .. }
47
| Function { input, .. }
48
| AnonymousAgg { input, .. } => container.extend(input.iter().rev().map(|e| e.node())),
49
Explode { expr: e, .. } => container.extend([*e]),
50
#[cfg(feature = "dynamic_group_by")]
51
Rolling {
52
function,
53
index_column,
54
period: _,
55
offset: _,
56
closed_window: _,
57
} => {
58
container.extend([*index_column, *function]);
59
},
60
Over {
61
function,
62
partition_by,
63
order_by,
64
mapping: _,
65
} => {
66
if let Some((n, _)) = order_by {
67
container.extend([*n]);
68
}
69
container.extend(partition_by.iter().rev().cloned());
70
container.extend([*function]);
71
},
72
Eval {
73
expr,
74
evaluation,
75
variant: _,
76
} => {
77
// We don't use the evaluation here because it does not contain inputs.
78
_ = evaluation;
79
container.extend([*expr]);
80
},
81
#[cfg(feature = "dtype-struct")]
82
StructEval { expr, evaluation } => {
83
// Evaluation is included. In case this is not allowed, use `inputs_rev_strict()`.
84
container.extend(evaluation.iter().rev().map(ExprIR::node));
85
container.extend([*expr]);
86
},
87
Slice {
88
input,
89
offset,
90
length,
91
} => {
92
container.extend([*length, *offset, *input]);
93
},
94
}
95
}
96
97
/// Push the inputs of this node to the given container, in reverse order.
98
/// This ensures the primary node responsible for the name is pushed last.
99
///
100
/// Unlike `inputs_rev`, this excludes Eval expressions. These use an extended schema,
101
/// determined by their input, which implies a different traversal order.
102
///
103
/// This is subtly different from `children_rev` as this only includes the input expressions,
104
/// not expressions used during evaluation.
105
pub fn inputs_rev_strict<E>(&self, container: &mut E)
106
where
107
E: Extend<Node>,
108
{
109
use AExpr::*;
110
111
match self {
112
#[cfg(feature = "dtype-struct")]
113
StructEval { expr, evaluation } => {
114
// Evaluation is explicitly excluded. It is up to the caller to handle
115
// any tree traversal if required.
116
_ = evaluation;
117
container.extend([*expr]);
118
},
119
expr => expr.inputs_rev(container),
120
}
121
}
122
123
/// Push the children of this node to the given container, in reverse order.
124
/// This ensures the primary node responsible for the name is pushed last.
125
///
126
/// This is subtly different from `input_rev` as this only all expressions included in the
127
/// expression not only the input expressions,
128
pub fn children_rev<E: Extend<Node>>(&self, container: &mut E) {
129
use AExpr::*;
130
131
match self {
132
Element | Column(_) | Literal(_) | Len => {},
133
#[cfg(feature = "dtype-struct")]
134
StructField(_) => {},
135
BinaryExpr { left, op: _, right } => {
136
container.extend([*right, *left]);
137
},
138
Cast { expr, .. } => container.extend([*expr]),
139
Sort { expr, .. } => container.extend([*expr]),
140
Gather { expr, idx, .. } => {
141
container.extend([*idx, *expr]);
142
},
143
SortBy { expr, by, .. } => {
144
container.extend(by.iter().cloned().rev());
145
container.extend([*expr]);
146
},
147
Filter { input, by } => {
148
container.extend([*by, *input]);
149
},
150
Agg(agg_e) => match agg_e.get_input() {
151
NodeInputs::Single(node) => container.extend([node]),
152
NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),
153
NodeInputs::Leaf => {},
154
},
155
Ternary {
156
truthy,
157
falsy,
158
predicate,
159
} => {
160
container.extend([*predicate, *falsy, *truthy]);
161
},
162
AnonymousFunction { input, .. }
163
| Function { input, .. }
164
| AnonymousAgg { input, .. } => container.extend(input.iter().rev().map(|e| e.node())),
165
Explode { expr: e, .. } => container.extend([*e]),
166
#[cfg(feature = "dynamic_group_by")]
167
Rolling {
168
function,
169
index_column,
170
period: _,
171
offset: _,
172
closed_window: _,
173
} => {
174
container.extend([*index_column, *function]);
175
},
176
Over {
177
function,
178
partition_by,
179
order_by,
180
mapping: _,
181
} => {
182
if let Some((n, _)) = order_by {
183
container.extend([*n]);
184
}
185
container.extend(partition_by.iter().rev().cloned());
186
container.extend([*function]);
187
},
188
Eval {
189
expr,
190
evaluation,
191
variant: _,
192
} => container.extend([*evaluation, *expr]),
193
#[cfg(feature = "dtype-struct")]
194
StructEval { expr, evaluation } => {
195
container.extend(evaluation.iter().rev().map(ExprIR::node));
196
container.extend([*expr]);
197
},
198
Slice {
199
input,
200
offset,
201
length,
202
} => {
203
container.extend([*length, *offset, *input]);
204
},
205
}
206
}
207
208
pub fn replace_inputs(mut self, inputs: &[Node]) -> Self {
209
use AExpr::*;
210
let input = match &mut self {
211
Element | Column(_) | Literal(_) | Len => return self,
212
#[cfg(feature = "dtype-struct")]
213
StructField(_) => return self,
214
Cast { expr, .. } => expr,
215
Explode { expr, .. } => expr,
216
BinaryExpr { left, right, .. } => {
217
*left = inputs[0];
218
*right = inputs[1];
219
return self;
220
},
221
Gather { expr, idx, .. } => {
222
*expr = inputs[0];
223
*idx = inputs[1];
224
return self;
225
},
226
Sort { expr, .. } => expr,
227
SortBy { expr, by, .. } => {
228
*expr = inputs[0];
229
by.clear();
230
by.extend_from_slice(&inputs[1..]);
231
return self;
232
},
233
Filter { input, by, .. } => {
234
*input = inputs[0];
235
*by = inputs[1];
236
return self;
237
},
238
Agg(a) => {
239
match a {
240
IRAggExpr::Quantile {
241
expr,
242
quantile,
243
method: _,
244
} => {
245
*expr = inputs[0];
246
*quantile = inputs[1];
247
},
248
_ => {
249
a.set_input(inputs[0]);
250
},
251
}
252
return self;
253
},
254
Ternary {
255
truthy,
256
falsy,
257
predicate,
258
} => {
259
*truthy = inputs[0];
260
*falsy = inputs[1];
261
*predicate = inputs[2];
262
return self;
263
},
264
AnonymousFunction { input, .. }
265
| Function { input, .. }
266
| AnonymousAgg { input, .. } => {
267
assert_eq!(input.len(), inputs.len());
268
for (e, node) in input.iter_mut().zip(inputs.iter()) {
269
e.set_node(*node);
270
}
271
return self;
272
},
273
Eval {
274
expr,
275
evaluation,
276
variant: _,
277
} => {
278
*expr = inputs[0];
279
_ = evaluation; // Intentional.
280
return self;
281
},
282
#[cfg(feature = "dtype-struct")]
283
StructEval { expr, evaluation } => {
284
*expr = inputs[0];
285
_ = evaluation; // Intentional.
286
return self;
287
},
288
Slice {
289
input,
290
offset,
291
length,
292
} => {
293
*input = inputs[0];
294
*offset = inputs[1];
295
*length = inputs[2];
296
return self;
297
},
298
#[cfg(feature = "dynamic_group_by")]
299
Rolling {
300
function,
301
index_column,
302
period: _,
303
offset: _,
304
closed_window: _,
305
} => {
306
*function = inputs[0];
307
*index_column = inputs[1];
308
return self;
309
},
310
Over {
311
function,
312
partition_by,
313
order_by,
314
..
315
} => {
316
let offset = order_by.is_some() as usize;
317
*function = inputs[0];
318
partition_by.clear();
319
partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]);
320
if let Some((_, options)) = order_by {
321
*order_by = Some((*inputs.last().unwrap(), *options));
322
}
323
return self;
324
},
325
};
326
*input = inputs[0];
327
self
328
}
329
330
pub fn replace_children(mut self, inputs: &[Node]) -> Self {
331
use AExpr::*;
332
let input = match &mut self {
333
Element | Column(_) | Literal(_) | Len => return self,
334
#[cfg(feature = "dtype-struct")]
335
StructField(_) => return self,
336
Cast { expr, .. } => expr,
337
Explode { expr, .. } => expr,
338
BinaryExpr { left, right, .. } => {
339
*left = inputs[0];
340
*right = inputs[1];
341
return self;
342
},
343
Gather { expr, idx, .. } => {
344
*expr = inputs[0];
345
*idx = inputs[1];
346
return self;
347
},
348
Sort { expr, .. } => expr,
349
SortBy { expr, by, .. } => {
350
*expr = inputs[0];
351
by.clear();
352
by.extend_from_slice(&inputs[1..]);
353
return self;
354
},
355
Filter { input, by, .. } => {
356
*input = inputs[0];
357
*by = inputs[1];
358
return self;
359
},
360
Agg(a) => {
361
if let IRAggExpr::Quantile {
362
expr,
363
quantile,
364
method: _,
365
} = a
366
{
367
*expr = inputs[0];
368
*quantile = inputs[1];
369
} else {
370
a.set_input(inputs[0]);
371
}
372
return self;
373
},
374
Ternary {
375
truthy,
376
falsy,
377
predicate,
378
} => {
379
*truthy = inputs[0];
380
*falsy = inputs[1];
381
*predicate = inputs[2];
382
return self;
383
},
384
AnonymousAgg { input, .. }
385
| AnonymousFunction { input, .. }
386
| Function { input, .. } => {
387
assert_eq!(input.len(), inputs.len());
388
for (e, node) in input.iter_mut().zip(inputs.iter()) {
389
e.set_node(*node);
390
}
391
return self;
392
},
393
Eval {
394
expr,
395
evaluation,
396
variant: _,
397
} => {
398
*expr = inputs[0];
399
*evaluation = inputs[1];
400
return self;
401
},
402
#[cfg(feature = "dtype-struct")]
403
StructEval { expr, evaluation } => {
404
assert_eq!(inputs.len(), evaluation.len() + 1);
405
*expr = inputs[0];
406
for (e, node) in evaluation.iter_mut().zip(inputs[1..].iter()) {
407
e.set_node(*node);
408
}
409
return self;
410
},
411
Slice {
412
input,
413
offset,
414
length,
415
} => {
416
*input = inputs[0];
417
*offset = inputs[1];
418
*length = inputs[2];
419
return self;
420
},
421
#[cfg(feature = "dynamic_group_by")]
422
Rolling {
423
function,
424
index_column,
425
period: _,
426
offset: _,
427
closed_window: _,
428
} => {
429
*function = inputs[0];
430
*index_column = inputs[1];
431
return self;
432
},
433
Over {
434
function,
435
partition_by,
436
order_by,
437
..
438
} => {
439
let offset = order_by.is_some() as usize;
440
*function = inputs[0];
441
partition_by.clear();
442
partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]);
443
if let Some((_, options)) = order_by {
444
*order_by = Some((*inputs.last().unwrap(), *options));
445
}
446
return self;
447
},
448
};
449
*input = inputs[0];
450
self
451
}
452
}
453
454
impl IRAggExpr {
455
pub fn get_input(&self) -> NodeInputs {
456
use IRAggExpr::*;
457
use NodeInputs::*;
458
459
match self {
460
Min { input, .. } => Single(*input),
461
Max { input, .. } => Single(*input),
462
Median(input) => Single(*input),
463
NUnique(input) => Single(*input),
464
First(input) => Single(*input),
465
FirstNonNull(input) => Single(*input),
466
Last(input) => Single(*input),
467
LastNonNull(input) => Single(*input),
468
Item { input, .. } => Single(*input),
469
Mean(input) => Single(*input),
470
Implode(input) => Single(*input),
471
Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]),
472
Sum(input) => Single(*input),
473
Count { input, .. } => Single(*input),
474
Std(input, _) => Single(*input),
475
Var(input, _) => Single(*input),
476
AggGroups(input) => Single(*input),
477
}
478
}
479
pub fn set_input(&mut self, input: Node) {
480
use IRAggExpr::*;
481
let node = match self {
482
Min { input, .. } => input,
483
Max { input, .. } => input,
484
Median(input) => input,
485
NUnique(input) => input,
486
First(input) => input,
487
FirstNonNull(input) => input,
488
Last(input) => input,
489
LastNonNull(input) => input,
490
Item { input, .. } => input,
491
Mean(input) => input,
492
Implode(input) => input,
493
Quantile { expr, .. } => expr,
494
Sum(input) => input,
495
Count { input, .. } => input,
496
Std(input, _) => input,
497
Var(input, _) => input,
498
AggGroups(input) => input,
499
};
500
*node = input;
501
}
502
}
503
504
pub enum NodeInputs {
505
Leaf,
506
Single(Node),
507
Many(Vec<Node>),
508
}
509
510
impl NodeInputs {
511
pub fn first(&self) -> Node {
512
match self {
513
NodeInputs::Single(node) => *node,
514
NodeInputs::Many(nodes) => nodes[0],
515
NodeInputs::Leaf => panic!(),
516
}
517
}
518
}
519
520