Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/planner.rs
8420 views
1
use polars_core::prelude::*;
2
use polars_plan::constants::{get_literal_name, get_pl_element_name, get_pl_structfields_name};
3
use polars_plan::prelude::expr_ir::ExprIR;
4
use polars_plan::prelude::*;
5
use recursive::recursive;
6
7
use crate::dispatch::{function_expr_to_groups_udf, function_expr_to_udf};
8
use crate::expressions as phys_expr;
9
use crate::expressions::*;
10
use crate::reduce::GroupedReduction;
11
12
pub fn get_expr_depth_limit() -> PolarsResult<u16> {
13
let depth = if let Ok(d) = std::env::var("POLARS_MAX_EXPR_DEPTH") {
14
let v = d
15
.parse::<u64>()
16
.map_err(|_| polars_err!(ComputeError: "could not parse 'max_expr_depth': {}", d))?;
17
u16::try_from(v).unwrap_or(0)
18
} else {
19
512
20
};
21
Ok(depth)
22
}
23
24
fn ok_checker(_i: usize, _state: &ExpressionConversionState) -> PolarsResult<()> {
25
Ok(())
26
}
27
28
pub fn create_physical_expressions_from_irs(
29
exprs: &[ExprIR],
30
expr_arena: &mut Arena<AExpr>,
31
schema: &SchemaRef,
32
state: &mut ExpressionConversionState,
33
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
34
create_physical_expressions_check_state(exprs, expr_arena, schema, state, ok_checker)
35
}
36
37
pub(crate) fn create_physical_expressions_check_state<F>(
38
exprs: &[ExprIR],
39
expr_arena: &mut Arena<AExpr>,
40
schema: &SchemaRef,
41
state: &mut ExpressionConversionState,
42
checker: F,
43
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
44
where
45
F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
46
{
47
exprs
48
.iter()
49
.enumerate()
50
.map(|(i, e)| {
51
state.reset();
52
let out = create_physical_expr(e, expr_arena, schema, state);
53
checker(i, state)?;
54
out
55
})
56
.collect()
57
}
58
59
pub(crate) fn create_physical_expressions_from_nodes(
60
exprs: &[Node],
61
expr_arena: &mut Arena<AExpr>,
62
schema: &SchemaRef,
63
state: &mut ExpressionConversionState,
64
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
65
create_physical_expressions_from_nodes_check_state(exprs, expr_arena, schema, state, ok_checker)
66
}
67
68
pub(crate) fn create_physical_expressions_from_nodes_check_state<F>(
69
exprs: &[Node],
70
expr_arena: &mut Arena<AExpr>,
71
schema: &SchemaRef,
72
state: &mut ExpressionConversionState,
73
checker: F,
74
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
75
where
76
F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
77
{
78
exprs
79
.iter()
80
.enumerate()
81
.map(|(i, e)| {
82
state.reset();
83
let out = create_physical_expr_inner(*e, expr_arena, schema, state);
84
checker(i, state)?;
85
out
86
})
87
.collect()
88
}
89
90
/// Flags tracked while logical expressions are converted into physical expressions.
#[derive(Copy, Clone, Debug)]
pub struct ExpressionConversionState {
    // Settings per context.
    // They remain active between expressions.
    pub allow_threading: bool,
    pub has_windows: bool,
    // Settings per expression.
    // Those are reset for every expression.
    local: LocalConversionState,
}

/// Per-expression flags; replaced by their defaults in `ExpressionConversionState::reset`.
#[derive(Copy, Clone, Debug, Default)]
struct LocalConversionState {
    has_window: bool,
    has_lit: bool,
}

impl ExpressionConversionState {
    /// Creates a state with the given threading setting, no window seen yet and all
    /// per-expression flags cleared.
    pub fn new(allow_threading: bool) -> Self {
        Self {
            allow_threading,
            has_windows: false,
            local: LocalConversionState::default(),
        }
    }

    /// Clears the per-expression flags; the per-context settings are left untouched.
    fn reset(&mut self) {
        self.local = LocalConversionState::default();
    }

    /// Records a window expression both for the whole context and for the current expression.
    fn set_window(&mut self) {
        self.has_windows = true;
        self.local.has_window = true;
    }
}
128
129
pub fn create_physical_expr(
130
expr_ir: &ExprIR,
131
expr_arena: &mut Arena<AExpr>,
132
schema: &SchemaRef, // Schema of the input.
133
state: &mut ExpressionConversionState,
134
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
135
let phys_expr = create_physical_expr_inner(expr_ir.node(), expr_arena, schema, state)?;
136
137
if let Some(name) = expr_ir.get_alias() {
138
Ok(Arc::new(AliasExpr::new(
139
phys_expr,
140
name.clone(),
141
node_to_expr(expr_ir.node(), expr_arena),
142
)))
143
} else {
144
Ok(phys_expr)
145
}
146
}
147
148
#[recursive]
149
fn create_physical_expr_inner(
150
expression: Node,
151
expr_arena: &mut Arena<AExpr>,
152
schema: &SchemaRef, // Schema of the input.
153
state: &mut ExpressionConversionState,
154
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
155
use AExpr::*;
156
157
let aexpr = expr_arena.get(expression);
158
match aexpr.clone() {
159
Len => Ok(Arc::new(phys_expr::CountExpr::new())),
160
#[cfg(feature = "dynamic_group_by")]
161
Rolling {
162
function,
163
index_column,
164
period,
165
offset,
166
closed_window,
167
} => {
168
let output_field = aexpr.to_field(&ToFieldContext::new(expr_arena, schema))?;
169
let index_column = create_physical_expr_inner(index_column, expr_arena, schema, state)?;
170
171
state.set_window();
172
let phys_function = create_physical_expr_inner(function, expr_arena, schema, state)?;
173
let expr = node_to_expr(expression, expr_arena);
174
175
// set again as the state can be reset
176
state.set_window();
177
Ok(Arc::new(RollingExpr {
178
phys_function,
179
index_column,
180
period,
181
offset,
182
closed_window,
183
expr,
184
output_field,
185
}))
186
},
187
Over {
188
function,
189
partition_by,
190
order_by,
191
mapping,
192
} => {
193
let output_field = aexpr.to_field(&ToFieldContext::new(expr_arena, schema))?;
194
state.set_window();
195
let phys_function = create_physical_expr_inner(function, expr_arena, schema, state)?;
196
197
let mut order_by_is_elementwise = false;
198
let order_by = order_by
199
.map(|(node, options)| {
200
order_by_is_elementwise |= is_elementwise_rec(node, expr_arena);
201
PolarsResult::Ok((
202
create_physical_expr_inner(node, expr_arena, schema, state)?,
203
options,
204
))
205
})
206
.transpose()?;
207
208
let expr = node_to_expr(expression, expr_arena);
209
210
// set again as the state can be reset
211
state.set_window();
212
let all_group_by_are_elementwise = partition_by
213
.iter()
214
.all(|n| is_elementwise_rec(*n, expr_arena));
215
let group_by =
216
create_physical_expressions_from_nodes(&partition_by, expr_arena, schema, state)?;
217
let mut apply_columns = aexpr_to_leaf_names(function, expr_arena);
218
// sort and then dedup removes consecutive duplicates == all duplicates
219
apply_columns.sort();
220
apply_columns.dedup();
221
222
if apply_columns.is_empty() {
223
if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) {
224
apply_columns.push(get_literal_name())
225
} else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Len)) {
226
apply_columns.push(PlSmallStr::from_static("len"))
227
} else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Element)) {
228
apply_columns.push(PlSmallStr::from_static("element"))
229
} else {
230
let e = node_to_expr(function, expr_arena);
231
polars_bail!(
232
ComputeError:
233
"cannot apply a window function, did not find a root column; \
234
this is likely due to a syntax error in this expression: {:?}", e
235
);
236
}
237
}
238
239
// Check if the branches have an aggregation
240
// when(a > sum)
241
// then (foo)
242
// otherwise(bar - sum)
243
let mut has_arity = false;
244
let mut agg_col = false;
245
for (_, e) in expr_arena.iter(function) {
246
match e {
247
AExpr::Ternary { .. } | AExpr::BinaryExpr { .. } => {
248
has_arity = true;
249
},
250
AExpr::Agg(_) => {
251
agg_col = true;
252
},
253
AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => {
254
if options.flags.returns_scalar() {
255
agg_col = true;
256
}
257
},
258
_ => {},
259
}
260
}
261
let has_different_group_sources = has_arity && agg_col;
262
263
Ok(Arc::new(WindowExpr {
264
group_by,
265
order_by,
266
apply_columns,
267
phys_function,
268
mapping,
269
expr,
270
has_different_group_sources,
271
output_field,
272
273
order_by_is_elementwise,
274
all_group_by_are_elementwise,
275
}))
276
},
277
Literal(value) => {
278
state.local.has_lit = true;
279
Ok(Arc::new(LiteralExpr::new(
280
value.clone(),
281
node_to_expr(expression, expr_arena),
282
)))
283
},
284
BinaryExpr { left, op, right } => {
285
let output_field = expr_arena
286
.get(expression)
287
.to_field(&ToFieldContext::new(expr_arena, schema))?;
288
let is_scalar = is_scalar_ae(expression, expr_arena);
289
let lhs = create_physical_expr_inner(left, expr_arena, schema, state)?;
290
let rhs = create_physical_expr_inner(right, expr_arena, schema, state)?;
291
Ok(Arc::new(phys_expr::BinaryExpr::new(
292
lhs,
293
op,
294
rhs,
295
node_to_expr(expression, expr_arena),
296
state.local.has_lit,
297
state.allow_threading,
298
is_scalar,
299
output_field,
300
)))
301
},
302
Column(column) => Ok(Arc::new(ColumnExpr::new(
303
column.clone(),
304
node_to_expr(expression, expr_arena),
305
schema.clone(),
306
))),
307
Element => {
308
let output_field = expr_arena
309
.get(expression)
310
.to_field(&ToFieldContext::new(expr_arena, schema))?;
311
312
Ok(Arc::new(ElementExpr::new(output_field)))
313
},
314
#[cfg(feature = "dtype-struct")]
315
StructField(field) => {
316
let output_field = expr_arena
317
.get(expression)
318
.to_field(&ToFieldContext::new(expr_arena, schema))?;
319
320
Ok(Arc::new(FieldExpr::new(
321
field.clone(),
322
node_to_expr(expression, expr_arena),
323
output_field,
324
)))
325
},
326
Sort { expr, options } => {
327
let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
328
Ok(Arc::new(SortExpr::new(
329
phys_expr,
330
options,
331
node_to_expr(expression, expr_arena),
332
)))
333
},
334
Gather {
335
expr,
336
idx,
337
returns_scalar,
338
null_on_oob,
339
} => {
340
let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
341
let phys_idx = create_physical_expr_inner(idx, expr_arena, schema, state)?;
342
Ok(Arc::new(GatherExpr {
343
phys_expr,
344
idx: phys_idx,
345
expr: node_to_expr(expression, expr_arena),
346
returns_scalar,
347
null_on_oob,
348
}))
349
},
350
SortBy {
351
expr,
352
by,
353
sort_options,
354
} => {
355
let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
356
let phys_by = create_physical_expressions_from_nodes(&by, expr_arena, schema, state)?;
357
Ok(Arc::new(SortByExpr::new(
358
phys_expr,
359
phys_by,
360
node_to_expr(expression, expr_arena),
361
sort_options.clone(),
362
)))
363
},
364
Filter { input, by } => {
365
let phys_input = create_physical_expr_inner(input, expr_arena, schema, state)?;
366
let phys_by = create_physical_expr_inner(by, expr_arena, schema, state)?;
367
Ok(Arc::new(FilterExpr::new(
368
phys_input,
369
phys_by,
370
node_to_expr(expression, expr_arena),
371
)))
372
},
373
Agg(agg) => {
374
let expr = agg.get_input().first();
375
let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;
376
let allow_threading = state.allow_threading;
377
378
let output_field = expr_arena
379
.get(expression)
380
.to_field(&ToFieldContext::new(expr_arena, schema))?;
381
382
// Special case: Quantile supports multiple inputs.
383
// TODO refactor to FunctionExpr.
384
if let IRAggExpr::Quantile {
385
quantile, method, ..
386
} = agg
387
{
388
let quantile = create_physical_expr_inner(quantile, expr_arena, schema, state)?;
389
return Ok(Arc::new(AggQuantileExpr::new(input, quantile, method)));
390
}
391
392
let groupby = GroupByMethod::from(agg.clone());
393
394
let agg_type = AggregationType {
395
groupby,
396
allow_threading,
397
};
398
399
Ok(Arc::new(AggregationExpr::new(
400
input,
401
agg_type,
402
output_field,
403
)))
404
},
405
Function {
406
input,
407
function: function @ (IRFunctionExpr::ArgMin | IRFunctionExpr::ArgMax),
408
options: _,
409
} => {
410
let phys_input =
411
create_physical_expr_inner(input[0].node(), expr_arena, schema, state)?;
412
413
let mut output_field = expr_arena
414
.get(expression)
415
.to_field(&ToFieldContext::new(expr_arena, schema))?;
416
output_field = Field::new(output_field.name().clone(), IDX_DTYPE.clone());
417
418
let groupby = match function {
419
IRFunctionExpr::ArgMin => GroupByMethod::ArgMin,
420
IRFunctionExpr::ArgMax => GroupByMethod::ArgMax,
421
_ => unreachable!(), // guaranteed by pattern
422
};
423
424
let agg_type = AggregationType {
425
groupby,
426
allow_threading: state.allow_threading,
427
};
428
429
Ok(Arc::new(AggregationExpr::new(
430
phys_input,
431
agg_type,
432
output_field,
433
)))
434
},
435
Function {
436
input: inputs,
437
function: function @ (IRFunctionExpr::MinBy | IRFunctionExpr::MaxBy),
438
options: _,
439
} => {
440
assert!(inputs.len() == 2);
441
let new_minmax_by = match function {
442
IRFunctionExpr::MinBy => AggMinMaxByExpr::new_min_by,
443
IRFunctionExpr::MaxBy => AggMinMaxByExpr::new_max_by,
444
_ => unreachable!(), // guaranteed by pattern
445
};
446
let input = create_physical_expr_inner(inputs[0].node(), expr_arena, schema, state)?;
447
let by = create_physical_expr_inner(inputs[1].node(), expr_arena, schema, state)?;
448
return Ok(Arc::new(new_minmax_by(input, by)));
449
},
450
Cast {
451
expr,
452
dtype,
453
options,
454
} => {
455
let phys_expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
456
Ok(Arc::new(CastExpr {
457
input: phys_expr,
458
dtype: dtype.clone(),
459
expr: node_to_expr(expression, expr_arena),
460
options,
461
}))
462
},
463
Ternary {
464
predicate,
465
truthy,
466
falsy,
467
} => {
468
let is_scalar = is_scalar_ae(expression, expr_arena);
469
let mut lit_count = 0u8;
470
state.reset();
471
let predicate = create_physical_expr_inner(predicate, expr_arena, schema, state)?;
472
lit_count += state.local.has_lit as u8;
473
state.reset();
474
let truthy = create_physical_expr_inner(truthy, expr_arena, schema, state)?;
475
lit_count += state.local.has_lit as u8;
476
state.reset();
477
let falsy = create_physical_expr_inner(falsy, expr_arena, schema, state)?;
478
lit_count += state.local.has_lit as u8;
479
Ok(Arc::new(TernaryExpr::new(
480
predicate,
481
truthy,
482
falsy,
483
node_to_expr(expression, expr_arena),
484
state.allow_threading && lit_count < 2,
485
is_scalar,
486
)))
487
},
488
AExpr::AnonymousAgg {
489
input,
490
fmt_str: _,
491
function,
492
} => {
493
let output_field = expr_arena
494
.get(expression)
495
.to_field(&ToFieldContext::new(expr_arena, schema))?;
496
497
let inputs = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;
498
let grouped_reduction = function
499
.clone()
500
.materialize()?
501
.as_any()
502
.downcast_ref::<Box<dyn GroupedReduction>>()
503
.unwrap()
504
.new_empty();
505
506
Ok(Arc::new(AnonymousAggregationExpr::new(
507
inputs,
508
grouped_reduction,
509
output_field,
510
)))
511
},
512
AnonymousFunction {
513
input,
514
function,
515
options,
516
fmt_str: _,
517
} => {
518
let is_scalar = is_scalar_ae(expression, expr_arena);
519
let output_field = expr_arena
520
.get(expression)
521
.to_field(&ToFieldContext::new(expr_arena, schema))?;
522
523
let input = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;
524
525
let function = function.clone().materialize()?;
526
let function = function.into_inner().as_column_udf();
527
528
Ok(Arc::new(ApplyExpr::new(
529
input,
530
SpecialEq::new(function),
531
None,
532
node_to_expr(expression, expr_arena),
533
options,
534
state.allow_threading,
535
schema.clone(),
536
output_field,
537
is_scalar,
538
true,
539
)))
540
},
541
Eval {
542
expr,
543
evaluation,
544
variant,
545
} => {
546
let is_scalar = is_scalar_ae(expression, expr_arena);
547
let evaluation_is_scalar = is_scalar_ae(evaluation, expr_arena);
548
let evaluation_is_elementwise = is_elementwise_rec(evaluation, expr_arena);
549
// @NOTE: This is actually also something the downstream apply code should care about.
550
let mut pd_group = ExprPushdownGroup::Pushable;
551
pd_group.update_with_expr_rec(expr_arena.get(evaluation), expr_arena, None);
552
let evaluation_is_fallible = matches!(pd_group, ExprPushdownGroup::Fallible);
553
554
let output_field = expr_arena
555
.get(expression)
556
.to_field(&ToFieldContext::new(expr_arena, schema))?;
557
let input_field = expr_arena
558
.get(expr)
559
.to_field(&ToFieldContext::new(expr_arena, schema))?;
560
let expr = create_physical_expr_inner(expr, expr_arena, schema, state)?;
561
562
let element_dtype = variant.element_dtype(&input_field.dtype)?;
563
let mut eval_schema = schema.as_ref().clone();
564
eval_schema.insert(get_pl_element_name(), element_dtype.clone());
565
let evaluation =
566
create_physical_expr_inner(evaluation, expr_arena, &Arc::new(eval_schema), state)?;
567
568
Ok(Arc::new(EvalExpr::new(
569
expr,
570
evaluation,
571
variant,
572
node_to_expr(expression, expr_arena),
573
output_field,
574
is_scalar,
575
evaluation_is_scalar,
576
evaluation_is_elementwise,
577
evaluation_is_fallible,
578
)))
579
},
580
#[cfg(feature = "dtype-struct")]
581
StructEval { expr, evaluation } => {
582
let is_scalar = is_scalar_ae(expression, expr_arena);
583
let output_field = expr_arena
584
.get(expression)
585
.to_field(&ToFieldContext::new(expr_arena, schema))?;
586
let input_field = expr_arena
587
.get(expr)
588
.to_field(&ToFieldContext::new(expr_arena, schema))?;
589
590
let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;
591
592
let mut eval_schema = schema.as_ref().clone();
593
eval_schema.insert(get_pl_structfields_name(), input_field.dtype().clone());
594
let eval_schema = Arc::new(eval_schema);
595
596
let evaluation = evaluation
597
.iter()
598
.map(|e| create_physical_expr(e, expr_arena, &eval_schema, state))
599
.collect::<PolarsResult<Vec<_>>>()?;
600
601
Ok(Arc::new(StructEvalExpr::new(
602
input,
603
evaluation,
604
node_to_expr(expression, expr_arena),
605
output_field,
606
is_scalar,
607
state.allow_threading,
608
)))
609
},
610
Function {
611
input,
612
function,
613
options,
614
} => {
615
let is_scalar = is_scalar_ae(expression, expr_arena);
616
617
let output_field = expr_arena
618
.get(expression)
619
.to_field(&ToFieldContext::new(expr_arena, schema))?;
620
621
let input = create_physical_expressions_from_irs(&input, expr_arena, schema, state)?;
622
let is_fallible = expr_arena.get(expression).is_fallible_top_level(expr_arena);
623
624
Ok(Arc::new(ApplyExpr::new(
625
input,
626
function_expr_to_udf(function.clone()),
627
function_expr_to_groups_udf(&function),
628
node_to_expr(expression, expr_arena),
629
options,
630
state.allow_threading,
631
schema.clone(),
632
output_field,
633
is_scalar,
634
is_fallible,
635
)))
636
},
637
638
Slice {
639
input,
640
offset,
641
length,
642
} => {
643
let input = create_physical_expr_inner(input, expr_arena, schema, state)?;
644
let offset = create_physical_expr_inner(offset, expr_arena, schema, state)?;
645
let length = create_physical_expr_inner(length, expr_arena, schema, state)?;
646
Ok(Arc::new(SliceExpr {
647
input,
648
offset,
649
length,
650
expr: node_to_expr(expression, expr_arena),
651
}))
652
},
653
Explode { expr, options } => {
654
let input = create_physical_expr_inner(expr, expr_arena, schema, state)?;
655
let function = SpecialEq::new(Arc::new(
656
move |c: &mut [polars_core::frame::column::Column]| c[0].explode(options),
657
) as Arc<dyn ColumnsUdf>);
658
659
let output_field = expr_arena
660
.get(expression)
661
.to_field(&ToFieldContext::new(expr_arena, schema))?;
662
663
Ok(Arc::new(ApplyExpr::new(
664
vec![input],
665
function,
666
None,
667
node_to_expr(expression, expr_arena),
668
FunctionOptions::groupwise(),
669
state.allow_threading,
670
schema.clone(),
671
output_field,
672
false,
673
false,
674
)))
675
},
676
}
677
}
678
679