Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/properties/general.rs
8446 views
1
use polars_utils::idx_vec::UnitVec;
2
use polars_utils::unitvec;
3
4
use super::super::*;
5
6
impl AExpr {
7
pub(crate) fn is_leaf(&self) -> bool {
8
matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
9
}
10
11
pub(crate) fn is_col(&self) -> bool {
12
matches!(self, AExpr::Column(_))
13
}
14
15
/// Checks whether this expression is elementwise. This only checks the top level expression.
16
pub(crate) fn is_elementwise_top_level(&self) -> bool {
17
use AExpr::*;
18
19
match self {
20
AnonymousFunction { options, .. } => options.is_elementwise(),
21
22
Function { options, .. } => options.is_elementwise(),
23
24
Literal(v) => v.is_scalar(),
25
26
Eval { variant, .. } => variant.is_elementwise(),
27
28
Element | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,
29
30
#[cfg(feature = "dtype-struct")]
31
StructEval { .. } | StructField(_) => true,
32
33
#[cfg(feature = "dynamic_group_by")]
34
Rolling { .. } => false,
35
36
Agg { .. }
37
| AnonymousAgg { .. }
38
| Explode { .. }
39
| Filter { .. }
40
| Gather { .. }
41
| Len
42
| Slice { .. }
43
| Sort { .. }
44
| SortBy { .. }
45
| Over { .. } => false,
46
}
47
}
48
49
/// Checks whether this expression is row-separable. This only checks the top level expression.
50
pub(crate) fn is_row_separable_top_level(&self) -> bool {
51
use AExpr::*;
52
53
match self {
54
AnonymousFunction { options, .. } => options.is_row_separable(),
55
Function { options, .. } => options.is_row_separable(),
56
Literal(v) => v.is_scalar(),
57
Explode { .. } | Filter { .. } => true,
58
_ => self.is_elementwise_top_level(),
59
}
60
}
61
62
pub(crate) fn does_not_modify_top_level(&self) -> bool {
63
match self {
64
AExpr::Column(_) => true,
65
AExpr::Function { function, .. } => {
66
matches!(function, IRFunctionExpr::SetSortedFlag(_))
67
},
68
_ => false,
69
}
70
}
71
}
72
73
// Traversal utilities
74
fn property_and_traverse<F>(stack: &mut UnitVec<Node>, ae: &AExpr, property: F) -> bool
75
where
76
F: Fn(&AExpr) -> bool,
77
{
78
if !property(ae) {
79
return false;
80
}
81
ae.inputs_rev(stack);
82
true
83
}
84
85
fn property_rec<F>(node: Node, expr_arena: &Arena<AExpr>, property: F) -> bool
86
where
87
F: Fn(&mut UnitVec<Node>, &AExpr, &Arena<AExpr>) -> bool,
88
{
89
let mut stack = unitvec![];
90
let mut ae = expr_arena.get(node);
91
92
loop {
93
if !property(&mut stack, ae, expr_arena) {
94
return false;
95
}
96
97
let Some(node) = stack.pop() else {
98
break;
99
};
100
101
ae = expr_arena.get(node);
102
}
103
104
true
105
}
106
107
/// Checks if the top-level expression node does not modify. If this is the case, then `stack` will
108
/// be extended further with any nested expression nodes.
109
fn does_not_modify(stack: &mut UnitVec<Node>, ae: &AExpr, _expr_arena: &Arena<AExpr>) -> bool {
110
property_and_traverse(stack, ae, |ae| ae.does_not_modify_top_level())
111
}
112
113
pub fn does_not_modify_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
114
property_rec(node, expr_arena, does_not_modify)
115
}
116
117
pub fn is_prop<P: Fn(&AExpr) -> bool>(
118
stack: &mut UnitVec<Node>,
119
ae: &AExpr,
120
expr_arena: &Arena<AExpr>,
121
prop_top_level: P,
122
) -> bool {
123
use AExpr::*;
124
125
if !prop_top_level(ae) {
126
return false;
127
}
128
129
match ae {
130
// Literals that aren't being projected are allowed to be non-scalar, so we don't add them
131
// for inspection. (e.g. `is_in(<literal>)`).
132
#[cfg(feature = "is_in")]
133
Function {
134
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),
135
input,
136
..
137
} => (|| {
138
if let Some(rhs) = input.get(1) {
139
assert_eq!(input.len(), 2); // A.is_in(B)
140
let rhs = rhs.node();
141
142
if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) {
143
stack.extend([input[0].node()]);
144
return;
145
}
146
};
147
ae.inputs_rev(stack);
148
})(),
149
_ => {
150
ae.inputs_rev(stack);
151
},
152
}
153
154
true
155
}
156
157
/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will
158
/// be extended further with any nested expression nodes.
159
pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
160
is_prop(stack, ae, expr_arena, |ae| ae.is_elementwise_top_level())
161
}
162
163
pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
164
where
165
Node: From<&'a N>,
166
{
167
nodes
168
.iter()
169
.all(|n| is_elementwise_rec(n.into(), expr_arena))
170
}
171
172
/// Recursive variant of `is_elementwise`
173
pub fn is_elementwise_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
174
property_rec(node, expr_arena, is_elementwise)
175
}
176
177
/// Checks if the top-level expression node is row-separable. If this is the case, then `stack` will
178
/// be extended further with any nested expression nodes.
179
pub fn is_row_separable(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
180
is_prop(stack, ae, expr_arena, |ae| ae.is_row_separable_top_level())
181
}
182
183
pub fn all_row_separable<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
184
where
185
Node: From<&'a N>,
186
{
187
nodes
188
.iter()
189
.all(|n| is_row_separable_rec(n.into(), expr_arena))
190
}
191
192
/// Recursive variant of `is_row_separable`
193
pub fn is_row_separable_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
194
property_rec(node, expr_arena, is_row_separable)
195
}
196
197
/// Classification of an expression tree for predicate pushdown.
///
/// `ExprPushdownGroup::update_with_expr` only ever moves a value from `Pushable` towards
/// `Barrier`, never back.
#[derive(Debug, Clone)]
pub enum ExprPushdownGroup {
    /// Elementwise and infallible: the expression can be pushed down.
    ///
    /// Example: a non-strict cast.
    Pushable,
    /// Elementwise but fallible: the expression is not pushed itself, yet it does not stop
    /// pushable expressions from being pushed.
    ///
    /// Fallible expressions land here rather than in `Barrier`, so more predicates get pushed. The
    /// trade-off is that an expression which would have raised an error may stop doing so once the
    /// offending rows are filtered out earlier.
    ///
    /// Examples: strict cast, `list.get(null_on_oob=False)`, `to_datetime(strict=True)`.
    Fallible,
    /// Not elementwise: cannot be pushed and blocks every expression at the current level.
    ///
    /// Example: `sort()`.
    Barrier,
}
216
217
impl ExprPushdownGroup {
218
/// Note:
219
/// * `stack` is not extended with any nodes if a barrier expression is seen.
220
/// * This function is not recursive - the caller should repeatedly
221
/// call this function with the `stack` to perform a recursive check.
222
pub fn update_with_expr(
223
&mut self,
224
stack: &mut UnitVec<Node>,
225
ae: &AExpr,
226
expr_arena: &Arena<AExpr>,
227
) -> &mut Self {
228
match self {
229
ExprPushdownGroup::Pushable | ExprPushdownGroup::Fallible => {
230
// Downgrade to unpushable if fallible
231
if ae.is_fallible_top_level(expr_arena) {
232
*self = ExprPushdownGroup::Fallible;
233
}
234
235
// Downgrade to barrier if non-elementwise
236
if !is_elementwise(stack, ae, expr_arena) {
237
*self = ExprPushdownGroup::Barrier
238
}
239
},
240
241
ExprPushdownGroup::Barrier => {},
242
}
243
244
self
245
}
246
247
pub fn update_with_expr_rec<'a>(
248
&mut self,
249
mut ae: &'a AExpr,
250
expr_arena: &'a Arena<AExpr>,
251
scratch: Option<&mut UnitVec<Node>>,
252
) -> &mut Self {
253
let mut local_scratch = unitvec![];
254
let stack = scratch.unwrap_or(&mut local_scratch);
255
256
loop {
257
self.update_with_expr(stack, ae, expr_arena);
258
259
if let ExprPushdownGroup::Barrier = self {
260
return self;
261
}
262
263
let Some(node) = stack.pop() else {
264
break;
265
};
266
267
ae = expr_arena.get(node);
268
}
269
270
self
271
}
272
273
pub fn blocks_pushdown(&self, maintain_errors: bool) -> bool {
274
match self {
275
ExprPushdownGroup::Barrier => true,
276
ExprPushdownGroup::Fallible => maintain_errors,
277
ExprPushdownGroup::Pushable => false,
278
}
279
}
280
}
281
282
pub fn can_pre_agg_exprs(
283
exprs: &[ExprIR],
284
expr_arena: &Arena<AExpr>,
285
_input_schema: &Schema,
286
) -> bool {
287
exprs
288
.iter()
289
.all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))
290
}
291
292
/// Checks whether an expression can be pre-aggregated in a group-by. Note that this also must be
293
/// implemented physically, so this isn't a complete list.
294
pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {
295
let aexpr = expr_arena.get(agg);
296
297
match aexpr {
298
AExpr::Len => true,
299
AExpr::Column(_) | AExpr::Literal(_) => false,
300
// We only allow expressions that end with an aggregation.
301
AExpr::Agg(_) => {
302
let has_aggregation =
303
|node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));
304
305
// check if the aggregation type is partitionable
306
// only simple aggregation like col().sum
307
// that can be divided in to the aggregation of their partitions are allowed
308
let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {
309
use AExpr::*;
310
match ae {
311
// struct is needed to keep both states
312
#[cfg(feature = "dtype-struct")]
313
Agg(IRAggExpr::Mean(_)) => {
314
// only numeric means for now.
315
// logical types seem to break because of casts to float.
316
matches!(
317
expr_arena
318
.get(agg)
319
.to_dtype(&ToFieldContext::new(expr_arena, _input_schema))
320
.map(|dt| { dt.is_primitive_numeric() }),
321
Ok(true)
322
)
323
},
324
// only allowed expressions
325
Agg(agg_e) => {
326
matches!(
327
agg_e,
328
IRAggExpr::Min { .. }
329
| IRAggExpr::Max { .. }
330
| IRAggExpr::Sum(_)
331
| IRAggExpr::Last(_)
332
| IRAggExpr::First(_)
333
| IRAggExpr::Count {
334
input: _,
335
include_nulls: true
336
}
337
)
338
},
339
Function { input, options, .. } => {
340
options.is_elementwise()
341
&& input.len() == 1
342
&& !has_aggregation(input[0].node())
343
},
344
BinaryExpr { left, right, .. } => {
345
!has_aggregation(*left) && !has_aggregation(*right)
346
},
347
Ternary {
348
truthy,
349
falsy,
350
predicate,
351
..
352
} => {
353
!has_aggregation(*truthy)
354
&& !has_aggregation(*falsy)
355
&& !has_aggregation(*predicate)
356
},
357
Literal(lv) => lv.is_scalar(),
358
Column(_) | Len | Cast { .. } => true,
359
_ => false,
360
}
361
});
362
363
#[cfg(feature = "object")]
364
{
365
for name in aexpr_to_leaf_names(agg, expr_arena) {
366
let dtype = _input_schema.get(&name).unwrap();
367
368
if let DataType::Object(_) = dtype {
369
return false;
370
}
371
}
372
}
373
can_partition
374
},
375
_ => false,
376
}
377
}
378
379
/// Identifies columns that are guaranteed to be non-NULL after applying this filter.
380
///
381
/// This is conservative in that it will not give false positives, but may not identify all columns.
382
///
383
/// Note, this must be called with the root node of filter expressions (the root nodes after splitting
384
/// with MintermIter is also allowed).
385
pub(crate) fn predicate_non_null_column_outputs(
386
predicate_node: Node,
387
expr_arena: &Arena<AExpr>,
388
non_null_column_callback: &mut dyn FnMut(&PlSmallStr),
389
) {
390
let mut minterm_iter = MintermIter::new(predicate_node, expr_arena);
391
let stack: &mut UnitVec<Node> = &mut unitvec![];
392
393
/// Only traverse the first input, e.g. `A.is_in(B)` we don't consider B.
394
macro_rules! traverse_first_input {
395
// &[ExprIR]
396
($inputs:expr) => {{
397
if let Some(expr_ir) = $inputs.first() {
398
stack.push(expr_ir.node())
399
}
400
401
false
402
}};
403
}
404
405
loop {
406
use AExpr::*;
407
408
let node = if let Some(node) = stack.pop() {
409
node
410
} else if let Some(minterm_node) = minterm_iter.next() {
411
// Some additional leaf exprs can be pruned.
412
match expr_arena.get(minterm_node) {
413
Function {
414
input,
415
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),
416
options: _,
417
} if !input.is_empty() => input.first().unwrap().node(),
418
419
Function {
420
input,
421
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
422
options: _,
423
} if !input.is_empty() => match expr_arena.get(input.first().unwrap().node()) {
424
Function {
425
input,
426
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
427
options: _,
428
} if !input.is_empty() => input.first().unwrap().node(),
429
430
_ => minterm_node,
431
},
432
433
_ => minterm_node,
434
}
435
} else {
436
break;
437
};
438
439
let ae = expr_arena.get(node);
440
441
// This match we traverse a subset of the operations that are guaranteed to maintain NULLs.
442
//
443
// This must not catch any operations that materialize NULLs, as otherwise e.g.
444
// `e.fill_null(False) >= False` will include NULLs
445
let traverse_all_inputs = match ae {
446
BinaryExpr {
447
left: _,
448
op,
449
right: _,
450
} => {
451
use Operator::*;
452
453
match op {
454
Eq | NotEq | Lt | LtEq | Gt | GtEq | Plus | Minus | Multiply | RustDivide
455
| TrueDivide | FloorDivide | Modulus | Xor => true,
456
457
// These can turn NULLs into true/false. E.g.:
458
// * (L & False) >= False becomes True
459
// * L | True becomes True
460
EqValidity | NotEqValidity | Or | LogicalOr | And | LogicalAnd => false,
461
}
462
},
463
464
Cast { dtype, .. } => {
465
// Forbid nested types, it's currently buggy:
466
// >>> pl.select(a=pl.lit(None), b=pl.lit(None).cast(pl.Struct({})))
467
// | a | b |
468
// | --- | --- |
469
// | null | struct[0] |
470
// |------|-----------|
471
// | null | {} |
472
//
473
// (issue at https://github.com/pola-rs/polars/issues/23276)
474
!dtype.is_nested()
475
},
476
477
Function {
478
input,
479
function: _,
480
options,
481
} => {
482
if options
483
.flags
484
.contains(FunctionFlags::PRESERVES_NULL_FIRST_INPUT)
485
{
486
traverse_first_input!(input)
487
} else {
488
options
489
.flags
490
.contains(FunctionFlags::PRESERVES_NULL_ALL_INPUTS)
491
}
492
},
493
494
Column(name) => {
495
non_null_column_callback(name);
496
false
497
},
498
499
_ => false,
500
};
501
502
if traverse_all_inputs {
503
ae.inputs_rev(stack);
504
}
505
}
506
}
507
508