GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/group_by.rs
use std::sync::Arc;

use polars_core::POOL;
use polars_core::prelude::{IntoColumn, PlHashSet, PlRandomState};
use polars_core::schema::Schema;
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_expr::groups::Grouper;
use polars_expr::hash_keys::HashKeys;
use polars_expr::hot_groups::{HotGrouper, new_hash_hot_grouper};
use polars_expr::reduce::GroupedReduction;
use polars_utils::IdxSize;
use polars_utils::cardinality_sketch::CardinalitySketch;
use polars_utils::hashing::HashPartitioner;
use polars_utils::itertools::Itertools;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::sparse_init_vec::SparseInitVec;
use rayon::prelude::*;

use super::compute_node_prelude::*;
use crate::async_executor;
use crate::async_primitives::connector::Receiver;
use crate::expression::StreamExpr;
use crate::morsel::get_ideal_morsel_size;
use crate::nodes::in_memory_source::InMemorySourceNode;

#[cfg(debug_assertions)]
const DEFAULT_HOT_TABLE_SIZE: usize = 4;
#[cfg(not(debug_assertions))]
const DEFAULT_HOT_TABLE_SIZE: usize = 4096;
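
// High-level flow of this node: each pipeline's local sink pushes keys
// through a small bounded "hot" table whose groups are aggregated
// immediately, while keys that miss (and evicted pre-aggregates) are
// buffered cold, partitioned by hash. Once the input is exhausted,
// combine_locals() merges everything per partition in parallel, and the
// node then serves the result as an in-memory source.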
struct LocalGroupBySinkState {
    hot_grouper: Box<dyn HotGrouper>,
    hot_grouped_reductions: Vec<Box<dyn GroupedReduction>>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // morsel_idxs_values_per_p[p][start..stop] contains the indices into
    // cold_morsels[i] for partition p, where start and stop are:
    // let start = morsel_idxs_offsets[i * num_partitions + p];
    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
    cold_morsels: Vec<(u64, HashKeys, DataFrame)>,
    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
    morsel_idxs_offsets_per_p: Vec<usize>,

    // Similar to the above, but for (evicted) pre-aggregates.
    pre_aggs: Vec<(HashKeys, Vec<Box<dyn GroupedReduction>>)>,
    pre_agg_idxs_values_per_p: Vec<Vec<IdxSize>>,
    pre_agg_idxs_offsets_per_p: Vec<usize>,
}

impl LocalGroupBySinkState {
    fn new(
        key_schema: Arc<Schema>,
        reductions: Vec<Box<dyn GroupedReduction>>,
        hot_table_size: usize,
        num_partitions: usize,
    ) -> Self {
        let hot_grouper = new_hash_hot_grouper(key_schema, hot_table_size);
        Self {
            hot_grouper,
            hot_grouped_reductions: reductions,

            sketch_per_p: vec![CardinalitySketch::new(); num_partitions],

            cold_morsels: Vec::new(),
            morsel_idxs_values_per_p: vec![Vec::new(); num_partitions],
            morsel_idxs_offsets_per_p: vec![0; num_partitions],

            pre_aggs: Vec::new(),
            pre_agg_idxs_values_per_p: vec![Vec::new(); num_partitions],
            pre_agg_idxs_offsets_per_p: vec![0; num_partitions],
        }
    }

    fn flush_evictions(&mut self, partitioner: &HashPartitioner) {
        let hash_keys = self.hot_grouper.take_evicted_keys();
        let reductions = self
            .hot_grouped_reductions
            .iter_mut()
            .map(|hgr| hgr.take_evictions())
            .collect_vec();
        self.add_pre_agg(hash_keys, reductions, partitioner);
    }

    fn add_pre_agg(
        &mut self,
        hash_keys: HashKeys,
        reductions: Vec<Box<dyn GroupedReduction>>,
        partitioner: &HashPartitioner,
    ) {
        hash_keys.gen_idxs_per_partition(
            partitioner,
            &mut self.pre_agg_idxs_values_per_p,
            &mut self.sketch_per_p,
            true,
        );
        self.pre_agg_idxs_offsets_per_p
            .extend(self.pre_agg_idxs_values_per_p.iter().map(|vp| vp.len()));
        self.pre_aggs.push((hash_keys, reductions));
    }
}
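
// Sink-phase state for the node: one LocalGroupBySinkState per pipeline,
// plus the configuration every pipeline shares.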
struct GroupBySinkState {
    key_selectors: Vec<StreamExpr>,
    grouper: Box<dyn Grouper>,
    uniq_grouped_reduction_cols: Vec<PlSmallStr>,
    grouped_reduction_cols: Vec<PlSmallStr>,
    grouped_reductions: Vec<Box<dyn GroupedReduction>>,
    locals: Vec<LocalGroupBySinkState>,
    random_state: PlRandomState,
    partitioner: HashPartitioner,
    has_order_sensitive_agg: bool,
}

impl GroupBySinkState {
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        receivers: Vec<Receiver<Morsel>>,
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        for (mut recv, local) in receivers.into_iter().zip(&mut self.locals) {
            let key_selectors = &self.key_selectors;
            let uniq_grouped_reduction_cols = &self.uniq_grouped_reduction_cols;
            let grouped_reduction_cols = &self.grouped_reduction_cols;
            let random_state = &self.random_state;
            let partitioner = self.partitioner.clone();
            let has_order_sensitive_agg = self.has_order_sensitive_agg;
            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                let mut hot_idxs = Vec::new();
                let mut hot_group_idxs = Vec::new();
                let mut cold_idxs = Vec::new();
                while let Ok(morsel) = recv.recv().await {
                    // Compute hot group indices from key.
                    let seq = morsel.seq().to_u64();
                    let mut df = morsel.into_df();
                    let mut key_columns = Vec::new();
                    for selector in key_selectors {
                        let s = selector.evaluate(&df, &state.in_memory_exec_state).await?;
                        key_columns.push(s.into_column());
                    }
                    let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?;
                    let hash_keys = HashKeys::from_df(&keys, *random_state, true, false);

                    hot_idxs.clear();
                    hot_group_idxs.clear();
                    cold_idxs.clear();
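                    // Split this morsel's rows into hot hits (groups resident
                    // in the bounded hot table, aggregated right away) and
                    // cold misses (buffered below and merged later in
                    // combine_locals).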
                    local.hot_grouper.insert_keys(
                        &hash_keys,
                        &mut hot_idxs,
                        &mut hot_group_idxs,
                        &mut cold_idxs,
                        has_order_sensitive_agg,
                    );

                    // Drop columns not used for reductions (key-only columns).
                    if uniq_grouped_reduction_cols.len() < grouped_reduction_cols.len() {
                        df = df._select_impl(uniq_grouped_reduction_cols).unwrap();
                    }
                    df.rechunk_mut(); // For gathers.

                    // Update hot reductions.
                    for (col, reduction) in grouped_reduction_cols
                        .iter()
                        .zip(&mut local.hot_grouped_reductions)
                    {
                        unsafe {
                            // SAFETY: we resize the reduction to the number of groups beforehand.
                            reduction.resize(local.hot_grouper.num_groups());
                            reduction.update_groups_while_evicting(
                                df.column(col).unwrap(),
                                &hot_idxs,
                                &hot_group_idxs,
                                seq,
                            )?;
                        }
                    }

                    // Store cold keys.
                    // TODO: don't always gather; if the majority is cold, simply
                    // store everything and remember the offsets into it.
                    if !cold_idxs.is_empty() {
                        unsafe {
                            let cold_keys = hash_keys.gather_unchecked(&cold_idxs);
                            let cold_df = df.take_slice_unchecked_impl(&cold_idxs, false);

                            cold_keys.gen_idxs_per_partition(
                                &partitioner,
                                &mut local.morsel_idxs_values_per_p,
                                &mut local.sketch_per_p,
                                true,
                            );
                            local
                                .morsel_idxs_offsets_per_p
                                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
                            local.cold_morsels.push((seq, cold_keys, cold_df));
                        }
                    }

                    // If we have too many evicted rows, flush them.
                    if local.hot_grouper.num_evictions() >= get_ideal_morsel_size() {
                        local.flush_evictions(&partitioner);
                    }
                }
                Ok(())
            }));
        }
    }
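
    // Merge all local hot tables and cold buffers into one grouper and one
    // set of reductions per partition, in parallel over partitions.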
    fn combine_locals(&mut self) -> PolarsResult<Vec<GroupByPartition>> {
        // Finalize pre-aggregations.
        POOL.install(|| {
            self.locals
                .as_mut_slice()
                .into_par_iter()
                .with_max_len(1)
                .for_each(|l| {
                    if l.hot_grouper.num_evictions() > 0 {
                        l.flush_evictions(&self.partitioner);
                    }
                    let hot_keys = l.hot_grouper.keys();
                    let hot_reductions = core::mem::take(&mut l.hot_grouped_reductions);
                    l.add_pre_agg(hot_keys, hot_reductions, &self.partitioner);
                });
        });

        // To reduce maximum memory usage we want to drop the morsels as soon
        // as they're processed, so we move them into Arcs. The drops might
        // also be expensive, so instead of dropping directly we put that on
        // a work queue.
        let morsels_per_local = self
            .locals
            .iter_mut()
            .map(|l| Arc::new(core::mem::take(&mut l.cold_morsels)))
            .collect_vec();
        let pre_aggs_per_local = self
            .locals
            .iter_mut()
            .map(|l| Arc::new(core::mem::take(&mut l.pre_aggs)))
            .collect_vec();
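
        // Type-erased wrapper so cold morsels and pre-aggregates can share a
        // single drop queue.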
        enum ToDrop<A, B> {
            A(A),
            B(B),
        }
        let (drop_q_send, drop_q_recv) = async_channel::bounded(self.locals.len());
        let num_partitions = self.locals[0].sketch_per_p.len();
        let output_per_partition: SparseInitVec<GroupByPartition> =
            SparseInitVec::with_capacity(num_partitions);
        let locals = &self.locals;
        let grouper_template = &self.grouper;
        let grouped_reductions_template = &self.grouped_reductions;
        let grouped_reduction_cols = &self.grouped_reduction_cols;
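
        // Spawn one task per partition; each merges that partition's slice of
        // every local's cold morsels and pre-aggregates.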
        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_morsels_per_local = Arc::new(morsels_per_local);
            let arc_pre_aggs_per_local = Arc::new(pre_aggs_per_local);
            let mut join_handles = Vec::new();
            for p in 0..num_partitions {
                let arc_morsels_per_local = Arc::clone(&arc_morsels_per_local);
                let arc_pre_aggs_per_local = Arc::clone(&arc_pre_aggs_per_local);
                let drop_q_send = drop_q_send.clone();
                let drop_q_recv = drop_q_recv.clone();
                let output_per_partition = &output_per_partition;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from the outer Arc and drop it.
                    let morsels_per_local = Arc::unwrap_or_clone(arc_morsels_per_local);
                    let pre_aggs_per_local = Arc::unwrap_or_clone(arc_pre_aggs_per_local);

                    // Compute the cardinality estimate for this partition.
                    let mut sketch = CardinalitySketch::new();
                    for l in locals {
                        sketch.combine(&l.sketch_per_p[p]);
                    }

                    // Allocate grouper and reductions.
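                    // Reserve 25% over the estimate so a slight sketch
                    // underestimate doesn't immediately force a rehash.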
                    let est_num_groups = sketch.estimate() * 5 / 4;
                    let mut p_grouper = grouper_template.new_empty();
                    let mut p_reductions = grouped_reductions_template
                        .iter()
                        .map(|gr| gr.new_empty())
                        .collect_vec();
                    p_grouper.reserve(est_num_groups);
                    for r in &mut p_reductions {
                        r.reserve(est_num_groups);
                    }

                    // Insert morsels.
                    let mut skip_drop_attempt = false;
                    let mut group_idxs = Vec::new();
                    for (l, l_morsels) in locals.iter().zip(morsels_per_local) {
                        // Try to help with dropping.
                        if !skip_drop_attempt {
                            drop(drop_q_recv.try_recv());
                        }

                        for (i, morsel) in l_morsels.iter().enumerate() {
                            let (seq_id, keys, cols) = morsel;
                            unsafe {
                                let p_morsel_idxs_start =
                                    l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                let p_morsel_idxs_stop =
                                    l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                    [p_morsel_idxs_start..p_morsel_idxs_stop];

                                group_idxs.clear();
                                p_grouper.insert_keys_subset(
                                    keys,
                                    p_morsel_idxs,
                                    Some(&mut group_idxs),
                                );
                                for (c, r) in grouped_reduction_cols.iter().zip(&mut p_reductions) {
                                    let values = cols.column(c.as_str()).unwrap();
                                    r.resize(p_grouper.num_groups());
                                    r.update_groups_subset(
                                        values,
                                        p_morsel_idxs,
                                        &group_idxs,
                                        *seq_id,
                                    )?;
                                }
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_morsels) {
                            // If we're the last thread to process this set of morsels
                            // we're probably falling behind the rest. Since the drop
                            // can be quite expensive, we skip a drop attempt, hoping
                            // someone else will pick up the slack.
                            drop(drop_q_send.try_send(ToDrop::A(l)));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // Insert pre-aggregates.
                    for (l, l_pre_aggs) in locals.iter().zip(pre_aggs_per_local) {
                        // Try to help with dropping.
                        if !skip_drop_attempt {
                            drop(drop_q_recv.try_recv());
                        }

                        for (i, key_pre_aggs) in l_pre_aggs.iter().enumerate() {
                            let (keys, pre_aggs) = key_pre_aggs;
                            unsafe {
                                let p_pre_agg_idxs_start =
                                    l.pre_agg_idxs_offsets_per_p[i * num_partitions + p];
                                let p_pre_agg_idxs_stop =
                                    l.pre_agg_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_pre_agg_idxs = &l.pre_agg_idxs_values_per_p[p]
                                    [p_pre_agg_idxs_start..p_pre_agg_idxs_stop];

                                group_idxs.clear();
                                p_grouper.insert_keys_subset(
                                    keys,
                                    p_pre_agg_idxs,
                                    Some(&mut group_idxs),
                                );
                                for (pre_agg, r) in pre_aggs.iter().zip(&mut p_reductions) {
                                    r.resize(p_grouper.num_groups());
                                    r.combine_subset(&**pre_agg, p_pre_agg_idxs, &group_idxs)?;
                                }
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_pre_aggs) {
                            // If we're the last thread to process this set of
                            // pre-aggregates we're probably falling behind the rest.
                            // Since the drop can be quite expensive, we skip a drop
                            // attempt, hoping someone else will pick up the slack.
                            drop(drop_q_send.try_send(ToDrop::B(l)));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done; help the others out by doing drops.
                    drop(drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(to_drop) = drop_q_recv.recv().await {
                        drop(to_drop);
                    }

                    output_per_partition
                        .try_set(
                            p,
                            GroupByPartition {
                                grouper: p_grouper,
                                grouped_reductions: p_reductions,
                            },
                        )
                        .ok()
                        .unwrap();

                    PolarsResult::Ok(())
                }));
            }

            // Drop the outer Arcs after spawning each thread so the inner Arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop-queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_morsels_per_local);
            drop(arc_pre_aggs_per_local);
            drop(drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await?;
                }
                PolarsResult::Ok(())
            })?;
            PolarsResult::Ok(())
        })?;

        // Drop remaining local state in parallel.
        POOL.install(|| {
            core::mem::take(&mut self.locals)
                .into_par_iter()
                .with_max_len(1)
                .for_each(drop);
        });

        Ok(output_per_partition.try_assume_init().ok().unwrap())
    }
}
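
// A fully merged partition, ready to be finalized into a DataFrame.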
struct GroupByPartition {
    grouper: Box<dyn Grouper>,
    grouped_reductions: Vec<Box<dyn GroupedReduction>>,
}

impl GroupByPartition {
    fn into_df(self, key_schema: &Schema, output_schema: &Schema) -> PolarsResult<DataFrame> {
        let mut out = self.grouper.get_keys_in_group_order(key_schema);
        let out_names = output_schema.iter_names().skip(out.width());
        for (mut r, name) in self.grouped_reductions.into_iter().zip(out_names) {
            unsafe {
                out.with_column_unchecked(r.finalize()?.with_name(name.clone()).into_column());
            }
        }
        Ok(out)
    }
}
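
// Lifecycle of the node: it starts as a sink consuming its input, then turns
// into an in-memory source serving the aggregated result, then is done.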
enum GroupByState {
    Sink(GroupBySinkState),
    Source(InMemorySourceNode),
    Done,
}

pub struct GroupByNode {
    state: GroupByState,
    key_schema: Arc<Schema>,
    output_schema: Arc<Schema>,
}

impl GroupByNode {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        key_schema: Arc<Schema>,
        key_selectors: Vec<StreamExpr>,
        grouper: Box<dyn Grouper>,
        grouped_reduction_cols: Vec<PlSmallStr>,
        grouped_reductions: Vec<Box<dyn GroupedReduction>>,
        output_schema: Arc<Schema>,
        random_state: PlRandomState,
        num_pipelines: usize,
        has_order_sensitive_agg: bool,
    ) -> Self {
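        // The POLARS_HOT_TABLE_SIZE environment variable overrides the
        // default hot table size (an unparsable value panics).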
        let hot_table_size = std::env::var("POLARS_HOT_TABLE_SIZE")
            .map(|sz| sz.parse::<usize>().unwrap())
            .unwrap_or(DEFAULT_HOT_TABLE_SIZE);
        let num_partitions = num_pipelines;
        let uniq_grouped_reduction_cols = grouped_reduction_cols
            .iter()
            .cloned()
            .collect::<PlHashSet<_>>()
            .into_iter()
            .collect_vec();
        let locals = (0..num_pipelines)
            .map(|_| {
                let reductions = grouped_reductions.iter().map(|gr| gr.new_empty()).collect();
                LocalGroupBySinkState::new(
                    key_schema.clone(),
                    reductions,
                    hot_table_size,
                    num_partitions,
                )
            })
            .collect();
        let partitioner = HashPartitioner::new(num_partitions, 0);
        Self {
            state: GroupByState::Sink(GroupBySinkState {
                key_selectors,
                grouped_reductions,
                grouper,
                random_state,
                uniq_grouped_reduction_cols,
                grouped_reduction_cols,
                locals,
                partitioner,
                has_order_sensitive_agg,
            }),
            key_schema,
            output_schema,
        }
    }
}

impl ComputeNode for GroupByNode {
    fn name(&self) -> &str {
        "group-by"
    }

    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 1 && send.len() == 1);

        // State transitions.
        match &mut self.state {
            // If the output doesn't want any more data, transition to being done.
            _ if send[0] == PortState::Done => {
                self.state = GroupByState::Done;
            },
            // Input is done, transition to being a source.
            GroupByState::Sink(_) if matches!(recv[0], PortState::Done) => {
                let GroupByState::Sink(mut sink) =
                    core::mem::replace(&mut self.state, GroupByState::Done)
                else {
                    unreachable!()
                };
                let partitions = sink.combine_locals()?;
                let dfs = POOL.install(|| {
                    partitions
                        .into_par_iter()
                        .map(|p| p.into_df(&self.key_schema, &self.output_schema))
                        .collect::<Result<Vec<_>, _>>()
                })?;

                let df = accumulate_dataframes_vertical_unchecked(dfs);
                let source = InMemorySourceNode::new(Arc::new(df), MorselSeq::new(0));
                self.state = GroupByState::Source(source);
            },
            // Defer to source node implementation.
            GroupByState::Source(src) => {
                src.update_state(&mut [], send, state)?;
                if send[0] == PortState::Done {
                    self.state = GroupByState::Done;
                }
            },
            // Nothing to change.
            GroupByState::Done | GroupByState::Sink(_) => {},
        }

        // Communicate our state.
        match &self.state {
            GroupByState::Sink { .. } => {
                send[0] = PortState::Blocked;
                recv[0] = PortState::Ready;
            },
            GroupByState::Source(..) => {
                recv[0] = PortState::Done;
                send[0] = PortState::Ready;
            },
            GroupByState::Done => {
                recv[0] = PortState::Done;
                send[0] = PortState::Done;
            },
        }
        Ok(())
    }

    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(send_ports.len() == 1 && recv_ports.len() == 1);
        match &mut self.state {
            GroupByState::Sink(sink) => {
                assert!(send_ports[0].is_none());
                sink.spawn(
                    scope,
                    recv_ports[0].take().unwrap().parallel(),
                    state,
                    join_handles,
                )
            },
            GroupByState::Source(source) => {
                assert!(recv_ports[0].is_none());
                source.spawn(scope, &mut [], send_ports, state, join_handles);
            },
            GroupByState::Done => unreachable!(),
        }
    }
}