use std::fmt;
use std::sync::Arc;
use arrow::array::Array;
use arrow::bitmap::{Bitmap, BitmapBuilder};
use polars_core::prelude::*;
#[cfg(feature = "parquet")]
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, SpecializedParquetColumnExpr};
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
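/// An expression that can be evaluated in an I/O context.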
pub trait PhysicalIoExpr: Send + Sync {
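    /// Take a [`DataFrame`] and produce a boolean [`Series`] to be used as a
    /// predicate mask.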
fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;
}
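/// Predicate shapes for which readers have a specialized, faster evaluation
/// path than generic expression evaluation.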
#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicate {
Equal(Scalar),
Between(Scalar, Scalar),
EqualOneOf(Box<[Scalar]>),
StartsWith(Box<[u8]>),
EndsWith(Box<[u8]>),
StartEndsWith(Box<[u8]>, Box<[u8]>),
}
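/// A predicate expression on a single column, optionally carrying a
/// specialized form that the parquet reader can evaluate directly on decoded
/// values.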
#[derive(Clone)]
pub struct ColumnPredicateExpr {
column_name: PlSmallStr,
dtype: DataType,
#[cfg(feature = "parquet")]
specialized: Option<SpecializedParquetColumnExpr>,
expr: Arc<dyn PhysicalIoExpr>,
}
impl ColumnPredicateExpr {
pub fn new(
column_name: PlSmallStr,
dtype: DataType,
expr: Arc<dyn PhysicalIoExpr>,
specialized: Option<SpecializedColumnPredicate>,
) -> Self {
use SpecializedColumnPredicate as S;
#[cfg(feature = "parquet")]
use SpecializedParquetColumnExpr as P;
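        // Lower the specialized predicate to its parquet counterpart. If any
        // scalar has no parquet representation, drop the specialization and
        // fall back to generic evaluation.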
#[cfg(feature = "parquet")]
let specialized = specialized.and_then(|s| {
Some(match s {
S::Equal(s) => P::Equal(cast_to_parquet_scalar(s)?),
S::Between(low, high) => {
P::Between(cast_to_parquet_scalar(low)?, cast_to_parquet_scalar(high)?)
},
                S::EqualOneOf(scalars) => P::EqualOneOf(
                    scalars
                        .into_iter()
                        .map(cast_to_parquet_scalar)
                        .collect::<Option<Box<[_]>>>()?,
                ),
S::StartsWith(s) => P::StartsWith(s),
S::EndsWith(s) => P::EndsWith(s),
S::StartEndsWith(start, end) => P::StartEndsWith(start, end),
})
});
Self {
column_name,
dtype,
#[cfg(feature = "parquet")]
specialized,
expr,
}
}
}
#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
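    // Evaluate the predicate over the decoded values, appending one bit per
    // row to `bm`.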
fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) {
        // Nulls are handled separately by `evaluate_null`, so all values given
        // here must be valid.
        assert!(values.validity().is_none_or(|v| v.set_bits() == v.len()));
let series =
Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype)
.unwrap();
let column = series.into_column();
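        // SAFETY: a single column of `values.len()` rows matches the given height.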
let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) };
let true_mask = self.expr.evaluate_io(&df).unwrap();
let true_mask = true_mask.bool().unwrap();
bm.reserve(true_mask.len());
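        // A null in the predicate result counts as `false` for that row.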
for chunk in true_mask.downcast_iter() {
match chunk.validity() {
None => bm.extend_from_bitmap(chunk.values()),
Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
}
}
}
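    // Decide whether a null value satisfies the predicate by evaluating it on
    // a single-row, all-null column.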
fn evaluate_null(&self) -> bool {
let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
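        // SAFETY: a single column of height 1 matches the given height.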
let df = unsafe { DataFrame::new_no_checks(1, vec![column]) };
let true_mask = self.expr.evaluate_io(&df).unwrap();
let true_mask = true_mask.bool().unwrap();
true_mask.get(0).unwrap_or(false)
}
fn as_specialized(&self) -> Option<&SpecializedParquetColumnExpr> {
self.specialized.as_ref()
}
}
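/// Cast a polars [`Scalar`] to its [`ParquetScalar`] counterpart, returning
/// `None` for types without a direct parquet representation.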
#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
use {AnyValue as A, ParquetScalar as P};
Some(match scalar.into_value() {
A::Null => P::Null,
A::Boolean(v) => P::Boolean(v),
A::UInt8(v) => P::UInt8(v),
A::UInt16(v) => P::UInt16(v),
A::UInt32(v) => P::UInt32(v),
A::UInt64(v) => P::UInt64(v),
A::Int8(v) => P::Int8(v),
A::Int16(v) => P::Int16(v),
A::Int32(v) => P::Int32(v),
A::Int64(v) => P::Int64(v),
#[cfg(feature = "dtype-time")]
A::Date(v) => P::Int32(v),
#[cfg(feature = "dtype-datetime")]
A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
#[cfg(feature = "dtype-duration")]
A::Duration(v, _) => P::Int64(v),
#[cfg(feature = "dtype-time")]
A::Time(v) => P::Int64(v),
A::Float32(v) => P::Float32(v),
A::Float64(v) => P::Float64(v),
#[cfg(feature = "dtype-categorical")]
A::Categorical(_, _) | A::CategoricalOwned(_, _) | A::Enum(_, _) | A::EnumOwned(_, _) => {
return None;
},
A::String(v) => P::String(v.into()),
A::StringOwned(v) => P::String(v.as_str().into()),
A::Binary(v) => P::Binary(v.into()),
A::BinaryOwned(v) => P::Binary(v.into()),
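        // The remaining types (nested, decimal, object, ...) have no direct
        // parquet scalar representation.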
_ => return None,
})
}
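/// Filter `df` in place with `predicate`, using either the parallel or the
/// sequential filter implementation.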
#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
df: &mut DataFrame,
predicate: Option<&dyn PhysicalIoExpr>,
parallel: bool,
) -> PolarsResult<()> {
if let (Some(predicate), false) = (&predicate, df.get_columns().is_empty()) {
let s = predicate.evaluate_io(df)?;
        let mask = s.bool().expect("filter predicate was not of type boolean");
if parallel {
*df = df.filter(mask)?;
} else {
*df = df._filter_seq(mask)?;
}
}
Ok(())
}
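/// The statistics of a column, as gathered from file metadata.
///
/// Each statistic is stored as a [`Series`] so that values for several batches
/// (e.g. parquet row groups) can be stacked together.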
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
field: Field,
null_count: Option<Series>,
min_value: Option<Series>,
max_value: Option<Series>,
}
impl ColumnStats {
pub fn new(
field: Field,
null_count: Option<Series>,
min_value: Option<Series>,
max_value: Option<Series>,
) -> Self {
Self {
field,
null_count,
min_value,
max_value,
}
}
pub fn from_field(field: Field) -> Self {
Self {
field,
null_count: None,
min_value: None,
max_value: None,
}
}
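    /// Statistics for a column holding a single constant value: that value is
    /// both the minimum and the maximum.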
pub fn from_column_literal(s: Series) -> Self {
debug_assert_eq!(s.len(), 1);
Self {
field: s.field().into_owned(),
null_count: None,
min_value: Some(s.clone()),
max_value: Some(s),
}
}
pub fn field_name(&self) -> &PlSmallStr {
self.field.name()
}
pub fn dtype(&self) -> &DataType {
self.field.dtype()
}
pub fn get_null_count_state(&self) -> Option<&Series> {
self.null_count.as_ref()
}
pub fn get_min_state(&self) -> Option<&Series> {
self.min_value.as_ref()
}
pub fn get_max_state(&self) -> Option<&Series> {
self.max_value.as_ref()
}
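    /// The total number of null values, summed over all batches, or `None` if
    /// unknown (e.g. for struct columns or when the statistic is missing).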
pub fn null_count(&self) -> Option<usize> {
match self.dtype() {
#[cfg(feature = "dtype-struct")]
DataType::Struct(_) => None,
_ => {
let s = self.get_null_count_state()?;
if s.null_count() != s.len() {
s.sum().ok()
} else {
None
}
},
}
}
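    /// Append the max statistic to the min statistic into one [`Series`], or
    /// `None` if either is missing or null, or if the dtype does not support
    /// min/max pruning.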
pub fn to_min_max(&self) -> Option<Series> {
let min_val = self.get_min_state()?;
let max_val = self.get_max_state()?;
let dtype = self.dtype();
if !use_min_max(dtype) {
return None;
}
let mut min_max_values = min_val.clone();
min_max_values.append(max_val).unwrap();
if min_max_values.null_count() > 0 {
None
} else {
Some(min_max_values)
}
}
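    /// The minimum value, if it is known, non-null and usable for pruning.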
pub fn to_min(&self) -> Option<&Series> {
let min_val = self.min_value.as_ref()?;
let dtype = min_val.dtype();
if !use_min_max(dtype) || min_val.len() != 1 {
return None;
}
if min_val.null_count() > 0 {
None
} else {
Some(min_val)
}
}
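    /// The maximum value, if it is known, non-null and usable for pruning.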
pub fn to_max(&self) -> Option<&Series> {
let max_val = self.max_value.as_ref()?;
let dtype = max_val.dtype();
if !use_min_max(dtype) || max_val.len() != 1 {
return None;
}
if max_val.null_count() > 0 {
None
} else {
Some(max_val)
}
}
}
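/// Whether min/max statistics are usable to prune values of this dtype.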
fn use_min_max(dtype: &DataType) -> bool {
dtype.is_primitive_numeric()
|| dtype.is_temporal()
|| matches!(
dtype,
DataType::String | DataType::Binary | DataType::Boolean
)
}
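/// The statistics for one column of one batch, as consumed by
/// [`SkipBatchPredicate::can_skip_batch`].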
pub struct ColumnStatistics {
pub dtype: DataType,
pub min: AnyValue<'static>,
pub max: AnyValue<'static>,
pub null_count: Option<IdxSize>,
}
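/// A predicate evaluated on batch statistics to decide whether whole batches
/// can be skipped without reading them.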
pub trait SkipBatchPredicate: Send + Sync {
fn schema(&self) -> &SchemaRef;
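    /// Whether a batch with the given length and column statistics can be
    /// skipped entirely.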
fn can_skip_batch(
&self,
batch_size: IdxSize,
live_columns: &PlIndexSet<PlSmallStr>,
mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
) -> PolarsResult<bool> {
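        // Build a single-row statistics DataFrame: the batch length plus
        // `{col}_min`, `{col}_max` and `{col}_nc` for every live column, with
        // nulls standing in for missing statistics.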
let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);
columns.push(Column::new_scalar(
PlSmallStr::from_static("len"),
Scalar::new(IDX_DTYPE, batch_size.into()),
1,
));
for col in live_columns.iter() {
let dtype = self.schema().get(col).unwrap();
let (min, max, nc) = match statistics.swap_remove(col) {
None => (
Scalar::null(dtype.clone()),
Scalar::null(dtype.clone()),
Scalar::null(IDX_DTYPE),
),
Some(stat) => (
Scalar::new(dtype.clone(), stat.min),
Scalar::new(dtype.clone(), stat.max),
Scalar::new(
IDX_DTYPE,
stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
),
),
};
columns.extend([
Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
]);
}
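        // SAFETY: every column was created with height 1.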
let df = unsafe { DataFrame::new_no_checks(1, columns) };
Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
}
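    /// Evaluate the predicate on a statistics [`DataFrame`] as built by
    /// [`Self::can_skip_batch`]; a set bit means the corresponding batch can
    /// be skipped.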
fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}
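/// A per-column decomposition of a predicate.
///
/// When `is_sumwise_complete` is true, the column predicates together are
/// equivalent to the full predicate, so rows that pass all of them need no
/// re-filtering.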
#[derive(Clone)]
pub struct ColumnPredicates {
pub predicates:
PlHashMap<PlSmallStr, (Arc<dyn PhysicalIoExpr>, Option<SpecializedColumnPredicate>)>,
pub is_sumwise_complete: bool,
}
#[allow(clippy::derivable_impls)]
impl Default for ColumnPredicates {
fn default() -> Self {
Self {
predicates: PlHashMap::default(),
is_sumwise_complete: false,
}
}
}
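/// A predicate wrapper that appends constant columns to the [`DataFrame`]
/// before delegating evaluation to its child.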
pub struct PhysicalExprWithConstCols<T> {
constants: Vec<(PlSmallStr, Scalar)>,
child: T,
}
impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
fn schema(&self) -> &SchemaRef {
self.child.schema()
}
fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
let mut df = df.clone();
for (name, scalar) in self.constants.iter() {
df.with_column(Column::new_scalar(
name.clone(),
scalar.clone(),
df.height(),
))?;
}
self.child.evaluate_with_stat_df(&df)
}
}
impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
let mut df = df.clone();
for (name, scalar) in self.constants.iter() {
df.with_column(Column::new_scalar(
name.clone(),
scalar.clone(),
df.height(),
))?;
}
self.child.evaluate_io(&df)
}
}
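/// Everything needed to push a predicate down into a scan: the full predicate,
/// the columns it touches, and optional batch-skipping, per-column and
/// hive-partition variants of it.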
#[derive(Clone)]
pub struct ScanIOPredicate {
pub predicate: Arc<dyn PhysicalIoExpr>,
pub live_columns: Arc<PlIndexSet<PlSmallStr>>,
pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,
pub column_predicates: Arc<ColumnPredicates>,
pub hive_predicate: Option<Arc<dyn PhysicalIoExpr>>,
pub hive_predicate_is_full_predicate: bool,
}
impl ScanIOPredicate {
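    /// Bake externally-known constant columns (e.g. hive partition values)
    /// into this predicate: they are removed from the live set and injected as
    /// constants into the child predicates.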
pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
if constant_columns.is_empty() {
return;
}
let mut live_columns = self.live_columns.as_ref().clone();
for (c, _) in constant_columns.iter() {
live_columns.swap_remove(c);
}
self.live_columns = Arc::new(live_columns);
if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
for (c, v) in constant_columns.iter() {
sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
let nc = if v.is_null() {
AnyValue::Null
} else {
(0 as IdxSize).into()
};
sbp_constant_columns
.push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
}
self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
constants: sbp_constant_columns,
child: skip_batch_predicate,
}));
}
let mut column_predicates = self.column_predicates.as_ref().clone();
for (c, _) in constant_columns.iter() {
column_predicates.predicates.remove(c);
}
self.column_predicates = Arc::new(column_predicates);
self.predicate = Arc::new(PhysicalExprWithConstCols {
constants: constant_columns,
child: self.predicate.clone(),
});
}
}
impl fmt::Debug for ScanIOPredicate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("scan_io_predicate")
}
}