Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/properties/general.rs
6940 views
1
use polars_utils::idx_vec::UnitVec;
2
use polars_utils::unitvec;
3
4
use super::super::*;
5
6
impl AExpr {
7
pub(crate) fn is_leaf(&self) -> bool {
8
matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
9
}
10
11
pub(crate) fn is_col(&self) -> bool {
12
matches!(self, AExpr::Column(_))
13
}
14
15
/// Checks whether this expression is elementwise. This only checks the top level expression.
16
pub(crate) fn is_elementwise_top_level(&self) -> bool {
17
use AExpr::*;
18
19
match self {
20
AnonymousFunction { options, .. } => options.is_elementwise(),
21
22
Function { options, .. } => options.is_elementwise(),
23
24
Literal(v) => v.is_scalar(),
25
26
Eval { variant, .. } => match variant {
27
EvalVariant::List => true,
28
EvalVariant::Cumulative { min_samples: _ } => false,
29
},
30
31
BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,
32
33
Agg { .. }
34
| Explode { .. }
35
| Filter { .. }
36
| Gather { .. }
37
| Len
38
| Slice { .. }
39
| Sort { .. }
40
| SortBy { .. }
41
| Window { .. } => false,
42
}
43
}
44
45
/// Checks whether this expression is row-separable. This only checks the top level expression.
46
pub(crate) fn is_row_separable_top_level(&self) -> bool {
47
use AExpr::*;
48
49
match self {
50
AnonymousFunction { options, .. } => options.is_row_separable(),
51
Function { options, .. } => options.is_row_separable(),
52
Literal(v) => v.is_scalar(),
53
Explode { .. } | Filter { .. } => true,
54
_ => self.is_elementwise_top_level(),
55
}
56
}
57
58
pub(crate) fn does_not_modify_top_level(&self) -> bool {
59
match self {
60
AExpr::Column(_) => true,
61
AExpr::Function { function, .. } => {
62
matches!(function, IRFunctionExpr::SetSortedFlag(_))
63
},
64
_ => false,
65
}
66
}
67
}
68
69
// Traversal utilities
70
fn property_and_traverse<F>(stack: &mut UnitVec<Node>, ae: &AExpr, property: F) -> bool
71
where
72
F: Fn(&AExpr) -> bool,
73
{
74
if !property(ae) {
75
return false;
76
}
77
ae.inputs_rev(stack);
78
true
79
}
80
81
fn property_rec<F>(node: Node, expr_arena: &Arena<AExpr>, property: F) -> bool
82
where
83
F: Fn(&mut UnitVec<Node>, &AExpr, &Arena<AExpr>) -> bool,
84
{
85
let mut stack = unitvec![];
86
let mut ae = expr_arena.get(node);
87
88
loop {
89
if !property(&mut stack, ae, expr_arena) {
90
return false;
91
}
92
93
let Some(node) = stack.pop() else {
94
break;
95
};
96
97
ae = expr_arena.get(node);
98
}
99
100
true
101
}
102
103
/// Checks if the top-level expression node does not modify. If this is the case, then `stack` will
104
/// be extended further with any nested expression nodes.
105
fn does_not_modify(stack: &mut UnitVec<Node>, ae: &AExpr, _expr_arena: &Arena<AExpr>) -> bool {
106
property_and_traverse(stack, ae, |ae| ae.does_not_modify_top_level())
107
}
108
109
pub fn does_not_modify_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
110
property_rec(node, expr_arena, does_not_modify)
111
}
112
113
pub fn is_prop<P: Fn(&AExpr) -> bool>(
114
stack: &mut UnitVec<Node>,
115
ae: &AExpr,
116
expr_arena: &Arena<AExpr>,
117
prop_top_level: P,
118
) -> bool {
119
use AExpr::*;
120
121
if !prop_top_level(ae) {
122
return false;
123
}
124
125
match ae {
126
// Literals that aren't being projected are allowed to be non-scalar, so we don't add them
127
// for inspection. (e.g. `is_in(<literal>)`).
128
#[cfg(feature = "is_in")]
129
Function {
130
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),
131
input,
132
..
133
} => (|| {
134
if let Some(rhs) = input.get(1) {
135
assert_eq!(input.len(), 2); // A.is_in(B)
136
let rhs = rhs.node();
137
138
if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) {
139
stack.extend([input[0].node()]);
140
return;
141
}
142
};
143
144
ae.inputs_rev(stack);
145
})(),
146
_ => ae.inputs_rev(stack),
147
}
148
149
true
150
}
151
152
/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will
153
/// be extended further with any nested expression nodes.
154
pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
155
is_prop(stack, ae, expr_arena, |ae| ae.is_elementwise_top_level())
156
}
157
158
pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
159
where
160
Node: From<&'a N>,
161
{
162
nodes
163
.iter()
164
.all(|n| is_elementwise_rec(n.into(), expr_arena))
165
}
166
167
/// Recursive variant of `is_elementwise`
168
pub fn is_elementwise_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
169
property_rec(node, expr_arena, is_elementwise)
170
}
171
172
/// Checks if the top-level expression node is row-separable. If this is the case, then `stack` will
173
/// be extended further with any nested expression nodes.
174
pub fn is_row_separable(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
175
is_prop(stack, ae, expr_arena, |ae| ae.is_row_separable_top_level())
176
}
177
178
pub fn all_row_separable<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
179
where
180
Node: From<&'a N>,
181
{
182
nodes
183
.iter()
184
.all(|n| is_row_separable_rec(n.into(), expr_arena))
185
}
186
187
/// Recursive variant of `is_row_separable`
188
pub fn is_row_separable_rec(node: Node, expr_arena: &Arena<AExpr>) -> bool {
189
property_rec(node, expr_arena, is_row_separable)
190
}
191
192
/// Classification of an expression with respect to pushdown.
#[derive(Debug, Clone)]
pub enum ExprPushdownGroup {
    /// Elementwise and infallible: the expression can be pushed.
    ///
    /// Example: a non-strict cast.
    Pushable,
    /// Elementwise but fallible: the expression cannot be pushed, yet it does not stop pushable
    /// expressions from being pushed.
    ///
    /// Fallible expressions are placed here instead of in [`ExprPushdownGroup::Barrier`]. As a
    /// result more predicates get pushed, but the expression may no longer raise if the
    /// problematic rows end up being filtered out.
    ///
    /// Examples: strict cast, `list.get(null_on_oob=False)`, `to_datetime(strict=True)`.
    Fallible,
    /// Non-elementwise: the expression cannot be pushed and blocks every expression at the
    /// current level.
    ///
    /// Example: `sort()`.
    Barrier,
}
211
212
impl ExprPushdownGroup {
213
/// Note:
214
/// * `stack` is not extended with any nodes if a barrier expression is seen.
215
/// * This function is not recursive - the caller should repeatedly
216
/// call this function with the `stack` to perform a recursive check.
217
pub fn update_with_expr(
218
&mut self,
219
stack: &mut UnitVec<Node>,
220
ae: &AExpr,
221
expr_arena: &Arena<AExpr>,
222
) -> &mut Self {
223
match self {
224
ExprPushdownGroup::Pushable | ExprPushdownGroup::Fallible => {
225
// Downgrade to unpushable if fallible
226
if match ae {
227
// Rows that go OOB on get/gather may be filtered out in earlier operations,
228
// so we don't push these down.
229
AExpr::Function {
230
function: IRFunctionExpr::ListExpr(IRListFunction::Get(false)),
231
..
232
} => true,
233
234
#[cfg(feature = "list_gather")]
235
AExpr::Function {
236
function: IRFunctionExpr::ListExpr(IRListFunction::Gather(false)),
237
..
238
} => true,
239
240
#[cfg(feature = "dtype-array")]
241
AExpr::Function {
242
function: IRFunctionExpr::ArrayExpr(IRArrayFunction::Get(false)),
243
..
244
} => true,
245
246
#[cfg(all(feature = "strings", feature = "temporal"))]
247
AExpr::Function {
248
input,
249
function:
250
IRFunctionExpr::StringExpr(IRStringFunction::Strptime(_, strptime_options)),
251
..
252
} => {
253
debug_assert!(input.len() <= 2);
254
255
let ambiguous_arg_is_infallible_scalar = input
256
.get(1)
257
.map(|x| expr_arena.get(x.node()))
258
.is_some_and(|ae| match ae {
259
AExpr::Literal(lv) => {
260
lv.extract_str().is_some_and(|ambiguous| match ambiguous {
261
"earliest" | "latest" | "null" => true,
262
"raise" => false,
263
v => {
264
if cfg!(debug_assertions) {
265
panic!("unhandled parameter to ambiguous: {v}")
266
}
267
false
268
},
269
})
270
},
271
_ => false,
272
});
273
274
let ambiguous_is_fallible = !ambiguous_arg_is_infallible_scalar;
275
276
strptime_options.strict || ambiguous_is_fallible
277
},
278
AExpr::Cast {
279
expr,
280
dtype: _,
281
options: CastOptions::Strict,
282
} => !matches!(expr_arena.get(*expr), AExpr::Literal(_)),
283
284
_ => false,
285
} {
286
*self = ExprPushdownGroup::Fallible;
287
}
288
289
// Downgrade to barrier if non-elementwise
290
if !is_elementwise(stack, ae, expr_arena) {
291
*self = ExprPushdownGroup::Barrier
292
}
293
},
294
295
ExprPushdownGroup::Barrier => {},
296
}
297
298
self
299
}
300
301
pub fn update_with_expr_rec<'a>(
302
&mut self,
303
mut ae: &'a AExpr,
304
expr_arena: &'a Arena<AExpr>,
305
scratch: Option<&mut UnitVec<Node>>,
306
) -> &mut Self {
307
let mut local_scratch = unitvec![];
308
let stack = scratch.unwrap_or(&mut local_scratch);
309
310
loop {
311
self.update_with_expr(stack, ae, expr_arena);
312
313
if let ExprPushdownGroup::Barrier = self {
314
return self;
315
}
316
317
let Some(node) = stack.pop() else {
318
break;
319
};
320
321
ae = expr_arena.get(node);
322
}
323
324
self
325
}
326
327
pub fn blocks_pushdown(&self, maintain_errors: bool) -> bool {
328
match self {
329
ExprPushdownGroup::Barrier => true,
330
ExprPushdownGroup::Fallible => maintain_errors,
331
ExprPushdownGroup::Pushable => false,
332
}
333
}
334
}
335
336
pub fn can_pre_agg_exprs(
337
exprs: &[ExprIR],
338
expr_arena: &Arena<AExpr>,
339
_input_schema: &Schema,
340
) -> bool {
341
exprs
342
.iter()
343
.all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))
344
}
345
346
/// Checks whether an expression can be pre-aggregated in a group-by. Note that this also must be
347
/// implemented physically, so this isn't a complete list.
348
pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {
349
let aexpr = expr_arena.get(agg);
350
351
match aexpr {
352
AExpr::Len => true,
353
AExpr::Column(_) | AExpr::Literal(_) => false,
354
// We only allow expressions that end with an aggregation.
355
AExpr::Agg(_) => {
356
let has_aggregation =
357
|node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));
358
359
// check if the aggregation type is partitionable
360
// only simple aggregation like col().sum
361
// that can be divided in to the aggregation of their partitions are allowed
362
let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {
363
use AExpr::*;
364
match ae {
365
// struct is needed to keep both states
366
#[cfg(feature = "dtype-struct")]
367
Agg(IRAggExpr::Mean(_)) => {
368
// only numeric means for now.
369
// logical types seem to break because of casts to float.
370
matches!(
371
expr_arena
372
.get(agg)
373
.get_dtype(_input_schema, expr_arena)
374
.map(|dt| { dt.is_primitive_numeric() }),
375
Ok(true)
376
)
377
},
378
// only allowed expressions
379
Agg(agg_e) => {
380
matches!(
381
agg_e,
382
IRAggExpr::Min { .. }
383
| IRAggExpr::Max { .. }
384
| IRAggExpr::Sum(_)
385
| IRAggExpr::Last(_)
386
| IRAggExpr::First(_)
387
| IRAggExpr::Count {
388
input: _,
389
include_nulls: true
390
}
391
)
392
},
393
Function { input, options, .. } => {
394
options.is_elementwise()
395
&& input.len() == 1
396
&& !has_aggregation(input[0].node())
397
},
398
BinaryExpr { left, right, .. } => {
399
!has_aggregation(*left) && !has_aggregation(*right)
400
},
401
Ternary {
402
truthy,
403
falsy,
404
predicate,
405
..
406
} => {
407
!has_aggregation(*truthy)
408
&& !has_aggregation(*falsy)
409
&& !has_aggregation(*predicate)
410
},
411
Literal(lv) => lv.is_scalar(),
412
Column(_) | Len | Cast { .. } => true,
413
_ => false,
414
}
415
});
416
417
#[cfg(feature = "object")]
418
{
419
for name in aexpr_to_leaf_names(agg, expr_arena) {
420
let dtype = _input_schema.get(&name).unwrap();
421
422
if let DataType::Object(_) = dtype {
423
return false;
424
}
425
}
426
}
427
can_partition
428
},
429
_ => false,
430
}
431
}
432
433
/// Identifies columns that are guaranteed to be non-NULL after applying this filter.
434
///
435
/// This is conservative in that it will not give false positives, but may not identify all columns.
436
///
437
/// Note, this must be called with the root node of filter expressions (the root nodes after splitting
438
/// with MintermIter is also allowed).
439
pub(crate) fn predicate_non_null_column_outputs(
440
predicate_node: Node,
441
expr_arena: &Arena<AExpr>,
442
non_null_column_callback: &mut dyn FnMut(&PlSmallStr),
443
) {
444
let mut minterm_iter = MintermIter::new(predicate_node, expr_arena);
445
let stack: &mut UnitVec<Node> = &mut unitvec![];
446
447
/// Only traverse the first input, e.g. `A.is_in(B)` we don't consider B.
448
macro_rules! traverse_first_input {
449
// &[ExprIR]
450
($inputs:expr) => {{
451
if let Some(expr_ir) = $inputs.first() {
452
stack.push(expr_ir.node())
453
}
454
455
false
456
}};
457
}
458
459
loop {
460
use AExpr::*;
461
462
let node = if let Some(node) = stack.pop() {
463
node
464
} else if let Some(minterm_node) = minterm_iter.next() {
465
// Some additional leaf exprs can be pruned.
466
match expr_arena.get(minterm_node) {
467
Function {
468
input,
469
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),
470
options: _,
471
} if !input.is_empty() => input.first().unwrap().node(),
472
473
Function {
474
input,
475
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
476
options: _,
477
} if !input.is_empty() => match expr_arena.get(input.first().unwrap().node()) {
478
Function {
479
input,
480
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
481
options: _,
482
} if !input.is_empty() => input.first().unwrap().node(),
483
484
_ => minterm_node,
485
},
486
487
_ => minterm_node,
488
}
489
} else {
490
break;
491
};
492
493
let ae = expr_arena.get(node);
494
495
// This match we traverse a subset of the operations that are guaranteed to maintain NULLs.
496
//
497
// This must not catch any operations that materialize NULLs, as otherwise e.g.
498
// `e.fill_null(False) >= False` will include NULLs
499
let traverse_all_inputs = match ae {
500
BinaryExpr {
501
left: _,
502
op,
503
right: _,
504
} => {
505
use Operator::*;
506
507
match op {
508
Eq | NotEq | Lt | LtEq | Gt | GtEq | Plus | Minus | Multiply | Divide
509
| TrueDivide | FloorDivide | Modulus | Xor => true,
510
511
// These can turn NULLs into true/false. E.g.:
512
// * (L & False) >= False becomes True
513
// * L | True becomes True
514
EqValidity | NotEqValidity | Or | LogicalOr | And | LogicalAnd => false,
515
}
516
},
517
518
Cast { dtype, .. } => {
519
// Forbid nested types, it's currently buggy:
520
// >>> pl.select(a=pl.lit(None), b=pl.lit(None).cast(pl.Struct({})))
521
// | a | b |
522
// | --- | --- |
523
// | null | struct[0] |
524
// |------|-----------|
525
// | null | {} |
526
//
527
// (issue at https://github.com/pola-rs/polars/issues/23276)
528
!dtype.is_nested()
529
},
530
531
Function {
532
input,
533
function: _,
534
options,
535
} => {
536
if options
537
.flags
538
.contains(FunctionFlags::PRESERVES_NULL_FIRST_INPUT)
539
{
540
traverse_first_input!(input)
541
} else {
542
options
543
.flags
544
.contains(FunctionFlags::PRESERVES_NULL_ALL_INPUTS)
545
}
546
},
547
548
Column(name) => {
549
non_null_column_callback(name);
550
false
551
},
552
553
_ => false,
554
};
555
556
if traverse_all_inputs {
557
ae.inputs_rev(stack);
558
}
559
}
560
}
561
562