Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs
7889 views
1
//! This module creates predicates that decide whether a record batch of rows can be skipped,
//! based on statistics (min, max and null count) about that record batch.
3
4
use polars_core::prelude::{AnyValue, DataType, Scalar};
5
use polars_core::schema::Schema;
6
use polars_utils::aliases::PlIndexMap;
7
use polars_utils::arena::{Arena, Node};
8
use polars_utils::format_pl_smallstr;
9
use polars_utils::pl_str::PlSmallStr;
10
11
use super::super::evaluate::{constant_evaluate, into_column};
12
use super::super::{AExpr, IRBooleanFunction, IRFunctionExpr, Operator};
13
use crate::plans::aexpr::builder::IntoAExprBuilder;
14
use crate::plans::predicates::get_binary_expr_col_and_lv;
15
use crate::plans::{AExprBuilder, aexpr_to_leaf_names_iter, is_scalar_ae, rename_columns};
16
17
/// Return a new boolean expression determines whether a batch can be skipped based on min, max and
18
/// null count statistics.
19
///
20
/// This is conversative and may return `None` or `false` when an expression is not yet supported.
21
///
22
/// To evaluate, the expression it is given all the original column appended with `_min` and
23
/// `_max`. The `min` or `max` cannot be null and when they are null it is assumed they are not
24
/// known.
25
pub fn aexpr_to_skip_batch_predicate(
26
e: Node,
27
expr_arena: &mut Arena<AExpr>,
28
schema: &Schema,
29
) -> Option<Node> {
30
aexpr_to_skip_batch_predicate_rec(e, expr_arena, schema, 0)
31
}
32
33
fn does_dtype_have_sufficient_order(dtype: &DataType) -> bool {
34
// Rules surrounding floats are really complicated. I should get around to that.
35
!dtype.is_nested() && !dtype.is_float() && !dtype.is_null() && !dtype.is_categorical()
36
}
37
38
fn is_stat_defined(
39
expr: impl IntoAExprBuilder,
40
dtype: &DataType,
41
arena: &mut Arena<AExpr>,
42
) -> AExprBuilder {
43
let mut expr = expr.into_aexpr_builder();
44
expr = expr.is_not_null(arena);
45
if dtype.is_float() {
46
let is_not_nan = expr.is_not_nan(arena);
47
expr = expr.and(is_not_nan, arena);
48
}
49
expr
50
}
51
52
#[recursive::recursive]
53
fn aexpr_to_skip_batch_predicate_rec(
54
e: Node,
55
arena: &mut Arena<AExpr>,
56
schema: &Schema,
57
depth: usize,
58
) -> Option<Node> {
59
use Operator as O;
60
61
macro_rules! rec {
62
($node:expr) => {{ aexpr_to_skip_batch_predicate_rec($node, arena, schema, depth + 1) }};
63
}
64
macro_rules! lv_cases {
65
(
66
$lv:expr, $lv_node:expr,
67
null: $null_case:expr,
68
not_null: $non_null_case:expr $(,)?
69
) => {{
70
if let Some(lv) = $lv {
71
if lv.is_null() {
72
$null_case
73
} else {
74
$non_null_case
75
}
76
} else {
77
let lv_node = AExprBuilder::new_from_node($lv_node);
78
79
let lv_is_null = lv_node.has_nulls(arena);
80
let lv_not_null = lv_node.has_no_nulls(arena);
81
82
let null_case = lv_is_null.and($null_case, arena);
83
let non_null_case = lv_not_null.and($non_null_case, arena);
84
85
null_case.or(non_null_case, arena).node()
86
}
87
}};
88
}
89
macro_rules! col {
90
(len) => {{ col!(PlSmallStr::from_static("len")) }};
91
($name:expr) => {{ AExprBuilder::new_from_node(arena.add(AExpr::Column($name))) }};
92
(min: $name:expr) => {{ col!(format_pl_smallstr!("{}_min", $name)) }};
93
(max: $name:expr) => {{ col!(format_pl_smallstr!("{}_max", $name)) }};
94
(null_count: $name:expr) => {{ col!(format_pl_smallstr!("{}_nc", $name)) }};
95
}
96
macro_rules! lv {
97
($lv:expr) => {{ AExprBuilder::lit_scalar(Scalar::from($lv), arena) }};
98
(idx: $lv:expr) => {{ AExprBuilder::lit_scalar(Scalar::new_idxsize($lv), arena) }};
99
}
100
101
let specialized = (|| {
102
if let Some(Some(lv)) = constant_evaluate(e, arena, schema, 0) {
103
if let Some(av) = lv.to_any_value() {
104
return match av {
105
AnyValue::Null => Some(lv!(true).node()),
106
AnyValue::Boolean(b) => Some(lv!(!b).node()),
107
_ => None,
108
};
109
}
110
}
111
112
match arena.get(e) {
113
AExpr::Element => None,
114
AExpr::Explode { .. } => None,
115
AExpr::Column(_) => None,
116
AExpr::Literal(_) => None,
117
AExpr::BinaryExpr { left, op, right } => {
118
let left = *left;
119
let right = *right;
120
121
match op {
122
O::Eq | O::EqValidity => {
123
let ((col, _), (lv, lv_node)) =
124
get_binary_expr_col_and_lv(left, right, arena, schema)?;
125
let dtype = schema.get(col)?;
126
127
if !does_dtype_have_sufficient_order(dtype) {
128
return None;
129
}
130
131
let op = *op;
132
let col = col.clone();
133
134
// col(A) == B -> {
135
// null_count(A) == 0 , if B.is_null(),
136
// null_count(A) == LEN || min(A) > B || max(A) < B, if B.is_not_null(),
137
// }
138
139
Some(lv_cases!(
140
lv, lv_node,
141
null: {
142
if matches!(op, O::Eq) {
143
lv!(false).node()
144
} else {
145
let col_nc = col!(null_count: col);
146
let idx_zero = lv!(idx: 0);
147
col_nc.eq(idx_zero, arena).node()
148
}
149
},
150
not_null: {
151
let col_min = col!(min: col);
152
let col_max = col!(max: col);
153
154
let min_is_defined = is_stat_defined(col_min.node(), dtype, arena);
155
let max_is_defined = is_stat_defined(col_max.node(), dtype, arena);
156
157
let min_gt = col_min.gt(lv_node, arena);
158
let min_gt = min_gt.and(min_is_defined, arena);
159
160
let max_lt = col_max.lt(lv_node, arena);
161
let max_lt = max_lt.and(max_is_defined, arena);
162
163
let col_nc = col!(null_count: col);
164
let len = col!(len);
165
let all_nulls = col_nc.eq(len, arena);
166
167
all_nulls.or(min_gt, arena).or(max_lt, arena).node()
168
}
169
))
170
},
171
O::NotEq | O::NotEqValidity => {
172
let ((col, _), (lv, lv_node)) =
173
get_binary_expr_col_and_lv(left, right, arena, schema)?;
174
let dtype = schema.get(col)?;
175
176
if !does_dtype_have_sufficient_order(dtype) {
177
return None;
178
}
179
180
let op = *op;
181
let col = col.clone();
182
183
// col(A) != B -> {
184
// null_count(A) == LEN , if B.is_null(),
185
// null_count(A) == 0 && min(A) == B && max(A) == B, if B.is_not_null(),
186
// }
187
188
Some(lv_cases!(
189
lv, lv_node,
190
null: {
191
if matches!(op, O::NotEq) {
192
lv!(false).node()
193
} else {
194
let col_nc = col!(null_count: col);
195
let len = col!(len);
196
col_nc.eq(len, arena).node()
197
}
198
},
199
not_null: {
200
let col_min = col!(min: col);
201
let col_max = col!(max: col);
202
let min_eq = col_min.eq(lv_node, arena);
203
let max_eq = col_max.eq(lv_node, arena);
204
205
let col_nc = col!(null_count: col);
206
let idx_zero = lv!(idx: 0);
207
let no_nulls = col_nc.eq(idx_zero, arena);
208
209
no_nulls.and(min_eq, arena).and(max_eq, arena).node()
210
}
211
))
212
},
213
O::Lt | O::Gt | O::LtEq | O::GtEq => {
214
let ((col, col_node), (lv, lv_node)) =
215
get_binary_expr_col_and_lv(left, right, arena, schema)?;
216
let dtype = schema.get(col)?;
217
218
if !does_dtype_have_sufficient_order(dtype) {
219
return None;
220
}
221
222
let col_is_left = col_node == left;
223
224
let op = *op;
225
let col = col.clone();
226
let lv_may_be_null = lv.is_none_or(|lv| lv.is_null());
227
228
// If B is null, this is always true.
229
//
230
// col(A) < B ~ B > col(A) ->
231
// null_count(A) == LEN || min(A) >= B
232
//
233
// col(A) > B ~ B < col(A) ->
234
// null_count(A) == LEN || max(A) <= B
235
//
236
// col(A) <= B ~ B >= col(A) ->
237
// null_count(A) == LEN || min(A) > B
238
//
239
// col(A) >= B ~ B <= col(A) ->
240
// null_count(A) == LEN || max(A) < B
241
242
let stat = match (op, col_is_left) {
243
(O::Lt | O::LtEq, true) | (O::Gt | O::GtEq, false) => col!(min: col),
244
(O::Lt | O::LtEq, false) | (O::Gt | O::GtEq, true) => col!(max: col),
245
_ => unreachable!(),
246
};
247
let cmp_op = match (op, col_is_left) {
248
(O::Lt, true) | (O::Gt, false) => O::GtEq,
249
(O::Lt, false) | (O::Gt, true) => O::LtEq,
250
251
(O::LtEq, true) | (O::GtEq, false) => O::Gt,
252
(O::LtEq, false) | (O::GtEq, true) => O::Lt,
253
254
_ => unreachable!(),
255
};
256
257
let stat_is_defined = is_stat_defined(stat, dtype, arena);
258
let cmp_op = stat.binary_op(lv_node, cmp_op, arena);
259
let mut expr = stat_is_defined.and(cmp_op, arena);
260
261
if lv_may_be_null {
262
let has_nulls = lv_node.into_aexpr_builder().has_nulls(arena);
263
expr = has_nulls.or(expr, arena);
264
}
265
Some(expr.node())
266
},
267
268
O::And | O::LogicalAnd => match (rec!(left), rec!(right)) {
269
(Some(left), Some(right)) => {
270
Some(AExprBuilder::new_from_node(left).or(right, arena).node())
271
},
272
(Some(n), None) | (None, Some(n)) => Some(n),
273
(None, None) => None,
274
},
275
O::Or | O::LogicalOr => {
276
let left = rec!(left)?;
277
let right = rec!(right)?;
278
Some(AExprBuilder::new_from_node(left).and(right, arena).node())
279
},
280
281
O::Plus
282
| O::Minus
283
| O::Multiply
284
| O::Divide
285
| O::TrueDivide
286
| O::FloorDivide
287
| O::Modulus
288
| O::Xor => None,
289
}
290
},
291
AExpr::Cast { .. } => None,
292
AExpr::Sort { .. } => None,
293
AExpr::Gather { .. } => None,
294
AExpr::SortBy { .. } => None,
295
AExpr::Filter { .. } => None,
296
AExpr::Agg(..) | AExpr::AnonymousStreamingAgg { .. } => None,
297
AExpr::Ternary { .. } => None,
298
AExpr::AnonymousFunction { .. } => None,
299
AExpr::Eval { .. } => None,
300
AExpr::Function {
301
input, function, ..
302
} => match function {
303
IRFunctionExpr::Boolean(f) => match f {
304
#[cfg(feature = "is_in")]
305
IRBooleanFunction::IsIn { nulls_equal } => {
306
if !is_scalar_ae(input[1].node(), arena) {
307
return None;
308
}
309
310
let nulls_equal = *nulls_equal;
311
let lv_node = input[1].node();
312
match (
313
into_column(input[0].node(), arena),
314
constant_evaluate(lv_node, arena, schema, 0),
315
) {
316
(Some(col), Some(_)) => {
317
use polars_core::prelude::ExplodeOptions;
318
319
let dtype = schema.get(col)?;
320
if !does_dtype_have_sufficient_order(dtype) {
321
return None;
322
}
323
324
// col(A).is_in([B1, ..., Bn]) ->
325
// ([B1, ..., Bn].has_no_nulls() || null_count(A) == 0) &&
326
// (
327
// min(A) > max[B1, ..., Bn] ||
328
// max(A) < min[B1, ..., Bn]
329
// )
330
let col = col.clone();
331
let lv_node = lv_node.into_aexpr_builder();
332
333
let lv_node_exploded = lv_node.explode(
334
arena,
335
ExplodeOptions {
336
empty_as_null: false,
337
keep_nulls: true,
338
},
339
);
340
let lv_min = lv_node_exploded.min(arena);
341
let lv_max = lv_node_exploded.max(arena);
342
343
let col_min = col!(min: col);
344
let col_max = col!(max: col);
345
346
let min_is_defined = is_stat_defined(col_min, dtype, arena);
347
let max_is_defined = is_stat_defined(col_max, dtype, arena);
348
349
let min_gt = col_min.gt(lv_max, arena);
350
let min_gt = min_is_defined.and(min_gt, arena);
351
352
let max_lt = col_max.lt(lv_min, arena);
353
let max_lt = max_is_defined.and(max_lt, arena);
354
355
let expr = min_gt.or(max_lt, arena);
356
357
let col_nc = col!(null_count: col);
358
let col_has_no_nulls = col_nc.has_no_nulls(arena);
359
360
let lv_has_not_nulls = lv_node_exploded.has_no_nulls(arena);
361
let null_case = lv_has_not_nulls.or(col_has_no_nulls, arena);
362
363
let min_max_is_in = null_case.and(expr, arena);
364
365
let col_nc = col!(null_count: col);
366
367
let min_is_max = col_min.eq(col_max, arena); // Eq so that (None == None) == None
368
let idx_zero = lv!(idx: 0);
369
let has_no_nulls = col_nc.eq(idx_zero, arena);
370
371
// The above case does always cover the fallback path. Since there
372
// is code that relies on the `min==max` always filtering normally,
373
// we add it here.
374
let exact_not_in =
375
col_min.is_in(lv_node, nulls_equal, arena).not(arena);
376
let exact_not_in =
377
min_is_max.and(has_no_nulls, arena).and(exact_not_in, arena);
378
379
Some(exact_not_in.or(min_max_is_in, arena).node())
380
},
381
_ => None,
382
}
383
},
384
IRBooleanFunction::IsNull => {
385
let col = into_column(input[0].node(), arena)?;
386
387
// col(A).is_null() -> null_count(A) == 0
388
let col_nc = col!(null_count: col);
389
let idx_zero = lv!(idx: 0);
390
Some(col_nc.eq(idx_zero, arena).node())
391
},
392
IRBooleanFunction::IsNotNull => {
393
let col = into_column(input[0].node(), arena)?;
394
395
// col(A).is_not_null() -> null_count(A) == LEN
396
let col_nc = col!(null_count: col);
397
let len = col!(len);
398
Some(col_nc.eq(len, arena).node())
399
},
400
#[cfg(feature = "is_between")]
401
IRBooleanFunction::IsBetween { closed } => {
402
let col = into_column(input[0].node(), arena)?;
403
let dtype = schema.get(col)?;
404
405
if !does_dtype_have_sufficient_order(dtype) {
406
return None;
407
}
408
409
// col(A).is_between(X, Y) ->
410
// null_count(A) == LEN ||
411
// min(A) >(=) Y ||
412
// max(A) <(=) X
413
414
let left_node = input[1].node();
415
let right_node = input[2].node();
416
417
_ = constant_evaluate(left_node, arena, schema, 0)?;
418
_ = constant_evaluate(right_node, arena, schema, 0)?;
419
420
let col = col.clone();
421
let closed = *closed;
422
423
let lhs_no_nulls = left_node.into_aexpr_builder().has_no_nulls(arena);
424
let rhs_no_nulls = right_node.into_aexpr_builder().has_no_nulls(arena);
425
426
let col_min = col!(min: col);
427
let col_max = col!(max: col);
428
429
use polars_ops::series::ClosedInterval;
430
let (left, right) = match closed {
431
ClosedInterval::Both => (O::Lt, O::Gt),
432
ClosedInterval::Left => (O::Lt, O::GtEq),
433
ClosedInterval::Right => (O::LtEq, O::Gt),
434
ClosedInterval::None => (O::LtEq, O::GtEq),
435
};
436
437
let left = col_max.binary_op(left_node, left, arena);
438
let right = col_min.binary_op(right_node, right, arena);
439
440
let min_is_defined = is_stat_defined(col_min, dtype, arena);
441
let max_is_defined = is_stat_defined(col_max, dtype, arena);
442
443
let left = max_is_defined.and(left, arena);
444
let right = min_is_defined.and(right, arena);
445
446
let interval = left.or(right, arena);
447
Some(
448
lhs_no_nulls
449
.and(rhs_no_nulls, arena)
450
.and(interval, arena)
451
.node(),
452
)
453
},
454
_ => None,
455
},
456
_ => None,
457
},
458
#[cfg(feature = "dynamic_group_by")]
459
AExpr::Rolling { .. } => None,
460
AExpr::Over { .. } => None,
461
AExpr::Slice { .. } => None,
462
AExpr::Len => None,
463
}
464
})();
465
466
if let Some(specialized) = specialized {
467
return Some(specialized);
468
}
469
470
// If we don't have a specialized implementation we can check if the whole block is constant
471
// and fill that value in. This is especially useful when filtering hive partitions which are
472
// filtered using this expression and which set their min == max.
473
//
474
// Essentially, what this does is
475
// E -> all(col(A_min) == col(A_max) & col(A_nc) == 0 for A in LIVE(E)) & ~(E)
476
477
let live_columns = PlIndexMap::from_iter(aexpr_to_leaf_names_iter(e, arena).map(|col| {
478
let min_name = format_pl_smallstr!("{col}_min");
479
(col.clone(), min_name)
480
}));
481
482
// We cannot do proper equalities for these.
483
if live_columns
484
.iter()
485
.any(|(c, _)| schema.get(c).is_none_or(|dt| dt.is_categorical()))
486
{
487
return None;
488
}
489
490
// Rename all uses of column names with the min value.
491
let expr = rename_columns(e, arena, &live_columns);
492
let mut expr = expr.into_aexpr_builder().not(arena);
493
for col in live_columns.keys() {
494
let col_min = col!(min: col);
495
let col_max = col!(max: col);
496
let col_nc = col!(null_count: col);
497
498
let min_is_max = col_min.eq(col_max, arena); // Eq so that (None == None) == None
499
let idx_zero = lv!(idx: 0);
500
let has_no_nulls = col_nc.eq(idx_zero, arena);
501
502
expr = min_is_max.and(has_no_nulls, arena).and(expr, arena);
503
}
504
Some(expr.node())
505
}
506
507