Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/physical_plan/lower_group_by.rs
8503 views
1
use std::sync::Arc;
2
3
use parking_lot::Mutex;
4
use polars_core::frame::DataFrame;
5
use polars_core::prelude::{Field, InitHashMaps, PlIndexMap, PlIndexSet, SortMultipleOptions};
6
use polars_core::schema::Schema;
7
use polars_error::{PolarsResult, polars_err};
8
use polars_expr::state::ExecutionState;
9
use polars_mem_engine::create_physical_plan;
10
use polars_plan::plans::expr_ir::{ExprIR, OutputName};
11
use polars_plan::plans::{AExpr, IR, IRAggExpr, IRFunctionExpr, NaiveExprMerger, write_group_by};
12
use polars_plan::prelude::{GroupbyOptions, *};
13
use polars_utils::arena::{Arena, Node};
14
use polars_utils::pl_str::PlSmallStr;
15
use polars_utils::{IdxSize, unique_column_name};
16
use recursive::recursive;
17
use slotmap::SlotMap;
18
19
use super::{ExprCache, PhysNode, PhysNodeKey, PhysNodeKind, PhysStream, StreamingLowerIRContext};
20
use crate::physical_plan::lower_expr::{
21
build_hstack_stream, build_select_stream, compute_output_schema, is_elementwise_rec_cached,
22
is_fake_elementwise_function, is_input_independent,
23
};
24
use crate::physical_plan::lower_ir::{
25
build_filter_stream, build_row_idx_stream, build_slice_stream,
26
};
27
use crate::utils::late_materialized_df::LateMaterializedDataFrame;
28
29
#[allow(clippy::too_many_arguments)]
30
fn build_group_by_fallback(
31
input: PhysStream,
32
keys: &[ExprIR],
33
aggs: &[ExprIR],
34
output_schema: Arc<Schema>,
35
maintain_order: bool,
36
options: Arc<GroupbyOptions>,
37
apply: Option<PlanCallback<DataFrame, DataFrame>>,
38
expr_arena: &mut Arena<AExpr>,
39
phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
40
format_str: Option<String>,
41
) -> PolarsResult<PhysStream> {
42
let input_schema = phys_sm[input.node].output_schema.clone();
43
let lmdf = Arc::new(LateMaterializedDataFrame::default());
44
let mut lp_arena = Arena::default();
45
let input_lp_node = lp_arena.add(lmdf.clone().as_ir_node(input_schema));
46
let group_by_lp_node = lp_arena.add(IR::GroupBy {
47
input: input_lp_node,
48
keys: keys.to_vec(),
49
aggs: aggs.to_vec(),
50
schema: output_schema.clone(),
51
maintain_order,
52
options,
53
apply,
54
});
55
let executor = Mutex::new(create_physical_plan(
56
group_by_lp_node,
57
&mut lp_arena,
58
expr_arena,
59
Some(crate::dispatch::build_streaming_query_executor),
60
)?);
61
62
let group_by_node = PhysNode {
63
output_schema,
64
kind: PhysNodeKind::InMemoryMap {
65
input,
66
map: Arc::new(move |df| {
67
lmdf.set_materialized_dataframe(df);
68
let mut state = ExecutionState::new();
69
executor.lock().execute(&mut state)
70
}),
71
format_str,
72
},
73
};
74
75
Ok(PhysStream::first(phys_sm.insert(group_by_node)))
76
}
77
78
// Given an aggregate expression returns a column expression which is to
// represent the aggregate result in the post-select.
//
// For each input to this aggregate uniq_input_names is updated to map the
// unique id of the input expressions to an input columns the aggregate
// expression expects.
//
// uniq_agg_exprs is updated with the unique id of the aggregate mapping to
// the aggregate expression and vector of unique input ids for that aggregate.
//
// Duplicate aggregates (same unique id in the merger) are deduplicated via
// the `uniq_agg_exprs` entry: only the first occurrence builds and pushes a
// new aggregation expression; later occurrences reuse its output column.
#[allow(clippy::too_many_arguments)]
fn replace_agg_uniq(
    expr: Node,
    expr_merger: &mut NaiveExprMerger,
    expr_cache: &mut ExprCache,
    expr_arena: &mut Arena<AExpr>,
    agg_exprs: &mut Vec<ExprIR>,
    uniq_input_names: &mut PlIndexMap<u32, PlSmallStr>,
    uniq_agg_exprs: &mut PlIndexMap<u32, (ExprIR, Vec<u32>)>,
    uniq_elementwise_exprs: &mut PlIndexMap<u32, ExprIR>,
) -> Node {
    // Clone the node so we can mutate the arena while inspecting its inputs.
    let aexpr = expr_arena.get(expr).clone();
    let mut inputs = Vec::new();
    aexpr.inputs_rev(&mut inputs);
    inputs.reverse();

    let agg_id = expr_merger.get_uniq_id(expr).unwrap();
    let name = uniq_agg_exprs
        .entry(agg_id)
        .or_insert_with(|| {
            // First time we see this aggregate: rewrite each of its inputs to
            // a column reference and record the inputs' unique ids.
            let mut input_ids = Vec::new();
            let input_cols = inputs
                .iter()
                .map(|input| {
                    let (input_id, node) = replace_elementwise_components(
                        *input,
                        expr_merger,
                        expr_cache,
                        expr_arena,
                        uniq_input_names,
                        uniq_elementwise_exprs,
                    );
                    if let Some(id) = input_id {
                        // Already elementwise.
                        input_ids.push(id);
                        node
                    } else {
                        // Non-elementwise input: register the rewritten node
                        // itself and refer to it through a fresh column name.
                        let input_id = expr_merger.add_and_get_uniq_id(node, expr_arena);
                        input_ids.push(input_id);
                        let input_col = uniq_input_names
                            .entry(input_id)
                            .or_insert_with(unique_column_name)
                            .clone();
                        expr_arena.add(AExpr::Column(input_col))
                    }
                })
                .collect::<Vec<_>>();
            let trans_agg_node = expr_arena.add(aexpr.replace_inputs(&input_cols));

            // Add to aggregation expressions and replace with a reference to its output.
            let agg_expr = ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()));
            agg_exprs.push(agg_expr.clone());
            (agg_expr, input_ids)
        })
        .0
        .output_name()
        .clone();
    // The caller sees the aggregate as a plain column in the post-select.
    expr_arena.add(AExpr::Column(name))
}
146
147
/// Replaces all elementwise subexpressions with column references, storing the elementwise
/// expressions uniquely in expr_merger/uniq_elementwise_exprs keys.
///
/// Returns `(Some(id), column_node)` when `expr` itself is elementwise (or an
/// input-independent scalar), where `id` is its unique id in the merger.
/// Otherwise returns `(None, rewritten_node)` where only the elementwise
/// subtrees of `expr` have been replaced by column references.
#[recursive]
fn replace_elementwise_components(
    expr: Node,
    expr_merger: &mut NaiveExprMerger,
    expr_cache: &mut ExprCache,
    expr_arena: &mut Arena<AExpr>,
    uniq_input_names: &mut PlIndexMap<u32, PlSmallStr>,
    uniq_elementwise_exprs: &mut PlIndexMap<u32, ExprIR>,
) -> (Option<u32>, Node) {
    if is_elementwise_rec_cached(expr, expr_arena, expr_cache)
        || (is_input_independent(expr, expr_arena, expr_cache) && is_scalar_ae(expr, expr_arena))
    {
        // Whole expression is elementwise: replace it with a (deduplicated)
        // uniquely-named column reference.
        let id = expr_merger.add_and_get_uniq_id(expr, expr_arena);
        let name = uniq_input_names
            .entry(id)
            .or_insert_with(unique_column_name)
            .clone();
        let node = uniq_elementwise_exprs
            .entry(id)
            .or_insert_with(|| ExprIR::from_column_name(name, expr_arena))
            .node();
        (Some(id), node)
    } else {
        // Not elementwise at this level: recurse into the inputs and rebuild
        // the node with the rewritten children.
        let aexpr = expr_arena.get(expr).clone();
        let mut inputs = Vec::new();
        aexpr.inputs_rev(&mut inputs);
        inputs.reverse();

        for input in &mut inputs {
            *input = replace_elementwise_components(
                *input,
                expr_merger,
                expr_cache,
                expr_arena,
                uniq_input_names,
                uniq_elementwise_exprs,
            )
            .1;
        }
        let rec_node = expr_arena.add(aexpr.replace_inputs(&inputs));
        (None, rec_node)
    }
}
192
193
/// Tries to lower an expression as a 'elementwise scalar agg expression'.
///
/// Such an expression is defined as the elementwise combination of scalar
/// aggregations.
///
/// On success returns the rewritten node, in which every aggregation has been
/// replaced by a column reference to an entry pushed into `agg_exprs` (via
/// `replace_agg_uniq`). Returns `None` when the expression contains a
/// construct the streaming group-by cannot (yet) handle, which makes the
/// caller fall back to a different lowering.
#[recursive]
#[allow(clippy::too_many_arguments)]
fn try_lower_elementwise_scalar_agg_expr(
    expr: Node,
    expr_merger: &mut NaiveExprMerger,
    expr_cache: &mut ExprCache,
    expr_arena: &mut Arena<AExpr>,
    agg_exprs: &mut Vec<ExprIR>,
    uniq_input_names: &mut PlIndexMap<u32, PlSmallStr>,
    uniq_agg_exprs: &mut PlIndexMap<u32, (ExprIR, Vec<u32>)>,
    uniq_elementwise_exprs: &mut PlIndexMap<u32, ExprIR>,
) -> Option<Node> {
    // Helper macros to simplify (recursive) calls.
    macro_rules! lower_rec {
        ($input:expr) => {
            try_lower_elementwise_scalar_agg_expr(
                $input,
                expr_merger,
                expr_cache,
                expr_arena,
                agg_exprs,
                uniq_input_names,
                uniq_agg_exprs,
                uniq_elementwise_exprs,
            )
        };
    }

    macro_rules! replace_agg_uniq {
        ($input:expr) => {
            replace_agg_uniq(
                $input,
                expr_merger,
                expr_cache,
                expr_arena,
                agg_exprs,
                uniq_input_names,
                uniq_agg_exprs,
                uniq_elementwise_exprs,
            )
        };
    }

    // Input-independent expressions can pass through unchanged if scalar,
    // otherwise they are imploded into a single list value per group.
    if is_input_independent(expr, expr_arena, expr_cache) {
        if expr_arena.get(expr).is_scalar(expr_arena) {
            return Some(expr);
        } else {
            let agg = IRAggExpr::Implode(expr);
            return Some(expr_arena.add(AExpr::Agg(agg)));
        }
    }

    match expr_arena.get(expr) {
        // Should be handled separately in `Eval`.
        AExpr::Element => unreachable!(),

        AExpr::StructField(_) => {
            // Reflecting StructEval expr state is not yet supported.
            None
        },

        AExpr::Column(_) => {
            // Implicit implode not yet supported.
            None
        },

        AExpr::Literal(lit) => {
            if lit.is_scalar() {
                Some(expr)
            } else {
                None
            }
        },

        #[cfg(feature = "dynamic_group_by")]
        AExpr::Rolling { .. } => None,

        // Length-changing / ordering-dependent constructs are not supported.
        AExpr::Slice { .. }
        | AExpr::Over { .. }
        | AExpr::Sort { .. }
        | AExpr::SortBy { .. }
        | AExpr::Gather { .. } => None,

        // Explode and filter are row-separable and should thus in theory work
        // in a streaming fashion but they change the length of the input which
        // means the same filter/explode should also be applied to the key
        // column, which is not (yet) supported.
        AExpr::Explode { .. } | AExpr::Filter { .. } => None,

        // Elementwise combinators: lower each child and rebuild the node.
        AExpr::BinaryExpr { left, op, right } => {
            let (left, op, right) = (*left, *op, *right);
            let left = lower_rec!(left)?;
            let right = lower_rec!(right)?;
            Some(expr_arena.add(AExpr::BinaryExpr { left, op, right }))
        },

        AExpr::Eval {
            expr,
            evaluation,
            variant,
        } => {
            let (expr, evaluation, variant) = (*expr, *evaluation, *variant);
            let expr = lower_rec!(expr)?;
            Some(expr_arena.add(AExpr::Eval {
                expr,
                evaluation,
                variant,
            }))
        },

        AExpr::StructEval { expr, evaluation } => {
            // @TODO: Reflect the lowering result of `expr` into the respective
            // StructField lowering calls.
            let (expr, evaluation) = (*expr, evaluation.clone());
            let expr = lower_rec!(expr)?;

            let new_evaluation = evaluation
                .into_iter()
                .map(|i| {
                    let new_node = lower_rec!(i.node())?;
                    Some(ExprIR::new(
                        new_node,
                        OutputName::Alias(i.output_name().clone()),
                    ))
                })
                .collect::<Option<Vec<_>>>()?;

            Some(expr_arena.add(AExpr::StructEval {
                expr,
                evaluation: new_evaluation,
            }))
        },

        AExpr::Ternary {
            predicate,
            truthy,
            falsy,
        } => {
            let (predicate, truthy, falsy) = (*predicate, *truthy, *falsy);
            let predicate = lower_rec!(predicate)?;
            let truthy = lower_rec!(truthy)?;
            let falsy = lower_rec!(falsy)?;
            Some(expr_arena.add(AExpr::Ternary {
                predicate,
                truthy,
                falsy,
            }))
        },

        // Function expressions that act as scalar aggregations are extracted
        // into the aggregate list wholesale.
        #[cfg(feature = "bitwise")]
        AExpr::Function {
            function:
                IRFunctionExpr::Bitwise(
                    IRBitwiseFunction::And | IRBitwiseFunction::Or | IRBitwiseFunction::Xor,
                ),
            ..
        } => Some(replace_agg_uniq!(expr)),

        #[cfg(feature = "approx_unique")]
        AExpr::Function {
            function: IRFunctionExpr::ApproxNUnique,
            ..
        } => Some(replace_agg_uniq!(expr)),

        AExpr::Function {
            function:
                IRFunctionExpr::Boolean(IRBooleanFunction::Any { .. } | IRBooleanFunction::All { .. })
                | IRFunctionExpr::MinBy
                | IRFunctionExpr::MaxBy
                | IRFunctionExpr::NullCount,
            ..
        } => Some(replace_agg_uniq!(expr)),

        AExpr::AnonymousAgg { .. } => Some(replace_agg_uniq!(expr)),

        // Genuinely elementwise (non-fake) functions: lower their inputs and
        // rebuild with the same function.
        node @ AExpr::Function { input, options, .. }
        | node @ AExpr::AnonymousFunction { input, options, .. }
            if options.is_elementwise() && !is_fake_elementwise_function(node) =>
        {
            let node = node.clone();
            let input = input.clone();
            let new_input = input
                .into_iter()
                .map(|i| {
                    // The function may be sensitive to names (e.g. pl.struct), so we restore them.
                    let new_node = lower_rec!(i.node())?;
                    Some(ExprIR::new(
                        new_node,
                        OutputName::Alias(i.output_name().clone()),
                    ))
                })
                .collect::<Option<Vec<_>>>()?;

            let mut new_node = node;
            match &mut new_node {
                AExpr::Function { input, .. } | AExpr::AnonymousFunction { input, .. } => {
                    *input = new_input;
                },
                _ => unreachable!(),
            }
            Some(expr_arena.add(new_node))
        },

        // Any remaining (non-elementwise, non-aggregating) function is unsupported.
        AExpr::Function { .. } | AExpr::AnonymousFunction { .. } => None,

        AExpr::Cast {
            expr,
            dtype,
            options,
        } => {
            let (expr, dtype, options) = (*expr, dtype.clone(), *options);
            let expr = lower_rec!(expr)?;
            Some(expr_arena.add(AExpr::Cast {
                expr,
                dtype,
                options,
            }))
        },

        AExpr::Agg(agg) => {
            match agg {
                // Directly supported scalar aggregations.
                IRAggExpr::Min { .. }
                | IRAggExpr::Max { .. }
                | IRAggExpr::First(_)
                | IRAggExpr::FirstNonNull(_)
                | IRAggExpr::Last(_)
                | IRAggExpr::LastNonNull(_)
                | IRAggExpr::Item { .. }
                | IRAggExpr::Mean(_)
                | IRAggExpr::Sum(_)
                | IRAggExpr::Var(..)
                | IRAggExpr::Std(..)
                | IRAggExpr::Count { .. } => Some(replace_agg_uniq!(expr)),
                // n_unique(x) is rewritten to count(unique(x), include_nulls)
                // so it can reuse the generic aggregate extraction.
                IRAggExpr::NUnique(uniq_input) => {
                    let function = IRFunctionExpr::Unique(false);
                    let uniq_input_expr = ExprIR::from_node(*uniq_input, expr_arena);
                    let uniq_node = expr_arena.add(AExpr::Function {
                        input: vec![uniq_input_expr],
                        options: function.function_options(),
                        function,
                    });

                    let count = IRAggExpr::Count {
                        input: uniq_node,
                        include_nulls: true,
                    };
                    let count_node = expr_arena.add(AExpr::Agg(count));
                    // Register the synthesized node so replace_agg_uniq can
                    // look up its unique id.
                    expr_merger.add_expr(count_node, expr_arena);
                    Some(replace_agg_uniq!(count_node))
                },
                IRAggExpr::Median(..)
                | IRAggExpr::Implode(..)
                | IRAggExpr::Quantile { .. }
                | IRAggExpr::AggGroups(..) => None, // TODO: allow all aggregates,
            }
        },
        AExpr::Len => {
            // len() has no inputs; register it (deduplicated) as an aggregate
            // and replace with a reference to its output column.
            let agg_id = expr_merger.get_uniq_id(expr).unwrap();
            let name = uniq_agg_exprs
                .entry(agg_id)
                .or_insert_with(|| {
                    let agg_expr = ExprIR::new(expr, OutputName::Alias(unique_column_name()));
                    agg_exprs.push(agg_expr.clone());
                    (agg_expr, Vec::new())
                })
                .0
                .output_name()
                .clone();
            Some(expr_arena.add(AExpr::Column(name)))
        },
    }
}
469
470
/// Tries to lower a (non-elementwise) aggregate *input* expression into a
/// transformed stream plus a node referring to its result.
///
/// Currently handles:
/// - fully elementwise expressions (passed through unchanged),
/// - `unique(...)` (lowered to an extra streaming group-by on the keys plus
///   the inner expression),
/// - `filter(...)` with elementwise input and predicate (lowered to a
///   select + filter on the stream).
///
/// The returned bool reports whether all key rows are still present in the
/// produced stream (`false` for filters, which may drop whole groups).
/// Returns `Ok(None)` when the expression is not supported.
#[allow(clippy::too_many_arguments)]
fn try_lower_agg_input_expr(
    input_stream: PhysStream,
    keys: &[ExprIR],
    expr: Node,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
) -> PolarsResult<Option<(PhysStream, Node, /* all_keys_included */ bool)>> {
    if is_elementwise_rec_cached(expr, expr_arena, expr_cache) {
        return Ok(Some((input_stream, expr, true)));
    }

    match expr_arena.get(expr) {
        AExpr::Function {
            input: uniq_input,
            function: IRFunctionExpr::Unique(stable),
            options: _,
        } => {
            assert!(uniq_input.len() == 1);
            let input_node = uniq_input[0].node();
            // A stable unique maps to an order-maintaining group-by.
            let maintain_order = *stable;

            // First lower the inner expression.
            let Some((stream, node, all_keys_included)) = try_lower_agg_input_expr(
                input_stream,
                keys,
                input_node,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?
            else {
                return Ok(None);
            };

            // unique(x) per group == group_by(keys + x) with no aggregations.
            let output_name = unique_column_name();
            let mut gb_keys = keys.to_vec();
            gb_keys.push(ExprIR::new(node, OutputName::Alias(output_name.clone())));

            let aggs = &[];
            let options = Arc::new(GroupbyOptions::default());
            let Some(stream) = try_build_streaming_group_by(
                stream,
                &gb_keys,
                aggs,
                maintain_order,
                options,
                None,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?
            else {
                return Ok(None);
            };

            let trans_output = expr_arena.add(AExpr::Column(output_name));
            Ok(Some((stream, trans_output, all_keys_included)))
        },

        AExpr::Filter {
            input: filter_input,
            by: predicate,
        } => {
            // Only supported when both sides are elementwise.
            if !is_elementwise_rec_cached(*filter_input, expr_arena, expr_cache)
                || !is_elementwise_rec_cached(*predicate, expr_arena, expr_cache)
            {
                return Ok(None);
            }

            // Select keys + value + predicate, then filter the stream on the
            // predicate column.
            let output_name = unique_column_name();
            let predicate_name = unique_column_name();
            let mut select_exprs = keys.to_vec();
            select_exprs.push(ExprIR::new(
                *filter_input,
                OutputName::Alias(output_name.clone()),
            ));
            select_exprs.push(ExprIR::new(
                *predicate,
                OutputName::Alias(predicate_name.clone()),
            ));

            let mut stream = build_select_stream(
                input_stream,
                &select_exprs,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?;
            stream = build_filter_stream(
                stream,
                ExprIR::from_column_name(predicate_name, expr_arena),
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?;

            // Filtering may remove all rows of a group, so keys are not
            // guaranteed to be fully represented in this stream.
            let trans_output = expr_arena.add(AExpr::Column(output_name));
            Ok(Some((stream, trans_output, false)))
        },
        _ => Ok(None),
    }
}
578
579
/// Tries to lower a group-by into the streaming `GroupBy` physical node.
///
/// Overall shape of the lowering:
/// 1. deduplicate key/aggregate subexpressions through a `NaiveExprMerger`;
/// 2. extract scalar aggregations and their (elementwise) inputs;
/// 3. build a pre-select computing keys + elementwise inputs;
/// 4. build extra per-input streams for aggregates whose single input needs
///    its own transformation (e.g. unique/filter);
/// 5. combine everything in one `PhysNodeKind::GroupBy` and add a
///    post-select (plus an optional sort for `maintain_order` and a slice).
///
/// Returns `Ok(None)` whenever some construct is unsupported, so the caller
/// can fall back to the in-memory engine.
#[allow(clippy::too_many_arguments)]
fn try_build_streaming_group_by(
    mut input: PhysStream,
    keys: &[ExprIR],
    aggs: &[ExprIR],
    maintain_order: bool,
    options: Arc<GroupbyOptions>,
    apply: Option<PlanCallback<DataFrame, DataFrame>>,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
) -> PolarsResult<Option<PhysStream>> {
    if apply.is_some() {
        return Ok(None); // TODO
    }

    #[cfg(feature = "dynamic_group_by")]
    if options.dynamic.is_some() || options.rolling.is_some() {
        return Ok(None); // TODO
    }

    if keys.is_empty() {
        return Err(
            polars_err!(ComputeError: "at least one key is required in a group_by operation"),
        );
    }

    // Not supported yet.
    let all_independent = keys
        .iter()
        .chain(aggs.iter())
        .all(|expr| is_input_independent(expr.node(), expr_arena, expr_cache));
    if all_independent {
        return Ok(None);
    }

    // Augment with row index if maintaining order.
    // The extra `first(row_idx)` aggregate is later used to sort the groups
    // back into first-occurrence order, then dropped from the output.
    let row_idx_name = unique_column_name();
    let row_idx_node = expr_arena.add(AExpr::Column(row_idx_name.clone()));
    let mut agg_storage;
    let aggs = if maintain_order {
        input = build_row_idx_stream(input, row_idx_name.clone(), None, phys_sm);
        let first_agg_node = expr_arena.add(AExpr::Agg(IRAggExpr::First(row_idx_node)));
        agg_storage = aggs.to_vec();
        agg_storage.push(ExprIR::from_node(first_agg_node, expr_arena));
        &agg_storage
    } else {
        aggs
    };

    // Fill all expressions into the merger, letting us extract common subexpressions later.
    let mut expr_merger = NaiveExprMerger::default();
    for key in keys {
        expr_merger.add_expr(key.node(), expr_arena);
    }
    for agg in aggs {
        expr_merger.add_expr(agg.node(), expr_arena);
    }

    // Extract aggregates, input expressions for those aggregates and replace
    // with agg node output columns.
    let mut uniq_input_names = PlIndexMap::new();
    let mut key_ids = PlIndexSet::new();
    // NOTE(review): `trans_agg_exprs` is filled by the lowering below but
    // never read afterwards in this function — looks write-only; confirm
    // whether it is still needed.
    let mut trans_agg_exprs = Vec::new();
    let mut trans_keys = Vec::new();
    let mut trans_output_exprs = Vec::new();
    for key in keys {
        let key_id = expr_merger.get_uniq_id(key.node()).unwrap();
        key_ids.insert(key_id);
        // Duplicate keys share one uniquely-named column.
        let key_name = uniq_input_names
            .entry(key_id)
            .or_insert_with(|| {
                let key_name = unique_column_name();
                trans_keys.push(ExprIR::from_column_name(key_name.clone(), expr_arena));
                key_name
            })
            .clone();

        let output_name = OutputName::Alias(key.output_name().clone());
        let trans_output_node = expr_arena.add(AExpr::Column(key_name));
        trans_output_exprs.push(ExprIR::new(trans_output_node, output_name));
    }

    // Maps aggregation expression ids to output column expression and a vec
    // of input expression ids.
    let mut uniq_agg_exprs = PlIndexMap::new();

    // Maps elementwise input expression ids to column expression.
    let mut uniq_elementwise_exprs = PlIndexMap::new();

    for agg in aggs {
        let Some(trans_node) = try_lower_elementwise_scalar_agg_expr(
            agg.node(),
            &mut expr_merger,
            expr_cache,
            expr_arena,
            &mut trans_agg_exprs,
            &mut uniq_input_names,
            &mut uniq_agg_exprs,
            &mut uniq_elementwise_exprs,
        ) else {
            return Ok(None);
        };
        let output_name = OutputName::Alias(agg.output_name().clone());
        trans_output_exprs.push(ExprIR::new(trans_node, output_name));
    }

    // We must lower the keys together with the elementwise inputs to the aggregations.
    let mut pre_select_input_ids = key_ids.clone();
    pre_select_input_ids.extend(uniq_elementwise_exprs.keys());

    let mut pre_select_exprs = Vec::new();
    for uniq_id in pre_select_input_ids {
        let name = &uniq_input_names[&uniq_id];
        let node = expr_merger.get_node(uniq_id).unwrap();
        pre_select_exprs.push(ExprIR::new(node, OutputName::Alias(name.clone())));
    }

    // If all inputs are input independent add a dummy column so the group sizes are correct. See #23868.
    let mut direct_input_needed = false;
    if pre_select_exprs
        .iter()
        .all(|e| is_input_independent(e.node(), expr_arena, expr_cache))
    {
        direct_input_needed = true;
        let dummy_col_name = phys_sm[input.node].output_schema.get_at_index(0).unwrap().0;
        let dummy_col = expr_arena.add(AExpr::Column(dummy_col_name.clone()));
        pre_select_exprs.push(ExprIR::new(
            dummy_col,
            OutputName::ColumnLhs(dummy_col_name.clone()),
        ));
    }

    // Create pre-select.
    let pre_select = build_select_stream(
        input,
        &pre_select_exprs,
        expr_arena,
        phys_sm,
        expr_cache,
        ctx,
    )?;

    // Create input streams.
    // Aggregates whose inputs are all elementwise read directly from the
    // pre-select; each remaining aggregate input gets a dedicated stream.
    let mut all_keys_included_in_other_inputs = false;
    let mut aggs_with_elementwise_inputs = Vec::new();
    let mut other_agg_input_streams = PlIndexMap::new();
    for (_uniq_agg_id, (agg_expr, input_ids)) in uniq_agg_exprs.iter() {
        if input_ids
            .iter()
            .all(|i| uniq_elementwise_exprs.contains_key(i))
        {
            aggs_with_elementwise_inputs.push(agg_expr.clone());
            direct_input_needed = true;
            continue;
        }

        // More than one non-elementwise input to this agg, unsure how to handle this.
        if input_ids.len() != 1 {
            return Ok(None);
        }

        let input_id = input_ids[0];
        let input_node = expr_merger.get_node(input_id).unwrap();
        let input_name = uniq_input_names[&input_id].clone();
        if !other_agg_input_streams.contains_key(&input_id) {
            let Some((stream, trans_node, keys_included)) = try_lower_agg_input_expr(
                pre_select,
                &trans_keys,
                input_node,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?
            else {
                return Ok(None);
            };
            all_keys_included_in_other_inputs |= keys_included;
            // The dedicated stream outputs the keys plus the transformed input.
            let mut trans_stream_outputs = trans_keys.clone();
            trans_stream_outputs.push(ExprIR::new(trans_node, OutputName::Alias(input_name)));
            let stream = build_select_stream(
                stream,
                &trans_stream_outputs,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?;
            other_agg_input_streams.insert(input_id, (stream, Vec::new()));
        }

        other_agg_input_streams[&input_id].1.push(agg_expr.clone());
    }

    // Reconstruct the output schema of this node.
    let mut group_by_output_schema = Schema::default();
    let mut inputs = Vec::new();
    let mut key_per_input = Vec::new();
    let mut aggs_per_input = Vec::new();
    if direct_input_needed || !all_keys_included_in_other_inputs {
        let this_input_schema = &phys_sm[pre_select.node].output_schema;
        let exprs = [
            trans_keys.as_slice(),
            aggs_with_elementwise_inputs.as_slice(),
        ]
        .concat();
        let elementwise_out_schema =
            compute_output_schema(this_input_schema, &exprs, expr_arena).unwrap();
        group_by_output_schema.merge((*elementwise_out_schema).clone());
        inputs.push(pre_select);
        key_per_input.push(trans_keys.clone());
        aggs_per_input.push(aggs_with_elementwise_inputs);
    }
    for (_input_id, (stream, aggs)) in other_agg_input_streams {
        let this_input_schema = &phys_sm[stream.node].output_schema;
        let exprs = [trans_keys.as_slice(), aggs.as_slice()].concat();
        let this_out_schema = compute_output_schema(this_input_schema, &exprs, expr_arena).unwrap();
        group_by_output_schema.merge((*this_out_schema).clone());
        inputs.push(stream);
        key_per_input.push(trans_keys.clone());
        aggs_per_input.push(aggs);
    }
    let group_by_output_schema = Arc::new(group_by_output_schema);

    let agg_node = phys_sm.insert(PhysNode::new(
        group_by_output_schema.clone(),
        PhysNodeKind::GroupBy {
            inputs,
            key_per_input,
            aggs_per_input,
        },
    ));

    // Sort the input based on the first row index if maintaining order.
    let post_select_input = if maintain_order {
        let sort_node = phys_sm.insert(PhysNode::new(
            group_by_output_schema,
            PhysNodeKind::Sort {
                input: PhysStream::first(agg_node),
                // The first-row-index aggregate is the last output expression.
                by_column: vec![trans_output_exprs.last().unwrap().clone()],
                slice: None,
                sort_options: SortMultipleOptions::new(),
            },
        ));
        trans_output_exprs.pop(); // Remove row idx from post-select.
        PhysStream::first(sort_node)
    } else {
        PhysStream::first(agg_node)
    };

    // Post-select restores the user-visible column names and order.
    let post_select = build_select_stream(
        post_select_input,
        &trans_output_exprs,
        expr_arena,
        phys_sm,
        expr_cache,
        ctx,
    )?;

    let out = if let Some((offset, len)) = options.slice {
        build_slice_stream(post_select, offset, len, phys_sm)
    } else {
        post_select
    };
    Ok(Some(out))
}
847
848
/// Tries to lower a group-by into a `SortedGroupBy` physical node, which
/// aggregates runs of equal keys in sorted order.
///
/// The keys are first reduced to a single column: multiple or nested keys are
/// row-encoded into one ordered binary column (decoded and unnested again
/// after the group-by); a single non-trivial key expression is materialized
/// via an hstack; a plain key column is used as-is. When the input is not
/// already sorted on the keys, a sort (stabilized by a row index) is inserted.
///
/// Returns `Ok(None)` when this strategy does not apply (no keys, apply
/// callback, rolling/dynamic options, unsorted keys with `maintain_order`,
/// or keys with unknown dtypes).
#[expect(clippy::too_many_arguments)]
pub fn try_build_sorted_group_by(
    input: PhysStream,
    keys: &[ExprIR],
    aggs: &[ExprIR],
    output_schema: Arc<Schema>,
    maintain_order: bool,
    options: Arc<GroupbyOptions>,
    apply: Option<PlanCallback<DataFrame, DataFrame>>,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
    are_keys_sorted: bool,
) -> PolarsResult<Option<PhysStream>> {
    let input_schema = phys_sm[input.node].output_schema.as_ref();

    if keys.is_empty()
        || apply.is_some()
        || options.is_rolling()
        || options.is_dynamic()
        || (!are_keys_sorted && maintain_order)
        || keys.iter().any(|k| {
            k.dtype(input_schema, expr_arena)
                .is_ok_and(|dtype| dtype.contains_unknown())
        })
    {
        return Ok(None);
    }

    let mut input = input;
    let mut input_column = unique_column_name();
    // `projected` records whether we added a synthetic key column that must
    // be mapped back to the user-visible schema at the end.
    let mut projected = false;
    let mut row_encoded: Option<Vec<Field>> = None;

    if keys.len() > 1 || keys[0].dtype(input_schema, expr_arena)?.is_nested() {
        // Multiple/nested keys: row-encode them into one ordered column.
        let key_fields = keys
            .iter()
            .map(|k| k.field(input_schema, expr_arena))
            .collect::<PolarsResult<Vec<_>>>()?;
        let expr = AExprBuilder::function(
            keys.to_vec(),
            IRFunctionExpr::RowEncode(
                key_fields.iter().map(|k| k.dtype().clone()).collect(),
                RowEncodingVariant::Ordered {
                    descending: None,
                    nulls_last: None,
                    broadcast_nulls: None,
                },
            ),
            expr_arena,
        )
        .expr_ir(input_column.clone());
        input = build_hstack_stream(input, &[expr], expr_arena, phys_sm, expr_cache, ctx)?;
        projected = true;
        row_encoded = Some(key_fields);
    } else if !matches!(expr_arena.get(keys[0].node()), AExpr::Column(c) if c == keys[0].output_name())
    {
        // Single non-trivial key expression: materialize it under a fresh name.
        input = build_hstack_stream(
            input,
            &[keys[0].with_alias(input_column.clone())],
            expr_arena,
            phys_sm,
            expr_cache,
            ctx,
        )?;
        projected = true;
    } else {
        // Single plain column key: use it directly.
        input_column = keys[0].output_name().clone();
    }

    let key = AExprBuilder::col(input_column.clone(), expr_arena).expr_ir(input_column.clone());

    let schema = phys_sm[input.node].output_schema.clone();
    if !are_keys_sorted {
        // Sort by (key, row index) so rows within a group keep their
        // original relative order.
        let row_idx_name = unique_column_name();
        input = build_row_idx_stream(input, row_idx_name.clone(), None, phys_sm);

        let row_idx_expr =
            AExprBuilder::col(row_idx_name.clone(), expr_arena).expr_ir(row_idx_name.clone());

        input = PhysStream::first(phys_sm.insert(PhysNode {
            output_schema: phys_sm[input.node].output_schema.clone(),
            kind: PhysNodeKind::Sort {
                input,
                by_column: vec![key, row_idx_expr],
                slice: None,
                sort_options: SortMultipleOptions::default(),
            },
        }));
    }

    // Schema of the SortedGroupBy node: key column first, then one column
    // per aggregate (non-scalar aggregates produce imploded lists).
    let mut gb_output_schema = Schema::with_capacity(aggs.len() + 1);
    gb_output_schema.insert(
        input_column.clone(),
        schema.get(input_column.as_str()).unwrap().clone(),
    );
    for agg in aggs {
        let field = agg.field(schema.as_ref(), expr_arena)?;
        let dtype = if agg.is_scalar(expr_arena) {
            field.dtype
        } else {
            field.dtype.implode()
        };
        gb_output_schema.insert(field.name, dtype);
    }
    input = PhysStream::first(
        phys_sm.insert(PhysNode {
            output_schema: Arc::new(gb_output_schema.clone()),
            kind: PhysNodeKind::SortedGroupBy {
                input,
                key: input_column.clone(),
                aggs: aggs.to_vec(),
                // Non-negative slice offsets can be pushed into the node
                // itself; negative offsets need a separate slice stream.
                slice: options
                    .slice
                    .filter(|(o, _)| *o >= 0)
                    .map(|(o, l)| (o as IdxSize, l as IdxSize)),
            },
        }),
    );
    if let Some((offset, length)) = options.slice.as_ref().filter(|(o, _)| *o < 0) {
        input = build_slice_stream(input, *offset, *length, phys_sm);
    }

    if projected {
        if let Some(key_fields) = row_encoded {
            // Decode the row-encoded key column back into the original key
            // columns, unnest them and reselect in output-schema order.
            let expr =
                AExprBuilder::col(input_column.clone(), expr_arena).expr_ir(input_column.clone());
            let expr = AExprBuilder::function(
                vec![expr],
                IRFunctionExpr::RowDecode(
                    key_fields,
                    RowEncodingVariant::Ordered {
                        descending: None,
                        nulls_last: None,
                        broadcast_nulls: None,
                    },
                ),
                expr_arena,
            )
            .expr_ir(input_column.clone());
            input = build_hstack_stream(input, &[expr], expr_arena, phys_sm, expr_cache, ctx)?;

            // Unnest the row encoded columns.
            input = PhysStream::first(phys_sm.insert(PhysNode {
                output_schema: output_schema.clone(),
                kind: PhysNodeKind::Map {
                    input,
                    map: Arc::new(move |df: DataFrame| df.unnest([input_column.clone()], None))
                        as _,
                    format_str: ctx.prepare_visualization.then(|| "UNNEST".to_string()),
                },
            }));

            let exprs = output_schema
                .iter_names()
                .map(|name| AExprBuilder::col(name.clone(), expr_arena).expr_ir(name.clone()))
                .collect::<Vec<_>>();
            input = build_select_stream(input, &exprs, expr_arena, phys_sm, expr_cache, ctx)?;
        } else {
            // Rename the synthetic key column back to the user-visible name
            // (first column of the output schema); other columns pass through.
            let exprs = std::iter::once(input_column)
                .map(|name| (name, output_schema.get_at_index(0).unwrap().0.clone()))
                .chain(
                    output_schema
                        .iter_names_cloned()
                        .skip(1)
                        .map(|name| (name.clone(), name.clone())),
                )
                .map(|(col_name, out_name)| {
                    AExprBuilder::col(col_name, expr_arena).expr_ir(out_name)
                })
                .collect::<Vec<_>>();
            input = build_select_stream(input, &exprs, expr_arena, phys_sm, expr_cache, ctx)?;
        }
    }

    Ok(Some(input))
}
1026
1027
/// Lowers a group-by IR node into a physical stream, trying strategies in
/// order of preference:
///
/// 1. (with `dynamic_group_by`) keyless rolling / dynamic group-by nodes;
/// 2. a sorted group-by when the keys are sorted (or forced via the
///    `POLARS_FORCE_SORTED_GROUP_BY` env var);
/// 3. the streaming hash group-by;
/// 4. the in-memory fallback (always succeeds), optionally labeled for
///    plan visualization.
#[allow(clippy::too_many_arguments)]
pub fn build_group_by_stream(
    input: PhysStream,
    keys: &[ExprIR],
    aggs: &[ExprIR],
    output_schema: Arc<Schema>,
    maintain_order: bool,
    options: Arc<GroupbyOptions>,
    apply: Option<PlanCallback<DataFrame, DataFrame>>,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
    are_keys_sorted: bool,
) -> PolarsResult<PhysStream> {
    #[cfg(feature = "dynamic_group_by")]
    if let Some(rolling_options) = options.as_ref().rolling.as_ref()
        && keys.is_empty()
        && apply.is_none()
    {
        let mut input = PhysStream::first(
            phys_sm.insert(PhysNode::new(
                output_schema.clone(),
                PhysNodeKind::RollingGroupBy {
                    input,
                    index_column: rolling_options.index_column.clone(),
                    period: rolling_options.period,
                    offset: rolling_options.offset,
                    closed: rolling_options.closed_window,
                    // Non-negative slice offsets are handled inside the node;
                    // negative ones get a separate slice stream below.
                    slice: options
                        .slice
                        .filter(|(o, _)| *o >= 0)
                        .map(|(o, l)| (o as IdxSize, l as IdxSize)),
                    aggs: aggs.to_vec(),
                },
            )),
        );
        if let Some((offset, length)) = options.slice.as_ref().filter(|(o, _)| *o < 0) {
            input = build_slice_stream(input, *offset, *length, phys_sm);
        }
        return Ok(input);
    } else if let Some(dynamic_options) = options.as_ref().dynamic.as_ref()
        && keys.is_empty()
        && apply.is_none()
    {
        let mut input = PhysStream::first(
            phys_sm.insert(PhysNode::new(
                output_schema.clone(),
                PhysNodeKind::DynamicGroupBy {
                    input,
                    options: dynamic_options.clone(),
                    aggs: aggs.to_vec(),
                    slice: options
                        .slice
                        .filter(|(o, _)| *o >= 0)
                        .map(|(o, l)| (o as IdxSize, l as IdxSize)),
                },
            )),
        );
        if let Some((offset, length)) = options.slice.as_ref().filter(|(o, _)| *o < 0) {
            input = build_slice_stream(input, *offset, *length, phys_sm);
        }
        return Ok(input);
    }

    if (are_keys_sorted || std::env::var("POLARS_FORCE_SORTED_GROUP_BY").is_ok_and(|v| v == "1"))
        && let Some(stream) = try_build_sorted_group_by(
            input,
            keys,
            aggs,
            output_schema.clone(),
            maintain_order,
            options.clone(),
            apply.clone(),
            expr_arena,
            phys_sm,
            expr_cache,
            ctx,
            are_keys_sorted,
        )?
    {
        Ok(stream)
    } else if let Some(stream) = try_build_streaming_group_by(
        input,
        keys,
        aggs,
        maintain_order,
        options.clone(),
        apply.clone(),
        expr_arena,
        phys_sm,
        expr_cache,
        ctx,
    )? {
        Ok(stream)
    } else {
        // Render a textual description of the group-by for the fallback
        // node's visualization label.
        let format_str = ctx.prepare_visualization.then(|| {
            let mut buffer = String::new();
            write_group_by(
                &mut buffer,
                0,
                expr_arena,
                keys,
                aggs,
                apply.as_ref(),
                maintain_order,
            )
            .unwrap();
            buffer
        });
        build_group_by_fallback(
            input,
            keys,
            aggs,
            output_schema,
            maintain_order,
            options,
            apply,
            expr_arena,
            phys_sm,
            format_str,
        )
    }
}
1151
1152