// Source: GitHub repository pola-rs/polars (branch `main`),
// path: crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs
use super::*;
pub(super) fn optimize_functions(
4
input: Vec<ExprIR>,
5
function: IRFunctionExpr,
6
options: FunctionOptions,
7
expr_arena: &mut Arena<AExpr>,
8
) -> PolarsResult<Option<AExpr>> {
9
let out = match function {
10
// is_null().any() -> null_count() > 0
11
// is_not_null().any() -> null_count() < len()
12
// CORRECTNESS: we can ignore 'ignore_nulls' since is_null/is_not_null never produces NULLS
13
IRFunctionExpr::Boolean(IRBooleanFunction::Any { ignore_nulls: _ }) => {
14
let input_node = expr_arena.get(input[0].node());
15
match input_node {
16
AExpr::Function {
17
input,
18
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
19
options: _,
20
} => Some(AExpr::BinaryExpr {
21
left: expr_arena.add(new_null_count(input)),
22
op: Operator::Gt,
23
right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))),
24
}),
25
AExpr::Function {
26
input,
27
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),
28
options: _,
29
} => {
30
// we should perform optimization only if the original expression is a column
31
// so in case of disabled CSE, we will not suffer from performance regression
32
if input.len() == 1 {
33
let is_not_null_input_node = input[0].node();
34
match expr_arena.get(is_not_null_input_node) {
35
AExpr::Column(_) => Some(AExpr::BinaryExpr {
36
op: Operator::Lt,
37
left: expr_arena.add(new_null_count(input)),
38
right: expr_arena.add(AExpr::Agg(IRAggExpr::Count {
39
input: is_not_null_input_node,
40
include_nulls: true,
41
})),
42
}),
43
_ => None,
44
}
45
} else {
46
None
47
}
48
},
49
_ => None,
50
}
51
},
52
// is_null().all() -> null_count() == len()
53
// is_not_null().all() -> null_count() == 0
54
IRFunctionExpr::Boolean(IRBooleanFunction::All { ignore_nulls: _ }) => {
55
let input_node = expr_arena.get(input[0].node());
56
match input_node {
57
AExpr::Function {
58
input,
59
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
60
options: _,
61
} => {
62
// we should perform optimization only if the original expression is a column
63
// so in case of disabled CSE, we will not suffer from performance regression
64
if input.len() == 1 {
65
let is_null_input_node = input[0].node();
66
match expr_arena.get(is_null_input_node) {
67
AExpr::Column(_) => Some(AExpr::BinaryExpr {
68
op: Operator::Eq,
69
right: expr_arena.add(new_null_count(input)),
70
left: expr_arena.add(AExpr::Agg(IRAggExpr::Count {
71
input: is_null_input_node,
72
include_nulls: true,
73
})),
74
}),
75
_ => None,
76
}
77
} else {
78
None
79
}
80
},
81
AExpr::Function {
82
input,
83
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),
84
options: _,
85
} => Some(AExpr::BinaryExpr {
86
left: expr_arena.add(new_null_count(input)),
87
op: Operator::Eq,
88
right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))),
89
}),
90
_ => None,
91
}
92
},
93
// sort().reverse() -> sort(reverse)
94
// sort_by().reverse() -> sort_by(reverse)
95
IRFunctionExpr::Reverse => {
96
let input = expr_arena.get(input[0].node());
97
match input {
98
AExpr::Sort { expr, options } => {
99
let mut options = *options;
100
options.descending = !options.descending;
101
Some(AExpr::Sort {
102
expr: *expr,
103
options,
104
})
105
},
106
AExpr::SortBy {
107
expr,
108
by,
109
sort_options,
110
} => {
111
let mut sort_options = sort_options.clone();
112
let reversed_descending = sort_options.descending.iter().map(|x| !*x).collect();
113
sort_options.descending = reversed_descending;
114
Some(AExpr::SortBy {
115
expr: *expr,
116
by: by.clone(),
117
sort_options,
118
})
119
},
120
// TODO: add support for cum_sum and other operation that allow reversing.
121
_ => None,
122
}
123
},
124
// flatten nested concat_str calls
125
#[cfg(all(feature = "strings", feature = "concat_str"))]
126
ref function @ IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {
127
delimiter: ref sep,
128
ignore_nulls,
129
}) if sep.is_empty() => {
130
if input
131
.iter()
132
.any(|e| is_string_concat(expr_arena.get(e.node()), ignore_nulls))
133
{
134
let mut new_inputs = Vec::with_capacity(input.len() * 2);
135
136
for e in input {
137
match get_string_concat_input(e.node(), expr_arena, ignore_nulls) {
138
Some(inp) => new_inputs.extend_from_slice(inp),
139
None => new_inputs.push(e.clone()),
140
}
141
}
142
Some(AExpr::Function {
143
input: new_inputs,
144
function: function.clone(),
145
options,
146
})
147
} else {
148
None
149
}
150
},
151
IRFunctionExpr::Boolean(IRBooleanFunction::Not) => {
152
let y = expr_arena.get(input[0].node());
153
154
match y {
155
// not(a and b) => not(a) or not(b)
156
AExpr::BinaryExpr {
157
left,
158
op: Operator::And | Operator::LogicalAnd,
159
right,
160
} => {
161
let left = *left;
162
let right = *right;
163
Some(AExpr::BinaryExpr {
164
left: expr_arena.add(AExpr::Function {
165
input: vec![ExprIR::from_node(left, expr_arena)],
166
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
167
options,
168
}),
169
op: Operator::Or,
170
right: expr_arena.add(AExpr::Function {
171
input: vec![ExprIR::from_node(right, expr_arena)],
172
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
173
options,
174
}),
175
})
176
},
177
// not(a or b) => not(a) and not(b)
178
AExpr::BinaryExpr {
179
left,
180
op: Operator::Or | Operator::LogicalOr,
181
right,
182
} => {
183
let left = *left;
184
let right = *right;
185
Some(AExpr::BinaryExpr {
186
left: expr_arena.add(AExpr::Function {
187
input: vec![ExprIR::from_node(left, expr_arena)],
188
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
189
options,
190
}),
191
op: Operator::And,
192
right: expr_arena.add(AExpr::Function {
193
input: vec![ExprIR::from_node(right, expr_arena)],
194
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
195
options,
196
}),
197
})
198
},
199
// not(not x) => x
200
AExpr::Function {
201
input,
202
function: IRFunctionExpr::Boolean(IRBooleanFunction::Not),
203
..
204
} => Some(expr_arena.get(input[0].node()).clone()),
205
// not(lit x) => !x
206
AExpr::Literal(lv) if lv.bool().is_some() => {
207
Some(AExpr::Literal(Scalar::from(!lv.bool().unwrap()).into()))
208
},
209
// not(x.is_null) => x.is_not_null
210
AExpr::Function {
211
input,
212
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
213
options,
214
} => Some(AExpr::Function {
215
input: input.clone(),
216
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),
217
options: *options,
218
}),
219
// not(x.is_not_null) => x.is_null
220
AExpr::Function {
221
input,
222
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull),
223
options,
224
} => Some(AExpr::Function {
225
input: input.clone(),
226
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsNull),
227
options: *options,
228
}),
229
// not(a == b) => a != b
230
AExpr::BinaryExpr {
231
left,
232
op: Operator::Eq,
233
right,
234
} => Some(AExpr::BinaryExpr {
235
left: *left,
236
op: Operator::NotEq,
237
right: *right,
238
}),
239
// not(a != b) => a == b
240
AExpr::BinaryExpr {
241
left,
242
op: Operator::NotEq,
243
right,
244
} => Some(AExpr::BinaryExpr {
245
left: *left,
246
op: Operator::Eq,
247
right: *right,
248
}),
249
// not(a < b) => a >= b
250
AExpr::BinaryExpr {
251
left,
252
op: Operator::Lt,
253
right,
254
} => Some(AExpr::BinaryExpr {
255
left: *left,
256
op: Operator::GtEq,
257
right: *right,
258
}),
259
// not(a <= b) => a > b
260
AExpr::BinaryExpr {
261
left,
262
op: Operator::LtEq,
263
right,
264
} => Some(AExpr::BinaryExpr {
265
left: *left,
266
op: Operator::Gt,
267
right: *right,
268
}),
269
// not(a > b) => a <= b
270
AExpr::BinaryExpr {
271
left,
272
op: Operator::Gt,
273
right,
274
} => Some(AExpr::BinaryExpr {
275
left: *left,
276
op: Operator::LtEq,
277
right: *right,
278
}),
279
// not(a >= b) => a < b
280
AExpr::BinaryExpr {
281
left,
282
op: Operator::GtEq,
283
right,
284
} => Some(AExpr::BinaryExpr {
285
left: *left,
286
op: Operator::Lt,
287
right: *right,
288
}),
289
#[cfg(feature = "is_between")]
290
// not(col('x').is_between(a,b)) => col('x') < a || col('x') > b
291
AExpr::Function {
292
input,
293
function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),
294
..
295
} => {
296
if !matches!(expr_arena.get(input[0].node()), AExpr::Column(_)) {
297
None
298
} else {
299
let left_cmp_op = match closed {
300
ClosedInterval::Both | ClosedInterval::Left => Operator::Lt,
301
ClosedInterval::None | ClosedInterval::Right => Operator::LtEq,
302
};
303
let right_cmp_op = match closed {
304
ClosedInterval::Both | ClosedInterval::Right => Operator::Gt,
305
ClosedInterval::None | ClosedInterval::Left => Operator::GtEq,
306
};
307
let left_left = input[0].node();
308
let right_left = input[1].node();
309
310
let left_right = left_left;
311
let right_right = input[2].node();
312
313
// input[0] is between input[1] and input[2]
314
Some(AExpr::BinaryExpr {
315
// input[0] (<,<=) input[1]
316
left: expr_arena.add(AExpr::BinaryExpr {
317
left: left_left,
318
op: left_cmp_op,
319
right: right_left,
320
}),
321
// OR
322
op: Operator::Or,
323
// input[0] (>,>=) input[2]
324
right: expr_arena.add(AExpr::BinaryExpr {
325
left: left_right,
326
op: right_cmp_op,
327
right: right_right,
328
}),
329
})
330
}
331
},
332
_ => None,
333
}
334
},
335
IRFunctionExpr::GatherEvery { n: 1, offset: 0 } => {
336
Some(expr_arena.get(input[0].node()).clone())
337
},
338
IRFunctionExpr::GatherEvery { n: 1, offset } => {
339
let offset_i64: i64 = offset.try_into().unwrap_or(i64::MAX);
340
let offset_node =
341
expr_arena.add(AExpr::Literal(LiteralValue::Scalar(offset_i64.into())));
342
let length_node = expr_arena.add(AExpr::Literal(LiteralValue::Scalar(
343
(usize::MAX as u64).into(),
344
)));
345
Some(AExpr::Slice {
346
input: input[0].node(),
347
offset: offset_node,
348
length: length_node,
349
})
350
},
351
_ => None,
352
};
353
Ok(out)
354
}
/// Returns `true` when `ae` is a horizontal string concatenation whose
/// delimiter is empty and whose `ignore_nulls` flag equals `ignore_nulls`.
#[cfg(all(feature = "strings", feature = "concat_str"))]
fn is_string_concat(ae: &AExpr, ignore_nulls: bool) -> bool {
    match ae {
        AExpr::Function {
            function:
                IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {
                    delimiter,
                    ignore_nulls: func_ignore_nulls,
                }),
            ..
        } => delimiter.is_empty() && *func_ignore_nulls == ignore_nulls,
        // Any other expression is not a concat call.
        _ => false,
    }
}
/// Looks up `node` in `expr_arena` and, if it is a horizontal string
/// concatenation with an empty delimiter and an `ignore_nulls` flag equal to
/// `ignore_nulls`, returns that call's inputs. Returns `None` for every other
/// expression.
#[cfg(all(feature = "strings", feature = "concat_str"))]
fn get_string_concat_input(
    node: Node,
    expr_arena: &Arena<AExpr>,
    ignore_nulls: bool,
) -> Option<&[ExprIR]> {
    match expr_arena.get(node) {
        AExpr::Function {
            input,
            function:
                IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {
                    delimiter,
                    ignore_nulls: func_ignore_nulls,
                }),
            ..
        } => {
            // Only concat calls with identical settings may be merged by the caller.
            let same_settings = delimiter.is_empty() && *func_ignore_nulls == ignore_nulls;
            if same_settings {
                Some(input.as_slice())
            } else {
                None
            }
        },
        _ => None,
    }
}