Path: blob/main/crates/polars-plan/src/plans/aexpr/predicates/column_expr.rs
8416 views
//! This module creates predicates splits predicates into partial per-column predicates.12use polars_core::datatypes::DataType;3use polars_core::prelude::AnyValue;4use polars_core::scalar::Scalar;5use polars_core::schema::Schema;6use polars_io::predicates::SpecializedColumnPredicate;7use polars_ops::series::ClosedInterval;8use polars_utils::aliases::PlHashMap;9use polars_utils::arena::{Arena, Node};10use polars_utils::pl_str::PlSmallStr;1112use super::get_binary_expr_col_and_lv;13use crate::dsl::Operator;14use crate::plans::aexpr::evaluate::{constant_evaluate, into_column};15use crate::plans::{16AExpr, IRBooleanFunction, IRFunctionExpr, MintermIter, aexpr_to_leaf_names_iter,17};1819pub struct ColumnPredicates {20pub predicates: PlHashMap<PlSmallStr, (Node, Option<SpecializedColumnPredicate>)>,2122/// Are all column predicates AND-ed together the original predicate.23pub is_sumwise_complete: bool,24}2526pub fn aexpr_to_column_predicates(27root: Node,28expr_arena: &mut Arena<AExpr>,29schema: &Schema,30) -> ColumnPredicates {31let mut predicates =32PlHashMap::<PlSmallStr, (Node, Option<SpecializedColumnPredicate>)>::default();33let mut is_sumwise_complete = true;3435let minterms = MintermIter::new(root, expr_arena).collect::<Vec<_>>();3637let mut leaf_names = Vec::with_capacity(2);38for minterm in minterms {39leaf_names.clear();40leaf_names.extend(aexpr_to_leaf_names_iter(minterm, expr_arena).cloned());4142if leaf_names.len() != 1 {43is_sumwise_complete = false;44continue;45}4647let column = leaf_names.pop().unwrap();48let Some(dtype) = schema.get(&column) else {49is_sumwise_complete = false;50continue;51};5253// We really don't want to deal with these types.54use DataType as D;55match dtype {56#[cfg(feature = "dtype-categorical")]57D::Enum(_, _) | D::Categorical(_, _) => {58is_sumwise_complete = false;59continue;60},61#[cfg(feature = "dtype-decimal")]62D::Decimal(_, _) => {63is_sumwise_complete = false;64continue;65},66#[cfg(feature = "object")]67D::Object(_) => {68is_sumwise_complete = false;69continue;70},71#[cfg(feature = "dtype-f16")]72D::Float16 => {73is_sumwise_complete = false;74continue;75},76D::Float32 | D::Float64 => {77is_sumwise_complete = false;78continue;79},80_ if dtype.is_nested() => {81is_sumwise_complete = false;82continue;83},84_ => {},85}8687let dtype = dtype.clone();88let entry = predicates.entry(column);8990entry91.and_modify(|n| {92let left = n.0;93n.0 = expr_arena.add(AExpr::BinaryExpr {94left,95op: Operator::LogicalAnd,96right: minterm,97});98n.1 = None;99})100.or_insert_with(|| {101(102minterm,103Some(()).and_then(|_| {104let aexpr = expr_arena.get(minterm);105106match aexpr {107#[cfg(all(feature = "regex", feature = "strings"))]108AExpr::Function {109input,110function: IRFunctionExpr::StringExpr(str_function),111options: _,112} if matches!(113str_function,114crate::plans::IRStringFunction::Contains { literal: _, strict: true } |115crate::plans::IRStringFunction::EndsWith |116crate::plans::IRStringFunction::StartsWith117) => {118use crate::plans::IRStringFunction;119120assert_eq!(input.len(), 2);121into_column(input[0].node(), expr_arena)?;122let lv = constant_evaluate(123input[1].node(),124expr_arena,125schema,1260,127)??;128129if !lv.is_scalar() {130return None;131}132let lv = lv.extract_str()?;133134match str_function {135IRStringFunction::Contains { literal, strict: _ } => {136let pattern = if *literal {137regex::escape(lv)138} else {139lv.to_string()140};141let pattern = regex::bytes::Regex::new(&pattern).ok()?;142Some(SpecializedColumnPredicate::RegexMatch(pattern))143},144IRStringFunction::StartsWith => Some(SpecializedColumnPredicate::StartsWith(lv.as_bytes().into())),145IRStringFunction::EndsWith => Some(SpecializedColumnPredicate::EndsWith(lv.as_bytes().into())),146_ => unreachable!(),147}148},149AExpr::Function {150input,151function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),152options: _,153} => {154assert_eq!(input.len(), 1);155if into_column(input[0].node(), expr_arena)156.is_some()157{158Some(SpecializedColumnPredicate::Equal(Scalar::null(159dtype,160)))161} else {162None163}164},165#[cfg(feature = "is_between")]166AExpr::Function {167input,168function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),169options: _,170} => {171into_column(input[0].node(), expr_arena)?;172173let (Some(l), Some(r)) = (174constant_evaluate(175input[1].node(),176expr_arena,177schema,1780,179)?,180constant_evaluate(181input[2].node(),182expr_arena,183schema,1840,185)?,186) else {187return None;188};189let l = l.to_any_value()?;190let r = r.to_any_value()?;191if l.dtype() != dtype || r.dtype() != dtype {192return None;193}194195let (low_closed, high_closed) = match closed {196ClosedInterval::Both => (true, true),197ClosedInterval::Left => (true, false),198ClosedInterval::Right => (false, true),199ClosedInterval::None => (false, false),200};201is_between(202&dtype,203Some(Scalar::new(dtype.clone(), l.into_static())),204Some(Scalar::new(dtype.clone(), r.into_static())),205low_closed,206high_closed,207)208},209#[cfg(feature = "is_in")]210AExpr::Function {211input,212function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { nulls_equal }),213options: _,214} => {215into_column(input[0].node(), expr_arena)?;216217let values = constant_evaluate(218input[1].node(),219expr_arena,220schema,2210,222)??;223let values = values.to_any_value()?;224225let values = match values {226AnyValue::List(v) => v,227#[cfg(feature = "dtype-array")]228AnyValue::Array(v, _) => v,229_ => return None,230};231232if values.dtype() != &dtype {233return None;234}235if !nulls_equal && values.has_nulls() {236return None;237}238239let values = values.iter()240.map(|av| {241Scalar::new(dtype.clone(), av.into_static())242})243.collect();244245Some(SpecializedColumnPredicate::EqualOneOf(values))246},247AExpr::Function {248input,249function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),250options: _,251} => {252if !dtype.is_bool() {253return None;254}255256assert_eq!(input.len(), 1);257if into_column(input[0].node(), expr_arena)258.is_some()259{260Some(SpecializedColumnPredicate::Equal(false.into()))261} else {262None263}264},265AExpr::BinaryExpr { left, op, right } => {266let ((_, _), (lv, lv_node)) =267get_binary_expr_col_and_lv(*left, *right, expr_arena, schema)?;268let lv = lv?;269let av = lv.to_any_value()?;270if av.dtype() != dtype {271return None;272}273let scalar = Scalar::new(dtype.clone(), av.into_static());274use Operator as O;275match (op, lv_node == *right) {276(O::Eq, _) if scalar.is_null() => None,277(O::Eq | O::EqValidity, _) => {278Some(SpecializedColumnPredicate::Equal(scalar))279},280(O::Lt, true) | (O::Gt, false) => {281is_between(&dtype, None, Some(scalar), false, false)282},283(O::Lt, false) | (O::Gt, true) => {284is_between(&dtype, Some(scalar), None, false, false)285},286(O::LtEq, true) | (O::GtEq, false) => {287is_between(&dtype, None, Some(scalar), false, true)288},289(O::LtEq, false) | (O::GtEq, true) => {290is_between(&dtype, Some(scalar), None, true, false)291},292_ => None,293}294},295_ => None,296}297}),298)299});300}301302ColumnPredicates {303predicates,304is_sumwise_complete,305}306}307308fn is_between(309dtype: &DataType,310low: Option<Scalar>,311high: Option<Scalar>,312mut low_closed: bool,313mut high_closed: bool,314) -> Option<SpecializedColumnPredicate> {315let dtype = dtype.to_physical();316317if !dtype.is_integer() {318return None;319}320assert!(low.is_some() || high.is_some());321322low_closed |= low.is_none();323high_closed |= high.is_none();324325let mut low = low.map_or_else(|| dtype.min().unwrap(), |sc| sc.to_physical());326let mut high = high.map_or_else(|| dtype.max().unwrap(), |sc| sc.to_physical());327328macro_rules! ints {329($($t:ident),+) => {330match (low.any_value_mut(), high.any_value_mut()) {331$(332(AV::$t(l), AV::$t(h)) => {333if !low_closed {334*l = l.checked_add(1)?;335}336if !high_closed {337*h = h.checked_sub(1)?;338}339if *l > *h {340// Really this ought to indicate that nothing should be341// loaded since the condition is impossible, but unclear342// how to do that at this abstraction layer. Could add343// SpecializedColumnPredicate::Impossible or something,344// maybe.345return None;346}347},348)+349_ => return None,350}351};352}353354use AnyValue as AV;355ints!(356Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64357);358359Some(SpecializedColumnPredicate::Between(low, high))360}361362#[cfg(test)]363mod tests {364use polars_error::PolarsResult;365366use super::*;367use crate::dsl::Expr;368use crate::dsl::functions::col;369use crate::plans::{ExprToIRContext, to_expr_ir, typed_lit};370371/// Given a single-column `Expr`, call `aexpr_to_column_predicates()` and372/// return the corresponding column's `Option<SpecializedColumnPredicate>`.373fn column_predicate_for_expr(374col_dtype: DataType,375col_name: &str,376expr: Expr,377) -> PolarsResult<Option<SpecializedColumnPredicate>> {378let mut arena = Arena::new();379let schema = Schema::from_iter_check_duplicates([(col_name.into(), col_dtype)])?;380let mut ctx = ExprToIRContext::new(&mut arena, &schema);381let expr_ir = to_expr_ir(expr, &mut ctx)?;382let column_predicates = aexpr_to_column_predicates(expr_ir.node(), &mut arena, &schema);383assert_eq!(column_predicates.predicates.len(), 1);384let Some((col_name2, (_, predicate))) =385column_predicates.predicates.clone().into_iter().next()386else {387panic!(388"Unexpected column predicates: {:?}",389column_predicates.predicates390);391};392assert_eq!(col_name, col_name2);393Ok(predicate)394}395396#[test]397fn column_predicate_for_inequality_operators() -> PolarsResult<()> {398let col_name = "testcol";399// Array of (expr, expected minimum, expected maximum):400let test_values: [(Expr, i8, i8); _] = [401(col(col_name).lt(typed_lit(10i8)), -128, 9),402(col(col_name).lt(typed_lit(-11i8)), -128, -12),403(col(col_name).gt(typed_lit(17i8)), 18, 127),404(col(col_name).gt(typed_lit(-10i8)), -9, 127),405(col(col_name).lt_eq(typed_lit(10i8)), -128, 10),406(col(col_name).lt_eq(typed_lit(-11i8)), -128, -11),407(col(col_name).gt_eq(typed_lit(17i8)), 17, 127),408(col(col_name).gt_eq(typed_lit(-10i8)), -10, 127),409];410for (expr, expected_min, expected_max) in test_values {411let predicate = column_predicate_for_expr(DataType::Int8, col_name, expr.clone())?;412if let Some(SpecializedColumnPredicate::Between(actual_min, actual_max)) = predicate {413assert_eq!(414(expected_min.into(), expected_max.into()),415(actual_min, actual_max)416);417} else {418panic!("{predicate:?} is unexpected for {expr:?}");419}420}421Ok(())422}423424#[test]425fn column_predicate_is_between() -> PolarsResult<()> {426let col_name = "testcol";427// ClosedInterval, expected min, expected max:428let test_values: [(_, i8, i8); _] = [429(ClosedInterval::Both, 1, 10),430(ClosedInterval::Left, 1, 9),431(ClosedInterval::Right, 2, 10),432(ClosedInterval::None, 2, 9),433];434for (interval, expected_min, expected_max) in test_values {435let expr = col(col_name).is_between(typed_lit(1i8), typed_lit(10i8), interval);436let predicate = column_predicate_for_expr(DataType::Int8, col_name, expr.clone())?;437if let Some(SpecializedColumnPredicate::Between(actual_min, actual_max)) = predicate {438assert_eq!(439(expected_min.into(), expected_max.into()),440(actual_min, actual_max)441);442} else {443panic!("{predicate:?} is unexpected for {expr:?}");444}445}446Ok(())447}448}449450451