Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs
7889 views
1
use arrow::datatypes::ArrowSchemaRef;
2
use either::Either;
3
use expr_expansion::rewrite_projections;
4
use hive::hive_partitions_from_paths;
5
use polars_core::chunked_array::cast::CastOptions;
6
use polars_core::config::verbose;
7
use polars_utils::format_pl_smallstr;
8
use polars_utils::itertools::Itertools;
9
use polars_utils::plpath::PlPath;
10
use polars_utils::unique_id::UniqueId;
11
12
use super::convert_utils::SplitPredicates;
13
use super::stack_opt::ConversionOptimizer;
14
use super::*;
15
use crate::constants::get_pl_element_name;
16
use crate::dsl::PartitionedSinkOptions;
17
use crate::dsl::sink2::FileProviderType;
18
19
mod concat;
20
mod datatype_fn_to_ir;
21
mod expr_expansion;
22
mod expr_to_ir;
23
mod functions;
24
mod join;
25
mod scans;
26
mod utils;
27
pub use expr_expansion::{expand_expression, is_regex_projection, prepare_projection};
28
pub use expr_to_ir::{ExprToIRContext, to_expr_ir};
29
use expr_to_ir::{to_expr_ir_materialized_lit, to_expr_irs};
30
use utils::DslConversionContext;
31
32
/// Builds the error-context string used while lowering DSL to IR.
///
/// Stringifies the given tokens and wraps them in single quotes, e.g.
/// `failed_here!(vertical concat)` produces `"'vertical concat'"`. The
/// trailing `.into()` lets the expansion coerce into whatever string-like
/// type `PolarsError::context` expects at the call site.
macro_rules! failed_here {
    ($($t:tt)*) => {
        format!("'{}'", stringify!($($t)*)).into()
    }
}
37
pub(super) use failed_here;
38
39
pub fn to_alp(
40
lp: DslPlan,
41
expr_arena: &mut Arena<AExpr>,
42
lp_arena: &mut Arena<IR>,
43
// Only `SIMPLIFY_EXPR`, `TYPE_COERCION`, `TYPE_CHECK` are respected.
44
opt_flags: &mut OptFlags,
45
) -> PolarsResult<Node> {
46
let conversion_optimizer = ConversionOptimizer::new(
47
opt_flags.contains(OptFlags::SIMPLIFY_EXPR),
48
opt_flags.contains(OptFlags::TYPE_COERCION),
49
opt_flags.contains(OptFlags::TYPE_CHECK),
50
);
51
52
let mut ctxt = DslConversionContext {
53
expr_arena,
54
lp_arena,
55
conversion_optimizer,
56
opt_flags,
57
nodes_scratch: &mut unitvec![],
58
cache_file_info: Default::default(),
59
pushdown_maintain_errors: optimizer::pushdown_maintain_errors(),
60
verbose: verbose(),
61
seen_caches: Default::default(),
62
};
63
64
match to_alp_impl(lp, &mut ctxt) {
65
Ok(out) => Ok(out),
66
Err(err) => {
67
if opt_flags.contains(OptFlags::EAGER) {
68
// If we dispatched to the lazy engine from the eager API, we don't want to resolve
69
// where in the query plan it went wrong. It is clear from the backtrace anyway.
70
return Err(err.remove_context());
71
};
72
let Some(ir_until_then) = lp_arena.last_node() else {
73
return Err(err);
74
};
75
let node_name = if let PolarsError::Context { msg, .. } = &err {
76
msg
77
} else {
78
"THIS_NODE"
79
};
80
let plan = IRPlan::new(
81
ir_until_then,
82
std::mem::take(lp_arena),
83
std::mem::take(expr_arena),
84
);
85
let location = format!("{}", plan.display());
86
Err(err.wrap_msg(|msg| {
87
format!("{msg}\n\nResolved plan until failure:\n\n\t---> FAILED HERE RESOLVING {node_name} <---\n{location}")
88
}))
89
},
90
}
91
}
92
93
fn run_conversion(lp: IR, ctxt: &mut DslConversionContext, name: &str) -> PolarsResult<Node> {
94
let lp_node = ctxt.lp_arena.add(lp);
95
ctxt.conversion_optimizer
96
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, lp_node, false)
97
.map_err(|e| e.context(format!("'{name}' failed").into()))?;
98
99
Ok(lp_node)
100
}
101
102
/// converts LogicalPlan to IR
103
/// it adds expressions & lps to the respective arenas as it traverses the plan
104
/// finally it returns the top node of the logical plan
105
#[recursive]
106
pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult<Node> {
107
let owned = Arc::unwrap_or_clone;
108
109
let v = match lp {
110
DslPlan::Scan {
111
sources,
112
unified_scan_args,
113
scan_type,
114
cached_ir,
115
} => scans::dsl_to_ir(sources, unified_scan_args, scan_type, cached_ir, ctxt)?,
116
#[cfg(feature = "python")]
117
DslPlan::PythonScan { options } => {
118
use crate::dsl::python_dsl::PythonOptionsDsl;
119
120
let schema = options.get_schema()?;
121
122
let PythonOptionsDsl {
123
scan_fn,
124
schema_fn: _,
125
python_source,
126
validate_schema,
127
is_pure,
128
} = options;
129
130
IR::PythonScan {
131
options: PythonOptions {
132
scan_fn,
133
schema,
134
python_source,
135
validate_schema,
136
output_schema: Default::default(),
137
with_columns: Default::default(),
138
n_rows: Default::default(),
139
predicate: Default::default(),
140
is_pure,
141
},
142
}
143
},
144
DslPlan::Union { inputs, args } => {
145
let mut inputs = inputs
146
.into_iter()
147
.map(|lp| to_alp_impl(lp, ctxt))
148
.collect::<PolarsResult<Vec<_>>>()
149
.map_err(|e| e.context(failed_here!(vertical concat)))?;
150
151
if args.diagonal {
152
inputs = concat::convert_diagonal_concat(inputs, ctxt.lp_arena, ctxt.expr_arena)?;
153
}
154
155
if args.to_supertypes {
156
concat::convert_st_union(
157
&mut inputs,
158
ctxt.lp_arena,
159
ctxt.expr_arena,
160
ctxt.opt_flags,
161
)
162
.map_err(|e| e.context(failed_here!(vertical concat)))?;
163
}
164
165
let first = *inputs.first().ok_or_else(
166
|| polars_err!(InvalidOperation: "expected at least one input in 'union'/'concat'"),
167
)?;
168
let schema = ctxt.lp_arena.get(first).schema(ctxt.lp_arena);
169
for n in &inputs[1..] {
170
let schema_i = ctxt.lp_arena.get(*n).schema(ctxt.lp_arena);
171
// The first argument
172
schema_i.matches_schema(schema.as_ref()).map_err(|_| polars_err!(InvalidOperation: "'union'/'concat' inputs should all have the same schema,\
173
got\n{:?} and \n{:?}", schema, schema_i)
174
)?;
175
}
176
177
let options = args.into();
178
IR::Union { inputs, options }
179
},
180
DslPlan::HConcat { inputs, options } => {
181
let inputs = inputs
182
.into_iter()
183
.map(|lp| to_alp_impl(lp, ctxt))
184
.collect::<PolarsResult<Vec<_>>>()
185
.map_err(|e| e.context(failed_here!(horizontal concat)))?;
186
187
let schema = concat::h_concat_schema(&inputs, ctxt.lp_arena)?;
188
189
IR::HConcat {
190
inputs,
191
schema,
192
options,
193
}
194
},
195
DslPlan::Filter { input, predicate } => {
196
let mut input =
197
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(filter)))?;
198
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
199
200
let mut out = Vec::with_capacity(1);
201
expr_expansion::expand_expression(
202
&predicate,
203
&PlHashSet::default(),
204
input_schema.as_ref().as_ref(),
205
&mut out,
206
ctxt.opt_flags,
207
)?;
208
209
let predicate = match out.len() {
210
1 => {
211
// all good
212
out.pop().unwrap()
213
},
214
0 => {
215
let msg = "The predicate expanded to zero expressions. \
216
This may for example be caused by a regex not matching column names or \
217
a column dtype match not hitting any dtypes in the DataFrame";
218
polars_bail!(ComputeError: msg);
219
},
220
_ => {
221
let mut expanded = String::new();
222
for e in out.iter().take(5) {
223
expanded.push_str(&format!("\t{e:?},\n"))
224
}
225
// pop latest comma
226
expanded.pop();
227
if out.len() > 5 {
228
expanded.push_str("\t...\n")
229
}
230
231
let msg = if cfg!(feature = "python") {
232
format!(
233
"The predicate passed to 'LazyFrame.filter' expanded to multiple expressions: \n\n{expanded}\n\
234
This is ambiguous. Try to combine the predicates with the 'all' or `any' expression."
235
)
236
} else {
237
format!(
238
"The predicate passed to 'LazyFrame.filter' expanded to multiple expressions: \n\n{expanded}\n\
239
This is ambiguous. Try to combine the predicates with the 'all_horizontal' or `any_horizontal' expression."
240
)
241
};
242
polars_bail!(ComputeError: msg)
243
},
244
};
245
let predicate_ae = to_expr_ir(
246
predicate,
247
&mut ExprToIRContext::new_with_opt_eager(
248
ctxt.expr_arena,
249
&input_schema,
250
ctxt.opt_flags,
251
),
252
)?;
253
254
if ctxt.opt_flags.predicate_pushdown() {
255
ctxt.nodes_scratch.clear();
256
257
if let Some(SplitPredicates { pushable, fallible }) = SplitPredicates::new(
258
predicate_ae.node(),
259
ctxt.expr_arena,
260
Some(ctxt.nodes_scratch),
261
ctxt.pushdown_maintain_errors,
262
) {
263
let mut update_input = |predicate: Node| -> PolarsResult<()> {
264
let predicate = ExprIR::from_node(predicate, ctxt.expr_arena);
265
ctxt.conversion_optimizer
266
.push_scratch(predicate.node(), ctxt.expr_arena);
267
let lp = IR::Filter { input, predicate };
268
input = run_conversion(lp, ctxt, "filter")?;
269
270
Ok(())
271
};
272
273
// Pushables first, then fallible.
274
275
for predicate in pushable {
276
update_input(predicate)?;
277
}
278
279
if let Some(node) = fallible {
280
update_input(node)?;
281
}
282
283
return Ok(input);
284
};
285
};
286
287
ctxt.conversion_optimizer
288
.push_scratch(predicate_ae.node(), ctxt.expr_arena);
289
let lp = IR::Filter {
290
input,
291
predicate: predicate_ae,
292
};
293
return run_conversion(lp, ctxt, "filter");
294
},
295
DslPlan::Slice { input, offset, len } => {
296
let input =
297
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(slice)))?;
298
IR::Slice { input, offset, len }
299
},
300
DslPlan::DataFrameScan { df, schema } => IR::DataFrameScan {
301
df,
302
schema,
303
output_schema: None,
304
},
305
DslPlan::Select {
306
expr,
307
input,
308
options,
309
} => {
310
let input =
311
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(select)))?;
312
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
313
let (exprs, schema) = prepare_projection(expr, &input_schema, ctxt.opt_flags)
314
.map_err(|e| e.context(failed_here!(select)))?;
315
316
if exprs.is_empty() {
317
ctxt.lp_arena.replace(input, utils::empty_df());
318
return Ok(input);
319
}
320
321
let eirs = to_expr_irs(
322
exprs,
323
&mut ExprToIRContext::new_with_opt_eager(
324
ctxt.expr_arena,
325
&input_schema,
326
ctxt.opt_flags,
327
),
328
)?;
329
ctxt.conversion_optimizer
330
.fill_scratch(&eirs, ctxt.expr_arena);
331
332
let schema = Arc::new(schema);
333
let lp = IR::Select {
334
expr: eirs,
335
input,
336
schema,
337
options,
338
};
339
340
return run_conversion(lp, ctxt, "select").map_err(|e| e.context(failed_here!(select)));
341
},
342
DslPlan::Sort {
343
input,
344
by_column,
345
slice,
346
mut sort_options,
347
} => {
348
let input =
349
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(select)))?;
350
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
351
352
// note: if given an Expr::Columns, count the individual cols
353
let n_by_exprs = if by_column.len() == 1 {
354
match &by_column[0] {
355
Expr::Selector(s) => s.into_columns(&input_schema, &Default::default())?.len(),
356
_ => 1,
357
}
358
} else {
359
by_column.len()
360
};
361
let n_desc = sort_options.descending.len();
362
polars_ensure!(
363
n_desc == n_by_exprs || n_desc == 1,
364
ComputeError: "the length of `descending` ({}) does not match the length of `by` ({})", n_desc, by_column.len()
365
);
366
let n_nulls_last = sort_options.nulls_last.len();
367
polars_ensure!(
368
n_nulls_last == n_by_exprs || n_nulls_last == 1,
369
ComputeError: "the length of `nulls_last` ({}) does not match the length of `by` ({})", n_nulls_last, by_column.len()
370
);
371
372
let mut expanded_cols = Vec::new();
373
let mut nulls_last = Vec::new();
374
let mut descending = Vec::new();
375
376
// note: nulls_last/descending need to be matched to expanded multi-output expressions.
377
// when one of nulls_last/descending has not been updated from the default (single
378
// value true/false), 'cycle' ensures that "by_column" iter is not truncated.
379
for (c, (&n, &d)) in by_column.into_iter().zip(
380
sort_options
381
.nulls_last
382
.iter()
383
.cycle()
384
.zip(sort_options.descending.iter().cycle()),
385
) {
386
let exprs = utils::expand_expressions(
387
input,
388
vec![c],
389
ctxt.lp_arena,
390
ctxt.expr_arena,
391
ctxt.opt_flags,
392
)
393
.map_err(|e| e.context(failed_here!(sort)))?;
394
395
nulls_last.extend(std::iter::repeat_n(n, exprs.len()));
396
descending.extend(std::iter::repeat_n(d, exprs.len()));
397
expanded_cols.extend(exprs);
398
}
399
sort_options.nulls_last = nulls_last;
400
sort_options.descending = descending;
401
402
ctxt.conversion_optimizer
403
.fill_scratch(&expanded_cols, ctxt.expr_arena);
404
let mut by_column = expanded_cols;
405
406
// Remove null columns in multi-columns sort
407
if by_column.len() > 1 {
408
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
409
410
let mut null_columns = vec![];
411
412
for (i, c) in by_column.iter().enumerate() {
413
if let DataType::Null = c.dtype(&input_schema, ctxt.expr_arena)? {
414
null_columns.push(i);
415
}
416
}
417
// All null columns, only take one.
418
if null_columns.len() == by_column.len() {
419
by_column.truncate(1);
420
sort_options.nulls_last.truncate(1);
421
sort_options.descending.truncate(1);
422
}
423
// Remove the null columns
424
else if !null_columns.is_empty() {
425
for i in null_columns.into_iter().rev() {
426
by_column.remove(i);
427
sort_options.nulls_last.remove(i);
428
sort_options.descending.remove(i);
429
}
430
}
431
}
432
if by_column.is_empty() {
433
return Ok(input);
434
};
435
436
let lp = IR::Sort {
437
input,
438
by_column,
439
slice,
440
sort_options,
441
};
442
443
return run_conversion(lp, ctxt, "sort").map_err(|e| e.context(failed_here!(sort)));
444
},
445
DslPlan::Cache { input, id } => {
446
let input = match ctxt.seen_caches.get(&id) {
447
Some(input) => *input,
448
None => {
449
let input = to_alp_impl(owned(input), ctxt)
450
.map_err(|e| e.context(failed_here!(cache)))?;
451
let seen_before = ctxt.seen_caches.insert(id, input);
452
assert!(
453
seen_before.is_none(),
454
"Cache could not have been created in the mean time. That would make the DAG cyclic."
455
);
456
input
457
},
458
};
459
460
IR::Cache { input, id }
461
},
462
DslPlan::GroupBy {
463
input,
464
keys,
465
predicates,
466
mut aggs,
467
apply,
468
maintain_order,
469
options,
470
} => {
471
// If the group by contains any predicates, we update the plan by turning the
472
// predicates into aggregations and filtering on them. Then, we recursively call
473
// this function.
474
if !predicates.is_empty() {
475
let predicate_names = (0..predicates.len())
476
.map(|i| format_pl_smallstr!("__POLARS_HAVING_{i}"))
477
.collect::<Arc<[_]>>();
478
let predicates = predicates
479
.into_iter()
480
.zip(predicate_names.iter())
481
.map(|(p, name)| p.alias(name.clone()))
482
.collect_vec();
483
aggs.extend(predicates);
484
485
let lp = DslPlan::GroupBy {
486
input,
487
keys,
488
predicates: vec![],
489
aggs,
490
apply,
491
maintain_order,
492
options,
493
};
494
let lp = DslBuilder::from(lp)
495
.filter(
496
all_horizontal(
497
predicate_names.iter().map(|n| col(n.clone())).collect_vec(),
498
)
499
.unwrap(),
500
)
501
.drop(Selector::ByName {
502
names: predicate_names,
503
strict: true,
504
})
505
.build();
506
return to_alp_impl(lp, ctxt);
507
}
508
509
// NOTE: As we went into this branch, we know that no predicates are provided.
510
let input =
511
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(group_by)))?;
512
513
// Rolling + group-by sorts the whole table, so remove unneeded columns
514
if ctxt.opt_flags.eager() && options.is_rolling() && !keys.is_empty() {
515
ctxt.opt_flags.insert(OptFlags::PROJECTION_PUSHDOWN)
516
}
517
518
let (keys, aggs, schema) = resolve_group_by(
519
input,
520
keys,
521
aggs,
522
&options,
523
ctxt.lp_arena,
524
ctxt.expr_arena,
525
ctxt.opt_flags,
526
)
527
.map_err(|e| e.context(failed_here!(group_by)))?;
528
529
let (apply, schema) = if let Some((apply, schema)) = apply {
530
(Some(apply), schema)
531
} else {
532
(None, schema)
533
};
534
535
ctxt.conversion_optimizer
536
.fill_scratch(&keys, ctxt.expr_arena);
537
ctxt.conversion_optimizer
538
.fill_scratch(&aggs, ctxt.expr_arena);
539
540
let lp = IR::GroupBy {
541
input,
542
keys,
543
aggs,
544
schema,
545
apply,
546
maintain_order,
547
options,
548
};
549
return run_conversion(lp, ctxt, "group_by")
550
.map_err(|e| e.context(failed_here!(group_by)));
551
},
552
DslPlan::Join {
553
input_left,
554
input_right,
555
left_on,
556
right_on,
557
predicates,
558
options,
559
} => {
560
return join::resolve_join(
561
Either::Left(input_left),
562
Either::Left(input_right),
563
left_on,
564
right_on,
565
predicates,
566
JoinOptionsIR::from(Arc::unwrap_or_clone(options)),
567
ctxt,
568
)
569
.map_err(|e| e.context(failed_here!(join)))
570
.map(|t| t.0);
571
},
572
DslPlan::HStack {
573
input,
574
exprs,
575
options,
576
} => {
577
let input = to_alp_impl(owned(input), ctxt)
578
.map_err(|e| e.context(failed_here!(with_columns)))?;
579
let (exprs, schema) =
580
resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena, ctxt.opt_flags)
581
.map_err(|e| e.context(failed_here!(with_columns)))?;
582
583
ctxt.conversion_optimizer
584
.fill_scratch(&exprs, ctxt.expr_arena);
585
let lp = IR::HStack {
586
input,
587
exprs,
588
schema,
589
options,
590
};
591
return run_conversion(lp, ctxt, "with_columns");
592
},
593
DslPlan::MatchToSchema {
594
input,
595
match_schema,
596
per_column,
597
extra_columns,
598
} => {
599
let input =
600
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(unique)))?;
601
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
602
603
assert_eq!(per_column.len(), match_schema.len());
604
605
if input_schema.as_ref() == &match_schema {
606
return Ok(input);
607
}
608
609
let mut exprs = Vec::with_capacity(match_schema.len());
610
let mut found_missing_columns = Vec::new();
611
let mut used_input_columns = 0;
612
613
for ((column, dtype), per_column) in match_schema.iter().zip(per_column.iter()) {
614
match input_schema.get(column) {
615
None => match &per_column.missing_columns {
616
MissingColumnsPolicyOrExpr::Raise => found_missing_columns.push(column),
617
MissingColumnsPolicyOrExpr::Insert => exprs.push(Expr::Alias(
618
Arc::new(Expr::Literal(LiteralValue::Scalar(Scalar::null(
619
dtype.clone(),
620
)))),
621
column.clone(),
622
)),
623
MissingColumnsPolicyOrExpr::InsertWith(expr) => {
624
exprs.push(Expr::Alias(Arc::new(expr.clone()), column.clone()))
625
},
626
},
627
Some(input_dtype) if dtype == input_dtype => {
628
used_input_columns += 1;
629
exprs.push(Expr::Column(column.clone()))
630
},
631
Some(input_dtype) => {
632
let from_dtype = input_dtype;
633
let to_dtype = dtype;
634
635
let policy = CastColumnsPolicy {
636
integer_upcast: per_column.integer_cast == UpcastOrForbid::Upcast,
637
float_upcast: per_column.float_cast == UpcastOrForbid::Upcast,
638
missing_struct_fields: per_column.missing_struct_fields,
639
extra_struct_fields: per_column.extra_struct_fields,
640
641
..Default::default()
642
};
643
644
let should_cast =
645
policy.should_cast_column(column, to_dtype, from_dtype)?;
646
647
let mut expr = Expr::Column(PlSmallStr::from_str(column));
648
if should_cast {
649
expr = expr.cast_with_options(to_dtype.clone(), CastOptions::NonStrict);
650
}
651
652
used_input_columns += 1;
653
exprs.push(expr);
654
},
655
}
656
}
657
658
// Report the error for missing columns
659
if let Some(lst) = found_missing_columns.first() {
660
use std::fmt::Write;
661
let mut formatted = String::new();
662
write!(&mut formatted, "\"{}\"", found_missing_columns[0]).unwrap();
663
for c in &found_missing_columns[1..] {
664
write!(&mut formatted, ", \"{c}\"").unwrap();
665
}
666
667
write!(&mut formatted, "\"{lst}\"").unwrap();
668
polars_bail!(SchemaMismatch: "missing columns in `match_to_schema`: {formatted}");
669
}
670
671
// Report the error for extra columns
672
if used_input_columns != input_schema.len()
673
&& extra_columns == ExtraColumnsPolicy::Raise
674
{
675
let found_extra_columns = input_schema
676
.iter_names()
677
.filter(|n| !match_schema.contains(n))
678
.collect::<Vec<_>>();
679
680
use std::fmt::Write;
681
let mut formatted = String::new();
682
write!(&mut formatted, "\"{}\"", found_extra_columns[0]).unwrap();
683
for c in &found_extra_columns[1..] {
684
write!(&mut formatted, ", \"{c}\"").unwrap();
685
}
686
687
polars_bail!(SchemaMismatch: "extra columns in `match_to_schema`: {formatted}");
688
}
689
690
let exprs = to_expr_irs(
691
exprs,
692
&mut ExprToIRContext::new_with_opt_eager(
693
ctxt.expr_arena,
694
&input_schema,
695
ctxt.opt_flags,
696
),
697
)?;
698
699
ctxt.conversion_optimizer
700
.fill_scratch(&exprs, ctxt.expr_arena);
701
let lp = IR::Select {
702
input,
703
expr: exprs,
704
schema: match_schema.clone(),
705
options: ProjectionOptions {
706
run_parallel: true,
707
duplicate_check: false,
708
should_broadcast: true,
709
},
710
};
711
return run_conversion(lp, ctxt, "match_to_schema");
712
},
713
DslPlan::PipeWithSchema { input, callback } => {
714
// Derive the schema from the input
715
let mut inputs = Vec::with_capacity(input.len());
716
let mut input_schemas = Vec::with_capacity(input.len());
717
718
for plan in input.as_ref() {
719
let ir = to_alp_impl(plan.clone(), ctxt)?;
720
let schema = ctxt.lp_arena.get(ir).schema(ctxt.lp_arena).into_owned();
721
722
let dsl = DslPlan::IR {
723
dsl: Arc::new(plan.clone()),
724
version: ctxt.lp_arena.version(),
725
node: Some(ir),
726
};
727
inputs.push(dsl);
728
input_schemas.push(schema);
729
}
730
731
// Adjust the input and start conversion again
732
let input_adjusted = callback.call((inputs, input_schemas))?;
733
return to_alp_impl(input_adjusted, ctxt);
734
},
735
#[cfg(feature = "pivot")]
736
DslPlan::Pivot {
737
input,
738
on,
739
on_columns,
740
index,
741
values,
742
agg,
743
maintain_order,
744
separator,
745
} => {
746
let input =
747
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(unique)))?;
748
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
749
750
let on = on.into_columns(input_schema.as_ref(), &Default::default())?;
751
let index = index.into_columns(input_schema.as_ref(), &Default::default())?;
752
let values = values.into_columns(input_schema.as_ref(), &Default::default())?;
753
754
polars_ensure!(!on.is_empty(), InvalidOperation: "`pivot` called without `on` columns.");
755
polars_ensure!(on.len() == on_columns.width(), InvalidOperation: "`pivot` expected `on` and `on_columns` to have the same amount of columns.");
756
if on.len() > 1 {
757
polars_ensure!(
758
on_columns.get_columns().iter().zip(on.iter()).all(|(c, o)| o == c.name()),
759
InvalidOperation: "`pivot` has mismatching column names between `on` and `on_columns`."
760
);
761
}
762
polars_ensure!(!values.is_empty(), InvalidOperation: "`pivot` called without `values` columns.");
763
764
let on_titles = if on_columns.width() == 1 {
765
on_columns.get_columns()[0].cast(&DataType::String)?
766
} else {
767
on_columns
768
.as_ref()
769
.clone()
770
.into_struct(PlSmallStr::EMPTY)
771
.cast(&DataType::String)?
772
.into_column()
773
};
774
let on_titles = on_titles.str()?;
775
776
let mut expr_schema = input_schema.as_ref().as_ref().clone();
777
let mut out = Vec::with_capacity(1);
778
let mut aggs = Vec::<ExprIR>::with_capacity(values.len() * on_columns.height());
779
for value in values.iter() {
780
out.clear();
781
let value_dtype = input_schema.try_get(value)?;
782
expr_schema.insert(get_pl_element_name(), value_dtype.clone());
783
expand_expression(
784
&agg,
785
&Default::default(),
786
&expr_schema,
787
&mut out,
788
ctxt.opt_flags,
789
)?;
790
polars_ensure!(
791
out.len() == 1,
792
InvalidOperation: "Pivot expression are not allowed to expand to more than 1 expression"
793
);
794
let agg = out.pop().unwrap();
795
let agg_ae = to_expr_ir(
796
agg,
797
&mut ExprToIRContext::new_with_opt_eager(
798
ctxt.expr_arena,
799
&expr_schema,
800
ctxt.opt_flags,
801
),
802
)?
803
.node();
804
805
polars_ensure!(
806
aexpr_to_leaf_names_iter(agg_ae, ctxt.expr_arena).count() == 0,
807
InvalidOperation: "explicit column references are not allowed in the `aggregate_function` of `pivot`"
808
);
809
810
for i in 0..on_columns.height() {
811
let mut name = String::new();
812
if values.len() > 1 {
813
name.push_str(value.as_str());
814
name.push_str(separator.as_str());
815
}
816
817
name.push_str(on_titles.get(i).unwrap_or("null"));
818
819
fn on_predicate(
820
on: &PlSmallStr,
821
on_column: &Column,
822
i: usize,
823
expr_arena: &mut Arena<AExpr>,
824
) -> AExprBuilder {
825
let e = AExprBuilder::col(on.clone(), expr_arena);
826
e.eq(
827
AExprBuilder::lit_scalar(
828
Scalar::new(
829
on_column.dtype().clone(),
830
on_column.get(i).unwrap().into_static(),
831
),
832
expr_arena,
833
),
834
expr_arena,
835
)
836
}
837
838
let predicate = if on.len() == 1 {
839
on_predicate(&on[0], &on_columns.get_columns()[0], i, ctxt.expr_arena)
840
} else {
841
AExprBuilder::function(
842
on.iter()
843
.enumerate()
844
.map(|(j, on_col)| {
845
on_predicate(
846
on_col,
847
&on_columns.get_columns()[j],
848
i,
849
ctxt.expr_arena,
850
)
851
.expr_ir(on_col.clone())
852
})
853
.collect::<Vec<_>>(),
854
IRFunctionExpr::Boolean(IRBooleanFunction::AllHorizontal),
855
ctxt.expr_arena,
856
)
857
};
858
859
let replacement_element = AExprBuilder::col(value.clone(), ctxt.expr_arena)
860
.filter(predicate, ctxt.expr_arena)
861
.node();
862
863
#[recursive::recursive]
864
fn deep_clone_element_replace(
865
ae: Node,
866
arena: &mut Arena<AExpr>,
867
replacement: Node,
868
) -> Node {
869
let slf = arena.get(ae).clone();
870
if matches!(slf, AExpr::Element) {
871
return deep_clone_ae(replacement, arena);
872
} else if matches!(slf, AExpr::Len) {
873
// For backwards-compatibility, we support providing `pl.len()` to mean
874
// the length of the group here.
875
let element = deep_clone_ae(replacement, arena);
876
return AExprBuilder::new_from_node(element).len(arena).node();
877
}
878
879
let mut children = vec![];
880
slf.children_rev(&mut children);
881
for child in &mut children {
882
*child = deep_clone_element_replace(*child, arena, replacement);
883
}
884
children.reverse();
885
886
arena.add(slf.replace_children(&children))
887
}
888
aggs.push(ExprIR::new(
889
deep_clone_element_replace(agg_ae, ctxt.expr_arena, replacement_element),
890
OutputName::Alias(name.into()),
891
));
892
}
893
}
894
895
let keys = index
896
.into_iter()
897
.map(|i| AExprBuilder::col(i.clone(), ctxt.expr_arena).expr_ir(i))
898
.collect();
899
IRBuilder::new(input, ctxt.expr_arena, ctxt.lp_arena)
900
.group_by(keys, aggs, None, maintain_order, Default::default())
901
.build()
902
},
903
DslPlan::Distinct { input, options } => {
904
let input =
905
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(unique)))?;
906
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena).into_owned();
907
908
// "subset" param supports cols and/or arbitrary expressions
909
let (input, subset, temp_cols) = if let Some(exprs) = options.subset {
910
let exprs = rewrite_projections(
911
exprs,
912
&PlHashSet::default(),
913
&input_schema,
914
ctxt.opt_flags,
915
)?;
916
917
// identify cols and exprs in "subset" param
918
let mut subset_colnames = vec![];
919
let mut subset_exprs = vec![];
920
for expr in &exprs {
921
match expr {
922
Expr::Column(name) => {
923
polars_ensure!(
924
input_schema.contains(name),
925
ColumnNotFound: "{name:?} not found"
926
);
927
subset_colnames.push(name.clone());
928
},
929
_ => subset_exprs.push(expr.clone()),
930
}
931
}
932
933
if subset_exprs.is_empty() {
934
// "subset" is a collection of basic cols (or empty)
935
(input, Some(subset_colnames.into_iter().collect()), vec![])
936
} else {
937
// "subset" contains exprs; add them as temporary cols
938
let (aliased_exprs, temp_names): (Vec<_>, Vec<_>) = subset_exprs
939
.into_iter()
940
.enumerate()
941
.map(|(idx, expr)| {
942
let temp_name = format_pl_smallstr!("__POLARS_UNIQUE_SUBSET_{}", idx);
943
(expr.alias(temp_name.clone()), temp_name)
944
})
945
.unzip();
946
947
subset_colnames.extend_from_slice(&temp_names);
948
949
// integrate the temporary cols with the existing "input" node
950
let (temp_expr_irs, schema) = resolve_with_columns(
951
aliased_exprs,
952
input,
953
ctxt.lp_arena,
954
ctxt.expr_arena,
955
ctxt.opt_flags,
956
)?;
957
ctxt.conversion_optimizer
958
.fill_scratch(&temp_expr_irs, ctxt.expr_arena);
959
960
let input_with_exprs = ctxt.lp_arena.add(IR::HStack {
961
input,
962
exprs: temp_expr_irs,
963
schema,
964
options: ProjectionOptions {
965
run_parallel: false,
966
duplicate_check: false,
967
should_broadcast: true,
968
},
969
});
970
(
971
input_with_exprs,
972
Some(subset_colnames.into_iter().collect()),
973
temp_names,
974
)
975
}
976
} else {
977
(input, None, vec![])
978
};
979
980
// `distinct` definition (will contain temporary cols if we have "subset" exprs)
981
let distinct_node = ctxt.lp_arena.add(IR::Distinct {
982
input,
983
options: DistinctOptionsIR {
984
subset,
985
maintain_order: options.maintain_order,
986
keep_strategy: options.keep_strategy,
987
slice: None,
988
},
989
});
990
991
// if no temporary cols (eg: we had no "subset" exprs), we're done...
992
if temp_cols.is_empty() {
993
return Ok(distinct_node);
994
}
995
996
// ...otherwise, drop them by projecting the original schema
997
return Ok(ctxt.lp_arena.add(IR::SimpleProjection {
998
input: distinct_node,
999
columns: input_schema,
1000
}));
1001
},
1002
DslPlan::MapFunction { input, function } => {
1003
let input = to_alp_impl(owned(input), ctxt)
1004
.map_err(|e| e.context(failed_here!(format!("{}", function).to_lowercase())))?;
1005
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
1006
1007
match function {
1008
DslFunction::Explode {
1009
columns,
1010
options,
1011
allow_empty,
1012
} => {
1013
let columns = columns.into_columns(&input_schema, &Default::default())?;
1014
polars_ensure!(!columns.is_empty() || allow_empty, InvalidOperation: "no columns provided in explode");
1015
if columns.is_empty() {
1016
return Ok(input);
1017
}
1018
let function = FunctionIR::Explode {
1019
columns: columns.into_iter().collect(),
1020
options,
1021
schema: Default::default(),
1022
};
1023
let ir = IR::MapFunction { input, function };
1024
return Ok(ctxt.lp_arena.add(ir));
1025
},
1026
DslFunction::FillNan(fill_value) => {
1027
let exprs = input_schema
1028
.iter()
1029
.filter_map(|(name, dtype)| match dtype {
1030
DataType::Float16 | DataType::Float32 | DataType::Float64 => Some(
1031
col(name.clone())
1032
.fill_nan(fill_value.clone())
1033
.alias(name.clone()),
1034
),
1035
_ => None,
1036
})
1037
.collect::<Vec<_>>();
1038
1039
let (exprs, schema) = resolve_with_columns(
1040
exprs,
1041
input,
1042
ctxt.lp_arena,
1043
ctxt.expr_arena,
1044
ctxt.opt_flags,
1045
)
1046
.map_err(|e| e.context(failed_here!(fill_nan)))?;
1047
1048
ctxt.conversion_optimizer
1049
.fill_scratch(&exprs, ctxt.expr_arena);
1050
1051
let lp = IR::HStack {
1052
input,
1053
exprs,
1054
schema,
1055
options: ProjectionOptions {
1056
duplicate_check: false,
1057
..Default::default()
1058
},
1059
};
1060
return run_conversion(lp, ctxt, "fill_nan");
1061
},
1062
DslFunction::Stats(sf) => {
1063
let exprs = match sf {
1064
StatsFunction::Var { ddof } => stats_helper(
1065
|dt| dt.is_primitive_numeric() || dt.is_bool() || dt.is_decimal(),
1066
|name| col(name.clone()).var(ddof),
1067
&input_schema,
1068
),
1069
StatsFunction::Std { ddof } => stats_helper(
1070
|dt| dt.is_primitive_numeric() || dt.is_bool() || dt.is_decimal(),
1071
|name| col(name.clone()).std(ddof),
1072
&input_schema,
1073
),
1074
StatsFunction::Quantile { quantile, method } => stats_helper(
1075
|dt| dt.is_primitive_numeric() || dt.is_decimal() || dt.is_temporal(),
1076
|name| col(name.clone()).quantile(quantile.clone(), method),
1077
&input_schema,
1078
),
1079
StatsFunction::Mean => stats_helper(
1080
|dt| {
1081
dt.is_primitive_numeric()
1082
|| dt.is_temporal()
1083
|| dt.is_bool()
1084
|| dt.is_decimal()
1085
},
1086
|name| col(name.clone()).mean(),
1087
&input_schema,
1088
),
1089
StatsFunction::Sum => stats_helper(
1090
|dt| {
1091
dt.is_primitive_numeric()
1092
|| dt.is_decimal()
1093
|| matches!(dt, DataType::Boolean | DataType::Duration(_))
1094
},
1095
|name| col(name.clone()).sum(),
1096
&input_schema,
1097
),
1098
StatsFunction::Min => stats_helper(
1099
|dt| dt.is_ord(),
1100
|name| col(name.clone()).min(),
1101
&input_schema,
1102
),
1103
StatsFunction::Max => stats_helper(
1104
|dt| dt.is_ord(),
1105
|name| col(name.clone()).max(),
1106
&input_schema,
1107
),
1108
StatsFunction::Median => stats_helper(
1109
|dt| {
1110
dt.is_primitive_numeric()
1111
|| dt.is_temporal()
1112
|| dt == &DataType::Boolean
1113
},
1114
|name| col(name.clone()).median(),
1115
&input_schema,
1116
),
1117
};
1118
let schema = Arc::new(expressions_to_schema(
1119
&exprs,
1120
&input_schema,
1121
|duplicate_name: &str| duplicate_name.to_string(),
1122
)?);
1123
let eirs = to_expr_irs(
1124
exprs,
1125
&mut ExprToIRContext::new_with_opt_eager(
1126
ctxt.expr_arena,
1127
&input_schema,
1128
ctxt.opt_flags,
1129
),
1130
)?;
1131
1132
ctxt.conversion_optimizer
1133
.fill_scratch(&eirs, ctxt.expr_arena);
1134
1135
let lp = IR::Select {
1136
input,
1137
expr: eirs,
1138
schema,
1139
options: ProjectionOptions {
1140
duplicate_check: false,
1141
..Default::default()
1142
},
1143
};
1144
return run_conversion(lp, ctxt, "stats");
1145
},
1146
DslFunction::Rename {
1147
existing,
1148
new,
1149
strict,
1150
} => {
1151
assert_eq!(existing.len(), new.len());
1152
if existing.is_empty() {
1153
return Ok(input);
1154
}
1155
1156
let existing_lut =
1157
PlIndexSet::from_iter(existing.iter().map(PlSmallStr::as_str));
1158
1159
let mut schema = Schema::with_capacity(input_schema.len());
1160
let mut num_replaced = 0;
1161
1162
// Turn the rename into a select.
1163
let expr = input_schema
1164
.iter()
1165
.map(|(n, dtype)| {
1166
Ok(match existing_lut.get_index_of(n.as_str()) {
1167
None => {
1168
schema.try_insert(n.clone(), dtype.clone())?;
1169
Expr::Column(n.clone())
1170
},
1171
Some(i) => {
1172
num_replaced += 1;
1173
schema.try_insert(new[i].clone(), dtype.clone())?;
1174
Expr::Column(n.clone()).alias(new[i].clone())
1175
},
1176
})
1177
})
1178
.collect::<PolarsResult<Vec<_>>>()?;
1179
1180
if strict && num_replaced != existing.len() {
1181
let col = existing.iter().find(|c| !input_schema.contains(c)).unwrap();
1182
polars_bail!(col_not_found = col);
1183
}
1184
1185
// Nothing changed, make into a no-op.
1186
if num_replaced == 0 {
1187
return Ok(input);
1188
}
1189
1190
let expr = to_expr_irs(
1191
expr,
1192
&mut ExprToIRContext::new_with_opt_eager(
1193
ctxt.expr_arena,
1194
&input_schema,
1195
ctxt.opt_flags,
1196
),
1197
)?;
1198
ctxt.conversion_optimizer
1199
.fill_scratch(&expr, ctxt.expr_arena);
1200
1201
IR::Select {
1202
input,
1203
expr,
1204
schema: Arc::new(schema),
1205
options: ProjectionOptions {
1206
run_parallel: false,
1207
duplicate_check: false,
1208
should_broadcast: false,
1209
},
1210
}
1211
},
1212
_ => {
1213
let function = function.into_function_ir(&input_schema)?;
1214
IR::MapFunction { input, function }
1215
},
1216
}
1217
},
1218
DslPlan::ExtContext { input, contexts } => {
1219
let input = to_alp_impl(owned(input), ctxt)
1220
.map_err(|e| e.context(failed_here!(with_context)))?;
1221
let contexts = contexts
1222
.into_iter()
1223
.map(|lp| to_alp_impl(lp, ctxt))
1224
.collect::<PolarsResult<Vec<_>>>()
1225
.map_err(|e| e.context(failed_here!(with_context)))?;
1226
1227
let mut schema = (**ctxt.lp_arena.get(input).schema(ctxt.lp_arena)).clone();
1228
for input in &contexts {
1229
let other_schema = ctxt.lp_arena.get(*input).schema(ctxt.lp_arena);
1230
for fld in other_schema.iter_fields() {
1231
if schema.get(fld.name()).is_none() {
1232
schema.with_column(fld.name, fld.dtype);
1233
}
1234
}
1235
}
1236
1237
IR::ExtContext {
1238
input,
1239
contexts,
1240
schema: Arc::new(schema),
1241
}
1242
},
1243
DslPlan::Sink { input, payload } => {
1244
let input =
1245
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(sink)))?;
1246
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
1247
let payload = match payload {
1248
SinkType::Memory => SinkTypeIR::Memory,
1249
SinkType::Callback(f) => SinkTypeIR::Callback(f),
1250
SinkType::File(options) => SinkTypeIR::File(options),
1251
SinkType::Partitioned(PartitionedSinkOptions {
1252
base_path,
1253
file_path_provider,
1254
partition_strategy,
1255
finish_callback,
1256
file_format,
1257
unified_sink_args,
1258
max_rows_per_file,
1259
approximate_bytes_per_file,
1260
}) => {
1261
let expr_to_ir_cx = &mut ExprToIRContext::new_with_opt_eager(
1262
ctxt.expr_arena,
1263
&input_schema,
1264
ctxt.opt_flags,
1265
);
1266
1267
let partition_strategy = match partition_strategy {
1268
PartitionStrategy::Keyed {
1269
keys,
1270
include_keys,
1271
keys_pre_grouped,
1272
per_partition_sort_by,
1273
} => {
1274
let keys = to_expr_irs(keys, expr_to_ir_cx)?;
1275
let per_partition_sort_by: Vec<SortColumnIR> = per_partition_sort_by
1276
.into_iter()
1277
.map(|s| {
1278
let SortColumn {
1279
expr,
1280
descending,
1281
nulls_last,
1282
} = s;
1283
Ok(SortColumnIR {
1284
expr: to_expr_ir(expr, expr_to_ir_cx)?,
1285
descending,
1286
nulls_last,
1287
})
1288
})
1289
.collect::<PolarsResult<_>>()?;
1290
1291
PartitionStrategyIR::Keyed {
1292
keys,
1293
include_keys,
1294
keys_pre_grouped,
1295
per_partition_sort_by,
1296
}
1297
},
1298
PartitionStrategy::FileSize => PartitionStrategyIR::FileSize,
1299
};
1300
1301
let options = PartitionedSinkOptionsIR {
1302
base_path,
1303
file_path_provider: file_path_provider.unwrap_or_else(|| {
1304
FileProviderType::Hive {
1305
extension: PlSmallStr::from_static(file_format.extension()),
1306
}
1307
}),
1308
partition_strategy,
1309
finish_callback,
1310
file_format,
1311
unified_sink_args,
1312
max_rows_per_file,
1313
approximate_bytes_per_file,
1314
};
1315
1316
ctxt.conversion_optimizer
1317
.fill_scratch(options.expr_irs_iter(), ctxt.expr_arena);
1318
1319
SinkTypeIR::Partitioned(options)
1320
},
1321
};
1322
1323
let lp = IR::Sink { input, payload };
1324
return run_conversion(lp, ctxt, "sink");
1325
},
1326
DslPlan::SinkMultiple { inputs } => {
1327
let inputs = inputs
1328
.into_iter()
1329
.map(|lp| to_alp_impl(lp, ctxt))
1330
.collect::<PolarsResult<Vec<_>>>()
1331
.map_err(|e| e.context(failed_here!(vertical concat)))?;
1332
IR::SinkMultiple { inputs }
1333
},
1334
#[cfg(feature = "merge_sorted")]
1335
DslPlan::MergeSorted {
1336
input_left,
1337
input_right,
1338
key,
1339
} => {
1340
let input_left = to_alp_impl(owned(input_left), ctxt)
1341
.map_err(|e| e.context(failed_here!(merge_sorted)))?;
1342
let input_right = to_alp_impl(owned(input_right), ctxt)
1343
.map_err(|e| e.context(failed_here!(merge_sorted)))?;
1344
1345
let left_schema = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
1346
let right_schema = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);
1347
1348
left_schema
1349
.ensure_is_exact_match(&right_schema)
1350
.map_err(|err| err.context("merge_sorted".into()))?;
1351
1352
left_schema
1353
.try_get(key.as_str())
1354
.map_err(|err| err.context("merge_sorted".into()))?;
1355
1356
IR::MergeSorted {
1357
input_left,
1358
input_right,
1359
key,
1360
}
1361
},
1362
DslPlan::IR { node, dsl, version } => {
1363
return match node {
1364
Some(node)
1365
if version == ctxt.lp_arena.version()
1366
&& ctxt.conversion_optimizer.used_arenas.insert(version) =>
1367
{
1368
Ok(node)
1369
},
1370
_ => to_alp_impl(owned(dsl), ctxt),
1371
};
1372
},
1373
};
1374
Ok(ctxt.lp_arena.add(v))
1375
}
1376
1377
fn resolve_with_columns(
1378
exprs: Vec<Expr>,
1379
input: Node,
1380
lp_arena: &Arena<IR>,
1381
expr_arena: &mut Arena<AExpr>,
1382
opt_flags: &mut OptFlags,
1383
) -> PolarsResult<(Vec<ExprIR>, SchemaRef)> {
1384
let input_schema = lp_arena.get(input).schema(lp_arena);
1385
let mut output_schema = (**input_schema).clone();
1386
let exprs = rewrite_projections(exprs, &PlHashSet::new(), &input_schema, opt_flags)?;
1387
let mut output_names = PlHashSet::with_capacity(exprs.len());
1388
1389
let eirs = to_expr_irs(
1390
exprs,
1391
&mut ExprToIRContext::new_with_opt_eager(expr_arena, &input_schema, opt_flags),
1392
)?;
1393
for eir in eirs.iter() {
1394
let field = eir.field(&input_schema, expr_arena)?;
1395
1396
if !output_names.insert(field.name().clone()) {
1397
let msg = format!(
1398
"the name '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\
1399
It's possible that multiple expressions are returning the same default column name. \
1400
If this is the case, try renaming the columns with `.alias(\"new_name\")` to avoid \
1401
duplicate column names.",
1402
field.name()
1403
);
1404
polars_bail!(ComputeError: msg)
1405
}
1406
output_schema.with_column(field.name, field.dtype.materialize_unknown(true)?);
1407
}
1408
1409
Ok((eirs, Arc::new(output_schema)))
1410
}
fn resolve_group_by(
1413
input: Node,
1414
keys: Vec<Expr>,
1415
aggs: Vec<Expr>,
1416
_options: &GroupbyOptions,
1417
lp_arena: &Arena<IR>,
1418
expr_arena: &mut Arena<AExpr>,
1419
opt_flags: &mut OptFlags,
1420
) -> PolarsResult<(Vec<ExprIR>, Vec<ExprIR>, SchemaRef)> {
1421
let input_schema = lp_arena.get(input).schema(lp_arena);
1422
let input_schema = input_schema.as_ref();
1423
let mut keys = rewrite_projections(keys, &PlHashSet::default(), input_schema, opt_flags)?;
1424
1425
// Initialize schema from keys
1426
let mut output_schema = expressions_to_schema(&keys, input_schema, |duplicate_name: &str| {
1427
format!("group_by keys contained duplicate output name '{duplicate_name}'")
1428
})?;
1429
let mut key_names: PlHashSet<PlSmallStr> = output_schema.iter_names().cloned().collect();
1430
1431
#[allow(unused_mut)]
1432
let mut pop_keys = false;
1433
// Add dynamic groupby index column(s)
1434
// Also add index columns to keys for expression expansion.
1435
#[cfg(feature = "dynamic_group_by")]
1436
{
1437
if let Some(options) = _options.rolling.as_ref() {
1438
let name = options.index_column.clone();
1439
let dtype = input_schema.try_get(name.as_str())?;
1440
keys.push(col(name.clone()));
1441
key_names.insert(name.clone());
1442
pop_keys = true;
1443
output_schema.with_column(name.clone(), dtype.clone());
1444
} else if let Some(options) = _options.dynamic.as_ref() {
1445
let name = options.index_column.clone();
1446
keys.push(col(name.clone()));
1447
key_names.insert(name.clone());
1448
pop_keys = true;
1449
let dtype = input_schema.try_get(name.as_str())?;
1450
if options.include_boundaries {
1451
output_schema.with_column("_lower_boundary".into(), dtype.clone());
1452
output_schema.with_column("_upper_boundary".into(), dtype.clone());
1453
}
1454
output_schema.with_column(name.clone(), dtype.clone());
1455
}
1456
}
1457
let keys_index_len = output_schema.len();
1458
if pop_keys {
1459
let _ = keys.pop();
1460
}
1461
let keys = to_expr_irs(
1462
keys,
1463
&mut ExprToIRContext::new_with_opt_eager(expr_arena, input_schema, opt_flags),
1464
)?;
1465
1466
// Add aggregation column(s)
1467
let aggs = rewrite_projections(aggs, &key_names, input_schema, opt_flags)?;
1468
let aggs = to_expr_irs(
1469
aggs,
1470
&mut ExprToIRContext::new_with_opt_eager(expr_arena, input_schema, opt_flags),
1471
)?;
1472
utils::validate_expressions(&keys, expr_arena, input_schema, "group by")?;
1473
utils::validate_expressions(&aggs, expr_arena, input_schema, "group by")?;
1474
1475
let mut aggs_schema = expr_irs_to_schema(&aggs, input_schema, expr_arena)?;
1476
1477
// Make sure aggregation columns do not contain duplicates
1478
if aggs_schema.len() < aggs.len() {
1479
let mut names = PlHashSet::with_capacity(aggs.len());
1480
for agg in aggs.iter() {
1481
let name = agg.output_name();
1482
polars_ensure!(names.insert(name.clone()), duplicate = name)
1483
}
1484
}
1485
1486
// Coerce aggregation column(s) into List unless not needed (auto-implode)
1487
debug_assert!(aggs_schema.len() == aggs.len());
1488
for ((_name, dtype), expr) in aggs_schema.iter_mut().zip(&aggs) {
1489
if !expr.is_scalar(expr_arena) {
1490
*dtype = dtype.clone().implode();
1491
}
1492
}
1493
1494
// Final output_schema
1495
output_schema.merge(aggs_schema);
1496
1497
// Make sure aggregation columns do not contain keys or index columns
1498
if output_schema.len() < (keys_index_len + aggs.len()) {
1499
let mut names = PlHashSet::with_capacity(output_schema.len());
1500
for agg in aggs.iter().chain(keys.iter()) {
1501
let name = agg.output_name();
1502
polars_ensure!(names.insert(name.clone()), duplicate = name)
1503
}
1504
}
1505
1506
Ok((keys, aggs, Arc::new(output_schema)))
1507
}
fn stats_helper<F, E>(condition: F, expr: E, schema: &Schema) -> Vec<Expr>
1510
where
1511
F: Fn(&DataType) -> bool,
1512
E: Fn(&PlSmallStr) -> Expr,
1513
{
1514
schema
1515
.iter()
1516
.map(|(name, dt)| {
1517
if condition(dt) {
1518
expr(name)
1519
} else {
1520
lit(NULL).cast(dt.clone()).alias(name.clone())
1521
}
1522
})
1523
.collect()
1524
}
pub(crate) fn maybe_init_projection_excluding_hive(
1527
reader_schema: &Either<ArrowSchemaRef, SchemaRef>,
1528
hive_parts: Option<&SchemaRef>,
1529
) -> Option<Arc<[PlSmallStr]>> {
1530
// Update `with_columns` with a projection so that hive columns aren't loaded from the
1531
// file
1532
let hive_schema = hive_parts?;
1533
1534
match &reader_schema {
1535
Either::Left(reader_schema) => hive_schema
1536
.iter_names()
1537
.any(|x| reader_schema.contains(x))
1538
.then(|| {
1539
reader_schema
1540
.iter_names_cloned()
1541
.filter(|x| !hive_schema.contains(x))
1542
.collect::<Arc<[_]>>()
1543
}),
1544
Either::Right(reader_schema) => hive_schema
1545
.iter_names()
1546
.any(|x| reader_schema.contains(x))
1547
.then(|| {
1548
reader_schema
1549
.iter_names_cloned()
1550
.filter(|x| !hive_schema.contains(x))
1551
.collect::<Arc<[_]>>()
1552
}),
1553
}
1554
}