Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/physical_plan/lower_group_by.rs
6939 views
1
use std::sync::Arc;
2
3
use parking_lot::Mutex;
4
use polars_core::frame::DataFrame;
5
use polars_core::prelude::{InitHashMaps, PlIndexMap, SortMultipleOptions};
6
use polars_core::schema::Schema;
7
use polars_error::{PolarsResult, polars_err};
8
use polars_expr::state::ExecutionState;
9
use polars_mem_engine::create_physical_plan;
10
use polars_plan::plans::expr_ir::{ExprIR, OutputName};
11
use polars_plan::plans::{AExpr, IR, IRAggExpr, IRFunctionExpr, NaiveExprMerger, write_group_by};
12
use polars_plan::prelude::{GroupbyOptions, *};
13
use polars_utils::arena::{Arena, Node};
14
use polars_utils::pl_str::PlSmallStr;
15
use polars_utils::unique_column_name;
16
use recursive::recursive;
17
use slotmap::SlotMap;
18
19
use super::{ExprCache, PhysNode, PhysNodeKey, PhysNodeKind, PhysStream, StreamingLowerIRContext};
20
use crate::physical_plan::lower_expr::{
21
build_select_stream, compute_output_schema, is_elementwise_rec_cached,
22
is_fake_elementwise_function, is_input_independent,
23
};
24
use crate::physical_plan::lower_ir::{build_row_idx_stream, build_slice_stream};
25
use crate::utils::late_materialized_df::LateMaterializedDataFrame;
26
27
#[allow(clippy::too_many_arguments)]
28
fn build_group_by_fallback(
29
input: PhysStream,
30
keys: &[ExprIR],
31
aggs: &[ExprIR],
32
output_schema: Arc<Schema>,
33
maintain_order: bool,
34
options: Arc<GroupbyOptions>,
35
apply: Option<PlanCallback<DataFrame, DataFrame>>,
36
expr_arena: &mut Arena<AExpr>,
37
phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
38
format_str: Option<String>,
39
) -> PolarsResult<PhysStream> {
40
let input_schema = phys_sm[input.node].output_schema.clone();
41
let lmdf = Arc::new(LateMaterializedDataFrame::default());
42
let mut lp_arena = Arena::default();
43
let input_lp_node = lp_arena.add(lmdf.clone().as_ir_node(input_schema));
44
let group_by_lp_node = lp_arena.add(IR::GroupBy {
45
input: input_lp_node,
46
keys: keys.to_vec(),
47
aggs: aggs.to_vec(),
48
schema: output_schema.clone(),
49
maintain_order,
50
options,
51
apply,
52
});
53
let executor = Mutex::new(create_physical_plan(
54
group_by_lp_node,
55
&mut lp_arena,
56
expr_arena,
57
None,
58
)?);
59
60
let group_by_node = PhysNode {
61
output_schema,
62
kind: PhysNodeKind::InMemoryMap {
63
input,
64
map: Arc::new(move |df| {
65
lmdf.set_materialized_dataframe(df);
66
let mut state = ExecutionState::new();
67
executor.lock().execute(&mut state)
68
}),
69
format_str,
70
},
71
};
72
73
Ok(PhysStream::first(phys_sm.insert(group_by_node)))
74
}
75
76
/// Tries to lower an expression as a 'elementwise scalar agg expression'.
///
/// Such an expression is defined as the elementwise combination of scalar
/// aggregations of elementwise combinations of the input columns / scalar literals.
///
/// On success, any scalar aggregations found are appended to `agg_exprs`
/// (keyed by unique column names), and the returned node is a rewrite of
/// `expr` that refers to those aggregation outputs by column. The caller is
/// expected to evaluate `agg_exprs` in a group-by node and the returned
/// expression in a post-select.
///
/// * `outer_name` - if `Some`, the name to give the aggregation output when
///   `expr` itself is directly an aggregation (preserves the user-facing name).
/// * `expr_merger` - assigns unique ids to structurally equal subexpressions,
///   so identical agg inputs / aggregations are only computed once.
/// * `uniq_input_exprs` - maps unique ids of aggregation *inputs* to the
///   generated column name they will be pre-selected as.
/// * `uniq_agg_exprs` - maps unique ids of *aggregations* to the name of
///   their output column, deduplicating repeated aggregations.
///
/// Returns `None` if the expression cannot be lowered to this form.
#[recursive]
#[allow(clippy::too_many_arguments)]
fn try_lower_elementwise_scalar_agg_expr(
    expr: Node,
    outer_name: Option<PlSmallStr>,
    expr_merger: &NaiveExprMerger,
    expr_cache: &mut ExprCache,
    expr_arena: &mut Arena<AExpr>,
    agg_exprs: &mut Vec<ExprIR>,
    uniq_input_exprs: &mut PlIndexMap<u32, PlSmallStr>,
    uniq_agg_exprs: &mut PlIndexMap<u32, PlSmallStr>,
) -> Option<Node> {
    // Helper macro to simplify recursive calls.
    // Note: recursive calls never forward `outer_name`; only the top-level
    // expression may take the user-facing output name.
    macro_rules! lower_rec {
        ($input:expr) => {
            try_lower_elementwise_scalar_agg_expr(
                $input,
                None,
                expr_merger,
                expr_cache,
                expr_arena,
                agg_exprs,
                uniq_input_exprs,
                uniq_agg_exprs,
            )
        };
    }

    match expr_arena.get(expr) {
        AExpr::Column(_) => {
            // Implicit implode not yet supported.
            None
        },

        AExpr::Literal(lit) => {
            // Scalar literals are trivially elementwise; non-scalar literals
            // (e.g. series literals) would change lengths, so reject them.
            if lit.is_scalar() {
                Some(expr)
            } else {
                None
            }
        },

        // Length-changing or order-sensitive operations cannot be lowered.
        AExpr::Slice { .. }
        | AExpr::Window { .. }
        | AExpr::Sort { .. }
        | AExpr::SortBy { .. }
        | AExpr::Gather { .. } => None,

        // Explode and filter are row-separable and should thus in theory work
        // in a streaming fashion but they change the length of the input which
        // means the same filter/explode should also be applied to the key
        // column, which is not (yet) supported.
        AExpr::Explode { .. } | AExpr::Filter { .. } => None,

        // Binary ops are elementwise: lower both sides and rebuild the node.
        AExpr::BinaryExpr { left, op, right } => {
            let (left, op, right) = (*left, *op, *right);
            let left = lower_rec!(left)?;
            let right = lower_rec!(right)?;
            Some(expr_arena.add(AExpr::BinaryExpr { left, op, right }))
        },

        // Only the outer `expr` is lowered; the evaluation subexpression runs
        // per element and is kept as-is.
        AExpr::Eval {
            expr,
            evaluation,
            variant,
        } => {
            let (expr, evaluation, variant) = (*expr, *evaluation, *variant);
            let expr = lower_rec!(expr)?;
            Some(expr_arena.add(AExpr::Eval {
                expr,
                evaluation,
                variant,
            }))
        },

        // Ternary (when/then/otherwise) is elementwise in all three inputs.
        AExpr::Ternary {
            predicate,
            truthy,
            falsy,
        } => {
            let (predicate, truthy, falsy) = (*predicate, *truthy, *falsy);
            let predicate = lower_rec!(predicate)?;
            let truthy = lower_rec!(truthy)?;
            let falsy = lower_rec!(falsy)?;
            Some(expr_arena.add(AExpr::Ternary {
                predicate,
                truthy,
                falsy,
            }))
        },

        // Bitwise and/or/xor reductions are scalar aggregations expressed as
        // functions; treat them like `AExpr::Agg` below.
        #[cfg(feature = "bitwise")]
        AExpr::Function {
            input: inner_exprs,
            function:
                IRFunctionExpr::Bitwise(
                    inner_fn @ (IRBitwiseFunction::And
                    | IRBitwiseFunction::Or
                    | IRBitwiseFunction::Xor),
                ),
            options,
        } => {
            assert!(inner_exprs.len() == 1);

            let input = inner_exprs[0].clone().node();
            let inner_fn = *inner_fn;
            let options = *options;

            if is_input_independent(input, expr_arena, expr_cache) {
                // TODO: we could simply return expr here, but we first need an is_scalar function, because if
                // it is not a scalar we need to return expr.implode().
                return None;
            }

            // The aggregation's input must itself be elementwise.
            if !is_elementwise_rec_cached(input, expr_arena, expr_cache) {
                return None;
            }

            // Deduplicate: identical aggregations share one output column.
            let agg_id = expr_merger.get_uniq_id(expr).unwrap();
            let name = uniq_agg_exprs
                .entry(agg_id)
                .or_insert_with(|| {
                    // Likewise deduplicate the aggregation's input expression.
                    let input_id = expr_merger.get_uniq_id(input).unwrap();
                    let input_col = uniq_input_exprs
                        .entry(input_id)
                        .or_insert_with(unique_column_name)
                        .clone();
                    let input_col_node = expr_arena.add(AExpr::Column(input_col));
                    let trans_agg_node = expr_arena.add(AExpr::Function {
                        input: vec![ExprIR::from_node(input_col_node, expr_arena)],
                        function: IRFunctionExpr::Bitwise(inner_fn),
                        options,
                    });

                    // Add to aggregation expressions and replace with a reference to its output.
                    let agg_expr = if let Some(name) = outer_name {
                        ExprIR::new(trans_agg_node, OutputName::Alias(name))
                    } else {
                        ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()))
                    };
                    agg_exprs.push(agg_expr.clone());
                    agg_expr.output_name().clone()
                })
                .clone();
            let result_node = expr_arena.add(AExpr::Column(name));
            Some(result_node)
        },

        // Boolean any/all reductions: same treatment as the bitwise case.
        AExpr::Function {
            input: inner_exprs,
            function:
                IRFunctionExpr::Boolean(
                    inner_fn @ (IRBooleanFunction::Any { .. } | IRBooleanFunction::All { .. }),
                ),
            options,
        } => {
            assert!(inner_exprs.len() == 1);

            let input = inner_exprs[0].clone().node();
            let inner_fn = inner_fn.clone();
            let options = *options;

            if is_input_independent(input, expr_arena, expr_cache) {
                // TODO: we could simply return expr here, but we first need an is_scalar function, because if
                // it is not a scalar we need to return expr.implode().
                return None;
            }

            if !is_elementwise_rec_cached(input, expr_arena, expr_cache) {
                return None;
            }

            // Deduplicate: identical aggregations share one output column.
            let agg_id = expr_merger.get_uniq_id(expr).unwrap();
            let name = uniq_agg_exprs
                .entry(agg_id)
                .or_insert_with(|| {
                    let input_id = expr_merger.get_uniq_id(input).unwrap();
                    let input_col = uniq_input_exprs
                        .entry(input_id)
                        .or_insert_with(unique_column_name)
                        .clone();
                    let input_col_node = expr_arena.add(AExpr::Column(input_col));
                    let trans_agg_node = expr_arena.add(AExpr::Function {
                        input: vec![ExprIR::from_node(input_col_node, expr_arena)],
                        function: IRFunctionExpr::Boolean(inner_fn),
                        options,
                    });

                    // Add to aggregation expressions and replace with a reference to its output.
                    let agg_expr = if let Some(name) = outer_name {
                        ExprIR::new(trans_agg_node, OutputName::Alias(name))
                    } else {
                        ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()))
                    };
                    agg_exprs.push(agg_expr.clone());
                    agg_expr.output_name().clone()
                })
                .clone();
            let result_node = expr_arena.add(AExpr::Column(name));
            Some(result_node)
        },

        // Generic elementwise (possibly anonymous) functions: lower each input
        // and rebuild the function node around the lowered inputs.
        node @ AExpr::Function { input, options, .. }
        | node @ AExpr::AnonymousFunction { input, options, .. }
            if options.is_elementwise() && !is_fake_elementwise_function(node) =>
        {
            let node = node.clone();
            let input = input.clone();
            let new_input = input
                .into_iter()
                .map(|i| {
                    // The function may be sensitive to names (e.g. pl.struct), so we restore them.
                    let new_node = lower_rec!(i.node())?;
                    Some(ExprIR::new(
                        new_node,
                        OutputName::Alias(i.output_name().clone()),
                    ))
                })
                .collect::<Option<Vec<_>>>()?;

            let mut new_node = node;
            match &mut new_node {
                AExpr::Function { input, .. } | AExpr::AnonymousFunction { input, .. } => {
                    *input = new_input;
                },
                _ => unreachable!(),
            }
            Some(expr_arena.add(new_node))
        },

        // Any remaining (non-elementwise) function cannot be lowered.
        AExpr::Function { .. } | AExpr::AnonymousFunction { .. } => None,

        // Casts are elementwise: lower the inner expression and re-wrap.
        AExpr::Cast {
            expr,
            dtype,
            options,
        } => {
            let (expr, dtype, options) = (*expr, dtype.clone(), *options);
            let expr = lower_rec!(expr)?;
            Some(expr_arena.add(AExpr::Cast {
                expr,
                dtype,
                options,
            }))
        },

        AExpr::Agg(agg) => {
            match agg {
                // Scalar aggregations with a single input expression.
                IRAggExpr::Min { input, .. }
                | IRAggExpr::Max { input, .. }
                | IRAggExpr::First(input)
                | IRAggExpr::Last(input)
                | IRAggExpr::Mean(input)
                | IRAggExpr::Sum(input)
                | IRAggExpr::Var(input, ..)
                | IRAggExpr::Std(input, ..)
                | IRAggExpr::Count { input, .. } => {
                    let agg = agg.clone();
                    let input = *input;
                    if is_input_independent(input, expr_arena, expr_cache) {
                        // TODO: we could simply return expr here, but we first need an is_scalar function, because if
                        // it is not a scalar we need to return expr.implode().
                        return None;
                    }

                    // The aggregation's input must be elementwise.
                    if !is_elementwise_rec_cached(input, expr_arena, expr_cache) {
                        return None;
                    }

                    // Deduplicate: identical aggregations share one output column.
                    let agg_id = expr_merger.get_uniq_id(expr).unwrap();
                    let name = uniq_agg_exprs
                        .entry(agg_id)
                        .or_insert_with(|| {
                            // Retarget the aggregation at the pre-selected
                            // input column instead of the original expression.
                            let mut trans_agg = agg;
                            let input_id = expr_merger.get_uniq_id(input).unwrap();
                            let input_col = uniq_input_exprs
                                .entry(input_id)
                                .or_insert_with(unique_column_name)
                                .clone();
                            let input_col_node = expr_arena.add(AExpr::Column(input_col));
                            trans_agg.set_input(input_col_node);
                            let trans_agg_node = expr_arena.add(AExpr::Agg(trans_agg));

                            // Add to aggregation expressions and replace with a reference to its output.
                            let agg_expr = if let Some(name) = outer_name {
                                ExprIR::new(trans_agg_node, OutputName::Alias(name))
                            } else {
                                ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()))
                            };
                            agg_exprs.push(agg_expr.clone());
                            agg_expr.output_name().clone()
                        })
                        .clone();

                    let result_node = expr_arena.add(AExpr::Column(name));
                    Some(result_node)
                },
                IRAggExpr::Median(..)
                | IRAggExpr::NUnique(..)
                | IRAggExpr::Implode(..)
                | IRAggExpr::Quantile { .. }
                | IRAggExpr::AggGroups(..) => None, // TODO: allow all aggregates,
            }
        },
        // `len()` needs no input column; it becomes an aggregation directly.
        AExpr::Len => {
            let agg_expr = if let Some(name) = outer_name {
                ExprIR::new(expr, OutputName::Alias(name))
            } else {
                ExprIR::new(expr, OutputName::Alias(unique_column_name()))
            };
            let result_node = expr_arena.add(AExpr::Column(agg_expr.output_name().clone()));
            agg_exprs.push(agg_expr);
            Some(result_node)
        },
    }
}
396
397
/// Attempts to lower the group-by to a fully streaming physical plan.
///
/// The lowering pipeline is:
/// 1. pre-select the (deduplicated) inputs of keys and aggregations,
/// 2. run a streaming `GroupBy` node over those columns,
/// 3. post-select to restore the user-facing output expressions and names,
/// with optional row-index augmentation + sort to maintain order, and an
/// optional slice at the end.
///
/// Returns:
/// * `None` - this group-by cannot be streamed (caller should fall back to
///   the in-memory engine),
/// * `Some(Err(..))` - the group-by is invalid (e.g. no keys),
/// * `Some(Ok(stream))` - the lowered streaming plan.
#[allow(clippy::too_many_arguments)]
fn try_build_streaming_group_by(
    mut input: PhysStream,
    keys: &[ExprIR],
    aggs: &[ExprIR],
    maintain_order: bool,
    options: Arc<GroupbyOptions>,
    apply: Option<PlanCallback<DataFrame, DataFrame>>,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
) -> Option<PolarsResult<PhysStream>> {
    // User-defined apply callbacks require materialized groups.
    if apply.is_some() {
        return None; // TODO
    }

    // Dynamic/rolling group-bys are not streamable here.
    #[cfg(feature = "dynamic_group_by")]
    if options.dynamic.is_some() || options.rolling.is_some() {
        return None; // TODO
    }

    if keys.is_empty() {
        return Some(Err(
            polars_err!(ComputeError: "at least one key is required in a group_by operation"),
        ));
    }

    // Augment with row index if maintaining order.
    // The `first(row_idx)` aggregation records each group's first occurrence,
    // which is later used to sort the groups back into appearance order.
    let row_idx_name = unique_column_name();
    let row_idx_node = expr_arena.add(AExpr::Column(row_idx_name.clone()));
    let mut agg_storage;
    let aggs = if maintain_order {
        input = build_row_idx_stream(input, row_idx_name.clone(), None, phys_sm);
        let first_agg_node = expr_arena.add(AExpr::Agg(IRAggExpr::First(row_idx_node)));
        agg_storage = aggs.to_vec();
        agg_storage.push(ExprIR::from_node(first_agg_node, expr_arena));
        &agg_storage
    } else {
        aggs
    };

    // If nothing references the input at all, bail out to the fallback path.
    let all_independent = keys
        .iter()
        .chain(aggs.iter())
        .all(|expr| is_input_independent(expr.node(), expr_arena, expr_cache));
    if all_independent {
        return None;
    }

    // Fill all expressions into the merger, letting us extract common subexpressions later.
    let mut expr_merger = NaiveExprMerger::default();
    for key in keys {
        expr_merger.add_expr(key.node(), expr_arena);
    }
    for agg in aggs {
        expr_merger.add_expr(agg.node(), expr_arena);
    }

    // Extract aggregates, input expressions for those aggregates and replace
    // with agg node output columns.
    let mut uniq_input_exprs = PlIndexMap::new();
    let mut trans_agg_exprs = Vec::new();
    let mut trans_keys = Vec::new();
    let mut trans_output_exprs = Vec::new();
    for key in keys {
        // Structurally identical keys share one pre-selected input column.
        let key_id = expr_merger.get_uniq_id(key.node()).unwrap();
        let uniq_col = uniq_input_exprs
            .entry(key_id)
            .or_insert_with(unique_column_name)
            .clone();

        // Keys might refer to the same column multiple times, we have to give a unique name to it.
        let uniq_name = unique_column_name();
        let trans_key_node = expr_arena.add(AExpr::Column(uniq_col));
        trans_keys.push(ExprIR::new(
            trans_key_node,
            OutputName::Alias(uniq_name.clone()),
        ));
        // The post-select maps the unique key column back to its user-facing name.
        let output_name = OutputName::Alias(key.output_name().clone());
        let trans_output_node = expr_arena.add(AExpr::Column(uniq_name));
        trans_output_exprs.push(ExprIR::new(trans_output_node, output_name));
    }

    // Lower each aggregation; `?` aborts the whole attempt (returning None)
    // as soon as one aggregation cannot be expressed in streaming form.
    let mut uniq_agg_exprs = PlIndexMap::new();
    for agg in aggs {
        let trans_node = try_lower_elementwise_scalar_agg_expr(
            agg.node(),
            Some(agg.output_name().clone()),
            &expr_merger,
            expr_cache,
            expr_arena,
            &mut trans_agg_exprs,
            &mut uniq_input_exprs,
            &mut uniq_agg_exprs,
        )?;
        let output_name = OutputName::Alias(agg.output_name().clone());
        trans_output_exprs.push(ExprIR::new(trans_node, output_name));
    }

    // We must lower the keys together with the input to the aggregations.
    let mut input_exprs = Vec::new();
    for (uniq_id, name) in uniq_input_exprs.iter() {
        let node = expr_merger.get_node(*uniq_id).unwrap();
        input_exprs.push(ExprIR::new(node, OutputName::Alias(name.clone())));
    }

    // If all inputs are input independent add a dummy column so the group sizes are correct. See #23868.
    if input_exprs
        .iter()
        .all(|e| is_input_independent(e.node(), expr_arena, expr_cache))
    {
        let dummy_col_name = phys_sm[input.node].output_schema.get_at_index(0).unwrap().0;
        let dummy_col = expr_arena.add(AExpr::Column(dummy_col_name.clone()));
        input_exprs.push(ExprIR::new(
            dummy_col,
            OutputName::ColumnLhs(dummy_col_name.clone()),
        ));
    }

    // NOTE(review): `.ok()?` turns a pre-select lowering *error* into a
    // silent fallback to the in-memory engine rather than propagating it.
    let pre_select =
        build_select_stream(input, &input_exprs, expr_arena, phys_sm, expr_cache, ctx).ok()?;

    let input_schema = &phys_sm[pre_select.node].output_schema;
    let group_by_output_schema = compute_output_schema(
        input_schema,
        &[trans_keys.as_slice(), trans_agg_exprs.as_slice()].concat(),
        expr_arena,
    )
    .unwrap();
    let agg_node = phys_sm.insert(PhysNode::new(
        group_by_output_schema.clone(),
        PhysNodeKind::GroupBy {
            input: pre_select,
            key: trans_keys,
            aggs: trans_agg_exprs,
        },
    ));

    // Sort the input based on the first row index if maintaining order.
    let post_select_input = if maintain_order {
        let sort_node = phys_sm.insert(PhysNode::new(
            group_by_output_schema,
            PhysNodeKind::Sort {
                input: PhysStream::first(agg_node),
                by_column: vec![ExprIR::from_node(row_idx_node, expr_arena)],
                slice: None,
                sort_options: SortMultipleOptions::new(),
            },
        ));
        trans_output_exprs.pop(); // Remove row idx from post-select.
        PhysStream::first(sort_node)
    } else {
        PhysStream::first(agg_node)
    };

    // Restore user-facing expressions/names on top of the group-by output.
    let post_select = build_select_stream(
        post_select_input,
        &trans_output_exprs,
        expr_arena,
        phys_sm,
        expr_cache,
        ctx,
    );

    // Apply the group-by's slice option, if any, after the post-select.
    let out = if let Some((offset, len)) = options.slice {
        post_select.map(|s| build_slice_stream(s, offset, len, phys_sm))
    } else {
        post_select
    };
    Some(out)
}
569
570
#[allow(clippy::too_many_arguments)]
571
pub fn build_group_by_stream(
572
input: PhysStream,
573
keys: &[ExprIR],
574
aggs: &[ExprIR],
575
output_schema: Arc<Schema>,
576
maintain_order: bool,
577
options: Arc<GroupbyOptions>,
578
apply: Option<PlanCallback<DataFrame, DataFrame>>,
579
expr_arena: &mut Arena<AExpr>,
580
phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
581
expr_cache: &mut ExprCache,
582
ctx: StreamingLowerIRContext,
583
) -> PolarsResult<PhysStream> {
584
let streaming = try_build_streaming_group_by(
585
input,
586
keys,
587
aggs,
588
maintain_order,
589
options.clone(),
590
apply.clone(),
591
expr_arena,
592
phys_sm,
593
expr_cache,
594
ctx,
595
);
596
if let Some(stream) = streaming {
597
stream
598
} else {
599
let format_str = ctx.prepare_visualization.then(|| {
600
let mut buffer = String::new();
601
write_group_by(
602
&mut buffer,
603
0,
604
expr_arena,
605
keys,
606
aggs,
607
apply.as_ref(),
608
maintain_order,
609
)
610
.unwrap();
611
buffer
612
});
613
build_group_by_fallback(
614
input,
615
keys,
616
aggs,
617
output_schema,
618
maintain_order,
619
options,
620
apply,
621
expr_arena,
622
phys_sm,
623
format_str,
624
)
625
}
626
}
627
628