Path: blob/main/crates/polars-mem-engine/src/scan_predicate/mod.rs
7884 views
pub mod functions;
pub mod skip_files_mask;

use core::fmt;
use std::sync::Arc;

use arrow::bitmap::Bitmap;
pub use functions::{create_scan_predicate, initialize_scan_predicate};
use polars_core::frame::DataFrame;
use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
use polars_core::scalar::Scalar;
use polars_core::schema::{Schema, SchemaRef};
use polars_error::PolarsResult;
use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
use polars_expr::state::ExecutionState;
use polars_io::predicates::{
    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicate,
};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::{IdxSize, format_pl_smallstr};

/// All the expressions and metadata used to filter out rows using predicates.
#[derive(Clone)]
pub struct ScanPredicate {
    /// The full row-filter expression; evaluating it yields the boolean mask.
    pub predicate: Arc<dyn PhysicalExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate expression used to skip record batches based on its statistics.
    ///
    /// This expression will be given a batch size along with a `min`, `max` and `null count` for
    /// each live column (set to `null` when it is not known) and the expression evaluates to
    /// `true` if the whole batch can for sure be skipped. This may be conservative and evaluate to
    /// `false` even when the batch could theoretically be skipped.
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Partial predicates for each column for filter when loading columnar formats.
    pub column_predicates: PhysicalColumnPredicates,

    /// Predicate only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
    /// `true` when `hive_predicate` covers the entire `predicate` (i.e. the
    /// predicate references hive columns only).
    pub hive_predicate_is_full_predicate: bool,
}

impl fmt::Debug for ScanPredicate {
    // Opaque placeholder: physical expressions have no useful Debug output,
    // so the whole struct renders as a fixed tag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("scan_predicate")
    }
}

/// Per-column physical predicates used to filter while decoding columnar formats.
#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    /// Map of column name -> (predicate on that column, optional specialized
    /// fast-path form of the same predicate).
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalExpr>, Option<SpecializedColumnPredicate>)>,
    /// When `true`, the conjunction ("sum") of the per-column predicates is
    /// equivalent to the full predicate; when `false` it is only a pre-filter.
    pub is_sumwise_complete: bool,
}

/// Helper to implement [`SkipBatchPredicate`].
struct SkipBatchPredicateHelper {
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    schema: SchemaRef,
}

/// Helper for the [`PhysicalExpr`] trait to include constant columns.
///
/// Wraps a child expression and, before delegating evaluation, appends each
/// `(name, scalar)` in `constants` to the input frame as a broadcast column.
pub struct PhysicalExprWithConstCols {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: Arc<dyn PhysicalExpr>,
}

impl PhysicalExpr for PhysicalExprWithConstCols {
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // Clone the frame so the constant columns do not leak into the caller's df.
        let mut df = df.clone();
        for (name, scalar) in &self.constants {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate(&df, state)
    }

    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Same constant-column injection as `evaluate`, for the grouped path.
        let mut df = df.clone();
        for (name, scalar) in &self.constants {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_on_groups(&df, groups, state)
    }

    // Schema/scalar-ness are unaffected by the injected columns: delegate.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.child.to_field(input_schema)
    }
    fn is_scalar(&self) -> bool {
        self.child.is_scalar()
    }
}

impl ScanPredicate {
    /// Specialize this predicate for a scan source where some live columns
    /// have a known constant value (e.g. hive partition values).
    ///
    /// Constant columns are removed from `live_columns` and baked into the
    /// (skip-batch) predicates via [`PhysicalExprWithConstCols`]. The hive
    /// predicate is dropped in the result, since the constants subsume it.
    pub fn with_constant_columns(
        &self,
        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
    ) -> Self {
        let constant_columns = constant_columns.into_iter();

        let mut live_columns = self.live_columns.as_ref().clone();
        let mut skip_batch_predicate_constants =
            Vec::with_capacity(if self.skip_batch_predicate.is_some() {
                // 3 stat entries (min/max/nc) per constant column, plus one
                // extra slot — presumably for a batch-length entry added by
                // the caller; TODO confirm.
                1 + constant_columns.size_hint().0 * 3
            } else {
                // No skip-batch predicate: no stat constants needed (capacity 0).
                Default::default()
            });

        let predicate_constants = constant_columns
            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
                // Only constants that the predicate actually references matter;
                // `swap_remove` also drops them from the live set.
                if !live_columns.swap_remove(&name) {
                    return None;
                }

                if self.skip_batch_predicate.is_some() {
                    let mut null_count: Scalar = (0 as IdxSize).into();

                    // If the constant value is Null, we don't know how many nulls there are
                    // because the length of the batch may vary.
                    if scalar.is_null() {
                        null_count.update(AnyValue::Null);
                    }

                    // Stats for a constant column: min == max == the constant.
                    skip_batch_predicate_constants.extend([
                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
                        (format_pl_smallstr!("{name}_nc"), null_count),
                    ]);
                }

                Some((name, scalar))
            })
            .collect();

        let predicate = Arc::new(PhysicalExprWithConstCols {
            constants: predicate_constants,
            child: self.predicate.clone(),
        });
        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
            Arc::new(PhysicalExprWithConstCols {
                constants: skip_batch_predicate_constants,
                child: skp.clone(),
            }) as _
        });

        Self {
            predicate,
            live_columns: Arc::new(live_columns),
            skip_batch_predicate,
            column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull
            // predicates.
            hive_predicate: None,
            hive_predicate_is_full_predicate: false,
        }
    }

    /// Create a predicate to skip batches using statistics.
    pub(crate) fn to_dyn_skip_batch_predicate(
        &self,
        schema: SchemaRef,
    ) -> Option<Arc<dyn SkipBatchPredicate>> {
        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
        Some(Arc::new(SkipBatchPredicateHelper {
            skip_batch_predicate,
            schema,
        }))
    }

    /// Convert into the I/O-layer representation ([`ScanIOPredicate`]),
    /// wrapping every physical expression via `phys_expr_to_io_expr`.
    ///
    /// An externally supplied `skip_batch_predicate` takes precedence;
    /// otherwise one is derived from `self` (if available).
    pub fn to_io(
        &self,
        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
        schema: SchemaRef,
    ) -> ScanIOPredicate {
        ScanIOPredicate {
            predicate: phys_expr_to_io_expr(self.predicate.clone()),
            live_columns: self.live_columns.clone(),
            skip_batch_predicate: skip_batch_predicate
                .cloned()
                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
            column_predicates: Arc::new(ColumnPredicates {
                predicates: self
                    .column_predicates
                    .predicates
                    .iter()
                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
                    .collect(),
                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
            }),
            hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
            hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
        }
    }
}

impl SkipBatchPredicate for SkipBatchPredicateHelper {
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Evaluate the skip-batch expression against a statistics dataframe
    /// (one row per batch) and return the resulting skip mask as a `Bitmap`.
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let array = self
            .skip_batch_predicate
            .evaluate(df, &Default::default())?;
        let array = array.bool()?.rechunk();
        let array = array.downcast_as_array();

        // Fold the validity mask into the values: a null result must count as
        // "cannot skip" (false), not as "skip".
        let array = if let Some(validity) = array.validity() {
            array.values() & validity
        } else {
            array.values().clone()
        };

        // @NOTE: Certain predicates like `1 == 1` will only output 1 value. We need to broadcast
        // the result back to the dataframe length.
        if array.len() == 1 && df.height() != 0 {
            return Ok(Bitmap::new_with_value(array.get_bit(0), df.height()));
        }

        // NOTE(review): when `df.height() == 0` and the predicate still yields
        // one value, this assert fires — presumably unreachable upstream; confirm.
        assert_eq!(array.len(), df.height());
        Ok(array)
    }
}