Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/group_by.rs
8433 views
1
use std::sync::Arc;
2
3
use polars_core::POOL;
4
use polars_core::prelude::{IntoColumn, PlHashSet, PlRandomState};
5
use polars_core::schema::Schema;
6
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
7
use polars_expr::groups::Grouper;
8
use polars_expr::hash_keys::HashKeys;
9
use polars_expr::hot_groups::{HotGrouper, new_hash_hot_grouper};
10
use polars_expr::reduce::GroupedReduction;
11
use polars_utils::cardinality_sketch::CardinalitySketch;
12
use polars_utils::hashing::HashPartitioner;
13
use polars_utils::itertools::Itertools;
14
use polars_utils::pl_str::PlSmallStr;
15
use polars_utils::sparse_init_vec::SparseInitVec;
16
use polars_utils::{IdxSize, UnitVec};
17
use rayon::prelude::*;
18
use tokio::sync::mpsc::{Receiver, channel};
19
20
use super::compute_node_prelude::*;
21
use crate::async_executor;
22
use crate::expression::StreamExpr;
23
use crate::morsel::get_ideal_morsel_size;
24
use crate::nodes::in_memory_source::InMemorySourceNode;
25
26
// Default capacity of the per-input "hot" pre-aggregation table.
// NOTE(review): presumably kept tiny in debug builds so the eviction/cold
// paths get exercised by tests — confirm intent.
#[cfg(debug_assertions)]
const DEFAULT_HOT_TABLE_SIZE: usize = 4;
#[cfg(not(debug_assertions))]
const DEFAULT_HOT_TABLE_SIZE: usize = 4096;
30
31
/// A pre-aggregated batch: hashed keys plus the partially reduced values for
/// those keys, waiting to be combined per-partition in `combine_locals`.
struct PreAgg {
    keys: HashKeys,
    // Indices into the full grouped-reduction set that `reductions` below
    // corresponds to (an input stream feeds only a subset of the reductions).
    reduction_idxs: UnitVec<usize>,
    reductions: Vec<Box<dyn GroupedReduction>>,
}
36
37
/// Per-pipeline sink state for the streaming group-by.
///
/// Each pipeline keeps one small "hot" grouper per input stream with its
/// grouped reductions, spills cold rows and evicted pre-aggregates into
/// partition-indexed buffers, and maintains a cardinality sketch per
/// partition so the final per-partition tables can be pre-sized.
struct LocalGroupBySinkState {
    // One hot grouper per input stream.
    hot_grouper_per_input: Vec<Box<dyn HotGrouper>>,
    hot_grouped_reductions: Vec<Box<dyn GroupedReduction>>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // morsel_idxs_values_per_p[p][start..stop] contains the offsets into cold_morsels[i]
    // for partition p, where start, stop are:
    // let start = morsel_idxs_offsets[i * num_partitions + p];
    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
    //
    // Each cold morsel is (input index, morsel sequence id, hashed keys, payload df).
    cold_morsels: Vec<(usize, u64, HashKeys, DataFrame)>,
    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
    morsel_idxs_offsets_per_p: Vec<usize>,

    // Similar to the above, but for (evicted) pre-aggregates.
    // The UnitVec contains the indices of the grouped reductions.
    pre_aggs: Vec<PreAgg>,
    pre_agg_idxs_values_per_p: Vec<Vec<IdxSize>>,
    pre_agg_idxs_offsets_per_p: Vec<usize>,
}
58
59
impl LocalGroupBySinkState {
60
fn new(
61
key_schema: Arc<Schema>,
62
reductions: Vec<Box<dyn GroupedReduction>>,
63
hot_table_size: usize,
64
num_partitions: usize,
65
num_inputs: usize,
66
) -> Self {
67
let hot_grouper_per_input = (0..num_inputs)
68
.map(|_| new_hash_hot_grouper(key_schema.clone(), hot_table_size))
69
.collect();
70
Self {
71
hot_grouper_per_input,
72
hot_grouped_reductions: reductions,
73
74
sketch_per_p: vec![CardinalitySketch::new(); num_partitions],
75
76
cold_morsels: Vec::new(),
77
morsel_idxs_values_per_p: vec![Vec::new(); num_partitions],
78
morsel_idxs_offsets_per_p: vec![0; num_partitions],
79
80
pre_aggs: Vec::new(),
81
pre_agg_idxs_values_per_p: vec![Vec::new(); num_partitions],
82
pre_agg_idxs_offsets_per_p: vec![0; num_partitions],
83
}
84
}
85
86
fn flush_evictions(
87
&mut self,
88
input_idx: usize,
89
reduction_idxs: &[usize],
90
partitioner: &HashPartitioner,
91
) {
92
let hash_keys = self.hot_grouper_per_input[input_idx].take_evicted_keys();
93
let reductions = reduction_idxs
94
.iter()
95
.map(|r| self.hot_grouped_reductions[*r].take_evictions())
96
.collect_vec();
97
self.add_pre_agg(hash_keys, reduction_idxs, reductions, partitioner);
98
}
99
100
fn add_pre_agg(
101
&mut self,
102
hash_keys: HashKeys,
103
reduction_idxs: &[usize],
104
reductions: Vec<Box<dyn GroupedReduction>>,
105
partitioner: &HashPartitioner,
106
) {
107
hash_keys.gen_idxs_per_partition(
108
partitioner,
109
&mut self.pre_agg_idxs_values_per_p,
110
&mut self.sketch_per_p,
111
true,
112
);
113
self.pre_agg_idxs_offsets_per_p
114
.extend(self.pre_agg_idxs_values_per_p.iter().map(|vp| vp.len()));
115
let pre_agg = PreAgg {
116
keys: hash_keys,
117
reduction_idxs: UnitVec::from_slice(reduction_idxs),
118
reductions,
119
};
120
self.pre_aggs.push(pre_agg);
121
}
122
}
123
124
/// Shared sink-side state: the group-by configuration plus one
/// `LocalGroupBySinkState` per pipeline.
struct GroupBySinkState {
    // Input stream i selects its keys with key_selectors_per_input[i].
    key_selectors_per_input: Vec<Vec<StreamExpr>>,
    // Input stream i feeds grouped_reductions[k] for each k in reductions_per_input[i].
    reductions_per_input: Vec<Vec<usize>>,
    grouper: Box<dyn Grouper>,
    // Per input stream: deduplicated set of columns read by any of its reductions.
    uniq_grouped_reduction_cols_per_input: Vec<Vec<PlSmallStr>>,
    // grouped_reductions[k] is passed input cols grouped_reduction_cols[k].
    grouped_reduction_cols: Vec<Vec<PlSmallStr>>,
    grouped_reductions: Vec<Box<dyn GroupedReduction>>,
    // One local state per pipeline.
    locals: Vec<LocalGroupBySinkState>,
    random_state: PlRandomState,
    partitioner: HashPartitioner,
    has_order_sensitive_agg: bool,
}
136
137
impl GroupBySinkState {
    /// Spawn one sink task per pipeline. Each task drains its merged receiver
    /// of `(input_idx, morsel)` pairs, evaluates the key selectors, updates
    /// the hot grouper/reductions for that input, and spills cold rows and
    /// oversized eviction sets into the local partitioned buffers.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        receivers: Vec<Receiver<(usize, Morsel)>>,
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        for (mut recv, local) in receivers.into_iter().zip(&mut self.locals) {
            let key_selectors_per_input = &self.key_selectors_per_input;
            let reductions_per_input = &self.reductions_per_input;
            let uniq_grouped_reduction_cols_per_input = &self.uniq_grouped_reduction_cols_per_input;
            let grouped_reduction_cols = &self.grouped_reduction_cols;
            let random_state = &self.random_state;
            let partitioner = self.partitioner.clone();
            let has_order_sensitive_agg = self.has_order_sensitive_agg;
            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                // Scratch buffers reused across morsels.
                let mut hot_idxs = Vec::new();
                let mut hot_group_idxs = Vec::new();
                let mut cold_idxs = Vec::new();
                let mut in_cols = Vec::new();
                while let Some((input_idx, morsel)) = recv.recv().await {
                    // Compute hot group indices from key.
                    let seq = morsel.seq().to_u64();
                    let mut df = morsel.into_df();
                    let mut key_columns = Vec::new();
                    for selector in &key_selectors_per_input[input_idx] {
                        let s = selector.evaluate(&df, &state.in_memory_exec_state).await?;
                        key_columns.push(s.into_column());
                    }
                    let keys = unsafe {
                        DataFrame::new_unchecked_with_broadcast(df.height(), key_columns)?
                    };
                    // NOTE(review): the meaning of the two boolean flags is not
                    // visible here — confirm against HashKeys::from_df.
                    let hash_keys = HashKeys::from_df(&keys, random_state.clone(), true, false);

                    // Classify each row as hot (kept in the small table) or cold.
                    let hot_grouper = &mut local.hot_grouper_per_input[input_idx];
                    hot_idxs.clear();
                    hot_group_idxs.clear();
                    cold_idxs.clear();
                    hot_grouper.insert_keys(
                        &hash_keys,
                        &mut hot_idxs,
                        &mut hot_group_idxs,
                        &mut cold_idxs,
                        has_order_sensitive_agg,
                    );

                    // Drop columns not used for reductions (key-only columns).
                    let uniq_grouped_reduction_cols =
                        &uniq_grouped_reduction_cols_per_input[input_idx];
                    if uniq_grouped_reduction_cols.len() < df.width() {
                        df = unsafe { df.select_unchecked(uniq_grouped_reduction_cols.as_slice()) }
                            .unwrap();
                    }
                    df.rechunk_mut(); // For gathers.

                    // Update hot reductions.
                    for red_idx in &reductions_per_input[input_idx] {
                        let cols = &grouped_reduction_cols[*red_idx];
                        let reduction = &mut local.hot_grouped_reductions[*red_idx];
                        for col in cols {
                            in_cols.push(df.column(col).unwrap());
                        }
                        unsafe {
                            // SAFETY: we resize the reduction to the number of groups beforehand.
                            reduction.resize(hot_grouper.num_groups());
                            reduction.update_groups_while_evicting(
                                &in_cols,
                                &hot_idxs,
                                &hot_group_idxs,
                                seq,
                            )?;
                        }
                        // The empty collect launders the borrow of `df` out of the
                        // buffer's type so it can be reused for the next morsel.
                        in_cols.clear();
                        in_cols = in_cols.into_iter().map(|_| unreachable!()).collect(); // Clear lifetimes.
                    }

                    // Store cold keys.
                    // TODO: don't always gather, if majority cold simply store all and remember offsets into it.
                    if !cold_idxs.is_empty() {
                        unsafe {
                            let cold_keys = hash_keys.gather_unchecked(&cold_idxs);
                            let cold_df = df.take_slice_unchecked_impl(&cold_idxs, false);

                            cold_keys.gen_idxs_per_partition(
                                &partitioner,
                                &mut local.morsel_idxs_values_per_p,
                                &mut local.sketch_per_p,
                                true,
                            );
                            local
                                .morsel_idxs_offsets_per_p
                                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
                            local
                                .cold_morsels
                                .push((input_idx, seq, cold_keys, cold_df));
                        }
                    }

                    // If we have too many evicted rows, flush them.
                    if hot_grouper.num_evictions() >= get_ideal_morsel_size() {
                        local.flush_evictions(
                            input_idx,
                            &reductions_per_input[input_idx],
                            &partitioner,
                        );
                    }
                }
                Ok(())
            }));
        }
    }

    /// Merge all per-pipeline sink state into one `GroupByPartition` per
    /// hash partition, processed by one task per partition. Consumes the
    /// locals' buffers; `self.locals` is emptied on return.
    fn combine_locals(&mut self) -> PolarsResult<Vec<GroupByPartition>> {
        // Finalize pre-aggregations.
        POOL.install(|| {
            self.locals
                .as_mut_slice()
                .into_par_iter()
                .with_max_len(1)
                .for_each(|l| {
                    // Flush any remaining evictions, then convert the hot
                    // tables themselves into pre-aggregates.
                    for (input_idx, r_idxs) in self.reductions_per_input.iter().enumerate() {
                        let hot_grouper = &mut l.hot_grouper_per_input[input_idx];
                        if hot_grouper.num_evictions() > 0 {
                            l.flush_evictions(input_idx, r_idxs, &self.partitioner);
                        }
                    }

                    // Option-wrap so each reduction can be moved out exactly once.
                    let mut opt_hot_reductions =
                        l.hot_grouped_reductions.drain(..).map(Some).collect_vec();
                    for (input_idx, r_idxs) in self.reductions_per_input.iter().enumerate() {
                        let hot_grouper = &mut l.hot_grouper_per_input[input_idx];
                        let hot_keys = hot_grouper.keys();
                        let hot_reductions = r_idxs
                            .iter()
                            .map(|r| opt_hot_reductions[*r].take().unwrap())
                            .collect_vec();
                        l.add_pre_agg(hot_keys, r_idxs, hot_reductions, &self.partitioner);
                    }
                });
        });

        // To reduce maximum memory usage we want to drop the morsels
        // as soon as they're processed, so we move into Arcs. The drops might
        // also be expensive, so instead of directly dropping we put that on
        // a work queue.
        let morsels_per_local = self
            .locals
            .iter_mut()
            .map(|l| Arc::new(core::mem::take(&mut l.cold_morsels)))
            .collect_vec();
        let pre_aggs_per_local = self
            .locals
            .iter_mut()
            .map(|l| Arc::new(core::mem::take(&mut l.pre_aggs)))
            .collect_vec();
        // Either payload kind can be queued for deferred dropping.
        enum ToDrop<A, B> {
            A(A),
            B(B),
        }
        let (drop_q_send, drop_q_recv) = async_channel::bounded(self.locals.len());
        let num_partitions = self.locals[0].sketch_per_p.len();
        let output_per_partition: SparseInitVec<GroupByPartition> =
            SparseInitVec::with_capacity(num_partitions);
        let locals = &self.locals;
        let grouper_template = &self.grouper;
        let reductions_per_input = &self.reductions_per_input;
        let grouped_reductions_template = &self.grouped_reductions;
        let grouped_reduction_cols = &self.grouped_reduction_cols;

        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_morsels_per_local = Arc::new(morsels_per_local);
            let arc_pre_aggs_per_local = Arc::new(pre_aggs_per_local);
            let mut join_handles = Vec::new();
            for p in 0..num_partitions {
                let arc_morsels_per_local = Arc::clone(&arc_morsels_per_local);
                let arc_pre_aggs_per_local = Arc::clone(&arc_pre_aggs_per_local);
                let drop_q_send = drop_q_send.clone();
                let drop_q_recv = drop_q_recv.clone();
                let output_per_partition = &output_per_partition;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from outer arc and drop outer arc.
                    let morsels_per_local = Arc::unwrap_or_clone(arc_morsels_per_local);
                    let pre_aggs_per_local = Arc::unwrap_or_clone(arc_pre_aggs_per_local);

                    // Compute cardinality estimate and total amount of
                    // payload for this partition.
                    let mut sketch = CardinalitySketch::new();
                    for l in locals {
                        sketch.combine(&l.sketch_per_p[p]);
                    }

                    // Allocate grouper and reductions.
                    // 25% headroom over the sketch estimate.
                    let est_num_groups = sketch.estimate() * 5 / 4;
                    let mut p_grouper = grouper_template.new_empty();
                    let mut p_reductions = grouped_reductions_template
                        .iter()
                        .map(|gr| gr.new_empty())
                        .collect_vec();
                    p_grouper.reserve(est_num_groups);
                    for r in &mut p_reductions {
                        r.reserve(est_num_groups);
                    }

                    // Insert morsels.
                    let mut skip_drop_attempt = false;
                    let mut group_idxs = Vec::new();
                    let mut in_cols = Vec::new();
                    for (l, l_morsels) in locals.iter().zip(morsels_per_local) {
                        // Try to help with dropping.
                        if !skip_drop_attempt {
                            drop(drop_q_recv.try_recv());
                        }

                        for (i, morsel) in l_morsels.iter().enumerate() {
                            let (input_idx, seq_id, keys, morsel_df) = morsel;
                            unsafe {
                                // Slice of this morsel's row indices that fall in partition p;
                                // see the offset layout documented on LocalGroupBySinkState.
                                let p_morsel_idxs_start =
                                    l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                let p_morsel_idxs_stop =
                                    l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                    [p_morsel_idxs_start..p_morsel_idxs_stop];

                                group_idxs.clear();
                                p_grouper.insert_keys_subset(
                                    keys,
                                    p_morsel_idxs,
                                    Some(&mut group_idxs),
                                );

                                for red_idx in &reductions_per_input[*input_idx] {
                                    let cols = &grouped_reduction_cols[*red_idx];
                                    let reduction = &mut p_reductions[*red_idx];
                                    for col in cols {
                                        in_cols.push(morsel_df.column(col).unwrap());
                                    }
                                    reduction.resize(p_grouper.num_groups());
                                    reduction.update_groups_subset(
                                        &in_cols,
                                        p_morsel_idxs,
                                        &group_idxs,
                                        *seq_id,
                                    )?;
                                    in_cols.clear();
                                }
                            }
                        }
                        // Empty collect launders the borrow of this local's morsels
                        // so the buffer can be reused for the next local.
                        in_cols = in_cols.into_iter().map(|_| unreachable!()).collect(); // Clear lifetimes.

                        if let Some(l) = Arc::into_inner(l_morsels) {
                            // If we're the last thread to process this set of morsels we're probably
                            // falling behind the rest, since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(drop_q_send.try_send(ToDrop::A(l)));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // Insert pre-aggregates.
                    for (l, l_pre_aggs) in locals.iter().zip(pre_aggs_per_local) {
                        // Try to help with dropping.
                        if !skip_drop_attempt {
                            drop(drop_q_recv.try_recv());
                        }

                        for (i, key_pre_aggs) in l_pre_aggs.iter().enumerate() {
                            let PreAgg {
                                keys,
                                reduction_idxs: r_idxs,
                                reductions: pre_aggs,
                            } = key_pre_aggs;
                            unsafe {
                                let p_pre_agg_idxs_start =
                                    l.pre_agg_idxs_offsets_per_p[i * num_partitions + p];
                                let p_pre_agg_idxs_stop =
                                    l.pre_agg_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_pre_agg_idxs = &l.pre_agg_idxs_values_per_p[p]
                                    [p_pre_agg_idxs_start..p_pre_agg_idxs_stop];

                                group_idxs.clear();
                                p_grouper.insert_keys_subset(
                                    keys,
                                    p_pre_agg_idxs,
                                    Some(&mut group_idxs),
                                );
                                for (pre_agg, r_idx) in pre_aggs.iter().zip(r_idxs.iter()) {
                                    let r = &mut p_reductions[*r_idx];
                                    r.resize(p_grouper.num_groups());
                                    r.combine_subset(&**pre_agg, p_pre_agg_idxs, &group_idxs)?;
                                }
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_pre_aggs) {
                            // If we're the last thread to process this set of morsels we're probably
                            // falling behind the rest, since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(drop_q_send.try_send(ToDrop::B(l)));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done, help others out by doing drops.
                    drop(drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(to_drop) = drop_q_recv.recv().await {
                        drop(to_drop);
                    }

                    output_per_partition
                        .try_set(
                            p,
                            GroupByPartition {
                                grouper: p_grouper,
                                grouped_reductions: p_reductions,
                            },
                        )
                        .ok()
                        .unwrap();

                    PolarsResult::Ok(())
                }));
            }

            // Drop outer arc after spawning each thread so the inner arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_morsels_per_local);
            drop(arc_pre_aggs_per_local);
            drop(drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await?;
                }
                PolarsResult::Ok(())
            })?;
            PolarsResult::Ok(())
        })?;

        // Drop remaining local state in parallel.
        POOL.install(|| {
            core::mem::take(&mut self.locals)
                .into_par_iter()
                .with_max_len(1)
                .for_each(drop);
        });

        Ok(output_per_partition.try_assume_init().ok().unwrap())
    }
}
495
496
/// The fully combined grouper and reductions for one hash partition,
/// produced by `GroupBySinkState::combine_locals`.
struct GroupByPartition {
    grouper: Box<dyn Grouper>,
    grouped_reductions: Vec<Box<dyn GroupedReduction>>,
}
500
501
impl GroupByPartition {
502
fn into_df(self, key_schema: &Schema, output_schema: &Schema) -> PolarsResult<DataFrame> {
503
let mut out = self.grouper.get_keys_in_group_order(key_schema);
504
let out_names = output_schema.iter_names().skip(out.width());
505
for (mut r, name) in self.grouped_reductions.into_iter().zip(out_names) {
506
unsafe {
507
out.push_column_unchecked(r.finalize()?.with_name(name.clone()).into_column());
508
}
509
}
510
Ok(out)
511
}
512
}
513
514
/// Lifecycle of the node: first a sink consuming all input, then an
/// in-memory source emitting the aggregated result, then done.
enum GroupByState {
    Sink(GroupBySinkState),
    Source(InMemorySourceNode),
    Done,
}
519
520
/// Streaming group-by compute node; see [`GroupByState`] for its lifecycle.
pub struct GroupByNode {
    state: GroupByState,
    key_schema: Arc<Schema>,
    // Number of input ports feeding this node.
    num_inputs: usize,
    num_pipelines: usize,
    // Output schema: key columns followed by the reduction output columns.
    output_schema: Arc<Schema>,
}
527
528
impl GroupByNode {
529
#[allow(clippy::too_many_arguments)]
530
pub fn new(
531
key_schema: Arc<Schema>,
532
// Input stream i selects keys with key_selectors_per_input[i].
533
key_selectors_per_input: Vec<Vec<StreamExpr>>,
534
// Input stream i feeds grouped_reductions[k] for each k in reductions_per_input[i].
535
reductions_per_input: Vec<Vec<usize>>,
536
grouper: Box<dyn Grouper>,
537
// grouped_reductions[k] is passed input cols grouped_reduction_cols[k].
538
grouped_reduction_cols: Vec<Vec<PlSmallStr>>,
539
grouped_reductions: Vec<Box<dyn GroupedReduction>>,
540
output_schema: Arc<Schema>,
541
random_state: PlRandomState,
542
num_pipelines: usize,
543
has_order_sensitive_agg: bool,
544
) -> Self {
545
let hot_table_size = std::env::var("POLARS_HOT_TABLE_SIZE")
546
.map(|sz| sz.parse::<usize>().unwrap())
547
.unwrap_or(DEFAULT_HOT_TABLE_SIZE);
548
let num_inputs = key_selectors_per_input.len();
549
let num_partitions = num_pipelines;
550
let uniq_grouped_reduction_cols_per_input = reductions_per_input
551
.iter()
552
.map(|rs| {
553
rs.iter()
554
.flat_map(|k| grouped_reduction_cols[*k].iter())
555
.cloned()
556
.collect::<PlHashSet<_>>()
557
.into_iter()
558
.collect_vec()
559
})
560
.collect_vec();
561
let locals = (0..num_pipelines)
562
.map(|_| {
563
let reductions = grouped_reductions.iter().map(|gr| gr.new_empty()).collect();
564
LocalGroupBySinkState::new(
565
key_schema.clone(),
566
reductions,
567
hot_table_size,
568
num_partitions,
569
num_inputs,
570
)
571
})
572
.collect();
573
let partitioner = HashPartitioner::new(num_partitions, 0);
574
Self {
575
state: GroupByState::Sink(GroupBySinkState {
576
key_selectors_per_input,
577
reductions_per_input,
578
grouped_reductions,
579
grouper,
580
random_state,
581
uniq_grouped_reduction_cols_per_input,
582
grouped_reduction_cols,
583
locals,
584
partitioner,
585
has_order_sensitive_agg,
586
}),
587
key_schema,
588
num_inputs,
589
num_pipelines,
590
output_schema,
591
}
592
}
593
}
594
595
impl ComputeNode for GroupByNode {
    fn name(&self) -> &str {
        "group-by"
    }

    /// Drive the sink → source → done state machine and publish the
    /// resulting port states to `recv`/`send`.
    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == self.num_inputs && send.len() == 1);

        // State transitions.
        match &mut self.state {
            // If the output doesn't want any more data, transition to being done.
            _ if send[0] == PortState::Done => {
                self.state = GroupByState::Done;
            },
            // All inputs are done: combine the sink state and become a source.
            GroupByState::Sink(_) if recv.iter().all(|r| matches!(r, PortState::Done)) => {
                let GroupByState::Sink(mut sink) =
                    core::mem::replace(&mut self.state, GroupByState::Done)
                else {
                    unreachable!()
                };
                let partitions = sink.combine_locals()?;
                // Materialize each partition's DataFrame in parallel.
                let dfs = POOL.install(|| {
                    partitions
                        .into_par_iter()
                        .map(|p| p.into_df(&self.key_schema, &self.output_schema))
                        .collect::<Result<Vec<_>, _>>()
                })?;

                let df = accumulate_dataframes_vertical_unchecked(dfs);
                let source = InMemorySourceNode::new(Arc::new(df), MorselSeq::new(0));
                self.state = GroupByState::Source(source);
            },
            // Defer to source node implementation.
            GroupByState::Source(src) => {
                src.update_state(&mut [], send, state)?;
                if send[0] == PortState::Done {
                    self.state = GroupByState::Done;
                }
            },
            // Nothing to change.
            GroupByState::Done | GroupByState::Sink(_) => {},
        }

        // Communicate our state.
        match &self.state {
            GroupByState::Sink { .. } => {
                recv.fill(PortState::Ready);
                send[0] = PortState::Blocked;
            },
            GroupByState::Source(..) => {
                recv.fill(PortState::Done);
                send[0] = PortState::Ready;
            },
            GroupByState::Done => {
                recv.fill(PortState::Done);
                send[0] = PortState::Done;
            },
        }
        Ok(())
    }

    /// Spawn the tasks for the current phase: in the sink phase, merge all
    /// input ports into per-pipeline channels (tagging each morsel with its
    /// input index) and hand them to the sink; in the source phase, delegate
    /// to the in-memory source node.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(send_ports.len() == 1 && recv_ports.len() == self.num_inputs);
        match &mut self.state {
            GroupByState::Sink(sink) => {
                assert!(send_ports[0].is_none());
                assert!(recv_ports.iter().any(|r| r.is_some()));

                // If we have multiple input streams merge them into one (still identifying which
                // input stream it came from).
                let (senders, receivers): (Vec<_>, Vec<_>) =
                    (0..self.num_pipelines).map(|_| channel(1)).unzip();
                for (i, recv_port) in recv_ports.iter_mut().enumerate() {
                    if let Some(recv_port) = recv_port.take() {
                        for (mut r, s) in recv_port
                            .parallel()
                            .into_iter()
                            .zip(senders.iter().cloned())
                        {
                            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                                while let Ok(morsel) = r.recv().await {
                                    // Stop forwarding once the sink side hung up.
                                    if s.send((i, morsel)).await.is_err() {
                                        break;
                                    }
                                }

                                Ok(())
                            }));
                        }
                    }
                }
                sink.spawn(scope, receivers, state, join_handles)
            },
            GroupByState::Source(source) => {
                assert!(recv_ports[0].is_none());
                source.spawn(scope, &mut [], send_ports, state, join_handles);
            },
            GroupByState::Done => unreachable!(),
        }
    }
}
709
710