Path: blob/main/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs
7889 views
//! This module creates predicates that can skip record batches of rows based on statistics about1//! that record batch.23use polars_core::prelude::{AnyValue, DataType, Scalar};4use polars_core::schema::Schema;5use polars_utils::aliases::PlIndexMap;6use polars_utils::arena::{Arena, Node};7use polars_utils::format_pl_smallstr;8use polars_utils::pl_str::PlSmallStr;910use super::super::evaluate::{constant_evaluate, into_column};11use super::super::{AExpr, IRBooleanFunction, IRFunctionExpr, Operator};12use crate::plans::aexpr::builder::IntoAExprBuilder;13use crate::plans::predicates::get_binary_expr_col_and_lv;14use crate::plans::{AExprBuilder, aexpr_to_leaf_names_iter, is_scalar_ae, rename_columns};1516/// Return a new boolean expression determines whether a batch can be skipped based on min, max and17/// null count statistics.18///19/// This is conversative and may return `None` or `false` when an expression is not yet supported.20///21/// To evaluate, the expression it is given all the original column appended with `_min` and22/// `_max`. The `min` or `max` cannot be null and when they are null it is assumed they are not23/// known.24pub fn aexpr_to_skip_batch_predicate(25e: Node,26expr_arena: &mut Arena<AExpr>,27schema: &Schema,28) -> Option<Node> {29aexpr_to_skip_batch_predicate_rec(e, expr_arena, schema, 0)30}3132fn does_dtype_have_sufficient_order(dtype: &DataType) -> bool {33// Rules surrounding floats are really complicated. I should get around to that.34!dtype.is_nested() && !dtype.is_float() && !dtype.is_null() && !dtype.is_categorical()35}3637fn is_stat_defined(38expr: impl IntoAExprBuilder,39dtype: &DataType,40arena: &mut Arena<AExpr>,41) -> AExprBuilder {42let mut expr = expr.into_aexpr_builder();43expr = expr.is_not_null(arena);44if dtype.is_float() {45let is_not_nan = expr.is_not_nan(arena);46expr = expr.and(is_not_nan, arena);47}48expr49}5051#[recursive::recursive]52fn aexpr_to_skip_batch_predicate_rec(53e: Node,54arena: &mut Arena<AExpr>,55schema: &Schema,56depth: usize,57) -> Option<Node> {58use Operator as O;5960macro_rules! rec {61($node:expr) => {{ aexpr_to_skip_batch_predicate_rec($node, arena, schema, depth + 1) }};62}63macro_rules! lv_cases {64(65$lv:expr, $lv_node:expr,66null: $null_case:expr,67not_null: $non_null_case:expr $(,)?68) => {{69if let Some(lv) = $lv {70if lv.is_null() {71$null_case72} else {73$non_null_case74}75} else {76let lv_node = AExprBuilder::new_from_node($lv_node);7778let lv_is_null = lv_node.has_nulls(arena);79let lv_not_null = lv_node.has_no_nulls(arena);8081let null_case = lv_is_null.and($null_case, arena);82let non_null_case = lv_not_null.and($non_null_case, arena);8384null_case.or(non_null_case, arena).node()85}86}};87}88macro_rules! col {89(len) => {{ col!(PlSmallStr::from_static("len")) }};90($name:expr) => {{ AExprBuilder::new_from_node(arena.add(AExpr::Column($name))) }};91(min: $name:expr) => {{ col!(format_pl_smallstr!("{}_min", $name)) }};92(max: $name:expr) => {{ col!(format_pl_smallstr!("{}_max", $name)) }};93(null_count: $name:expr) => {{ col!(format_pl_smallstr!("{}_nc", $name)) }};94}95macro_rules! lv {96($lv:expr) => {{ AExprBuilder::lit_scalar(Scalar::from($lv), arena) }};97(idx: $lv:expr) => {{ AExprBuilder::lit_scalar(Scalar::new_idxsize($lv), arena) }};98}99100let specialized = (|| {101if let Some(Some(lv)) = constant_evaluate(e, arena, schema, 0) {102if let Some(av) = lv.to_any_value() {103return match av {104AnyValue::Null => Some(lv!(true).node()),105AnyValue::Boolean(b) => Some(lv!(!b).node()),106_ => None,107};108}109}110111match arena.get(e) {112AExpr::Element => None,113AExpr::Explode { .. } => None,114AExpr::Column(_) => None,115AExpr::Literal(_) => None,116AExpr::BinaryExpr { left, op, right } => {117let left = *left;118let right = *right;119120match op {121O::Eq | O::EqValidity => {122let ((col, _), (lv, lv_node)) =123get_binary_expr_col_and_lv(left, right, arena, schema)?;124let dtype = schema.get(col)?;125126if !does_dtype_have_sufficient_order(dtype) {127return None;128}129130let op = *op;131let col = col.clone();132133// col(A) == B -> {134// null_count(A) == 0 , if B.is_null(),135// null_count(A) == LEN || min(A) > B || max(A) < B, if B.is_not_null(),136// }137138Some(lv_cases!(139lv, lv_node,140null: {141if matches!(op, O::Eq) {142lv!(false).node()143} else {144let col_nc = col!(null_count: col);145let idx_zero = lv!(idx: 0);146col_nc.eq(idx_zero, arena).node()147}148},149not_null: {150let col_min = col!(min: col);151let col_max = col!(max: col);152153let min_is_defined = is_stat_defined(col_min.node(), dtype, arena);154let max_is_defined = is_stat_defined(col_max.node(), dtype, arena);155156let min_gt = col_min.gt(lv_node, arena);157let min_gt = min_gt.and(min_is_defined, arena);158159let max_lt = col_max.lt(lv_node, arena);160let max_lt = max_lt.and(max_is_defined, arena);161162let col_nc = col!(null_count: col);163let len = col!(len);164let all_nulls = col_nc.eq(len, arena);165166all_nulls.or(min_gt, arena).or(max_lt, arena).node()167}168))169},170O::NotEq | O::NotEqValidity => {171let ((col, _), (lv, lv_node)) =172get_binary_expr_col_and_lv(left, right, arena, schema)?;173let dtype = schema.get(col)?;174175if !does_dtype_have_sufficient_order(dtype) {176return None;177}178179let op = *op;180let col = col.clone();181182// col(A) != B -> {183// null_count(A) == LEN , if B.is_null(),184// null_count(A) == 0 && min(A) == B && max(A) == B, if B.is_not_null(),185// }186187Some(lv_cases!(188lv, lv_node,189null: {190if matches!(op, O::NotEq) {191lv!(false).node()192} else {193let col_nc = col!(null_count: col);194let len = col!(len);195col_nc.eq(len, arena).node()196}197},198not_null: {199let col_min = col!(min: col);200let col_max = col!(max: col);201let min_eq = col_min.eq(lv_node, arena);202let max_eq = col_max.eq(lv_node, arena);203204let col_nc = col!(null_count: col);205let idx_zero = lv!(idx: 0);206let no_nulls = col_nc.eq(idx_zero, arena);207208no_nulls.and(min_eq, arena).and(max_eq, arena).node()209}210))211},212O::Lt | O::Gt | O::LtEq | O::GtEq => {213let ((col, col_node), (lv, lv_node)) =214get_binary_expr_col_and_lv(left, right, arena, schema)?;215let dtype = schema.get(col)?;216217if !does_dtype_have_sufficient_order(dtype) {218return None;219}220221let col_is_left = col_node == left;222223let op = *op;224let col = col.clone();225let lv_may_be_null = lv.is_none_or(|lv| lv.is_null());226227// If B is null, this is always true.228//229// col(A) < B ~ B > col(A) ->230// null_count(A) == LEN || min(A) >= B231//232// col(A) > B ~ B < col(A) ->233// null_count(A) == LEN || max(A) <= B234//235// col(A) <= B ~ B >= col(A) ->236// null_count(A) == LEN || min(A) > B237//238// col(A) >= B ~ B <= col(A) ->239// null_count(A) == LEN || max(A) < B240241let stat = match (op, col_is_left) {242(O::Lt | O::LtEq, true) | (O::Gt | O::GtEq, false) => col!(min: col),243(O::Lt | O::LtEq, false) | (O::Gt | O::GtEq, true) => col!(max: col),244_ => unreachable!(),245};246let cmp_op = match (op, col_is_left) {247(O::Lt, true) | (O::Gt, false) => O::GtEq,248(O::Lt, false) | (O::Gt, true) => O::LtEq,249250(O::LtEq, true) | (O::GtEq, false) => O::Gt,251(O::LtEq, false) | (O::GtEq, true) => O::Lt,252253_ => unreachable!(),254};255256let stat_is_defined = is_stat_defined(stat, dtype, arena);257let cmp_op = stat.binary_op(lv_node, cmp_op, arena);258let mut expr = stat_is_defined.and(cmp_op, arena);259260if lv_may_be_null {261let has_nulls = lv_node.into_aexpr_builder().has_nulls(arena);262expr = has_nulls.or(expr, arena);263}264Some(expr.node())265},266267O::And | O::LogicalAnd => match (rec!(left), rec!(right)) {268(Some(left), Some(right)) => {269Some(AExprBuilder::new_from_node(left).or(right, arena).node())270},271(Some(n), None) | (None, Some(n)) => Some(n),272(None, None) => None,273},274O::Or | O::LogicalOr => {275let left = rec!(left)?;276let right = rec!(right)?;277Some(AExprBuilder::new_from_node(left).and(right, arena).node())278},279280O::Plus281| O::Minus282| O::Multiply283| O::Divide284| O::TrueDivide285| O::FloorDivide286| O::Modulus287| O::Xor => None,288}289},290AExpr::Cast { .. } => None,291AExpr::Sort { .. } => None,292AExpr::Gather { .. } => None,293AExpr::SortBy { .. } => None,294AExpr::Filter { .. } => None,295AExpr::Agg(..) | AExpr::AnonymousStreamingAgg { .. } => None,296AExpr::Ternary { .. } => None,297AExpr::AnonymousFunction { .. } => None,298AExpr::Eval { .. } => None,299AExpr::Function {300input, function, ..301} => match function {302IRFunctionExpr::Boolean(f) => match f {303#[cfg(feature = "is_in")]304IRBooleanFunction::IsIn { nulls_equal } => {305if !is_scalar_ae(input[1].node(), arena) {306return None;307}308309let nulls_equal = *nulls_equal;310let lv_node = input[1].node();311match (312into_column(input[0].node(), arena),313constant_evaluate(lv_node, arena, schema, 0),314) {315(Some(col), Some(_)) => {316use polars_core::prelude::ExplodeOptions;317318let dtype = schema.get(col)?;319if !does_dtype_have_sufficient_order(dtype) {320return None;321}322323// col(A).is_in([B1, ..., Bn]) ->324// ([B1, ..., Bn].has_no_nulls() || null_count(A) == 0) &&325// (326// min(A) > max[B1, ..., Bn] ||327// max(A) < min[B1, ..., Bn]328// )329let col = col.clone();330let lv_node = lv_node.into_aexpr_builder();331332let lv_node_exploded = lv_node.explode(333arena,334ExplodeOptions {335empty_as_null: false,336keep_nulls: true,337},338);339let lv_min = lv_node_exploded.min(arena);340let lv_max = lv_node_exploded.max(arena);341342let col_min = col!(min: col);343let col_max = col!(max: col);344345let min_is_defined = is_stat_defined(col_min, dtype, arena);346let max_is_defined = is_stat_defined(col_max, dtype, arena);347348let min_gt = col_min.gt(lv_max, arena);349let min_gt = min_is_defined.and(min_gt, arena);350351let max_lt = col_max.lt(lv_min, arena);352let max_lt = max_is_defined.and(max_lt, arena);353354let expr = min_gt.or(max_lt, arena);355356let col_nc = col!(null_count: col);357let col_has_no_nulls = col_nc.has_no_nulls(arena);358359let lv_has_not_nulls = lv_node_exploded.has_no_nulls(arena);360let null_case = lv_has_not_nulls.or(col_has_no_nulls, arena);361362let min_max_is_in = null_case.and(expr, arena);363364let col_nc = col!(null_count: col);365366let min_is_max = col_min.eq(col_max, arena); // Eq so that (None == None) == None367let idx_zero = lv!(idx: 0);368let has_no_nulls = col_nc.eq(idx_zero, arena);369370// The above case does always cover the fallback path. Since there371// is code that relies on the `min==max` always filtering normally,372// we add it here.373let exact_not_in =374col_min.is_in(lv_node, nulls_equal, arena).not(arena);375let exact_not_in =376min_is_max.and(has_no_nulls, arena).and(exact_not_in, arena);377378Some(exact_not_in.or(min_max_is_in, arena).node())379},380_ => None,381}382},383IRBooleanFunction::IsNull => {384let col = into_column(input[0].node(), arena)?;385386// col(A).is_null() -> null_count(A) == 0387let col_nc = col!(null_count: col);388let idx_zero = lv!(idx: 0);389Some(col_nc.eq(idx_zero, arena).node())390},391IRBooleanFunction::IsNotNull => {392let col = into_column(input[0].node(), arena)?;393394// col(A).is_not_null() -> null_count(A) == LEN395let col_nc = col!(null_count: col);396let len = col!(len);397Some(col_nc.eq(len, arena).node())398},399#[cfg(feature = "is_between")]400IRBooleanFunction::IsBetween { closed } => {401let col = into_column(input[0].node(), arena)?;402let dtype = schema.get(col)?;403404if !does_dtype_have_sufficient_order(dtype) {405return None;406}407408// col(A).is_between(X, Y) ->409// null_count(A) == LEN ||410// min(A) >(=) Y ||411// max(A) <(=) X412413let left_node = input[1].node();414let right_node = input[2].node();415416_ = constant_evaluate(left_node, arena, schema, 0)?;417_ = constant_evaluate(right_node, arena, schema, 0)?;418419let col = col.clone();420let closed = *closed;421422let lhs_no_nulls = left_node.into_aexpr_builder().has_no_nulls(arena);423let rhs_no_nulls = right_node.into_aexpr_builder().has_no_nulls(arena);424425let col_min = col!(min: col);426let col_max = col!(max: col);427428use polars_ops::series::ClosedInterval;429let (left, right) = match closed {430ClosedInterval::Both => (O::Lt, O::Gt),431ClosedInterval::Left => (O::Lt, O::GtEq),432ClosedInterval::Right => (O::LtEq, O::Gt),433ClosedInterval::None => (O::LtEq, O::GtEq),434};435436let left = col_max.binary_op(left_node, left, arena);437let right = col_min.binary_op(right_node, right, arena);438439let min_is_defined = is_stat_defined(col_min, dtype, arena);440let max_is_defined = is_stat_defined(col_max, dtype, arena);441442let left = max_is_defined.and(left, arena);443let right = min_is_defined.and(right, arena);444445let interval = left.or(right, arena);446Some(447lhs_no_nulls448.and(rhs_no_nulls, arena)449.and(interval, arena)450.node(),451)452},453_ => None,454},455_ => None,456},457#[cfg(feature = "dynamic_group_by")]458AExpr::Rolling { .. } => None,459AExpr::Over { .. } => None,460AExpr::Slice { .. } => None,461AExpr::Len => None,462}463})();464465if let Some(specialized) = specialized {466return Some(specialized);467}468469// If we don't have a specialized implementation we can check if the whole block is constant470// and fill that value in. This is especially useful when filtering hive partitions which are471// filtered using this expression and which set their min == max.472//473// Essentially, what this does is474// E -> all(col(A_min) == col(A_max) & col(A_nc) == 0 for A in LIVE(E)) & ~(E)475476let live_columns = PlIndexMap::from_iter(aexpr_to_leaf_names_iter(e, arena).map(|col| {477let min_name = format_pl_smallstr!("{col}_min");478(col.clone(), min_name)479}));480481// We cannot do proper equalities for these.482if live_columns483.iter()484.any(|(c, _)| schema.get(c).is_none_or(|dt| dt.is_categorical()))485{486return None;487}488489// Rename all uses of column names with the min value.490let expr = rename_columns(e, arena, &live_columns);491let mut expr = expr.into_aexpr_builder().not(arena);492for col in live_columns.keys() {493let col_min = col!(min: col);494let col_max = col!(max: col);495let col_nc = col!(null_count: col);496497let min_is_max = col_min.eq(col_max, arena); // Eq so that (None == None) == None498let idx_zero = lv!(idx: 0);499let has_no_nulls = col_nc.eq(idx_zero, arena);500501expr = min_is_max.and(has_no_nulls, arena).and(expr, arena);502}503Some(expr.node())504}505506507