Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/predicates/column_expr.rs
8416 views
1
//! This module splits a predicate into partial per-column predicates.
2
3
use polars_core::datatypes::DataType;
4
use polars_core::prelude::AnyValue;
5
use polars_core::scalar::Scalar;
6
use polars_core::schema::Schema;
7
use polars_io::predicates::SpecializedColumnPredicate;
8
use polars_ops::series::ClosedInterval;
9
use polars_utils::aliases::PlHashMap;
10
use polars_utils::arena::{Arena, Node};
11
use polars_utils::pl_str::PlSmallStr;
12
13
use super::get_binary_expr_col_and_lv;
14
use crate::dsl::Operator;
15
use crate::plans::aexpr::evaluate::{constant_evaluate, into_column};
16
use crate::plans::{
17
AExpr, IRBooleanFunction, IRFunctionExpr, MintermIter, aexpr_to_leaf_names_iter,
18
};
19
20
pub struct ColumnPredicates {
21
pub predicates: PlHashMap<PlSmallStr, (Node, Option<SpecializedColumnPredicate>)>,
22
23
/// Are all column predicates AND-ed together the original predicate.
24
pub is_sumwise_complete: bool,
25
}
26
27
pub fn aexpr_to_column_predicates(
28
root: Node,
29
expr_arena: &mut Arena<AExpr>,
30
schema: &Schema,
31
) -> ColumnPredicates {
32
let mut predicates =
33
PlHashMap::<PlSmallStr, (Node, Option<SpecializedColumnPredicate>)>::default();
34
let mut is_sumwise_complete = true;
35
36
let minterms = MintermIter::new(root, expr_arena).collect::<Vec<_>>();
37
38
let mut leaf_names = Vec::with_capacity(2);
39
for minterm in minterms {
40
leaf_names.clear();
41
leaf_names.extend(aexpr_to_leaf_names_iter(minterm, expr_arena).cloned());
42
43
if leaf_names.len() != 1 {
44
is_sumwise_complete = false;
45
continue;
46
}
47
48
let column = leaf_names.pop().unwrap();
49
let Some(dtype) = schema.get(&column) else {
50
is_sumwise_complete = false;
51
continue;
52
};
53
54
// We really don't want to deal with these types.
55
use DataType as D;
56
match dtype {
57
#[cfg(feature = "dtype-categorical")]
58
D::Enum(_, _) | D::Categorical(_, _) => {
59
is_sumwise_complete = false;
60
continue;
61
},
62
#[cfg(feature = "dtype-decimal")]
63
D::Decimal(_, _) => {
64
is_sumwise_complete = false;
65
continue;
66
},
67
#[cfg(feature = "object")]
68
D::Object(_) => {
69
is_sumwise_complete = false;
70
continue;
71
},
72
#[cfg(feature = "dtype-f16")]
73
D::Float16 => {
74
is_sumwise_complete = false;
75
continue;
76
},
77
D::Float32 | D::Float64 => {
78
is_sumwise_complete = false;
79
continue;
80
},
81
_ if dtype.is_nested() => {
82
is_sumwise_complete = false;
83
continue;
84
},
85
_ => {},
86
}
87
88
let dtype = dtype.clone();
89
let entry = predicates.entry(column);
90
91
entry
92
.and_modify(|n| {
93
let left = n.0;
94
n.0 = expr_arena.add(AExpr::BinaryExpr {
95
left,
96
op: Operator::LogicalAnd,
97
right: minterm,
98
});
99
n.1 = None;
100
})
101
.or_insert_with(|| {
102
(
103
minterm,
104
Some(()).and_then(|_| {
105
let aexpr = expr_arena.get(minterm);
106
107
match aexpr {
108
#[cfg(all(feature = "regex", feature = "strings"))]
109
AExpr::Function {
110
input,
111
function: IRFunctionExpr::StringExpr(str_function),
112
options: _,
113
} if matches!(
114
str_function,
115
crate::plans::IRStringFunction::Contains { literal: _, strict: true } |
116
crate::plans::IRStringFunction::EndsWith |
117
crate::plans::IRStringFunction::StartsWith
118
) => {
119
use crate::plans::IRStringFunction;
120
121
assert_eq!(input.len(), 2);
122
into_column(input[0].node(), expr_arena)?;
123
let lv = constant_evaluate(
124
input[1].node(),
125
expr_arena,
126
schema,
127
0,
128
)??;
129
130
if !lv.is_scalar() {
131
return None;
132
}
133
let lv = lv.extract_str()?;
134
135
match str_function {
136
IRStringFunction::Contains { literal, strict: _ } => {
137
let pattern = if *literal {
138
regex::escape(lv)
139
} else {
140
lv.to_string()
141
};
142
let pattern = regex::bytes::Regex::new(&pattern).ok()?;
143
Some(SpecializedColumnPredicate::RegexMatch(pattern))
144
},
145
IRStringFunction::StartsWith => Some(SpecializedColumnPredicate::StartsWith(lv.as_bytes().into())),
146
IRStringFunction::EndsWith => Some(SpecializedColumnPredicate::EndsWith(lv.as_bytes().into())),
147
_ => unreachable!(),
148
}
149
},
150
AExpr::Function {
151
input,
152
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
153
options: _,
154
} => {
155
assert_eq!(input.len(), 1);
156
if into_column(input[0].node(), expr_arena)
157
.is_some()
158
{
159
Some(SpecializedColumnPredicate::Equal(Scalar::null(
160
dtype,
161
)))
162
} else {
163
None
164
}
165
},
166
#[cfg(feature = "is_between")]
167
AExpr::Function {
168
input,
169
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),
170
options: _,
171
} => {
172
into_column(input[0].node(), expr_arena)?;
173
174
let (Some(l), Some(r)) = (
175
constant_evaluate(
176
input[1].node(),
177
expr_arena,
178
schema,
179
0,
180
)?,
181
constant_evaluate(
182
input[2].node(),
183
expr_arena,
184
schema,
185
0,
186
)?,
187
) else {
188
return None;
189
};
190
let l = l.to_any_value()?;
191
let r = r.to_any_value()?;
192
if l.dtype() != dtype || r.dtype() != dtype {
193
return None;
194
}
195
196
let (low_closed, high_closed) = match closed {
197
ClosedInterval::Both => (true, true),
198
ClosedInterval::Left => (true, false),
199
ClosedInterval::Right => (false, true),
200
ClosedInterval::None => (false, false),
201
};
202
is_between(
203
&dtype,
204
Some(Scalar::new(dtype.clone(), l.into_static())),
205
Some(Scalar::new(dtype.clone(), r.into_static())),
206
low_closed,
207
high_closed,
208
)
209
},
210
#[cfg(feature = "is_in")]
211
AExpr::Function {
212
input,
213
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { nulls_equal }),
214
options: _,
215
} => {
216
into_column(input[0].node(), expr_arena)?;
217
218
let values = constant_evaluate(
219
input[1].node(),
220
expr_arena,
221
schema,
222
0,
223
)??;
224
let values = values.to_any_value()?;
225
226
let values = match values {
227
AnyValue::List(v) => v,
228
#[cfg(feature = "dtype-array")]
229
AnyValue::Array(v, _) => v,
230
_ => return None,
231
};
232
233
if values.dtype() != &dtype {
234
return None;
235
}
236
if !nulls_equal && values.has_nulls() {
237
return None;
238
}
239
240
let values = values.iter()
241
.map(|av| {
242
Scalar::new(dtype.clone(), av.into_static())
243
})
244
.collect();
245
246
Some(SpecializedColumnPredicate::EqualOneOf(values))
247
},
248
AExpr::Function {
249
input,
250
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
251
options: _,
252
} => {
253
if !dtype.is_bool() {
254
return None;
255
}
256
257
assert_eq!(input.len(), 1);
258
if into_column(input[0].node(), expr_arena)
259
.is_some()
260
{
261
Some(SpecializedColumnPredicate::Equal(false.into()))
262
} else {
263
None
264
}
265
},
266
AExpr::BinaryExpr { left, op, right } => {
267
let ((_, _), (lv, lv_node)) =
268
get_binary_expr_col_and_lv(*left, *right, expr_arena, schema)?;
269
let lv = lv?;
270
let av = lv.to_any_value()?;
271
if av.dtype() != dtype {
272
return None;
273
}
274
let scalar = Scalar::new(dtype.clone(), av.into_static());
275
use Operator as O;
276
match (op, lv_node == *right) {
277
(O::Eq, _) if scalar.is_null() => None,
278
(O::Eq | O::EqValidity, _) => {
279
Some(SpecializedColumnPredicate::Equal(scalar))
280
},
281
(O::Lt, true) | (O::Gt, false) => {
282
is_between(&dtype, None, Some(scalar), false, false)
283
},
284
(O::Lt, false) | (O::Gt, true) => {
285
is_between(&dtype, Some(scalar), None, false, false)
286
},
287
(O::LtEq, true) | (O::GtEq, false) => {
288
is_between(&dtype, None, Some(scalar), false, true)
289
},
290
(O::LtEq, false) | (O::GtEq, true) => {
291
is_between(&dtype, Some(scalar), None, true, false)
292
},
293
_ => None,
294
}
295
},
296
_ => None,
297
}
298
}),
299
)
300
});
301
}
302
303
ColumnPredicates {
304
predicates,
305
is_sumwise_complete,
306
}
307
}
308
309
fn is_between(
310
dtype: &DataType,
311
low: Option<Scalar>,
312
high: Option<Scalar>,
313
mut low_closed: bool,
314
mut high_closed: bool,
315
) -> Option<SpecializedColumnPredicate> {
316
let dtype = dtype.to_physical();
317
318
if !dtype.is_integer() {
319
return None;
320
}
321
assert!(low.is_some() || high.is_some());
322
323
low_closed |= low.is_none();
324
high_closed |= high.is_none();
325
326
let mut low = low.map_or_else(|| dtype.min().unwrap(), |sc| sc.to_physical());
327
let mut high = high.map_or_else(|| dtype.max().unwrap(), |sc| sc.to_physical());
328
329
macro_rules! ints {
330
($($t:ident),+) => {
331
match (low.any_value_mut(), high.any_value_mut()) {
332
$(
333
(AV::$t(l), AV::$t(h)) => {
334
if !low_closed {
335
*l = l.checked_add(1)?;
336
}
337
if !high_closed {
338
*h = h.checked_sub(1)?;
339
}
340
if *l > *h {
341
// Really this ought to indicate that nothing should be
342
// loaded since the condition is impossible, but unclear
343
// how to do that at this abstraction layer. Could add
344
// SpecializedColumnPredicate::Impossible or something,
345
// maybe.
346
return None;
347
}
348
},
349
)+
350
_ => return None,
351
}
352
};
353
}
354
355
use AnyValue as AV;
356
ints!(
357
Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64
358
);
359
360
Some(SpecializedColumnPredicate::Between(low, high))
361
}
362
363
#[cfg(test)]
mod tests {
    use polars_error::PolarsResult;

    use super::*;
    use crate::dsl::Expr;
    use crate::dsl::functions::col;
    use crate::plans::{ExprToIRContext, to_expr_ir, typed_lit};

    /// Runs `aexpr_to_column_predicates()` on a single-column `Expr` over a
    /// one-column schema and returns the specialized predicate (if any) that
    /// was derived for that column.
    fn column_predicate_for_expr(
        col_dtype: DataType,
        col_name: &str,
        expr: Expr,
    ) -> PolarsResult<Option<SpecializedColumnPredicate>> {
        let mut arena = Arena::new();
        let schema = Schema::from_iter_check_duplicates([(col_name.into(), col_dtype)])?;

        let mut ctx = ExprToIRContext::new(&mut arena, &schema);
        let root = to_expr_ir(expr, &mut ctx)?.node();

        let ColumnPredicates { predicates, .. } =
            aexpr_to_column_predicates(root, &mut arena, &schema);
        assert_eq!(predicates.len(), 1);

        let Some((col_name2, (_, predicate))) = predicates.clone().into_iter().next() else {
            panic!("Unexpected column predicates: {:?}", predicates);
        };
        assert_eq!(col_name, col_name2);
        Ok(predicate)
    }

    /// Asserts that `expr`, applied to an `Int8` column named `col_name`, is
    /// specialized to `Between(expected_min, expected_max)`.
    fn assert_int8_between(
        col_name: &str,
        expr: Expr,
        expected_min: i8,
        expected_max: i8,
    ) -> PolarsResult<()> {
        let predicate = column_predicate_for_expr(DataType::Int8, col_name, expr.clone())?;
        if let Some(SpecializedColumnPredicate::Between(actual_min, actual_max)) = predicate {
            assert_eq!(
                (expected_min.into(), expected_max.into()),
                (actual_min, actual_max)
            );
        } else {
            panic!("{predicate:?} is unexpected for {expr:?}");
        }
        Ok(())
    }

    #[test]
    fn column_predicate_for_inequality_operators() -> PolarsResult<()> {
        let col_name = "testcol";
        // Each case is (expr, expected minimum, expected maximum).
        let cases: [(Expr, i8, i8); 8] = [
            (col(col_name).lt(typed_lit(10i8)), -128, 9),
            (col(col_name).lt(typed_lit(-11i8)), -128, -12),
            (col(col_name).gt(typed_lit(17i8)), 18, 127),
            (col(col_name).gt(typed_lit(-10i8)), -9, 127),
            (col(col_name).lt_eq(typed_lit(10i8)), -128, 10),
            (col(col_name).lt_eq(typed_lit(-11i8)), -128, -11),
            (col(col_name).gt_eq(typed_lit(17i8)), 17, 127),
            (col(col_name).gt_eq(typed_lit(-10i8)), -10, 127),
        ];
        for (expr, expected_min, expected_max) in cases {
            assert_int8_between(col_name, expr, expected_min, expected_max)?;
        }
        Ok(())
    }

    #[test]
    fn column_predicate_is_between() -> PolarsResult<()> {
        let col_name = "testcol";
        // Each case is (closed interval kind, expected minimum, expected maximum).
        let cases: [(ClosedInterval, i8, i8); 4] = [
            (ClosedInterval::Both, 1, 10),
            (ClosedInterval::Left, 1, 9),
            (ClosedInterval::Right, 2, 10),
            (ClosedInterval::None, 2, 9),
        ];
        for (interval, expected_min, expected_max) in cases {
            let expr = col(col_name).is_between(typed_lit(1i8), typed_lit(10i8), interval);
            assert_int8_between(col_name, expr, expected_min, expected_max)?;
        }
        Ok(())
    }
}
450
451