// Path: blob/main/crates/polars-mem-engine/src/predicate.rs
use core::fmt;
use std::sync::Arc;

use arrow::bitmap::Bitmap;
use polars_core::frame::DataFrame;
use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
use polars_core::scalar::Scalar;
use polars_core::schema::{Schema, SchemaRef};
use polars_error::PolarsResult;
use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
use polars_expr::state::ExecutionState;
use polars_io::predicates::{
    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicate,
};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::{IdxSize, format_pl_smallstr};

/// All the expressions and metadata used to filter out rows using predicates.
#[derive(Clone)]
pub struct ScanPredicate {
    /// The full row-level filter expression.
    pub predicate: Arc<dyn PhysicalExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate expression used to skip record batches based on its statistics.
    ///
    /// This expression will be given a batch size along with a `min`, `max` and `null count` for
    /// each live column (set to `null` when it is not known) and the expression evaluates to
    /// `true` if the whole batch can for sure be skipped. This may be conservative and evaluate to
    /// `false` even when the batch could theoretically be skipped.
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Partial predicates for each column for filter when loading columnar formats.
    pub column_predicates: PhysicalColumnPredicates,

    /// Predicate only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
    pub hive_predicate_is_full_predicate: bool,
}

impl fmt::Debug for ScanPredicate {
    // Opaque placeholder: the inner `dyn PhysicalExpr` trees have no
    // `Debug` representation we can delegate to here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("scan_predicate")
    }
}

/// Per-column partial predicates (physical-expression form), keyed by column name.
///
/// Each entry pairs the column's physical expression with an optional specialized
/// (fast-path) predicate form.
#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalExpr>, Option<SpecializedColumnPredicate>)>,
    // NOTE(review): presumably `true` means the per-column predicates together are
    // equivalent to the full predicate — confirm against where this flag is produced.
    pub is_sumwise_complete: bool,
}

/// Helper to implement [`SkipBatchPredicate`].
struct SkipBatchPredicateHelper {
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    schema: SchemaRef,
}

/// Helper for the [`PhysicalExpr`] trait to include constant columns.
///
/// Wraps a child expression and, before every evaluation, appends each constant
/// as a full-height scalar column to (a clone of) the input `DataFrame`, so the
/// child can reference columns that are not physically present in the scan.
pub struct PhysicalExprWithConstCols {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: Arc<dyn PhysicalExpr>,
}

impl PhysicalExpr for PhysicalExprWithConstCols {
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // Clone so the caller's frame is untouched; each constant becomes a
        // scalar column broadcast to the frame's height.
        let mut df = df.clone();
        for (name, scalar) in &self.constants {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate(&df, state)
    }

    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Same constant-column injection as `evaluate`, for the grouped path.
        let mut df = df.clone();
        for (name, scalar) in &self.constants {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_on_groups(&df, groups, state)
    }

    // Schema/scalar-ness are delegated: injecting constant columns does not
    // change the child's output field or cardinality.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.child.to_field(input_schema)
    }
    fn is_scalar(&self) -> bool {
        self.child.is_scalar()
    }
}

impl ScanPredicate {
    /// Bind known constant values for some live columns into the predicate.
    ///
    /// Each `(name, scalar)` that is currently a live column is removed from
    /// `live_columns` and instead injected as a constant via
    /// [`PhysicalExprWithConstCols`] — both into the main predicate and (when
    /// present) into the skip-batch predicate, where the constant also supplies
    /// the `{name}_min` / `{name}_max` / `{name}_nc` statistics columns.
    /// Constants whose names are not live are ignored.
    ///
    /// The returned predicate drops the hive predicate (set to `None`);
    /// presumably the constants being bound here are the hive/partition values
    /// themselves — TODO confirm at the call sites.
    pub fn with_constant_columns(
        &self,
        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
    ) -> Self {
        let constant_columns = constant_columns.into_iter();

        let mut live_columns = self.live_columns.as_ref().clone();
        // 3 stat entries (min/max/nc) per constant, plus 1 spare; zero-sized
        // when there is no skip-batch predicate to feed.
        let mut skip_batch_predicate_constants =
            Vec::with_capacity(if self.skip_batch_predicate.is_some() {
                1 + constant_columns.size_hint().0 * 3
            } else {
                Default::default()
            });

        let predicate_constants = constant_columns
            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
                // Only constants that the predicate actually references matter;
                // `swap_remove` also drops them from the live set.
                if !live_columns.swap_remove(&name) {
                    return None;
                }

                if self.skip_batch_predicate.is_some() {
                    let mut null_count: Scalar = (0 as IdxSize).into();

                    // If the constant value is Null, we don't know how many nulls there are
                    // because the length of the batch may vary.
                    if scalar.is_null() {
                        null_count.update(AnyValue::Null);
                    }

                    // A constant column's min and max are the constant itself.
                    skip_batch_predicate_constants.extend([
                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
                        (format_pl_smallstr!("{name}_nc"), null_count),
                    ]);
                }

                Some((name, scalar))
            })
            .collect();

        let predicate = Arc::new(PhysicalExprWithConstCols {
            constants: predicate_constants,
            child: self.predicate.clone(),
        });
        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
            Arc::new(PhysicalExprWithConstCols {
                constants: skip_batch_predicate_constants,
                child: skp.clone(),
            }) as _
        });

        Self {
            predicate,
            live_columns: Arc::new(live_columns),
            skip_batch_predicate,
            column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull
            // predicates.
            hive_predicate: None,
            hive_predicate_is_full_predicate: false,
        }
    }

    /// Create a predicate to skip batches using statistics.
    ///
    /// Returns `None` when this scan has no skip-batch expression.
    pub(crate) fn to_dyn_skip_batch_predicate(
        &self,
        schema: SchemaRef,
    ) -> Option<Arc<dyn SkipBatchPredicate>> {
        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
        Some(Arc::new(SkipBatchPredicateHelper {
            skip_batch_predicate,
            schema,
        }))
    }

    /// Convert this predicate into its IO-layer form ([`ScanIOPredicate`]),
    /// wrapping every physical expression with `phys_expr_to_io_expr`.
    ///
    /// A caller-provided `skip_batch_predicate` takes precedence; otherwise one
    /// is derived from `self` via [`Self::to_dyn_skip_batch_predicate`].
    pub fn to_io(
        &self,
        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
        schema: SchemaRef,
    ) -> ScanIOPredicate {
        ScanIOPredicate {
            predicate: phys_expr_to_io_expr(self.predicate.clone()),
            live_columns: self.live_columns.clone(),
            skip_batch_predicate: skip_batch_predicate
                .cloned()
                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
            column_predicates: Arc::new(ColumnPredicates {
                predicates: self
                    .column_predicates
                    .predicates
                    .iter()
                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
                    .collect(),
                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
            }),
            hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
            hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
        }
    }
}

impl SkipBatchPredicate for SkipBatchPredicateHelper {
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Evaluate the skip-batch expression against a DataFrame of batch
    /// statistics, returning one bit per row (`true` = batch can be skipped).
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let array = self
            .skip_batch_predicate
            .evaluate(df, &Default::default())?;
        let array = array.bool()?;
        let array = array.downcast_as_array();

        // AND with the validity mask so a null predicate result reads as
        // `false` (i.e. "don't skip") rather than as an arbitrary bit.
        let array = if let Some(validity) = array.validity() {
            array.values() & validity
        } else {
            array.values().clone()
        };

        // @NOTE: Certain predicates like `1 == 1` will only output 1 value. We need to broadcast
        // the result back to the dataframe length.
        if array.len() == 1 && df.height() != 0 {
            return Ok(Bitmap::new_with_value(array.get_bit(0), df.height()));
        }

        assert_eq!(array.len(), df.height());
        Ok(array)
    }
}