GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/predicates.rs
use std::fmt;

use arrow::array::Array;
use arrow::bitmap::{Bitmap, BitmapBuilder};
use polars_core::prelude::*;
#[cfg(feature = "parquet")]
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, SpecializedParquetColumnExpr};
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

pub trait PhysicalIoExpr: Send + Sync {
    /// Takes a [`DataFrame`] and produces a boolean [`Series`] that serves
    /// as a predicate mask.
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;
}
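
// Editor's sketch (illustrative, not part of the original file): a minimal
// implementor that keeps rows where a hypothetical column "a" is positive.
//
//     struct PositiveA;
//
//     impl PhysicalIoExpr for PositiveA {
//         fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
//             let s = df.column("a")?.as_materialized_series();
//             Ok(s.gt(0)?.into_series())
//         }
//     }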

#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicate {
    Equal(Scalar),

    Between(Scalar, Scalar),

    EqualOneOf(Box<[Scalar]>),

    StartsWith(Box<[u8]>),
    EndsWith(Box<[u8]>),
    StartEndsWith(Box<[u8]>, Box<[u8]>),
}
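
// Editor's gloss (not in the original source): these variants carry the
// simple predicate shapes that can be pushed down to specialized kernels,
// roughly:
//
//     col("a") == 10            -> Equal(10)
//     col("a") in 0..=9         -> Between(0, 9)
//     col("a") in {1, 2, 3}     -> EqualOneOf([1, 2, 3])
//     string starts with "ab"   -> StartsWith(b"ab")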

#[derive(Clone)]
pub struct ColumnPredicateExpr {
    column_name: PlSmallStr,
    dtype: DataType,
    #[cfg(feature = "parquet")]
    specialized: Option<SpecializedParquetColumnExpr>,
    expr: Arc<dyn PhysicalIoExpr>,
}

impl ColumnPredicateExpr {
    pub fn new(
        column_name: PlSmallStr,
        dtype: DataType,
        expr: Arc<dyn PhysicalIoExpr>,
        specialized: Option<SpecializedColumnPredicate>,
    ) -> Self {
        use SpecializedColumnPredicate as S;
        #[cfg(feature = "parquet")]
        use SpecializedParquetColumnExpr as P;
        #[cfg(feature = "parquet")]
        let specialized = specialized.and_then(|s| {
            Some(match s {
                S::Equal(s) => P::Equal(cast_to_parquet_scalar(s)?),
                S::Between(low, high) => {
                    P::Between(cast_to_parquet_scalar(low)?, cast_to_parquet_scalar(high)?)
                },
                S::EqualOneOf(scalars) => P::EqualOneOf(
                    scalars
                        .into_iter()
                        .map(|s| cast_to_parquet_scalar(s).ok_or(()))
                        .collect::<Result<Box<_>, ()>>()
                        .ok()?,
                ),
                S::StartsWith(s) => P::StartsWith(s),
                S::EndsWith(s) => P::EndsWith(s),
                S::StartEndsWith(start, end) => P::StartEndsWith(start, end),
            })
        });

        Self {
            column_name,
            dtype,
            #[cfg(feature = "parquet")]
            specialized,
            expr,
        }
    }
}

#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) {
        // We should never evaluate nulls with this.
        assert!(values.validity().is_none_or(|v| v.unset_bits() == 0));

        // @TODO: Probably these unwraps should be removed.
        let series =
            Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype)
                .unwrap();
        let column = series.into_column();
        let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        bm.reserve(true_mask.len());
        for chunk in true_mask.downcast_iter() {
            match chunk.validity() {
                None => bm.extend_from_bitmap(chunk.values()),
                // A null in the mask means the row is filtered out, so AND
                // with the validity to turn nulls into `false`.
                Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
            }
        }
    }

    fn evaluate_null(&self) -> bool {
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        let df = unsafe { DataFrame::new_no_checks(1, vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        true_mask.get(0).unwrap_or(false)
    }

    fn as_specialized(&self) -> Option<&SpecializedParquetColumnExpr> {
        self.specialized.as_ref()
    }
}
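
// Editor's gloss: `evaluate_mut` assumes the incoming array has no nulls and
// produces one mask bit per value, while `evaluate_null` answers "does a null
// pass the predicate?" by evaluating on a single-row, all-null frame, which
// presumably lets the reader resolve all-null runs without decoding them.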

#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use {AnyValue as A, ParquetScalar as P};

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // @TODO: Cast to string
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _) | A::CategoricalOwned(_, _) | A::Enum(_, _) | A::EnumOwned(_, _) => {
            return None;
        },

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        _ => return None,
    })
}
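
// Editor's note: the temporal arms above reuse polars' physical storage
// (Date as i32 days, Datetime/Duration/Time as i64), so the parquet side can
// compare raw physical values without a separate logical-type conversion.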

#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    if let (Some(predicate), false) = (&predicate, df.get_columns().is_empty()) {
        let s = predicate.evaluate_io(df)?;
        let mask = s.bool().expect("filter predicate was not of type boolean");

        if parallel {
            *df = df.filter(mask)?;
        } else {
            *df = df._filter_seq(mask)?;
        }
    }
    Ok(())
}
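
// Editor's sketch (hypothetical helper names): a reader typically decodes a
// batch and filters it in place; `parallel` picks the multi-threaded filter
// over the sequential one.
//
//     let mut df = decode_batch(&row_group)?; // hypothetical
//     apply_predicate(&mut df, Some(&*scan_predicate), true)?;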

/// Statistics of the values in a column.
///
/// The following statistics are tracked for each row group:
/// - Null count
/// - Minimum value
/// - Maximum value
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    field: Field,
    // Each `Series` holds one value per row group.
    null_count: Option<Series>,
    min_value: Option<Series>,
    max_value: Option<Series>,
}

impl ColumnStats {
    /// Constructs a new [`ColumnStats`].
    pub fn new(
        field: Field,
        null_count: Option<Series>,
        min_value: Option<Series>,
        max_value: Option<Series>,
    ) -> Self {
        Self {
            field,
            null_count,
            min_value,
            max_value,
        }
    }

    /// Constructs a new [`ColumnStats`] with only the [`Field`] information and no statistics.
    pub fn from_field(field: Field) -> Self {
        Self {
            field,
            null_count: None,
            min_value: None,
            max_value: None,
        }
    }

    /// Constructs a new [`ColumnStats`] from a single-value [`Series`].
    pub fn from_column_literal(s: Series) -> Self {
        debug_assert_eq!(s.len(), 1);
        Self {
            field: s.field().into_owned(),
            null_count: None,
            min_value: Some(s.clone()),
            max_value: Some(s),
        }
    }

    /// Returns the name of the column.
    pub fn field_name(&self) -> &PlSmallStr {
        self.field.name()
    }

    /// Returns the [`DataType`] of the column.
    pub fn dtype(&self) -> &DataType {
        self.field.dtype()
    }

    /// Returns the null count of each row group of the column.
    pub fn get_null_count_state(&self) -> Option<&Series> {
        self.null_count.as_ref()
    }

    /// Returns the minimum value of each row group of the column.
    pub fn get_min_state(&self) -> Option<&Series> {
        self.min_value.as_ref()
    }

    /// Returns the maximum value of each row group of the column.
    pub fn get_max_state(&self) -> Option<&Series> {
        self.max_value.as_ref()
    }

    /// Returns the total null count of the column across all row groups.
    pub fn null_count(&self) -> Option<usize> {
        match self.dtype() {
            #[cfg(feature = "dtype-struct")]
            DataType::Struct(_) => None,
            _ => {
                let s = self.get_null_count_state()?;
                // If all values are null, there are no statistics.
                if s.null_count() != s.len() {
                    s.sum().ok()
                } else {
                    None
                }
            },
        }
    }

    /// Returns the minimum and maximum values of the column as a single [`Series`].
    pub fn to_min_max(&self) -> Option<Series> {
        let min_val = self.get_min_state()?;
        let max_val = self.get_max_state()?;
        let dtype = self.dtype();

        if !use_min_max(dtype) {
            return None;
        }

        let mut min_max_values = min_val.clone();
        min_max_values.append(max_val).unwrap();
        if min_max_values.null_count() > 0 {
            None
        } else {
            Some(min_max_values)
        }
    }

    /// Returns the minimum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no minimum value is available.
    pub fn to_min(&self) -> Option<&Series> {
        // @scalar-opt
        let min_val = self.min_value.as_ref()?;
        let dtype = min_val.dtype();

        if !use_min_max(dtype) || min_val.len() != 1 {
            return None;
        }

        if min_val.null_count() > 0 {
            None
        } else {
            Some(min_val)
        }
    }

    /// Returns the maximum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no maximum value is available.
    pub fn to_max(&self) -> Option<&Series> {
        // @scalar-opt
        let max_val = self.max_value.as_ref()?;
        let dtype = max_val.dtype();

        if !use_min_max(dtype) || max_val.len() != 1 {
            return None;
        }

        if max_val.null_count() > 0 {
            None
        } else {
            Some(max_val)
        }
    }
}
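
// Editor's sketch (illustrative values, not from the original file): stats
// for an Int64 column "a" spanning three row groups could be assembled as
//
//     let stats = ColumnStats::new(
//         Field::new("a".into(), DataType::Int64),
//         Some(Series::new("null_count".into(), &[0u32, 2, 0])),
//         Some(Series::new("min".into(), &[1i64, 5, 9])),
//         Some(Series::new("max".into(), &[4i64, 8, 12])),
//     );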

/// Returns whether the [`DataType`] supports minimum/maximum operations.
fn use_min_max(dtype: &DataType) -> bool {
    dtype.is_primitive_numeric()
        || dtype.is_temporal()
        || matches!(
            dtype,
            DataType::String | DataType::Binary | DataType::Boolean
        )
}

/// Minimum, maximum, and null-count statistics for a single column of a
/// single batch.
pub struct ColumnStatistics {
    pub dtype: DataType,
    pub min: AnyValue<'static>,
    pub max: AnyValue<'static>,
    pub null_count: Option<IdxSize>,
}

pub trait SkipBatchPredicate: Send + Sync {
    fn schema(&self) -> &SchemaRef;

    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            let (min, max, nc) = match statistics.swap_remove(col) {
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    Scalar::new(
                        IDX_DTYPE,
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }

        // SAFETY:
        // * Each column is length = 1
        // * We have an IndexSet, so each column name is unique
        let df = unsafe { DataFrame::new_no_checks(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}
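
// The statistics frame handed to `evaluate_with_stat_df` has a single row;
// for live columns `a` and `b` its layout is (editor's illustration):
//
//     len | a_min | a_max | a_nc | b_min | b_max | b_nc
//
// `len` is the batch height and `{col}_nc` the null count, each stored as a
// length-1 scalar column; missing statistics are encoded as nulls.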

#[derive(Clone)]
pub struct ColumnPredicates {
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalIoExpr>, Option<SpecializedColumnPredicate>)>,
    pub is_sumwise_complete: bool,
}

// I want to be explicit here.
#[allow(clippy::derivable_impls)]
impl Default for ColumnPredicates {
    fn default() -> Self {
        Self {
            predicates: PlHashMap::default(),
            is_sumwise_complete: false,
        }
    }
}

/// Wraps a child predicate and prepends constant columns to the [`DataFrame`]
/// before evaluating it.
pub struct PhysicalExprWithConstCols<T> {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}

impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
    fn schema(&self) -> &SchemaRef {
        self.child.schema()
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }
        self.child.evaluate_with_stat_df(&df)
    }
}

impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_io(&df)
    }
}

#[derive(Clone)]
pub struct ScanIOPredicate {
    pub predicate: Arc<dyn PhysicalIoExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate that gets given statistics and evaluates whether a batch can be skipped.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    /// Predicates that each refer to a single column and can be evaluated per
    /// column, independently of the others.
    pub column_predicates: Arc<ColumnPredicates>,

    /// Predicate parts only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalIoExpr>>,

    pub hive_predicate_is_full_predicate: bool,
}

impl ScanIOPredicate {
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                // A constant column has zero nulls, unless the constant itself
                // is null, in which case the count depends on the batch height
                // and is left unknown here.
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}
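
// Editor's sketch (hypothetical values): for a hive-partitioned scan where
// `year = 2024` is constant within a file, the caller can pin it so the
// wrapped predicates only need the columns physically present in the file:
//
//     scan_predicate.set_external_constant_columns(vec![(
//         PlSmallStr::from_static("year"),
//         Scalar::new(DataType::Int32, AnyValue::Int32(2024)),
//     )]);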

impl fmt::Debug for ScanIOPredicate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("scan_io_predicate")
    }
}