GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-mem-engine/src/scan_predicate/mod.rs

pub mod functions;
pub mod skip_files_mask;
use core::fmt;
use std::sync::Arc;

use arrow::bitmap::Bitmap;
pub use functions::{create_scan_predicate, initialize_scan_predicate};
use polars_core::frame::DataFrame;
use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
use polars_core::scalar::Scalar;
use polars_core::schema::{Schema, SchemaRef};
use polars_error::PolarsResult;
use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
use polars_expr::state::ExecutionState;
use polars_io::predicates::{
    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicate,
};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::{IdxSize, format_pl_smallstr};

/// All the expressions and metadata used to filter out rows using predicates.
#[derive(Clone)]
pub struct ScanPredicate {
    pub predicate: Arc<dyn PhysicalExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate expression used to skip record batches based on their statistics.
    ///
    /// This expression is given a batch size along with a `min`, `max` and `null count` for
    /// each live column (set to `null` when it is not known), and it evaluates to `true` if
    /// the whole batch can for sure be skipped. This may be conservative and evaluate to
    /// `false` even when the batch could theoretically be skipped.
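    ///
    /// For example, given the predicate `a > 5`, a batch whose statistics report `a_max == 3`
    /// cannot contain a matching row, so the skip expression may evaluate to `true` for it.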
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Partial per-column predicates, used for filtering while loading columnar formats.
    pub column_predicates: PhysicalColumnPredicates,

    /// Predicate referring only to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
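    /// Whether `hive_predicate` is the full predicate (i.e. the predicate refers
    /// exclusively to hive columns).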
    pub hive_predicate_is_full_predicate: bool,
}

impl fmt::Debug for ScanPredicate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("scan_predicate")
    }
}

#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalExpr>, Option<SpecializedColumnPredicate>)>,
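    /// Assumed semantics (the flag is undocumented here): `true` when the column predicates
    /// together are equivalent to the full predicate, so rows that pass all of them need no
    /// re-filtering.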
    pub is_sumwise_complete: bool,
}

/// Helper to implement [`SkipBatchPredicate`].
struct SkipBatchPredicateHelper {
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    schema: SchemaRef,
}

/// Helper for the [`PhysicalExpr`] trait to include constant columns.
pub struct PhysicalExprWithConstCols {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: Arc<dyn PhysicalExpr>,
}

impl PhysicalExpr for PhysicalExprWithConstCols {
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let mut df = df.clone();
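        // Materialize every constant as a scalar column broadcast to the frame's height,
        // then defer to the wrapped expression.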
        for (name, scalar) in &self.constants {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate(&df, state)
    }

    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut df = df.clone();
        for (name, scalar) in &self.constants {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_on_groups(&df, groups, state)
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.child.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        self.child.is_scalar()
    }
}

impl ScanPredicate {
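    /// Bind constant columns into the predicate and drop them from the live set. (A typical
    /// source of such constants is hive partitioning, where a partition column is constant
    /// within a single file.)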
    pub fn with_constant_columns(
        &self,
        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
    ) -> Self {
        let constant_columns = constant_columns.into_iter();

        let mut live_columns = self.live_columns.as_ref().clone();
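        // Reserve three statistic slots (min / max / null count) per constant column; the
        // extra slot presumably accounts for the batch-length value the skip predicate is
        // also given.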
        let mut skip_batch_predicate_constants =
            Vec::with_capacity(if self.skip_batch_predicate.is_some() {
                1 + constant_columns.size_hint().0 * 3
            } else {
                Default::default()
            });

        let predicate_constants = constant_columns
            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
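                // Only constants that actually appear in the predicate matter; they are no
                // longer live (their value is known), so remove them from the live set and
                // bake them into the predicate instead.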
                if !live_columns.swap_remove(&name) {
                    return None;
                }

                if self.skip_batch_predicate.is_some() {
                    let mut null_count: Scalar = (0 as IdxSize).into();

                    // If the constant value is Null, we don't know how many nulls there are
                    // because the length of the batch may vary.
                    if scalar.is_null() {
                        null_count.update(AnyValue::Null);
                    }

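                    // For a constant column the statistics are exact:
                    // min == max == the constant value itself.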
                    skip_batch_predicate_constants.extend([
                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
                        (format_pl_smallstr!("{name}_nc"), null_count),
                    ]);
                }

                Some((name, scalar))
            })
            .collect();

        let predicate = Arc::new(PhysicalExprWithConstCols {
            constants: predicate_constants,
            child: self.predicate.clone(),
        });
        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
            Arc::new(PhysicalExprWithConstCols {
                constants: skip_batch_predicate_constants,
                child: skp.clone(),
            }) as _
        });

        Self {
            predicate,
            live_columns: Arc::new(live_columns),
            skip_batch_predicate,
            column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull predicates.
            hive_predicate: None,
            hive_predicate_is_full_predicate: false,
        }
    }

    /// Create a predicate to skip batches using statistics.
    pub(crate) fn to_dyn_skip_batch_predicate(
        &self,
        schema: SchemaRef,
    ) -> Option<Arc<dyn SkipBatchPredicate>> {
        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
        Some(Arc::new(SkipBatchPredicateHelper {
            skip_batch_predicate,
            schema,
        }))
    }

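    /// Lower this predicate into the IO-layer representation (`ScanIOPredicate`) consumed
    /// by the file readers.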
    pub fn to_io(
        &self,
        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
        schema: SchemaRef,
    ) -> ScanIOPredicate {
        ScanIOPredicate {
            predicate: phys_expr_to_io_expr(self.predicate.clone()),
            live_columns: self.live_columns.clone(),
            skip_batch_predicate: skip_batch_predicate
                .cloned()
                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
            column_predicates: Arc::new(ColumnPredicates {
                predicates: self
                    .column_predicates
                    .predicates
                    .iter()
                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
                    .collect(),
                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
            }),
            hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
            hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
        }
    }
}

impl SkipBatchPredicate for SkipBatchPredicateHelper {
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let array = self
            .skip_batch_predicate
            .evaluate(df, &Default::default())?;
        let array = array.bool()?.rechunk();
        let array = array.downcast_as_array();

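        // A null result means "unknown whether the batch can be skipped", which must not
        // cause a skip; AND-ing with the validity mask maps null to `false`.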
        let array = if let Some(validity) = array.validity() {
            array.values() & validity
        } else {
            array.values().clone()
        };

        // @NOTE: Certain predicates like `1 == 1` will only output 1 value. We need to
        // broadcast the result back to the dataframe length.
        if array.len() == 1 && df.height() != 0 {
            return Ok(Bitmap::new_with_value(array.get_bit(0), df.height()));
        }

        assert_eq!(array.len(), df.height());
        Ok(array)
    }
}