GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-mem-engine/src/predicate.rs

use core::fmt;
use std::sync::Arc;

use arrow::bitmap::Bitmap;
use polars_core::frame::DataFrame;
use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet};
use polars_core::scalar::Scalar;
use polars_core::schema::{Schema, SchemaRef};
use polars_error::PolarsResult;
use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr};
use polars_expr::state::ExecutionState;
use polars_io::predicates::{
    ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicate,
};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::{IdxSize, format_pl_smallstr};

/// All the expressions and metadata used to filter out rows using predicates.
#[derive(Clone)]
pub struct ScanPredicate {
    pub predicate: Arc<dyn PhysicalExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate expression used to skip record batches based on their statistics.
    ///
    /// This expression is given a batch size along with a `min`, `max` and `null count` for each
    /// live column (set to `null` when it is not known), and it evaluates to `true` only if the
    /// whole batch can certainly be skipped. It may be conservative and evaluate to `false` even
    /// when the batch could theoretically be skipped.
    pub skip_batch_predicate: Option<Arc<dyn PhysicalExpr>>,

    /// Partial per-column predicates used to filter rows while loading columnar formats.
    pub column_predicates: PhysicalColumnPredicates,

    /// Predicate referring only to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalExpr>>,
    pub hive_predicate_is_full_predicate: bool,
}
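
// Illustrative sketch, not part of the original source: the skip-batch expression is
// evaluated against a statistics DataFrame with one row per batch. Its columns follow
// the `{name}_min` / `{name}_max` / `{name}_nc` naming used by `with_constant_columns`
// below, plus a batch-size column (assumed here to be called `len`; the doc comment
// above only says "batch size"). For a predicate like `col("a") > 10`:
//
//     len | a_min | a_max | a_nc
//     ----|-------|-------|-----
//     100 |   0   |   5   |  0     -> true  (every value is <= 10; batch can be skipped)
//     100 |   0   |  20   |  0     -> false (some values may satisfy a > 10)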

impl fmt::Debug for ScanPredicate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("scan_predicate")
    }
}

#[derive(Clone)]
pub struct PhysicalColumnPredicates {
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalExpr>, Option<SpecializedColumnPredicate>)>,
    pub is_sumwise_complete: bool,
}

/// Helper to implement [`SkipBatchPredicate`].
struct SkipBatchPredicateHelper {
    skip_batch_predicate: Arc<dyn PhysicalExpr>,
    schema: SchemaRef,
}

/// Helper for the [`PhysicalExpr`] trait to include constant columns.
pub struct PhysicalExprWithConstCols {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: Arc<dyn PhysicalExpr>,
}

impl PhysicalExpr for PhysicalExprWithConstCols {
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let mut df = df.clone();
        for (name, scalar) in &self.constants {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate(&df, state)
    }

    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut df = df.clone();
        for (name, scalar) in &self.constants {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_on_groups(&df, groups, state)
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.child.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        self.child.is_scalar()
    }
}
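
// Illustrative sketch, not part of the original source: `PhysicalExprWithConstCols`
// lets a predicate that refers to columns absent from the scanned file (e.g. hive
// partition columns) still evaluate, by broadcasting each constant to the frame
// height before delegating to the child. Assuming `child` was compiled from
// `col("year") > 2020` and the file itself has no `year` column:
//
//     let expr = PhysicalExprWithConstCols {
//         constants: vec![(PlSmallStr::from_static("year"), Scalar::from(2024))],
//         child,
//     };
//     let mask = expr.evaluate(&df, &ExecutionState::default())?;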

impl ScanPredicate {
    pub fn with_constant_columns(
        &self,
        constant_columns: impl IntoIterator<Item = (PlSmallStr, Scalar)>,
    ) -> Self {
        let constant_columns = constant_columns.into_iter();

        let mut live_columns = self.live_columns.as_ref().clone();
        let mut skip_batch_predicate_constants =
            Vec::with_capacity(if self.skip_batch_predicate.is_some() {
                1 + constant_columns.size_hint().0 * 3
            } else {
                Default::default()
            });

        let predicate_constants = constant_columns
            .filter_map(|(name, scalar): (PlSmallStr, Scalar)| {
                if !live_columns.swap_remove(&name) {
                    return None;
                }

                if self.skip_batch_predicate.is_some() {
                    let mut null_count: Scalar = (0 as IdxSize).into();

                    // If the constant value is null, we don't know how many nulls there are
                    // because the length of the batch may vary.
                    if scalar.is_null() {
                        null_count.update(AnyValue::Null);
                    }

                    skip_batch_predicate_constants.extend([
                        (format_pl_smallstr!("{name}_min"), scalar.clone()),
                        (format_pl_smallstr!("{name}_max"), scalar.clone()),
                        (format_pl_smallstr!("{name}_nc"), null_count),
                    ]);
                }

                Some((name, scalar))
            })
            .collect();

        let predicate = Arc::new(PhysicalExprWithConstCols {
            constants: predicate_constants,
            child: self.predicate.clone(),
        });
        let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| {
            Arc::new(PhysicalExprWithConstCols {
                constants: skip_batch_predicate_constants,
                child: skp.clone(),
            }) as _
        });

        Self {
            predicate,
            live_columns: Arc::new(live_columns),
            skip_batch_predicate,
            column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull predicates.
            hive_predicate: None,
            hive_predicate_is_full_predicate: false,
        }
    }
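
    // Illustrative sketch, not part of the original source: a scan of a
    // hive-partitioned dataset could inject the partition values before handing
    // the predicate to a reader, along the lines of:
    //
    //     let predicate = predicate.with_constant_columns([
    //         (PlSmallStr::from_static("year"), Scalar::from(2024)),
    //         (PlSmallStr::from_static("month"), Scalar::from(6)),
    //     ]);
    //
    // Constant columns drop out of `live_columns`, and the skip-batch child
    // expression receives matching `{name}_min` / `{name}_max` / `{name}_nc`
    // constants so that batch skipping keeps working unchanged.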

    /// Create a predicate to skip batches using statistics.
    pub(crate) fn to_dyn_skip_batch_predicate(
        &self,
        schema: SchemaRef,
    ) -> Option<Arc<dyn SkipBatchPredicate>> {
        let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone();
        Some(Arc::new(SkipBatchPredicateHelper {
            skip_batch_predicate,
            schema,
        }))
    }

    pub fn to_io(
        &self,
        skip_batch_predicate: Option<&Arc<dyn SkipBatchPredicate>>,
        schema: SchemaRef,
    ) -> ScanIOPredicate {
        ScanIOPredicate {
            predicate: phys_expr_to_io_expr(self.predicate.clone()),
            live_columns: self.live_columns.clone(),
            skip_batch_predicate: skip_batch_predicate
                .cloned()
                .or_else(|| self.to_dyn_skip_batch_predicate(schema)),
            column_predicates: Arc::new(ColumnPredicates {
                predicates: self
                    .column_predicates
                    .predicates
                    .iter()
                    .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone())))
                    .collect(),
                is_sumwise_complete: self.column_predicates.is_sumwise_complete,
            }),
            hive_predicate: self.hive_predicate.clone().map(phys_expr_to_io_expr),
            hive_predicate_is_full_predicate: self.hive_predicate_is_full_predicate,
        }
    }
}
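
// Illustrative sketch, not part of the original source: IO readers consume the
// crate-agnostic `ScanIOPredicate` rather than the physical expression, e.g.
//
//     let io_predicate = scan_predicate.to_io(None, schema.clone());
//
// When no skip-batch predicate override is passed, `to_io` falls back to
// `to_dyn_skip_batch_predicate(schema)` via `or_else`.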

impl SkipBatchPredicate for SkipBatchPredicateHelper {
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let array = self
            .skip_batch_predicate
            .evaluate(df, &Default::default())?;
        let array = array.bool()?;
        let array = array.downcast_as_array();

        let array = if let Some(validity) = array.validity() {
            array.values() & validity
        } else {
            array.values().clone()
        };

        // @NOTE: Certain predicates like `1 == 1` will only output 1 value. We need to broadcast
        // the result back to the dataframe length.
        if array.len() == 1 && df.height() != 0 {
            return Ok(Bitmap::new_with_value(array.get_bit(0), df.height()));
        }

        assert_eq!(array.len(), df.height());
        Ok(array)
    }
}
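
// Illustrative sketch, not part of the original source: because of the broadcast
// above, even a constant-foldable predicate (like `1 == 1`) yields a mask matching
// the statistics frame. With a hypothetical `stats_df` holding one row per record
// batch:
//
//     let mask = helper.evaluate_with_stat_df(&stats_df)?;
//     assert_eq!(mask.len(), stats_df.height());
//     // set bits mark batches that can certainly be skipped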