Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/planner.rs
6939 views
1
use polars_core::prelude::*;
2
use polars_plan::prelude::expr_ir::ExprIR;
3
use polars_plan::prelude::*;
4
use recursive::recursive;
5
6
use crate::expressions as phys_expr;
7
use crate::expressions::*;
8
9
pub fn get_expr_depth_limit() -> PolarsResult<u16> {
10
let depth = if let Ok(d) = std::env::var("POLARS_MAX_EXPR_DEPTH") {
11
let v = d
12
.parse::<u64>()
13
.map_err(|_| polars_err!(ComputeError: "could not parse 'max_expr_depth': {}", d))?;
14
u16::try_from(v).unwrap_or(0)
15
} else {
16
512
17
};
18
Ok(depth)
19
}
20
21
fn ok_checker(_i: usize, _state: &ExpressionConversionState) -> PolarsResult<()> {
22
Ok(())
23
}
24
25
pub fn create_physical_expressions_from_irs(
26
exprs: &[ExprIR],
27
context: Context,
28
expr_arena: &Arena<AExpr>,
29
schema: &SchemaRef,
30
state: &mut ExpressionConversionState,
31
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
32
create_physical_expressions_check_state(exprs, context, expr_arena, schema, state, ok_checker)
33
}
34
35
pub(crate) fn create_physical_expressions_check_state<F>(
36
exprs: &[ExprIR],
37
context: Context,
38
expr_arena: &Arena<AExpr>,
39
schema: &SchemaRef,
40
state: &mut ExpressionConversionState,
41
checker: F,
42
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
43
where
44
F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
45
{
46
exprs
47
.iter()
48
.enumerate()
49
.map(|(i, e)| {
50
state.reset();
51
let out = create_physical_expr(e, context, expr_arena, schema, state);
52
checker(i, state)?;
53
out
54
})
55
.collect()
56
}
57
58
pub(crate) fn create_physical_expressions_from_nodes(
59
exprs: &[Node],
60
context: Context,
61
expr_arena: &Arena<AExpr>,
62
schema: &SchemaRef,
63
state: &mut ExpressionConversionState,
64
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>> {
65
create_physical_expressions_from_nodes_check_state(
66
exprs, context, expr_arena, schema, state, ok_checker,
67
)
68
}
69
70
pub(crate) fn create_physical_expressions_from_nodes_check_state<F>(
71
exprs: &[Node],
72
context: Context,
73
expr_arena: &Arena<AExpr>,
74
schema: &SchemaRef,
75
state: &mut ExpressionConversionState,
76
checker: F,
77
) -> PolarsResult<Vec<Arc<dyn PhysicalExpr>>>
78
where
79
F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>,
80
{
81
exprs
82
.iter()
83
.enumerate()
84
.map(|(i, e)| {
85
state.reset();
86
let out = create_physical_expr_inner(*e, context, expr_arena, schema, state);
87
checker(i, state)?;
88
out
89
})
90
.collect()
91
}
92
93
#[derive(Copy, Clone)]
94
pub struct ExpressionConversionState {
95
// settings per context
96
// they remain activate between
97
// expressions
98
pub allow_threading: bool,
99
pub has_windows: bool,
100
// settings per expression
101
// those are reset every expression
102
local: LocalConversionState,
103
}
104
105
/// Flags collected while planning a single expression.
///
/// All flags default to `false`.
#[derive(Clone, Copy, Default)]
struct LocalConversionState {
    /// An `implode` aggregation has been planned.
    has_implode: bool,
    /// A window expression has been planned.
    has_window: bool,
    /// A literal has been planned.
    has_lit: bool,
}
111
112
impl ExpressionConversionState {
113
pub fn new(allow_threading: bool) -> Self {
114
Self {
115
allow_threading,
116
has_windows: false,
117
local: LocalConversionState {
118
..Default::default()
119
},
120
}
121
}
122
123
fn reset(&mut self) {
124
self.local = LocalConversionState::default();
125
}
126
127
fn has_implode(&self) -> bool {
128
self.local.has_implode
129
}
130
131
fn set_window(&mut self) {
132
self.has_windows = true;
133
self.local.has_window = true;
134
}
135
}
136
137
pub fn create_physical_expr(
138
expr_ir: &ExprIR,
139
ctxt: Context,
140
expr_arena: &Arena<AExpr>,
141
schema: &SchemaRef,
142
state: &mut ExpressionConversionState,
143
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
144
let phys_expr = create_physical_expr_inner(expr_ir.node(), ctxt, expr_arena, schema, state)?;
145
146
if let Some(name) = expr_ir.get_alias() {
147
Ok(Arc::new(AliasExpr::new(
148
phys_expr,
149
name.clone(),
150
node_to_expr(expr_ir.node(), expr_arena),
151
)))
152
} else {
153
Ok(phys_expr)
154
}
155
}
156
157
/// Recursively lowers the `AExpr` stored at `expression` in `expr_arena` into a physical
/// expression.
///
/// `ctxt` is the context the node is planned in. Some children are planned in a different
/// context:
/// * window functions and their order-by expressions use `Context::Aggregation`;
/// * `Eval` plans its input and evaluation in `Context::Default`.
///
/// `state` carries per-context settings (`allow_threading`, `has_windows`) and per-expression
/// flags (`state.local`), which are read and updated while recursing. Several paths call
/// `state.reset()`: the `Ternary` arm directly, and the `create_physical_expressions_*`
/// helpers for every child they plan. Those calls clear `state.local` part-way through this
/// expression.
///
/// NOTE(review): `#[recursive]` comes from the `recursive` crate, presumably to avoid stack
/// overflow on deeply nested expressions — confirm against that crate's documentation.
///
/// # Errors
/// Propagates errors from schema resolution (`to_field` / `to_field_with_ctx`), from UDF
/// materialization, and from planning children. It also bails with `ComputeError` or
/// `InvalidOperation` for the unsupported cases checked below.
#[recursive]
fn create_physical_expr_inner(
    expression: Node,
    ctxt: Context,
    expr_arena: &Arena<AExpr>,
    schema: &SchemaRef,
    state: &mut ExpressionConversionState,
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
    use AExpr::*;

    match expr_arena.get(expression) {
        Len => Ok(Arc::new(phys_expr::CountExpr::new())),
        Window {
            function,
            partition_by,
            order_by,
            options,
        } => {
            let function = *function;
            // Record the window (context-wide and per-expression) before planning children.
            state.set_window();
            let phys_function = create_physical_expr_inner(
                function,
                Context::Aggregation,
                expr_arena,
                schema,
                state,
            )?;

            // Optional order-by key, also planned in the aggregation context.
            let order_by = order_by
                .map(|(node, options)| {
                    PolarsResult::Ok((
                        create_physical_expr_inner(
                            node,
                            Context::Aggregation,
                            expr_arena,
                            schema,
                            state,
                        )?,
                        options,
                    ))
                })
                .transpose()?;

            let function_expr = node_to_expr(function, expr_arena);
            let expr = node_to_expr(expression, expr_arena);

            // set again as the state can be reset
            // (planning the children above may call `state.reset()`, e.g. via a nested
            // `Ternary`, which clears the per-expression flags)
            state.set_window();
            match options {
                WindowType::Over(mapping) => {
                    // TODO! Order by
                    // Note: this helper calls `state.reset()` for every partition key.
                    let group_by = create_physical_expressions_from_nodes(
                        partition_by,
                        Context::Aggregation,
                        expr_arena,
                        schema,
                        state,
                    )?;
                    let mut apply_columns = aexpr_to_leaf_names(function, expr_arena);
                    // sort and then dedup removes consecutive duplicates == all duplicates
                    apply_columns.sort();
                    apply_columns.dedup();

                    // No leaf columns: fall back to a synthetic name when the function
                    // contains a literal or `len`; otherwise the window cannot be applied.
                    if apply_columns.is_empty() {
                        if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) {
                            apply_columns.push(PlSmallStr::from_static("literal"))
                        } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Len)) {
                            apply_columns.push(PlSmallStr::from_static("len"))
                        } else {
                            let e = node_to_expr(function, expr_arena);
                            polars_bail!(
                                ComputeError:
                                "cannot apply a window function, did not find a root column; \
                                this is likely due to a syntax error in this expression: {:?}", e
                            );
                        }
                    }

                    // Check if the branches have an aggregation
                    // when(a > sum)
                    // then (foo)
                    // otherwise(bar - sum)
                    // `has_arity`: the function contains a ternary or binary node.
                    // `agg_col`: it contains an aggregation or a scalar-returning function.
                    let mut has_arity = false;
                    let mut agg_col = false;
                    for (_, e) in expr_arena.iter(function) {
                        match e {
                            AExpr::Ternary { .. } | AExpr::BinaryExpr { .. } => {
                                has_arity = true;
                            },
                            AExpr::Agg(_) => {
                                agg_col = true;
                            },
                            AExpr::Function { options, .. }
                            | AExpr::AnonymousFunction { options, .. } => {
                                if options.flags.returns_scalar() {
                                    agg_col = true;
                                }
                            },
                            _ => {},
                        }
                    }
                    let has_different_group_sources = has_arity && agg_col;

                    Ok(Arc::new(WindowExpr {
                        group_by,
                        order_by,
                        apply_columns,
                        function: function_expr,
                        phys_function,
                        mapping: *mapping,
                        expr,
                        has_different_group_sources,
                    }))
                },
                // Rolling windows only exist when the `dynamic_group_by` feature is enabled.
                #[cfg(feature = "dynamic_group_by")]
                WindowType::Rolling(options) => Ok(Arc::new(RollingExpr {
                    function: function_expr,
                    phys_function,
                    options: options.clone(),
                    expr,
                })),
            }
        },
        Literal(value) => {
            // Remember that this expression contains a literal; read by `BinaryExpr`
            // and `Ternary` below.
            state.local.has_lit = true;
            Ok(Arc::new(LiteralExpr::new(
                value.clone(),
                node_to_expr(expression, expr_arena),
            )))
        },
        BinaryExpr { left, op, right } => {
            let output_field = expr_arena.get(expression).to_field(schema, expr_arena)?;
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let lhs = create_physical_expr_inner(*left, ctxt, expr_arena, schema, state)?;
            let rhs = create_physical_expr_inner(*right, ctxt, expr_arena, schema, state)?;
            // `state.local.has_lit` is read after both operands are planned, so it reflects
            // every literal planned since the last `state.reset()`, not only those in
            // `left`/`right`.
            Ok(Arc::new(phys_expr::BinaryExpr::new(
                lhs,
                *op,
                rhs,
                node_to_expr(expression, expr_arena),
                state.local.has_lit,
                state.allow_threading,
                is_scalar,
                output_field,
            )))
        },
        Column(column) => Ok(Arc::new(ColumnExpr::new(
            column.clone(),
            node_to_expr(expression, expr_arena),
            schema.clone(),
        ))),
        Sort { expr, options } => {
            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
            Ok(Arc::new(SortExpr::new(
                phys_expr,
                *options,
                node_to_expr(expression, expr_arena),
            )))
        },
        Gather {
            expr,
            idx,
            returns_scalar,
        } => {
            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
            let phys_idx = create_physical_expr_inner(*idx, ctxt, expr_arena, schema, state)?;
            Ok(Arc::new(GatherExpr {
                phys_expr,
                idx: phys_idx,
                expr: node_to_expr(expression, expr_arena),
                returns_scalar: *returns_scalar,
            }))
        },
        SortBy {
            expr,
            by,
            sort_options,
        } => {
            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
            // Note: this helper calls `state.reset()` for every sort key.
            let phys_by =
                create_physical_expressions_from_nodes(by, ctxt, expr_arena, schema, state)?;
            Ok(Arc::new(SortByExpr::new(
                phys_expr,
                phys_by,
                node_to_expr(expression, expr_arena),
                sort_options.clone(),
            )))
        },
        Filter { input, by } => {
            let phys_input = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?;
            let phys_by = create_physical_expr_inner(*by, ctxt, expr_arena, schema, state)?;
            Ok(Arc::new(FilterExpr::new(
                phys_input,
                phys_by,
                node_to_expr(expression, expr_arena),
            )))
        },
        Agg(agg) => {
            let expr = agg.get_input().first();
            let input = create_physical_expr_inner(expr, ctxt, expr_arena, schema, state)?;
            // An `implode` seen earlier in this expression cannot be followed by another
            // aggregation in the aggregation context.
            polars_ensure!(!(state.has_implode() && matches!(ctxt, Context::Aggregation)), InvalidOperation: "'implode' followed by an aggregation is not allowed");
            state.local.has_implode |= matches!(agg, IRAggExpr::Implode(_));
            let allow_threading = state.allow_threading;

            match ctxt {
                // Default context (quantile excluded): map the aggregation to a
                // `GroupByMethod` and honour the state's threading flag; no output field is
                // precomputed.
                Context::Default if !matches!(agg, IRAggExpr::Quantile { .. }) => {
                    use {GroupByMethod as GBM, IRAggExpr as I};

                    let groupby = match agg {
                        I::Min { propagate_nans, .. } if *propagate_nans => GBM::NanMin,
                        I::Min { .. } => GBM::Min,
                        I::Max { propagate_nans, .. } if *propagate_nans => GBM::NanMax,
                        I::Max { .. } => GBM::Max,
                        I::Median(_) => GBM::Median,
                        I::NUnique(_) => GBM::NUnique,
                        I::First(_) => GBM::First,
                        I::Last(_) => GBM::Last,
                        I::Mean(_) => GBM::Mean,
                        I::Implode(_) => GBM::Implode,
                        // Excluded by the match guard above.
                        I::Quantile { .. } => unreachable!(),
                        I::Sum(_) => GBM::Sum,
                        I::Count {
                            input: _,
                            include_nulls,
                        } => GBM::Count {
                            include_nulls: *include_nulls,
                        },
                        I::Std(_, ddof) => GBM::Std(*ddof),
                        I::Var(_, ddof) => GBM::Var(*ddof),
                        I::AggGroups(_) => {
                            polars_bail!(InvalidOperation: "agg groups expression only supported in aggregation context")
                        },
                    };

                    let agg_type = AggregationType {
                        groupby,
                        allow_threading,
                    };

                    Ok(Arc::new(AggregationExpr::new(input, agg_type, None)))
                },
                // Any other context, or a quantile in the default context.
                _ => {
                    // Quantile has its own physical expression with a planned quantile input.
                    if let IRAggExpr::Quantile {
                        quantile,
                        method: interpol,
                        ..
                    } = agg
                    {
                        let quantile =
                            create_physical_expr_inner(*quantile, ctxt, expr_arena, schema, state)?;
                        return Ok(Arc::new(AggQuantileExpr::new(input, quantile, *interpol)));
                    }

                    // The output field is resolved in the aggregation context.
                    let field = expr_arena.get(expression).to_field_with_ctx(
                        schema,
                        Context::Aggregation,
                        expr_arena,
                    )?;

                    // Threading is always disabled on this path, regardless of
                    // `state.allow_threading`.
                    let groupby = GroupByMethod::from(agg.clone());
                    let agg_type = AggregationType {
                        groupby,
                        allow_threading: false,
                    };
                    Ok(Arc::new(AggregationExpr::new(input, agg_type, Some(field))))
                },
            }
        },
        Cast {
            expr,
            dtype,
            options,
        } => {
            let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
            Ok(Arc::new(CastExpr {
                input: phys_expr,
                dtype: dtype.clone(),
                expr: node_to_expr(expression, expr_arena),
                options: *options,
            }))
        },
        Ternary {
            predicate,
            truthy,
            falsy,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            // Each branch is planned after a `state.reset()`, so `has_lit` is observed per
            // branch; threading is only allowed when fewer than two branches contain a
            // literal.
            // NOTE(review): these resets also clear `has_implode`/`has_window` flags set by
            // nodes planned before this ternary in the same expression — confirm intended.
            let mut lit_count = 0u8;
            state.reset();
            let predicate =
                create_physical_expr_inner(*predicate, ctxt, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            state.reset();
            let truthy = create_physical_expr_inner(*truthy, ctxt, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            state.reset();
            let falsy = create_physical_expr_inner(*falsy, ctxt, expr_arena, schema, state)?;
            lit_count += state.local.has_lit as u8;
            Ok(Arc::new(TernaryExpr::new(
                predicate,
                truthy,
                falsy,
                node_to_expr(expression, expr_arena),
                state.allow_threading && lit_count < 2,
                is_scalar,
            )))
        },
        AnonymousFunction {
            input,
            function,
            options,
            fmt_str: _,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let output_field = expr_arena
                .get(expression)
                .to_field_with_ctx(schema, ctxt, expr_arena)?;

            // Note: this helper calls `state.reset()` for every input.
            let input =
                create_physical_expressions_from_irs(input, ctxt, expr_arena, schema, state)?;

            // Materialize the user-defined function and obtain it as a column UDF.
            let function = function.clone().materialize()?;
            let function = function.into_inner().as_column_udf();

            Ok(Arc::new(ApplyExpr::new(
                input,
                SpecialEq::new(function),
                node_to_expr(expression, expr_arena),
                *options,
                state.allow_threading,
                schema.clone(),
                output_field,
                is_scalar,
            )))
        },
        Eval {
            expr,
            evaluation,
            variant,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let evaluation_is_scalar = is_scalar_ae(*evaluation, expr_arena);
            // Start from `Pushable` and let the evaluation expression tree narrow it.
            let mut pd_group = ExprPushdownGroup::Pushable;
            pd_group.update_with_expr_rec(expr_arena.get(*evaluation), expr_arena, None);

            let output_field_with_ctx = expr_arena
                .get(expression)
                .to_field_with_ctx(schema, ctxt, expr_arena)?;
            let non_aggregated_output_field =
                expr_arena.get(expression).to_field(schema, expr_arena)?;
            let input_field = expr_arena.get(*expr).to_field(schema, expr_arena)?;
            let expr =
                create_physical_expr_inner(*expr, Context::Default, expr_arena, schema, state)?;

            // The evaluation is planned against a one-field schema whose (empty-named)
            // column has the element dtype derived from the input's dtype by `variant`.
            let element_dtype = variant.element_dtype(&input_field.dtype)?;
            let eval_schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]);
            let evaluation = create_physical_expr_inner(
                *evaluation,
                Context::Default,
                expr_arena,
                &Arc::new(eval_schema),
                state,
            )?;

            Ok(Arc::new(EvalExpr::new(
                expr,
                evaluation,
                *variant,
                node_to_expr(expression, expr_arena),
                state.allow_threading,
                output_field_with_ctx,
                non_aggregated_output_field.dtype,
                is_scalar,
                pd_group,
                evaluation_is_scalar,
            )))
        },
        Function {
            input,
            function,
            options,
        } => {
            let is_scalar = is_scalar_ae(expression, expr_arena);
            let output_field = expr_arena
                .get(expression)
                .to_field_with_ctx(schema, ctxt, expr_arena)?;
            // Note: this helper calls `state.reset()` for every input.
            let input =
                create_physical_expressions_from_irs(input, ctxt, expr_arena, schema, state)?;

            Ok(Arc::new(ApplyExpr::new(
                input,
                function.clone().into(),
                node_to_expr(expression, expr_arena),
                *options,
                state.allow_threading,
                schema.clone(),
                output_field,
                is_scalar,
            )))
        },
        Slice {
            input,
            offset,
            length,
        } => {
            let input = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?;
            let offset = create_physical_expr_inner(*offset, ctxt, expr_arena, schema, state)?;
            let length = create_physical_expr_inner(*length, ctxt, expr_arena, schema, state)?;
            // Checked after planning the children, so an `implode` inside them is detected.
            polars_ensure!(!(state.has_implode() && matches!(ctxt, Context::Aggregation)),
                InvalidOperation: "'implode' followed by a slice during aggregation is not allowed");
            Ok(Arc::new(SliceExpr {
                input,
                offset,
                length,
                expr: node_to_expr(expression, expr_arena),
            }))
        },
        Explode { expr, skip_empty } => {
            let input = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?;
            let skip_empty = *skip_empty;
            // UDF that explodes its first (only) input column.
            let function = SpecialEq::new(Arc::new(
                move |c: &mut [polars_core::frame::column::Column]| c[0].explode(skip_empty),
            ) as Arc<dyn ColumnsUdf>);

            let field = expr_arena
                .get(expression)
                .to_field_with_ctx(schema, ctxt, expr_arena)?;

            // Wrapped as a groupwise apply; `false` is passed in the position other
            // `ApplyExpr::new` calls use for `is_scalar`.
            Ok(Arc::new(ApplyExpr::new(
                vec![input],
                function,
                node_to_expr(expression, expr_arena),
                FunctionOptions::groupwise(),
                state.allow_threading,
                schema.clone(),
                field,
                false,
            )))
        },
    }
}
598
599