Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/predicates/column_expr.rs
7889 views
1
//! This module splits predicates into partial, per-column predicates.
2
3
use polars_core::datatypes::DataType;
4
use polars_core::prelude::AnyValue;
5
use polars_core::scalar::Scalar;
6
use polars_core::schema::Schema;
7
use polars_io::predicates::SpecializedColumnPredicate;
8
use polars_ops::series::ClosedInterval;
9
use polars_utils::aliases::PlHashMap;
10
use polars_utils::arena::{Arena, Node};
11
use polars_utils::pl_str::PlSmallStr;
12
13
use super::get_binary_expr_col_and_lv;
14
use crate::dsl::Operator;
15
use crate::plans::aexpr::evaluate::{constant_evaluate, into_column};
16
use crate::plans::{
17
AExpr, IRBooleanFunction, IRFunctionExpr, MintermIter, aexpr_to_leaf_names_iter,
18
};
19
20
pub struct ColumnPredicates {
21
pub predicates: PlHashMap<PlSmallStr, (Node, Option<SpecializedColumnPredicate>)>,
22
23
/// Are all column predicates AND-ed together the original predicate.
24
pub is_sumwise_complete: bool,
25
}
26
27
pub fn aexpr_to_column_predicates(
28
root: Node,
29
expr_arena: &mut Arena<AExpr>,
30
schema: &Schema,
31
) -> ColumnPredicates {
32
let mut predicates =
33
PlHashMap::<PlSmallStr, (Node, Option<SpecializedColumnPredicate>)>::default();
34
let mut is_sumwise_complete = true;
35
36
let minterms = MintermIter::new(root, expr_arena).collect::<Vec<_>>();
37
38
let mut leaf_names = Vec::with_capacity(2);
39
for minterm in minterms {
40
leaf_names.clear();
41
leaf_names.extend(aexpr_to_leaf_names_iter(minterm, expr_arena).cloned());
42
43
if leaf_names.len() != 1 {
44
is_sumwise_complete = false;
45
continue;
46
}
47
48
let column = leaf_names.pop().unwrap();
49
let Some(dtype) = schema.get(&column) else {
50
is_sumwise_complete = false;
51
continue;
52
};
53
54
// We really don't want to deal with these types.
55
use DataType as D;
56
match dtype {
57
#[cfg(feature = "dtype-categorical")]
58
D::Enum(_, _) | D::Categorical(_, _) => {
59
is_sumwise_complete = false;
60
continue;
61
},
62
#[cfg(feature = "dtype-decimal")]
63
D::Decimal(_, _) => {
64
is_sumwise_complete = false;
65
continue;
66
},
67
#[cfg(feature = "object")]
68
D::Object(_) => {
69
is_sumwise_complete = false;
70
continue;
71
},
72
#[cfg(feature = "dtype-f16")]
73
D::Float16 => {
74
is_sumwise_complete = false;
75
continue;
76
},
77
D::Float32 | D::Float64 => {
78
is_sumwise_complete = false;
79
continue;
80
},
81
_ if dtype.is_nested() => {
82
is_sumwise_complete = false;
83
continue;
84
},
85
_ => {},
86
}
87
88
let dtype = dtype.clone();
89
let entry = predicates.entry(column);
90
91
entry
92
.and_modify(|n| {
93
let left = n.0;
94
n.0 = expr_arena.add(AExpr::BinaryExpr {
95
left,
96
op: Operator::LogicalAnd,
97
right: minterm,
98
});
99
n.1 = None;
100
})
101
.or_insert_with(|| {
102
(
103
minterm,
104
Some(()).and_then(|_| {
105
let aexpr = expr_arena.get(minterm);
106
107
match aexpr {
108
#[cfg(all(feature = "regex", feature = "strings"))]
109
AExpr::Function {
110
input,
111
function: IRFunctionExpr::StringExpr(str_function),
112
options: _,
113
} if matches!(
114
str_function,
115
crate::plans::IRStringFunction::Contains { literal: _, strict: true } |
116
crate::plans::IRStringFunction::EndsWith |
117
crate::plans::IRStringFunction::StartsWith
118
) => {
119
use crate::plans::IRStringFunction;
120
121
assert_eq!(input.len(), 2);
122
into_column(input[0].node(), expr_arena)?;
123
let lv = constant_evaluate(
124
input[1].node(),
125
expr_arena,
126
schema,
127
0,
128
)??;
129
130
if !lv.is_scalar() {
131
return None;
132
}
133
let lv = lv.extract_str()?;
134
135
match str_function {
136
IRStringFunction::Contains { literal, strict: _ } => {
137
let pattern = if *literal {
138
regex::escape(lv)
139
} else {
140
lv.to_string()
141
};
142
let pattern = regex::bytes::Regex::new(&pattern).ok()?;
143
Some(SpecializedColumnPredicate::RegexMatch(pattern))
144
},
145
IRStringFunction::StartsWith => Some(SpecializedColumnPredicate::StartsWith(lv.as_bytes().into())),
146
IRStringFunction::EndsWith => Some(SpecializedColumnPredicate::EndsWith(lv.as_bytes().into())),
147
_ => unreachable!(),
148
}
149
},
150
AExpr::Function {
151
input,
152
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
153
options: _,
154
} => {
155
assert_eq!(input.len(), 1);
156
if into_column(input[0].node(), expr_arena)
157
.is_some()
158
{
159
Some(SpecializedColumnPredicate::Equal(Scalar::null(
160
dtype,
161
)))
162
} else {
163
None
164
}
165
},
166
#[cfg(feature = "is_between")]
167
AExpr::Function {
168
input,
169
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),
170
options: _,
171
} => {
172
let (Some(l), Some(r)) = (
173
constant_evaluate(
174
input[1].node(),
175
expr_arena,
176
schema,
177
0,
178
)?,
179
constant_evaluate(
180
input[2].node(),
181
expr_arena,
182
schema,
183
0,
184
)?,
185
) else {
186
return None;
187
};
188
let l = l.to_any_value()?;
189
let r = r.to_any_value()?;
190
if l.dtype() != dtype || r.dtype() != dtype {
191
return None;
192
}
193
194
let (low_closed, high_closed) = match closed {
195
ClosedInterval::Both => (true, true),
196
ClosedInterval::Left => (true, false),
197
ClosedInterval::Right => (false, true),
198
ClosedInterval::None => (false, false),
199
};
200
is_between(
201
&dtype,
202
Some(Scalar::new(dtype.clone(), l.into_static())),
203
Some(Scalar::new(dtype.clone(), r.into_static())),
204
low_closed,
205
high_closed,
206
)
207
},
208
#[cfg(feature = "is_in")]
209
AExpr::Function {
210
input,
211
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { nulls_equal }),
212
options: _,
213
} => {
214
into_column(input[0].node(), expr_arena)?;
215
216
let values = constant_evaluate(
217
input[1].node(),
218
expr_arena,
219
schema,
220
0,
221
)??;
222
let values = values.to_any_value()?;
223
224
let values = match values {
225
AnyValue::List(v) => v,
226
#[cfg(feature = "dtype-array")]
227
AnyValue::Array(v, _) => v,
228
_ => return None,
229
};
230
231
if values.dtype() != &dtype {
232
return None;
233
}
234
if !nulls_equal && values.has_nulls() {
235
return None;
236
}
237
238
let values = values.iter()
239
.map(|av| {
240
Scalar::new(dtype.clone(), av.into_static())
241
})
242
.collect();
243
244
Some(SpecializedColumnPredicate::EqualOneOf(values))
245
},
246
AExpr::Function {
247
input,
248
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
249
options: _,
250
} => {
251
if !dtype.is_bool() {
252
return None;
253
}
254
255
assert_eq!(input.len(), 1);
256
if into_column(input[0].node(), expr_arena)
257
.is_some()
258
{
259
Some(SpecializedColumnPredicate::Equal(false.into()))
260
} else {
261
None
262
}
263
},
264
AExpr::BinaryExpr { left, op, right } => {
265
let ((_, _), (lv, lv_node)) =
266
get_binary_expr_col_and_lv(*left, *right, expr_arena, schema)?;
267
let lv = lv?;
268
let av = lv.to_any_value()?;
269
if av.dtype() != dtype {
270
return None;
271
}
272
let scalar = Scalar::new(dtype.clone(), av.into_static());
273
use Operator as O;
274
match (op, lv_node == *right) {
275
(O::Eq, _) if scalar.is_null() => None,
276
(O::Eq | O::EqValidity, _) => {
277
Some(SpecializedColumnPredicate::Equal(scalar))
278
},
279
(O::Lt, true) | (O::Gt, false) => {
280
is_between(&dtype, None, Some(scalar), false, false)
281
},
282
(O::Lt, false) | (O::Gt, true) => {
283
is_between(&dtype, Some(scalar), None, false, false)
284
},
285
(O::LtEq, true) | (O::GtEq, false) => {
286
is_between(&dtype, None, Some(scalar), false, true)
287
},
288
(O::LtEq, false) | (O::GtEq, true) => {
289
is_between(&dtype, Some(scalar), None, true, false)
290
},
291
_ => None,
292
}
293
},
294
_ => None,
295
}
296
}),
297
)
298
});
299
}
300
301
ColumnPredicates {
302
predicates,
303
is_sumwise_complete,
304
}
305
}
306
307
fn is_between(
308
dtype: &DataType,
309
low: Option<Scalar>,
310
high: Option<Scalar>,
311
mut low_closed: bool,
312
mut high_closed: bool,
313
) -> Option<SpecializedColumnPredicate> {
314
let dtype = dtype.to_physical();
315
316
if !dtype.is_integer() {
317
return None;
318
}
319
assert!(low.is_some() || high.is_some());
320
321
low_closed |= low.is_none();
322
high_closed |= high.is_none();
323
324
let mut low = low.map_or_else(|| dtype.min().unwrap(), |sc| sc.to_physical());
325
let mut high = high.map_or_else(|| dtype.max().unwrap(), |sc| sc.to_physical());
326
327
macro_rules! ints {
328
($($t:ident),+) => {
329
match (low.any_value_mut(), high.any_value_mut()) {
330
$(
331
(AV::$t(l), AV::$t(h)) => {
332
if !low_closed {
333
*l = l.checked_add(1)?;
334
}
335
if !high_closed {
336
*h = h.checked_sub(1)?;
337
}
338
if *l > *h {
339
// Really this ought to indicate that nothing should be
340
// loaded since the condition is impossible, but unclear
341
// how to do that at this abstraction layer. Could add
342
// SpecializedColumnPredicate::Impossible or something,
343
// maybe.
344
return None;
345
}
346
},
347
)+
348
_ => return None,
349
}
350
};
351
}
352
353
use AnyValue as AV;
354
ints!(
355
Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64
356
);
357
358
Some(SpecializedColumnPredicate::Between(low, high))
359
}
360
361
#[cfg(test)]
mod tests {
    use polars_error::PolarsResult;

    use super::*;
    use crate::dsl::{Expr, col};
    use crate::plans::{ExprToIRContext, to_expr_ir, typed_lit};

    /// Lower a single-column `expr` to IR, run `aexpr_to_column_predicates()` on it and return
    /// the specialized predicate derived for `col_name`.
    ///
    /// Asserts that exactly one column predicate was produced, and that it is for `col_name`.
    fn column_predicate_for_expr(
        col_dtype: DataType,
        col_name: &str,
        expr: Expr,
    ) -> PolarsResult<Option<SpecializedColumnPredicate>> {
        let schema = Schema::from_iter_check_duplicates([(col_name.into(), col_dtype)])?;
        let mut arena = Arena::new();
        let expr_ir = {
            let mut ctx = ExprToIRContext::new(&mut arena, &schema);
            to_expr_ir(expr, &mut ctx)?
        };

        let ColumnPredicates { predicates, .. } =
            aexpr_to_column_predicates(expr_ir.node(), &mut arena, &schema);
        assert_eq!(predicates.len(), 1);

        match predicates.iter().next() {
            Some((found_name, (_, predicate))) => {
                assert_eq!(col_name, found_name.as_str());
                Ok(predicate.clone())
            },
            None => panic!("Unexpected column predicates: {:?}", predicates),
        }
    }

    #[test]
    fn column_predicate_for_inequality_operators() -> PolarsResult<()> {
        let col_name = "testcol";
        // (expression, expected inclusive minimum, expected inclusive maximum)
        let cases: [(Expr, i8, i8); _] = [
            (col(col_name).lt(typed_lit(10i8)), -128, 9),
            (col(col_name).lt(typed_lit(-11i8)), -128, -12),
            (col(col_name).gt(typed_lit(17i8)), 18, 127),
            (col(col_name).gt(typed_lit(-10i8)), -9, 127),
            (col(col_name).lt_eq(typed_lit(10i8)), -128, 10),
            (col(col_name).lt_eq(typed_lit(-11i8)), -128, -11),
            (col(col_name).gt_eq(typed_lit(17i8)), 17, 127),
            (col(col_name).gt_eq(typed_lit(-10i8)), -10, 127),
        ];

        for (expr, expected_min, expected_max) in cases {
            match column_predicate_for_expr(DataType::Int8, col_name, expr.clone())? {
                Some(SpecializedColumnPredicate::Between(actual_min, actual_max)) => {
                    assert_eq!(
                        (expected_min.into(), expected_max.into()),
                        (actual_min, actual_max)
                    );
                },
                predicate => panic!("{predicate:?} is unexpected for {expr:?}"),
            }
        }
        Ok(())
    }

    #[test]
    fn column_predicate_is_between() -> PolarsResult<()> {
        let col_name = "testcol";
        // (interval closedness, expected inclusive minimum, expected inclusive maximum)
        let cases: [(_, i8, i8); _] = [
            (ClosedInterval::Both, 1, 10),
            (ClosedInterval::Left, 1, 9),
            (ClosedInterval::Right, 2, 10),
            (ClosedInterval::None, 2, 9),
        ];

        for (interval, expected_min, expected_max) in cases {
            let expr = col(col_name).is_between(typed_lit(1i8), typed_lit(10i8), interval);
            match column_predicate_for_expr(DataType::Int8, col_name, expr.clone())? {
                Some(SpecializedColumnPredicate::Between(actual_min, actual_max)) => {
                    assert_eq!(
                        (expected_min.into(), expected_max.into()),
                        (actual_min, actual_max)
                    );
                },
                predicate => panic!("{predicate:?} is unexpected for {expr:?}"),
            }
        }
        Ok(())
    }
}
447
448