Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/predicates.rs
8424 views
1
use std::fmt;
2
3
use arrow::array::Array;
4
use arrow::bitmap::{Bitmap, BitmapBuilder};
5
use polars_core::prelude::*;
6
#[cfg(feature = "parquet")]
7
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, SpecializedParquetColumnExpr};
8
use polars_utils::format_pl_smallstr;
9
#[cfg(feature = "serde")]
10
use serde::{Deserialize, Serialize};
11
12
/// A physical expression that can be evaluated during I/O (e.g. as a scan
/// predicate), independent of the rest of the query engine.
pub trait PhysicalIoExpr: Send + Sync {
    /// Take a [`DataFrame`] and produces a boolean [`Series`] that serves
    /// as a predicate mask
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;
}
17
18
/// A single-column predicate that has a specialized (cheap-to-evaluate) form.
///
/// Used to derive fast paths (see [`ColumnPredicateExpr::new`]) instead of
/// evaluating the full generic expression.
#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicate {
    /// Value is equal to the given scalar.
    Equal(Scalar),
    /// A closed (inclusive) range.
    Between(Scalar, Scalar),
    /// Value is equal to one of the given scalars.
    EqualOneOf(Box<[Scalar]>),
    /// String/binary value starts with the given bytes.
    StartsWith(Box<[u8]>),
    /// String/binary value ends with the given bytes.
    EndsWith(Box<[u8]>),
    /// Value matches the given byte-oriented regex.
    RegexMatch(regex::bytes::Regex),
}
28
29
/// A predicate expression scoped to a single column.
#[derive(Clone)]
pub struct ColumnPredicateExpr {
    // Name of the column this predicate applies to.
    column_name: PlSmallStr,
    // Dtype the raw values are interpreted as before evaluation.
    dtype: DataType,
    // Parquet-specific specialized form of the predicate, if one could be
    // derived from the `SpecializedColumnPredicate` given at construction.
    #[cfg(feature = "parquet")]
    specialized: Option<SpecializedParquetColumnExpr>,
    // Generic fallback expression, used when no specialized path applies.
    expr: Arc<dyn PhysicalIoExpr>,
}
37
38
impl ColumnPredicateExpr {
39
pub fn new(
40
column_name: PlSmallStr,
41
dtype: DataType,
42
expr: Arc<dyn PhysicalIoExpr>,
43
specialized: Option<SpecializedColumnPredicate>,
44
) -> Self {
45
use SpecializedColumnPredicate as S;
46
#[cfg(feature = "parquet")]
47
use SpecializedParquetColumnExpr as P;
48
#[cfg(feature = "parquet")]
49
let specialized = specialized.and_then(|s| {
50
Some(match s {
51
S::Equal(s) => P::Equal(cast_to_parquet_scalar(s)?),
52
S::Between(low, high) => {
53
P::Between(cast_to_parquet_scalar(low)?, cast_to_parquet_scalar(high)?)
54
},
55
S::EqualOneOf(scalars) => P::EqualOneOf(
56
scalars
57
.into_iter()
58
.map(|s| cast_to_parquet_scalar(s).ok_or(()))
59
.collect::<Result<Box<_>, ()>>()
60
.ok()?,
61
),
62
S::StartsWith(s) => P::StartsWith(s),
63
S::EndsWith(s) => P::EndsWith(s),
64
S::RegexMatch(s) => P::RegexMatch(s),
65
})
66
});
67
68
Self {
69
column_name,
70
dtype,
71
#[cfg(feature = "parquet")]
72
specialized,
73
expr,
74
}
75
}
76
}
77
78
#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    // Evaluate the predicate over a chunk of raw values, appending one bit
    // per value to `bm` (set = value passes the predicate).
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) {
        // We should never evaluate nulls with this.
        // NOTE(review): `set_bits() == 0` only accepts a validity bitmap that
        // is entirely unset — confirm this matches the "no nulls" invariant
        // stated above (an all-*valid* bitmap would want `unset_bits() == 0`).
        assert!(values.validity().is_none_or(|v| v.set_bits() == 0));

        // Wrap the raw arrow chunk in a single-column DataFrame so the
        // generic expression can be evaluated against it.
        // @TODO: Probably these unwraps should be removed.
        let series =
            Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype)
                .unwrap();
        let column = series.into_column();
        // SAFETY: exactly one column whose length equals the stated height.
        let df = unsafe { DataFrame::new_unchecked(values.len(), vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        bm.reserve(true_mask.len());
        // Null results in the mask count as "does not pass": AND the values
        // with the validity bitmap where one is present.
        for chunk in true_mask.downcast_iter() {
            match chunk.validity() {
                None => bm.extend_from_bitmap(chunk.values()),
                Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
            }
        }
    }
    // Returns whether a null value passes the predicate, by evaluating the
    // expression against a single all-null row.
    fn evaluate_null(&self) -> bool {
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        // SAFETY: one column of length 1, matching the stated height.
        let df = unsafe { DataFrame::new_unchecked(1, vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        // A null/missing mask value is treated as "does not pass".
        true_mask.get(0).unwrap_or(false)
    }

    // The specialized fast-path form derived at construction time, if any.
    fn as_specialized(&self) -> Option<&SpecializedParquetColumnExpr> {
        self.specialized.as_ref()
    }
}
118
119
/// Attempts to convert a Polars [`Scalar`] into a [`ParquetScalar`] usable by
/// the specialized parquet predicate paths.
///
/// Returns `None` for values without a (supported) parquet scalar
/// representation, e.g. categoricals/enums and nested types.
#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use {AnyValue as A, ParquetScalar as P};

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        // `AnyValue::Date` is gated on the `dtype-date` feature, so this arm
        // must be too (it was previously gated on `dtype-time`, which breaks
        // builds enabling `dtype-time` without `dtype-date`).
        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // @TODO: Cast to string
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _) | A::CategoricalOwned(_, _) | A::Enum(_, _) | A::EnumOwned(_, _) => {
            return None;
        },

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        // Anything else (nested types, decimals, objects, ...) is unsupported.
        _ => return None,
    })
}
162
163
/// Filters `df` in place with the given predicate, if any.
///
/// Does nothing when `predicate` is `None` or when the frame has no columns.
/// `parallel` selects between the parallel and sequential filter kernels.
#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    let Some(predicate) = predicate else {
        return Ok(());
    };
    // Nothing to filter on in a frame without columns.
    if df.columns().is_empty() {
        return Ok(());
    }

    let mask_series = predicate.evaluate_io(df)?;
    let mask = mask_series
        .bool()
        .expect("filter predicates was not of type boolean");

    *df = if parallel {
        df.filter(mask)?
    } else {
        df.filter_seq(mask)?
    };

    Ok(())
}
181
182
/// Statistics of the values in a column.
///
/// The following statistics are tracked for each row group:
/// - Null count
/// - Minimum value
/// - Maximum value
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    // Name and dtype of the column these statistics describe.
    field: Field,
    // Each Series contains the stats for each row group.
    null_count: Option<Series>,
    min_value: Option<Series>,
    max_value: Option<Series>,
}
197
198
impl ColumnStats {
199
/// Constructs a new [`ColumnStats`].
200
pub fn new(
201
field: Field,
202
null_count: Option<Series>,
203
min_value: Option<Series>,
204
max_value: Option<Series>,
205
) -> Self {
206
Self {
207
field,
208
null_count,
209
min_value,
210
max_value,
211
}
212
}
213
214
/// Constructs a new [`ColumnStats`] with only the [`Field`] information and no statistics.
215
pub fn from_field(field: Field) -> Self {
216
Self {
217
field,
218
null_count: None,
219
min_value: None,
220
max_value: None,
221
}
222
}
223
224
/// Constructs a new [`ColumnStats`] from a single-value Series.
225
pub fn from_column_literal(s: Series) -> Self {
226
debug_assert_eq!(s.len(), 1);
227
Self {
228
field: s.field().into_owned(),
229
null_count: None,
230
min_value: Some(s.clone()),
231
max_value: Some(s),
232
}
233
}
234
235
pub fn field_name(&self) -> &PlSmallStr {
236
self.field.name()
237
}
238
239
/// Returns the [`DataType`] of the column.
240
pub fn dtype(&self) -> &DataType {
241
self.field.dtype()
242
}
243
244
/// Returns the null count of each row group of the column.
245
pub fn get_null_count_state(&self) -> Option<&Series> {
246
self.null_count.as_ref()
247
}
248
249
/// Returns the minimum value of each row group of the column.
250
pub fn get_min_state(&self) -> Option<&Series> {
251
self.min_value.as_ref()
252
}
253
254
/// Returns the maximum value of each row group of the column.
255
pub fn get_max_state(&self) -> Option<&Series> {
256
self.max_value.as_ref()
257
}
258
259
/// Returns the null count of the column.
260
pub fn null_count(&self) -> Option<usize> {
261
match self.dtype() {
262
#[cfg(feature = "dtype-struct")]
263
DataType::Struct(_) => None,
264
_ => {
265
let s = self.get_null_count_state()?;
266
// if all null, there are no statistics.
267
if s.null_count() != s.len() {
268
s.sum().ok()
269
} else {
270
None
271
}
272
},
273
}
274
}
275
276
/// Returns the minimum and maximum values of the column as a single [`Series`].
277
pub fn to_min_max(&self) -> Option<Series> {
278
let min_val = self.get_min_state()?;
279
let max_val = self.get_max_state()?;
280
let dtype = self.dtype();
281
282
if !use_min_max(dtype) {
283
return None;
284
}
285
286
let mut min_max_values = min_val.clone();
287
min_max_values.append(max_val).unwrap();
288
if min_max_values.null_count() > 0 {
289
None
290
} else {
291
Some(min_max_values)
292
}
293
}
294
295
/// Returns the minimum value of the column as a single-value [`Series`].
296
///
297
/// Returns `None` if no maximum value is available.
298
pub fn to_min(&self) -> Option<&Series> {
299
// @scalar-opt
300
let min_val = self.min_value.as_ref()?;
301
let dtype = min_val.dtype();
302
303
if !use_min_max(dtype) || min_val.len() != 1 {
304
return None;
305
}
306
307
if min_val.null_count() > 0 {
308
None
309
} else {
310
Some(min_val)
311
}
312
}
313
314
/// Returns the maximum value of the column as a single-value [`Series`].
315
///
316
/// Returns `None` if no maximum value is available.
317
pub fn to_max(&self) -> Option<&Series> {
318
// @scalar-opt
319
let max_val = self.max_value.as_ref()?;
320
let dtype = max_val.dtype();
321
322
if !use_min_max(dtype) || max_val.len() != 1 {
323
return None;
324
}
325
326
if max_val.null_count() > 0 {
327
None
328
} else {
329
Some(max_val)
330
}
331
}
332
}
333
334
/// Returns whether the [`DataType`] supports minimum/maximum operations.
335
fn use_min_max(dtype: &DataType) -> bool {
336
dtype.is_primitive_numeric()
337
|| dtype.is_temporal()
338
|| matches!(
339
dtype,
340
DataType::String | DataType::Binary | DataType::Boolean
341
)
342
}
343
344
/// Statistics of a column within a single batch, used for skip-batch decisions.
pub struct ColumnStatistics {
    /// Dtype of the column the statistics describe.
    pub dtype: DataType,
    /// Minimum value in the batch.
    pub min: AnyValue<'static>,
    /// Maximum value in the batch.
    pub max: AnyValue<'static>,
    /// Null count of the batch, if known.
    pub null_count: Option<IdxSize>,
}
350
351
/// A predicate evaluated against batch-level statistics to decide whether a
/// whole batch can be skipped without reading it.
pub trait SkipBatchPredicate: Send + Sync {
    /// The schema the statistics columns are resolved against.
    fn schema(&self) -> &SchemaRef;

    /// Returns whether a batch with the given size and per-column statistics
    /// can be skipped.
    ///
    /// Builds a one-row statistics [`DataFrame`] containing a `len` column
    /// plus `{col}_min`, `{col}_max` and `{col}_nc` columns for every live
    /// column (null scalars when no statistics are available for a column),
    /// then delegates to [`Self::evaluate_with_stat_df`].
    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        // One `len` column + (min, max, nc) per live column.
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            let (min, max, nc) = match statistics.swap_remove(col) {
                // No statistics for this column: all-null scalars.
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    Scalar::new(
                        IDX_DTYPE,
                        // Unknown null count becomes a null scalar.
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }

        // SAFETY:
        // * Each column is length = 1
        // * We have an IndexSet, so each column name is unique
        let df = unsafe { DataFrame::new_unchecked(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }

    /// Evaluates the predicate over a statistics [`DataFrame`] as built by
    /// [`Self::can_skip_batch`]; a set bit means the row's batch can be
    /// skipped.
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}
400
401
/// Per-column predicates derived from a full scan predicate.
#[derive(Clone)]
pub struct ColumnPredicates {
    /// For each column: the predicate expression over that column alone,
    /// optionally with a specialized fast-path form.
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalIoExpr>, Option<SpecializedColumnPredicate>)>,
    // NOTE(review): presumably true when the per-column predicates together
    // are equivalent to the full predicate — confirm against the producer.
    pub is_sumwise_complete: bool,
}
407
408
// I want to be explicit here: this matches what `derive(Default)` would
// produce today, but the `false` for `is_sumwise_complete` is a semantic
// choice that should stay visible in source.
#[allow(clippy::derivable_impls)]
impl Default for ColumnPredicates {
    fn default() -> Self {
        Self {
            predicates: PlHashMap::default(),
            is_sumwise_complete: false,
        }
    }
}
418
419
/// Wraps a predicate `child` and injects constant scalar columns into the
/// input [`DataFrame`] before delegating evaluation to the child.
pub struct PhysicalExprWithConstCols<T> {
    // (column name, value) pairs added to every input frame before evaluation.
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}
423
424
impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
425
fn schema(&self) -> &SchemaRef {
426
self.child.schema()
427
}
428
429
fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
430
let mut df = df.clone();
431
for (name, scalar) in self.constants.iter() {
432
df.with_column(Column::new_scalar(
433
name.clone(),
434
scalar.clone(),
435
df.height(),
436
))?;
437
}
438
self.child.evaluate_with_stat_df(&df)
439
}
440
}
441
442
impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
443
fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
444
let mut df = df.clone();
445
for (name, scalar) in self.constants.iter() {
446
df.with_column(Column::new_scalar(
447
name.clone(),
448
scalar.clone(),
449
df.height(),
450
))?;
451
}
452
453
self.child.evaluate_io(&df)
454
}
455
}
456
457
/// All predicate information used while scanning a source.
#[derive(Clone)]
pub struct ScanIOPredicate {
    /// The full predicate expression.
    pub predicate: Arc<dyn PhysicalIoExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate that gets given statistics and evaluates whether a batch can be skipped.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    /// The predicate decomposed into per-column predicates.
    pub column_predicates: Arc<ColumnPredicates>,

    /// Predicate parts only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalIoExpr>>,

    // NOTE(review): presumably true when `hive_predicate` is equivalent to
    // the entire `predicate` — confirm against the producer.
    pub hive_predicate_is_full_predicate: bool,
}
475
476
impl ScanIOPredicate {
    /// Pushes externally-known constant columns into this predicate.
    ///
    /// The constant columns are removed from `live_columns` and
    /// `column_predicates` (they no longer need to be read/evaluated per
    /// column), and `predicate` / `skip_batch_predicate` are wrapped so the
    /// constants are materialized into the frame before evaluation.
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        // Constant columns are no longer "live": they need not be read.
        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            // The skip-batch predicate consumes statistics columns, so each
            // constant `c` expands to `{c}_min`, `{c}_max` and `{c}_nc`.
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                // A non-null constant implies a null count of zero; for a
                // null constant the count is left as a null scalar.
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        // Per-column predicates for constant columns are no longer needed.
        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        // Wrap the full predicate so the constants exist in its input frame.
        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}
519
520
impl fmt::Debug for ScanIOPredicate {
521
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522
f.write_str("scan_io_predicate")
523
}
524
}
525
526