Path: blob/main/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs
8458 views
//! This module creates predicates that can skip record batches of rows based on statistics about1//! that record batch.23use polars_core::prelude::{AnyValue, DataType, Scalar};4use polars_core::schema::Schema;5use polars_utils::aliases::PlIndexMap;6use polars_utils::arena::{Arena, Node};7use polars_utils::format_pl_smallstr;8use polars_utils::pl_str::PlSmallStr;910use super::super::evaluate::{constant_evaluate, into_column};11use super::super::{AExpr, IRBooleanFunction, IRFunctionExpr, Operator};12use crate::plans::aexpr::builder::IntoAExprBuilder;13use crate::plans::predicates::get_binary_expr_col_and_lv;14use crate::plans::{AExprBuilder, aexpr_to_leaf_names_iter, is_scalar_ae, rename_columns};1516/// Return a new boolean expression determines whether a batch can be skipped based on min, max and17/// null count statistics.18///19/// This is conversative and may return `None` or `false` when an expression is not yet supported.20///21/// To evaluate, the expression it is given all the original column appended with `_min` and22/// `_max`. The `min` or `max` cannot be null and when they are null it is assumed they are not23/// known.24pub fn aexpr_to_skip_batch_predicate(25e: Node,26expr_arena: &mut Arena<AExpr>,27schema: &Schema,28) -> Option<Node> {29aexpr_to_skip_batch_predicate_rec(e, expr_arena, schema, 0)30}3132fn does_dtype_have_sufficient_order(dtype: &DataType) -> bool {33// Rules surrounding floats are really complicated. I should get around to that.34!dtype.is_nested() && !dtype.is_float() && !dtype.is_null() && !dtype.is_categorical()35}3637fn is_stat_defined(38expr: impl IntoAExprBuilder,39dtype: &DataType,40arena: &mut Arena<AExpr>,41) -> AExprBuilder {42let mut expr = expr.into_aexpr_builder();43expr = expr.is_not_null(arena);44if dtype.is_float() {45let is_not_nan = expr.is_not_nan(arena);46expr = expr.and(is_not_nan, arena);47}48expr49}5051#[recursive::recursive]52fn aexpr_to_skip_batch_predicate_rec(53e: Node,54arena: &mut Arena<AExpr>,55schema: &Schema,56depth: usize,57) -> Option<Node> {58use Operator as O;5960macro_rules! rec {61($node:expr) => {{ aexpr_to_skip_batch_predicate_rec($node, arena, schema, depth + 1) }};62}63macro_rules! lv_cases {64(65$lv:expr, $lv_node:expr,66null: $null_case:expr,67not_null: $non_null_case:expr $(,)?68) => {{69if let Some(lv) = $lv {70if lv.is_null() {71$null_case72} else {73$non_null_case74}75} else {76let lv_node = AExprBuilder::new_from_node($lv_node);7778let lv_is_null = lv_node.has_nulls(arena);79let lv_not_null = lv_node.has_no_nulls(arena);8081let null_case = lv_is_null.and($null_case, arena);82let non_null_case = lv_not_null.and($non_null_case, arena);8384null_case.or(non_null_case, arena).node()85}86}};87}88macro_rules! col {89(len) => {{ col!(PlSmallStr::from_static("len")) }};90($name:expr) => {{ AExprBuilder::new_from_node(arena.add(AExpr::Column($name))) }};91(min: $name:expr) => {{ col!(format_pl_smallstr!("{}_min", $name)) }};92(max: $name:expr) => {{ col!(format_pl_smallstr!("{}_max", $name)) }};93(null_count: $name:expr) => {{ col!(format_pl_smallstr!("{}_nc", $name)) }};94}95macro_rules! lv {96($lv:expr) => {{ AExprBuilder::lit_scalar(Scalar::from($lv), arena) }};97(idx: $lv:expr) => {{ AExprBuilder::lit_scalar(Scalar::new_idxsize($lv), arena) }};98}99100let specialized = (|| {101if let Some(Some(lv)) = constant_evaluate(e, arena, schema, 0) {102if let Some(av) = lv.to_any_value() {103return match av {104AnyValue::Null => Some(lv!(true).node()),105AnyValue::Boolean(b) => Some(lv!(!b).node()),106_ => None,107};108}109}110111match arena.get(e) {112AExpr::Element => None,113AExpr::Explode { .. } => None,114AExpr::Column(_) => None,115#[cfg(feature = "dtype-struct")]116AExpr::StructField(_) => None,117AExpr::Literal(_) => None,118AExpr::BinaryExpr { left, op, right } => {119let left = *left;120let right = *right;121122match op {123O::Eq | O::EqValidity => {124let ((col, _), (lv, lv_node)) =125get_binary_expr_col_and_lv(left, right, arena, schema)?;126let dtype = schema.get(col)?;127128if !does_dtype_have_sufficient_order(dtype) {129return None;130}131132let op = *op;133let col = col.clone();134135// col(A) == B -> {136// null_count(A) == 0 , if B.is_null(),137// null_count(A) == LEN || min(A) > B || max(A) < B, if B.is_not_null(),138// }139140Some(lv_cases!(141lv, lv_node,142null: {143if matches!(op, O::Eq) {144lv!(false).node()145} else {146let col_nc = col!(null_count: col);147let idx_zero = lv!(idx: 0);148col_nc.eq(idx_zero, arena).node()149}150},151not_null: {152let col_min = col!(min: col);153let col_max = col!(max: col);154155let min_is_defined = is_stat_defined(col_min.node(), dtype, arena);156let max_is_defined = is_stat_defined(col_max.node(), dtype, arena);157158let min_gt = col_min.gt(lv_node, arena);159let min_gt = min_gt.and(min_is_defined, arena);160161let max_lt = col_max.lt(lv_node, arena);162let max_lt = max_lt.and(max_is_defined, arena);163164let col_nc = col!(null_count: col);165let len = col!(len);166let all_nulls = col_nc.eq(len, arena);167168all_nulls.or(min_gt, arena).or(max_lt, arena).node()169}170))171},172O::NotEq | O::NotEqValidity => {173let ((col, _), (lv, lv_node)) =174get_binary_expr_col_and_lv(left, right, arena, schema)?;175let dtype = schema.get(col)?;176177if !does_dtype_have_sufficient_order(dtype) {178return None;179}180181let op = *op;182let col = col.clone();183184// col(A) != B -> {185// null_count(A) == LEN , if B.is_null(),186// null_count(A) == 0 && min(A) == B && max(A) == B, if B.is_not_null(),187// }188189Some(lv_cases!(190lv, lv_node,191null: {192if matches!(op, O::NotEq) {193lv!(false).node()194} else {195let col_nc = col!(null_count: col);196let len = col!(len);197col_nc.eq(len, arena).node()198}199},200not_null: {201let col_min = col!(min: col);202let col_max = col!(max: col);203let min_eq = col_min.eq(lv_node, arena);204let max_eq = col_max.eq(lv_node, arena);205206let col_nc = col!(null_count: col);207let idx_zero = lv!(idx: 0);208let no_nulls = col_nc.eq(idx_zero, arena);209210no_nulls.and(min_eq, arena).and(max_eq, arena).node()211}212))213},214O::Lt | O::Gt | O::LtEq | O::GtEq => {215let ((col, col_node), (lv, lv_node)) =216get_binary_expr_col_and_lv(left, right, arena, schema)?;217let dtype = schema.get(col)?;218219if !does_dtype_have_sufficient_order(dtype) {220return None;221}222223let col_is_left = col_node == left;224225let op = *op;226let col = col.clone();227let lv_may_be_null = lv.is_none_or(|lv| lv.is_null());228229// If B is null, this is always true.230//231// col(A) < B ~ B > col(A) ->232// null_count(A) == LEN || min(A) >= B233//234// col(A) > B ~ B < col(A) ->235// null_count(A) == LEN || max(A) <= B236//237// col(A) <= B ~ B >= col(A) ->238// null_count(A) == LEN || min(A) > B239//240// col(A) >= B ~ B <= col(A) ->241// null_count(A) == LEN || max(A) < B242243let stat = match (op, col_is_left) {244(O::Lt | O::LtEq, true) | (O::Gt | O::GtEq, false) => col!(min: col),245(O::Lt | O::LtEq, false) | (O::Gt | O::GtEq, true) => col!(max: col),246_ => unreachable!(),247};248let cmp_op = match (op, col_is_left) {249(O::Lt, true) | (O::Gt, false) => O::GtEq,250(O::Lt, false) | (O::Gt, true) => O::LtEq,251252(O::LtEq, true) | (O::GtEq, false) => O::Gt,253(O::LtEq, false) | (O::GtEq, true) => O::Lt,254255_ => unreachable!(),256};257258let stat_is_defined = is_stat_defined(stat, dtype, arena);259let cmp_op = stat.binary_op(lv_node, cmp_op, arena);260let mut expr = stat_is_defined.and(cmp_op, arena);261262if lv_may_be_null {263let has_nulls = lv_node.into_aexpr_builder().has_nulls(arena);264expr = has_nulls.or(expr, arena);265}266Some(expr.node())267},268269O::And | O::LogicalAnd => match (rec!(left), rec!(right)) {270(Some(left), Some(right)) => {271Some(AExprBuilder::new_from_node(left).or(right, arena).node())272},273(Some(n), None) | (None, Some(n)) => Some(n),274(None, None) => None,275},276O::Or | O::LogicalOr => {277let left = rec!(left)?;278let right = rec!(right)?;279Some(AExprBuilder::new_from_node(left).and(right, arena).node())280},281282O::Plus283| O::Minus284| O::Multiply285| O::RustDivide286| O::TrueDivide287| O::FloorDivide288| O::Modulus289| O::Xor => None,290}291},292AExpr::Cast { .. } => None,293AExpr::Sort { .. } => None,294AExpr::Gather { .. } => None,295AExpr::SortBy { .. } => None,296AExpr::Filter { .. } => None,297AExpr::Agg(..) | AExpr::AnonymousAgg { .. } => None,298AExpr::Ternary { .. } => None,299AExpr::AnonymousFunction { .. } => None,300AExpr::Eval { .. } => None,301#[cfg(feature = "dtype-struct")]302AExpr::StructEval { .. } => None,303AExpr::Function {304input, function, ..305} => match function {306IRFunctionExpr::Boolean(f) => match f {307#[cfg(feature = "is_in")]308IRBooleanFunction::IsIn { nulls_equal } => {309if !is_scalar_ae(input[1].node(), arena) {310return None;311}312313let nulls_equal = *nulls_equal;314let lv_node = input[1].node();315match (316into_column(input[0].node(), arena),317constant_evaluate(lv_node, arena, schema, 0),318) {319(Some(col), Some(_)) => {320use polars_core::prelude::ExplodeOptions;321322let dtype = schema.get(col)?;323if !does_dtype_have_sufficient_order(dtype) {324return None;325}326327// col(A).is_in([B1, ..., Bn]) ->328// ([B1, ..., Bn].has_no_nulls() || null_count(A) == 0) &&329// (330// min(A) > max[B1, ..., Bn] ||331// max(A) < min[B1, ..., Bn]332// )333let col = col.clone();334let lv_node = lv_node.into_aexpr_builder();335336let lv_node_exploded = lv_node.explode(337arena,338ExplodeOptions {339empty_as_null: false,340keep_nulls: true,341},342);343let lv_min = lv_node_exploded.min(arena);344let lv_max = lv_node_exploded.max(arena);345346let col_min = col!(min: col);347let col_max = col!(max: col);348349let min_is_defined = is_stat_defined(col_min, dtype, arena);350let max_is_defined = is_stat_defined(col_max, dtype, arena);351352let min_gt = col_min.gt(lv_max, arena);353let min_gt = min_is_defined.and(min_gt, arena);354355let max_lt = col_max.lt(lv_min, arena);356let max_lt = max_is_defined.and(max_lt, arena);357358let expr = min_gt.or(max_lt, arena);359360let col_nc = col!(null_count: col);361let col_has_no_nulls = col_nc.has_no_nulls(arena);362363let lv_has_not_nulls = lv_node_exploded.has_no_nulls(arena);364let null_case = lv_has_not_nulls.or(col_has_no_nulls, arena);365366let min_max_is_in = null_case.and(expr, arena);367368let col_nc = col!(null_count: col);369370let min_is_max = col_min.eq(col_max, arena); // Eq so that (None == None) == None371let idx_zero = lv!(idx: 0);372let has_no_nulls = col_nc.eq(idx_zero, arena);373374// The above case does always cover the fallback path. Since there375// is code that relies on the `min==max` always filtering normally,376// we add it here.377let exact_not_in =378col_min.is_in(lv_node, nulls_equal, arena).not(arena);379let exact_not_in =380min_is_max.and(has_no_nulls, arena).and(exact_not_in, arena);381382Some(exact_not_in.or(min_max_is_in, arena).node())383},384_ => None,385}386},387IRBooleanFunction::IsNull => {388let col = into_column(input[0].node(), arena)?;389390// col(A).is_null() -> null_count(A) == 0391let col_nc = col!(null_count: col);392let idx_zero = lv!(idx: 0);393Some(col_nc.eq(idx_zero, arena).node())394},395IRBooleanFunction::IsNotNull => {396let col = into_column(input[0].node(), arena)?;397398// col(A).is_not_null() -> null_count(A) == LEN399let col_nc = col!(null_count: col);400let len = col!(len);401Some(col_nc.eq(len, arena).node())402},403#[cfg(feature = "is_between")]404IRBooleanFunction::IsBetween { closed } => {405let col = into_column(input[0].node(), arena)?;406let dtype = schema.get(col)?;407408if !does_dtype_have_sufficient_order(dtype) {409return None;410}411412// col(A).is_between(X, Y) ->413// null_count(A) == LEN ||414// min(A) >(=) Y ||415// max(A) <(=) X416417let left_node = input[1].node();418let right_node = input[2].node();419420_ = constant_evaluate(left_node, arena, schema, 0)?;421_ = constant_evaluate(right_node, arena, schema, 0)?;422423let col = col.clone();424let closed = *closed;425426let lhs_no_nulls = left_node.into_aexpr_builder().has_no_nulls(arena);427let rhs_no_nulls = right_node.into_aexpr_builder().has_no_nulls(arena);428429let col_min = col!(min: col);430let col_max = col!(max: col);431432use polars_ops::series::ClosedInterval;433let (left, right) = match closed {434ClosedInterval::Both => (O::Lt, O::Gt),435ClosedInterval::Left => (O::Lt, O::GtEq),436ClosedInterval::Right => (O::LtEq, O::Gt),437ClosedInterval::None => (O::LtEq, O::GtEq),438};439440let left = col_max.binary_op(left_node, left, arena);441let right = col_min.binary_op(right_node, right, arena);442443let min_is_defined = is_stat_defined(col_min, dtype, arena);444let max_is_defined = is_stat_defined(col_max, dtype, arena);445446let left = max_is_defined.and(left, arena);447let right = min_is_defined.and(right, arena);448449let interval = left.or(right, arena);450Some(451lhs_no_nulls452.and(rhs_no_nulls, arena)453.and(interval, arena)454.node(),455)456},457_ => None,458},459_ => None,460},461#[cfg(feature = "dynamic_group_by")]462AExpr::Rolling { .. } => None,463AExpr::Over { .. } => None,464AExpr::Slice { .. } => None,465AExpr::Len => None,466}467})();468469if let Some(specialized) = specialized {470return Some(specialized);471}472473// If we don't have a specialized implementation we can check if the whole block is constant474// and fill that value in. This is especially useful when filtering hive partitions which are475// filtered using this expression and which set their min == max.476//477// Essentially, what this does is478// E -> all(col(A_min) == col(A_max) & col(A_nc) == 0 for A in LIVE(E)) & ~(E)479480let live_columns = PlIndexMap::from_iter(aexpr_to_leaf_names_iter(e, arena).map(|col| {481let min_name = format_pl_smallstr!("{col}_min");482(col.clone(), min_name)483}));484485// We cannot do proper equalities for these.486if live_columns487.iter()488.any(|(c, _)| schema.get(c).is_none_or(|dt| dt.is_categorical()))489{490return None;491}492493// Rename all uses of column names with the min value.494let expr = rename_columns(e, arena, &live_columns);495let mut expr = expr.into_aexpr_builder().not(arena);496for col in live_columns.keys() {497let col_min = col!(min: col);498let col_max = col!(max: col);499let col_nc = col!(null_count: col);500501let min_is_max = col_min.eq(col_max, arena); // Eq so that (None == None) == None502let idx_zero = lv!(idx: 0);503let has_no_nulls = col_nc.eq(idx_zero, arena);504505expr = min_is_max.and(has_no_nulls, arena).and(expr, arena);506}507Some(expr.node())508}509510511