Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs
8458 views
1
//! This module creates predicates that can skip record batches of rows based on statistics about
2
//! that record batch.
3
4
use polars_core::prelude::{AnyValue, DataType, Scalar};
5
use polars_core::schema::Schema;
6
use polars_utils::aliases::PlIndexMap;
7
use polars_utils::arena::{Arena, Node};
8
use polars_utils::format_pl_smallstr;
9
use polars_utils::pl_str::PlSmallStr;
10
11
use super::super::evaluate::{constant_evaluate, into_column};
12
use super::super::{AExpr, IRBooleanFunction, IRFunctionExpr, Operator};
13
use crate::plans::aexpr::builder::IntoAExprBuilder;
14
use crate::plans::predicates::get_binary_expr_col_and_lv;
15
use crate::plans::{AExprBuilder, aexpr_to_leaf_names_iter, is_scalar_ae, rename_columns};
16
17
/// Return a new boolean expression determines whether a batch can be skipped based on min, max and
18
/// null count statistics.
19
///
20
/// This is conversative and may return `None` or `false` when an expression is not yet supported.
21
///
22
/// To evaluate, the expression it is given all the original column appended with `_min` and
23
/// `_max`. The `min` or `max` cannot be null and when they are null it is assumed they are not
24
/// known.
25
pub fn aexpr_to_skip_batch_predicate(
26
e: Node,
27
expr_arena: &mut Arena<AExpr>,
28
schema: &Schema,
29
) -> Option<Node> {
30
aexpr_to_skip_batch_predicate_rec(e, expr_arena, schema, 0)
31
}
32
33
fn does_dtype_have_sufficient_order(dtype: &DataType) -> bool {
34
// Rules surrounding floats are really complicated. I should get around to that.
35
!dtype.is_nested() && !dtype.is_float() && !dtype.is_null() && !dtype.is_categorical()
36
}
37
38
fn is_stat_defined(
39
expr: impl IntoAExprBuilder,
40
dtype: &DataType,
41
arena: &mut Arena<AExpr>,
42
) -> AExprBuilder {
43
let mut expr = expr.into_aexpr_builder();
44
expr = expr.is_not_null(arena);
45
if dtype.is_float() {
46
let is_not_nan = expr.is_not_nan(arena);
47
expr = expr.and(is_not_nan, arena);
48
}
49
expr
50
}
51
52
#[recursive::recursive]
53
fn aexpr_to_skip_batch_predicate_rec(
54
e: Node,
55
arena: &mut Arena<AExpr>,
56
schema: &Schema,
57
depth: usize,
58
) -> Option<Node> {
59
use Operator as O;
60
61
macro_rules! rec {
62
($node:expr) => {{ aexpr_to_skip_batch_predicate_rec($node, arena, schema, depth + 1) }};
63
}
64
macro_rules! lv_cases {
65
(
66
$lv:expr, $lv_node:expr,
67
null: $null_case:expr,
68
not_null: $non_null_case:expr $(,)?
69
) => {{
70
if let Some(lv) = $lv {
71
if lv.is_null() {
72
$null_case
73
} else {
74
$non_null_case
75
}
76
} else {
77
let lv_node = AExprBuilder::new_from_node($lv_node);
78
79
let lv_is_null = lv_node.has_nulls(arena);
80
let lv_not_null = lv_node.has_no_nulls(arena);
81
82
let null_case = lv_is_null.and($null_case, arena);
83
let non_null_case = lv_not_null.and($non_null_case, arena);
84
85
null_case.or(non_null_case, arena).node()
86
}
87
}};
88
}
89
macro_rules! col {
90
(len) => {{ col!(PlSmallStr::from_static("len")) }};
91
($name:expr) => {{ AExprBuilder::new_from_node(arena.add(AExpr::Column($name))) }};
92
(min: $name:expr) => {{ col!(format_pl_smallstr!("{}_min", $name)) }};
93
(max: $name:expr) => {{ col!(format_pl_smallstr!("{}_max", $name)) }};
94
(null_count: $name:expr) => {{ col!(format_pl_smallstr!("{}_nc", $name)) }};
95
}
96
macro_rules! lv {
97
($lv:expr) => {{ AExprBuilder::lit_scalar(Scalar::from($lv), arena) }};
98
(idx: $lv:expr) => {{ AExprBuilder::lit_scalar(Scalar::new_idxsize($lv), arena) }};
99
}
100
101
let specialized = (|| {
102
if let Some(Some(lv)) = constant_evaluate(e, arena, schema, 0) {
103
if let Some(av) = lv.to_any_value() {
104
return match av {
105
AnyValue::Null => Some(lv!(true).node()),
106
AnyValue::Boolean(b) => Some(lv!(!b).node()),
107
_ => None,
108
};
109
}
110
}
111
112
match arena.get(e) {
113
AExpr::Element => None,
114
AExpr::Explode { .. } => None,
115
AExpr::Column(_) => None,
116
#[cfg(feature = "dtype-struct")]
117
AExpr::StructField(_) => None,
118
AExpr::Literal(_) => None,
119
AExpr::BinaryExpr { left, op, right } => {
120
let left = *left;
121
let right = *right;
122
123
match op {
124
O::Eq | O::EqValidity => {
125
let ((col, _), (lv, lv_node)) =
126
get_binary_expr_col_and_lv(left, right, arena, schema)?;
127
let dtype = schema.get(col)?;
128
129
if !does_dtype_have_sufficient_order(dtype) {
130
return None;
131
}
132
133
let op = *op;
134
let col = col.clone();
135
136
// col(A) == B -> {
137
// null_count(A) == 0 , if B.is_null(),
138
// null_count(A) == LEN || min(A) > B || max(A) < B, if B.is_not_null(),
139
// }
140
141
Some(lv_cases!(
142
lv, lv_node,
143
null: {
144
if matches!(op, O::Eq) {
145
lv!(false).node()
146
} else {
147
let col_nc = col!(null_count: col);
148
let idx_zero = lv!(idx: 0);
149
col_nc.eq(idx_zero, arena).node()
150
}
151
},
152
not_null: {
153
let col_min = col!(min: col);
154
let col_max = col!(max: col);
155
156
let min_is_defined = is_stat_defined(col_min.node(), dtype, arena);
157
let max_is_defined = is_stat_defined(col_max.node(), dtype, arena);
158
159
let min_gt = col_min.gt(lv_node, arena);
160
let min_gt = min_gt.and(min_is_defined, arena);
161
162
let max_lt = col_max.lt(lv_node, arena);
163
let max_lt = max_lt.and(max_is_defined, arena);
164
165
let col_nc = col!(null_count: col);
166
let len = col!(len);
167
let all_nulls = col_nc.eq(len, arena);
168
169
all_nulls.or(min_gt, arena).or(max_lt, arena).node()
170
}
171
))
172
},
173
O::NotEq | O::NotEqValidity => {
174
let ((col, _), (lv, lv_node)) =
175
get_binary_expr_col_and_lv(left, right, arena, schema)?;
176
let dtype = schema.get(col)?;
177
178
if !does_dtype_have_sufficient_order(dtype) {
179
return None;
180
}
181
182
let op = *op;
183
let col = col.clone();
184
185
// col(A) != B -> {
186
// null_count(A) == LEN , if B.is_null(),
187
// null_count(A) == 0 && min(A) == B && max(A) == B, if B.is_not_null(),
188
// }
189
190
Some(lv_cases!(
191
lv, lv_node,
192
null: {
193
if matches!(op, O::NotEq) {
194
lv!(false).node()
195
} else {
196
let col_nc = col!(null_count: col);
197
let len = col!(len);
198
col_nc.eq(len, arena).node()
199
}
200
},
201
not_null: {
202
let col_min = col!(min: col);
203
let col_max = col!(max: col);
204
let min_eq = col_min.eq(lv_node, arena);
205
let max_eq = col_max.eq(lv_node, arena);
206
207
let col_nc = col!(null_count: col);
208
let idx_zero = lv!(idx: 0);
209
let no_nulls = col_nc.eq(idx_zero, arena);
210
211
no_nulls.and(min_eq, arena).and(max_eq, arena).node()
212
}
213
))
214
},
215
O::Lt | O::Gt | O::LtEq | O::GtEq => {
216
let ((col, col_node), (lv, lv_node)) =
217
get_binary_expr_col_and_lv(left, right, arena, schema)?;
218
let dtype = schema.get(col)?;
219
220
if !does_dtype_have_sufficient_order(dtype) {
221
return None;
222
}
223
224
let col_is_left = col_node == left;
225
226
let op = *op;
227
let col = col.clone();
228
let lv_may_be_null = lv.is_none_or(|lv| lv.is_null());
229
230
// If B is null, this is always true.
231
//
232
// col(A) < B ~ B > col(A) ->
233
// null_count(A) == LEN || min(A) >= B
234
//
235
// col(A) > B ~ B < col(A) ->
236
// null_count(A) == LEN || max(A) <= B
237
//
238
// col(A) <= B ~ B >= col(A) ->
239
// null_count(A) == LEN || min(A) > B
240
//
241
// col(A) >= B ~ B <= col(A) ->
242
// null_count(A) == LEN || max(A) < B
243
244
let stat = match (op, col_is_left) {
245
(O::Lt | O::LtEq, true) | (O::Gt | O::GtEq, false) => col!(min: col),
246
(O::Lt | O::LtEq, false) | (O::Gt | O::GtEq, true) => col!(max: col),
247
_ => unreachable!(),
248
};
249
let cmp_op = match (op, col_is_left) {
250
(O::Lt, true) | (O::Gt, false) => O::GtEq,
251
(O::Lt, false) | (O::Gt, true) => O::LtEq,
252
253
(O::LtEq, true) | (O::GtEq, false) => O::Gt,
254
(O::LtEq, false) | (O::GtEq, true) => O::Lt,
255
256
_ => unreachable!(),
257
};
258
259
let stat_is_defined = is_stat_defined(stat, dtype, arena);
260
let cmp_op = stat.binary_op(lv_node, cmp_op, arena);
261
let mut expr = stat_is_defined.and(cmp_op, arena);
262
263
if lv_may_be_null {
264
let has_nulls = lv_node.into_aexpr_builder().has_nulls(arena);
265
expr = has_nulls.or(expr, arena);
266
}
267
Some(expr.node())
268
},
269
270
O::And | O::LogicalAnd => match (rec!(left), rec!(right)) {
271
(Some(left), Some(right)) => {
272
Some(AExprBuilder::new_from_node(left).or(right, arena).node())
273
},
274
(Some(n), None) | (None, Some(n)) => Some(n),
275
(None, None) => None,
276
},
277
O::Or | O::LogicalOr => {
278
let left = rec!(left)?;
279
let right = rec!(right)?;
280
Some(AExprBuilder::new_from_node(left).and(right, arena).node())
281
},
282
283
O::Plus
284
| O::Minus
285
| O::Multiply
286
| O::RustDivide
287
| O::TrueDivide
288
| O::FloorDivide
289
| O::Modulus
290
| O::Xor => None,
291
}
292
},
293
AExpr::Cast { .. } => None,
294
AExpr::Sort { .. } => None,
295
AExpr::Gather { .. } => None,
296
AExpr::SortBy { .. } => None,
297
AExpr::Filter { .. } => None,
298
AExpr::Agg(..) | AExpr::AnonymousAgg { .. } => None,
299
AExpr::Ternary { .. } => None,
300
AExpr::AnonymousFunction { .. } => None,
301
AExpr::Eval { .. } => None,
302
#[cfg(feature = "dtype-struct")]
303
AExpr::StructEval { .. } => None,
304
AExpr::Function {
305
input, function, ..
306
} => match function {
307
IRFunctionExpr::Boolean(f) => match f {
308
#[cfg(feature = "is_in")]
309
IRBooleanFunction::IsIn { nulls_equal } => {
310
if !is_scalar_ae(input[1].node(), arena) {
311
return None;
312
}
313
314
let nulls_equal = *nulls_equal;
315
let lv_node = input[1].node();
316
match (
317
into_column(input[0].node(), arena),
318
constant_evaluate(lv_node, arena, schema, 0),
319
) {
320
(Some(col), Some(_)) => {
321
use polars_core::prelude::ExplodeOptions;
322
323
let dtype = schema.get(col)?;
324
if !does_dtype_have_sufficient_order(dtype) {
325
return None;
326
}
327
328
// col(A).is_in([B1, ..., Bn]) ->
329
// ([B1, ..., Bn].has_no_nulls() || null_count(A) == 0) &&
330
// (
331
// min(A) > max[B1, ..., Bn] ||
332
// max(A) < min[B1, ..., Bn]
333
// )
334
let col = col.clone();
335
let lv_node = lv_node.into_aexpr_builder();
336
337
let lv_node_exploded = lv_node.explode(
338
arena,
339
ExplodeOptions {
340
empty_as_null: false,
341
keep_nulls: true,
342
},
343
);
344
let lv_min = lv_node_exploded.min(arena);
345
let lv_max = lv_node_exploded.max(arena);
346
347
let col_min = col!(min: col);
348
let col_max = col!(max: col);
349
350
let min_is_defined = is_stat_defined(col_min, dtype, arena);
351
let max_is_defined = is_stat_defined(col_max, dtype, arena);
352
353
let min_gt = col_min.gt(lv_max, arena);
354
let min_gt = min_is_defined.and(min_gt, arena);
355
356
let max_lt = col_max.lt(lv_min, arena);
357
let max_lt = max_is_defined.and(max_lt, arena);
358
359
let expr = min_gt.or(max_lt, arena);
360
361
let col_nc = col!(null_count: col);
362
let col_has_no_nulls = col_nc.has_no_nulls(arena);
363
364
let lv_has_not_nulls = lv_node_exploded.has_no_nulls(arena);
365
let null_case = lv_has_not_nulls.or(col_has_no_nulls, arena);
366
367
let min_max_is_in = null_case.and(expr, arena);
368
369
let col_nc = col!(null_count: col);
370
371
let min_is_max = col_min.eq(col_max, arena); // Eq so that (None == None) == None
372
let idx_zero = lv!(idx: 0);
373
let has_no_nulls = col_nc.eq(idx_zero, arena);
374
375
// The above case does always cover the fallback path. Since there
376
// is code that relies on the `min==max` always filtering normally,
377
// we add it here.
378
let exact_not_in =
379
col_min.is_in(lv_node, nulls_equal, arena).not(arena);
380
let exact_not_in =
381
min_is_max.and(has_no_nulls, arena).and(exact_not_in, arena);
382
383
Some(exact_not_in.or(min_max_is_in, arena).node())
384
},
385
_ => None,
386
}
387
},
388
IRBooleanFunction::IsNull => {
389
let col = into_column(input[0].node(), arena)?;
390
391
// col(A).is_null() -> null_count(A) == 0
392
let col_nc = col!(null_count: col);
393
let idx_zero = lv!(idx: 0);
394
Some(col_nc.eq(idx_zero, arena).node())
395
},
396
IRBooleanFunction::IsNotNull => {
397
let col = into_column(input[0].node(), arena)?;
398
399
// col(A).is_not_null() -> null_count(A) == LEN
400
let col_nc = col!(null_count: col);
401
let len = col!(len);
402
Some(col_nc.eq(len, arena).node())
403
},
404
#[cfg(feature = "is_between")]
405
IRBooleanFunction::IsBetween { closed } => {
406
let col = into_column(input[0].node(), arena)?;
407
let dtype = schema.get(col)?;
408
409
if !does_dtype_have_sufficient_order(dtype) {
410
return None;
411
}
412
413
// col(A).is_between(X, Y) ->
414
// null_count(A) == LEN ||
415
// min(A) >(=) Y ||
416
// max(A) <(=) X
417
418
let left_node = input[1].node();
419
let right_node = input[2].node();
420
421
_ = constant_evaluate(left_node, arena, schema, 0)?;
422
_ = constant_evaluate(right_node, arena, schema, 0)?;
423
424
let col = col.clone();
425
let closed = *closed;
426
427
let lhs_no_nulls = left_node.into_aexpr_builder().has_no_nulls(arena);
428
let rhs_no_nulls = right_node.into_aexpr_builder().has_no_nulls(arena);
429
430
let col_min = col!(min: col);
431
let col_max = col!(max: col);
432
433
use polars_ops::series::ClosedInterval;
434
let (left, right) = match closed {
435
ClosedInterval::Both => (O::Lt, O::Gt),
436
ClosedInterval::Left => (O::Lt, O::GtEq),
437
ClosedInterval::Right => (O::LtEq, O::Gt),
438
ClosedInterval::None => (O::LtEq, O::GtEq),
439
};
440
441
let left = col_max.binary_op(left_node, left, arena);
442
let right = col_min.binary_op(right_node, right, arena);
443
444
let min_is_defined = is_stat_defined(col_min, dtype, arena);
445
let max_is_defined = is_stat_defined(col_max, dtype, arena);
446
447
let left = max_is_defined.and(left, arena);
448
let right = min_is_defined.and(right, arena);
449
450
let interval = left.or(right, arena);
451
Some(
452
lhs_no_nulls
453
.and(rhs_no_nulls, arena)
454
.and(interval, arena)
455
.node(),
456
)
457
},
458
_ => None,
459
},
460
_ => None,
461
},
462
#[cfg(feature = "dynamic_group_by")]
463
AExpr::Rolling { .. } => None,
464
AExpr::Over { .. } => None,
465
AExpr::Slice { .. } => None,
466
AExpr::Len => None,
467
}
468
})();
469
470
if let Some(specialized) = specialized {
471
return Some(specialized);
472
}
473
474
// If we don't have a specialized implementation we can check if the whole block is constant
475
// and fill that value in. This is especially useful when filtering hive partitions which are
476
// filtered using this expression and which set their min == max.
477
//
478
// Essentially, what this does is
479
// E -> all(col(A_min) == col(A_max) & col(A_nc) == 0 for A in LIVE(E)) & ~(E)
480
481
let live_columns = PlIndexMap::from_iter(aexpr_to_leaf_names_iter(e, arena).map(|col| {
482
let min_name = format_pl_smallstr!("{col}_min");
483
(col.clone(), min_name)
484
}));
485
486
// We cannot do proper equalities for these.
487
if live_columns
488
.iter()
489
.any(|(c, _)| schema.get(c).is_none_or(|dt| dt.is_categorical()))
490
{
491
return None;
492
}
493
494
// Rename all uses of column names with the min value.
495
let expr = rename_columns(e, arena, &live_columns);
496
let mut expr = expr.into_aexpr_builder().not(arena);
497
for col in live_columns.keys() {
498
let col_min = col!(min: col);
499
let col_max = col!(max: col);
500
let col_nc = col!(null_count: col);
501
502
let min_is_max = col_min.eq(col_max, arena); // Eq so that (None == None) == None
503
let idx_zero = lv!(idx: 0);
504
let has_no_nulls = col_nc.eq(idx_zero, arena);
505
506
expr = min_is_max.and(has_no_nulls, arena).and(expr, arena);
507
}
508
Some(expr.node())
509
}
510
511