// Path: blob/main/crates/polars-plan/src/plans/aexpr/predicates/column_expr.rs
// 7889 views
//! This module creates predicates splits predicates into partial per-column predicates.12use polars_core::datatypes::DataType;3use polars_core::prelude::AnyValue;4use polars_core::scalar::Scalar;5use polars_core::schema::Schema;6use polars_io::predicates::SpecializedColumnPredicate;7use polars_ops::series::ClosedInterval;8use polars_utils::aliases::PlHashMap;9use polars_utils::arena::{Arena, Node};10use polars_utils::pl_str::PlSmallStr;1112use super::get_binary_expr_col_and_lv;13use crate::dsl::Operator;14use crate::plans::aexpr::evaluate::{constant_evaluate, into_column};15use crate::plans::{16AExpr, IRBooleanFunction, IRFunctionExpr, MintermIter, aexpr_to_leaf_names_iter,17};1819pub struct ColumnPredicates {20pub predicates: PlHashMap<PlSmallStr, (Node, Option<SpecializedColumnPredicate>)>,2122/// Are all column predicates AND-ed together the original predicate.23pub is_sumwise_complete: bool,24}2526pub fn aexpr_to_column_predicates(27root: Node,28expr_arena: &mut Arena<AExpr>,29schema: &Schema,30) -> ColumnPredicates {31let mut predicates =32PlHashMap::<PlSmallStr, (Node, Option<SpecializedColumnPredicate>)>::default();33let mut is_sumwise_complete = true;3435let minterms = MintermIter::new(root, expr_arena).collect::<Vec<_>>();3637let mut leaf_names = Vec::with_capacity(2);38for minterm in minterms {39leaf_names.clear();40leaf_names.extend(aexpr_to_leaf_names_iter(minterm, expr_arena).cloned());4142if leaf_names.len() != 1 {43is_sumwise_complete = false;44continue;45}4647let column = leaf_names.pop().unwrap();48let Some(dtype) = schema.get(&column) else {49is_sumwise_complete = false;50continue;51};5253// We really don't want to deal with these types.54use DataType as D;55match dtype {56#[cfg(feature = "dtype-categorical")]57D::Enum(_, _) | D::Categorical(_, _) => {58is_sumwise_complete = false;59continue;60},61#[cfg(feature = "dtype-decimal")]62D::Decimal(_, _) => {63is_sumwise_complete = false;64continue;65},66#[cfg(feature = "object")]67D::Object(_) => 
{68is_sumwise_complete = false;69continue;70},71#[cfg(feature = "dtype-f16")]72D::Float16 => {73is_sumwise_complete = false;74continue;75},76D::Float32 | D::Float64 => {77is_sumwise_complete = false;78continue;79},80_ if dtype.is_nested() => {81is_sumwise_complete = false;82continue;83},84_ => {},85}8687let dtype = dtype.clone();88let entry = predicates.entry(column);8990entry91.and_modify(|n| {92let left = n.0;93n.0 = expr_arena.add(AExpr::BinaryExpr {94left,95op: Operator::LogicalAnd,96right: minterm,97});98n.1 = None;99})100.or_insert_with(|| {101(102minterm,103Some(()).and_then(|_| {104let aexpr = expr_arena.get(minterm);105106match aexpr {107#[cfg(all(feature = "regex", feature = "strings"))]108AExpr::Function {109input,110function: IRFunctionExpr::StringExpr(str_function),111options: _,112} if matches!(113str_function,114crate::plans::IRStringFunction::Contains { literal: _, strict: true } |115crate::plans::IRStringFunction::EndsWith |116crate::plans::IRStringFunction::StartsWith117) => {118use crate::plans::IRStringFunction;119120assert_eq!(input.len(), 2);121into_column(input[0].node(), expr_arena)?;122let lv = constant_evaluate(123input[1].node(),124expr_arena,125schema,1260,127)??;128129if !lv.is_scalar() {130return None;131}132let lv = lv.extract_str()?;133134match str_function {135IRStringFunction::Contains { literal, strict: _ } => {136let pattern = if *literal {137regex::escape(lv)138} else {139lv.to_string()140};141let pattern = regex::bytes::Regex::new(&pattern).ok()?;142Some(SpecializedColumnPredicate::RegexMatch(pattern))143},144IRStringFunction::StartsWith => Some(SpecializedColumnPredicate::StartsWith(lv.as_bytes().into())),145IRStringFunction::EndsWith => Some(SpecializedColumnPredicate::EndsWith(lv.as_bytes().into())),146_ => unreachable!(),147}148},149AExpr::Function {150input,151function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),152options: _,153} => {154assert_eq!(input.len(), 1);155if into_column(input[0].node(), 
expr_arena)156.is_some()157{158Some(SpecializedColumnPredicate::Equal(Scalar::null(159dtype,160)))161} else {162None163}164},165#[cfg(feature = "is_between")]166AExpr::Function {167input,168function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),169options: _,170} => {171let (Some(l), Some(r)) = (172constant_evaluate(173input[1].node(),174expr_arena,175schema,1760,177)?,178constant_evaluate(179input[2].node(),180expr_arena,181schema,1820,183)?,184) else {185return None;186};187let l = l.to_any_value()?;188let r = r.to_any_value()?;189if l.dtype() != dtype || r.dtype() != dtype {190return None;191}192193let (low_closed, high_closed) = match closed {194ClosedInterval::Both => (true, true),195ClosedInterval::Left => (true, false),196ClosedInterval::Right => (false, true),197ClosedInterval::None => (false, false),198};199is_between(200&dtype,201Some(Scalar::new(dtype.clone(), l.into_static())),202Some(Scalar::new(dtype.clone(), r.into_static())),203low_closed,204high_closed,205)206},207#[cfg(feature = "is_in")]208AExpr::Function {209input,210function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { nulls_equal }),211options: _,212} => {213into_column(input[0].node(), expr_arena)?;214215let values = constant_evaluate(216input[1].node(),217expr_arena,218schema,2190,220)??;221let values = values.to_any_value()?;222223let values = match values {224AnyValue::List(v) => v,225#[cfg(feature = "dtype-array")]226AnyValue::Array(v, _) => v,227_ => return None,228};229230if values.dtype() != &dtype {231return None;232}233if !nulls_equal && values.has_nulls() {234return None;235}236237let values = values.iter()238.map(|av| {239Scalar::new(dtype.clone(), av.into_static())240})241.collect();242243Some(SpecializedColumnPredicate::EqualOneOf(values))244},245AExpr::Function {246input,247function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),248options: _,249} => {250if !dtype.is_bool() {251return None;252}253254assert_eq!(input.len(), 1);255if 
into_column(input[0].node(), expr_arena)256.is_some()257{258Some(SpecializedColumnPredicate::Equal(false.into()))259} else {260None261}262},263AExpr::BinaryExpr { left, op, right } => {264let ((_, _), (lv, lv_node)) =265get_binary_expr_col_and_lv(*left, *right, expr_arena, schema)?;266let lv = lv?;267let av = lv.to_any_value()?;268if av.dtype() != dtype {269return None;270}271let scalar = Scalar::new(dtype.clone(), av.into_static());272use Operator as O;273match (op, lv_node == *right) {274(O::Eq, _) if scalar.is_null() => None,275(O::Eq | O::EqValidity, _) => {276Some(SpecializedColumnPredicate::Equal(scalar))277},278(O::Lt, true) | (O::Gt, false) => {279is_between(&dtype, None, Some(scalar), false, false)280},281(O::Lt, false) | (O::Gt, true) => {282is_between(&dtype, Some(scalar), None, false, false)283},284(O::LtEq, true) | (O::GtEq, false) => {285is_between(&dtype, None, Some(scalar), false, true)286},287(O::LtEq, false) | (O::GtEq, true) => {288is_between(&dtype, Some(scalar), None, true, false)289},290_ => None,291}292},293_ => None,294}295}),296)297});298}299300ColumnPredicates {301predicates,302is_sumwise_complete,303}304}305306fn is_between(307dtype: &DataType,308low: Option<Scalar>,309high: Option<Scalar>,310mut low_closed: bool,311mut high_closed: bool,312) -> Option<SpecializedColumnPredicate> {313let dtype = dtype.to_physical();314315if !dtype.is_integer() {316return None;317}318assert!(low.is_some() || high.is_some());319320low_closed |= low.is_none();321high_closed |= high.is_none();322323let mut low = low.map_or_else(|| dtype.min().unwrap(), |sc| sc.to_physical());324let mut high = high.map_or_else(|| dtype.max().unwrap(), |sc| sc.to_physical());325326macro_rules! 
ints {327($($t:ident),+) => {328match (low.any_value_mut(), high.any_value_mut()) {329$(330(AV::$t(l), AV::$t(h)) => {331if !low_closed {332*l = l.checked_add(1)?;333}334if !high_closed {335*h = h.checked_sub(1)?;336}337if *l > *h {338// Really this ought to indicate that nothing should be339// loaded since the condition is impossible, but unclear340// how to do that at this abstraction layer. Could add341// SpecializedColumnPredicate::Impossible or something,342// maybe.343return None;344}345},346)+347_ => return None,348}349};350}351352use AnyValue as AV;353ints!(354Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64355);356357Some(SpecializedColumnPredicate::Between(low, high))358}359360#[cfg(test)]361mod tests {362use polars_error::PolarsResult;363364use super::*;365use crate::dsl::{Expr, col};366use crate::plans::{ExprToIRContext, to_expr_ir, typed_lit};367368/// Given a single-column `Expr`, call `aexpr_to_column_predicates()` and369/// return the corresponding column's `Option<SpecializedColumnPredicate>`.370fn column_predicate_for_expr(371col_dtype: DataType,372col_name: &str,373expr: Expr,374) -> PolarsResult<Option<SpecializedColumnPredicate>> {375let mut arena = Arena::new();376let schema = Schema::from_iter_check_duplicates([(col_name.into(), col_dtype)])?;377let mut ctx = ExprToIRContext::new(&mut arena, &schema);378let expr_ir = to_expr_ir(expr, &mut ctx)?;379let column_predicates = aexpr_to_column_predicates(expr_ir.node(), &mut arena, &schema);380assert_eq!(column_predicates.predicates.len(), 1);381let Some((col_name2, (_, predicate))) =382column_predicates.predicates.clone().into_iter().next()383else {384panic!(385"Unexpected column predicates: {:?}",386column_predicates.predicates387);388};389assert_eq!(col_name, col_name2);390Ok(predicate)391}392393#[test]394fn column_predicate_for_inequality_operators() -> PolarsResult<()> {395let col_name = "testcol";396// Array of (expr, expected minimum, expected maximum):397let test_values: 
[(Expr, i8, i8); _] = [398(col(col_name).lt(typed_lit(10i8)), -128, 9),399(col(col_name).lt(typed_lit(-11i8)), -128, -12),400(col(col_name).gt(typed_lit(17i8)), 18, 127),401(col(col_name).gt(typed_lit(-10i8)), -9, 127),402(col(col_name).lt_eq(typed_lit(10i8)), -128, 10),403(col(col_name).lt_eq(typed_lit(-11i8)), -128, -11),404(col(col_name).gt_eq(typed_lit(17i8)), 17, 127),405(col(col_name).gt_eq(typed_lit(-10i8)), -10, 127),406];407for (expr, expected_min, expected_max) in test_values {408let predicate = column_predicate_for_expr(DataType::Int8, col_name, expr.clone())?;409if let Some(SpecializedColumnPredicate::Between(actual_min, actual_max)) = predicate {410assert_eq!(411(expected_min.into(), expected_max.into()),412(actual_min, actual_max)413);414} else {415panic!("{predicate:?} is unexpected for {expr:?}");416}417}418Ok(())419}420421#[test]422fn column_predicate_is_between() -> PolarsResult<()> {423let col_name = "testcol";424// ClosedInterval, expected min, expected max:425let test_values: [(_, i8, i8); _] = [426(ClosedInterval::Both, 1, 10),427(ClosedInterval::Left, 1, 9),428(ClosedInterval::Right, 2, 10),429(ClosedInterval::None, 2, 9),430];431for (interval, expected_min, expected_max) in test_values {432let expr = col(col_name).is_between(typed_lit(1i8), typed_lit(10i8), interval);433let predicate = column_predicate_for_expr(DataType::Int8, col_name, expr.clone())?;434if let Some(SpecializedColumnPredicate::Between(actual_min, actual_max)) = predicate {435assert_eq!(436(expected_min.into(), expected_max.into()),437(actual_min, actual_max)438);439} else {440panic!("{predicate:?} is unexpected for {expr:?}");441}442}443Ok(())444}445}446447448