Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/merge_sorted.rs
8479 views
1
use std::collections::VecDeque;
2
3
use polars_core::prelude::ChunkCompareIneq;
4
use polars_ops::frame::_merge_sorted_dfs;
5
6
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;
7
use crate::async_primitives::distributor_channel::distributor_channel;
8
use crate::morsel::{SourceToken, get_ideal_morsel_size};
9
use crate::nodes::compute_node_prelude::*;
10
11
/// Performs `merge_sorted` with the last column being regarded as the key column. This key column
12
/// is also popped in the send pipe.
13
pub struct MergeSortedNode {
14
seq: MorselSeq,
15
16
starting_nulls: bool,
17
18
// Not yet merged buffers.
19
left_unmerged: VecDeque<DataFrame>,
20
right_unmerged: VecDeque<DataFrame>,
21
}
22
23
impl MergeSortedNode {
24
pub fn new() -> Self {
25
Self {
26
seq: MorselSeq::default(),
27
28
starting_nulls: false,
29
30
left_unmerged: VecDeque::new(),
31
right_unmerged: VecDeque::new(),
32
}
33
}
34
}
35
36
/// Find a part amongst both unmerged buffers which is mergeable.
37
///
38
/// This returns `None` if there is nothing mergeable at this point.
39
fn find_mergeable(
40
left_unmerged: &mut VecDeque<DataFrame>,
41
right_unmerged: &mut VecDeque<DataFrame>,
42
43
is_first: bool,
44
starting_nulls: &mut bool,
45
) -> PolarsResult<Option<(DataFrame, DataFrame)>> {
46
fn first_non_empty(vd: &mut VecDeque<DataFrame>) -> Option<DataFrame> {
47
let mut df = vd.pop_front()?;
48
while df.height() == 0 {
49
df = vd.pop_front()?;
50
}
51
Some(df)
52
}
53
54
loop {
55
let (mut left, mut right) = match (
56
first_non_empty(left_unmerged),
57
first_non_empty(right_unmerged),
58
) {
59
(Some(l), Some(r)) => (l, r),
60
(Some(l), None) => {
61
left_unmerged.push_front(l);
62
return Ok(None);
63
},
64
(None, Some(r)) => {
65
right_unmerged.push_front(r);
66
return Ok(None);
67
},
68
(None, None) => return Ok(None),
69
};
70
71
let left_key = left.columns().last().unwrap();
72
let right_key = right.columns().last().unwrap();
73
74
let left_null_count = left_key.null_count();
75
let right_null_count = right_key.null_count();
76
77
let has_nulls = left_null_count > 0 || right_null_count > 0;
78
79
// If we are on the first morsel we need to decide whether we have
80
// nulls first or not.
81
if is_first
82
&& has_nulls
83
&& (left_key.head(Some(1)).has_nulls() || right_key.head(Some(1)).has_nulls())
84
{
85
*starting_nulls = true;
86
}
87
88
// For both left and right, find row index of the minimum of the maxima
89
// of the left and right key columns. We can safely merge until this
90
// point.
91
let mut left_cutoff = left.height();
92
let mut right_cutoff = right.height();
93
94
let left_key_last = left_key.tail(Some(1));
95
let right_key_last = right_key.tail(Some(1));
96
97
// We already made sure we had data to work with.
98
assert!(!left_key_last.is_empty());
99
assert!(!right_key_last.is_empty());
100
101
if has_nulls {
102
if *starting_nulls {
103
// If there are starting nulls do those first, then repeat
104
// without the nulls.
105
left_cutoff = left_null_count;
106
right_cutoff = right_null_count;
107
} else {
108
// If there are ending nulls then first do things without the
109
// nulls and then repeat with only the nulls the nulls.
110
let left_is_all_nulls = left_null_count == left.height();
111
let right_is_all_nulls = right_null_count == right.height();
112
113
match (left_is_all_nulls, right_is_all_nulls) {
114
(false, false) => {
115
let left_nulls;
116
let right_nulls;
117
(left, left_nulls) =
118
left.split_at((left.height() - left_null_count) as i64);
119
(right, right_nulls) =
120
right.split_at((right.height() - right_null_count) as i64);
121
122
left_unmerged.push_front(left_nulls);
123
left_unmerged.push_front(left);
124
right_unmerged.push_front(right_nulls);
125
right_unmerged.push_front(right);
126
continue;
127
},
128
(true, false) => left_cutoff = 0,
129
(false, true) => right_cutoff = 0,
130
(true, true) => {},
131
}
132
}
133
} else if left_key_last.lt(&right_key_last)?.all() {
134
// @TODO: This is essentially search sorted, but that does not
135
// support categoricals at moment.
136
let gt_mask = right_key.gt(&left_key_last)?;
137
right_cutoff = gt_mask.downcast_as_array().values().leading_zeros();
138
} else if left_key_last.gt(&right_key_last)?.all() {
139
// @TODO: This is essentially search sorted, but that does not
140
// support categoricals at moment.
141
let gt_mask = left_key.gt(&right_key_last)?;
142
left_cutoff = gt_mask.downcast_as_array().values().leading_zeros();
143
}
144
145
let left_mergeable: DataFrame;
146
let right_mergeable: DataFrame;
147
(left_mergeable, left) = left.split_at(left_cutoff as i64);
148
(right_mergeable, right) = right.split_at(right_cutoff as i64);
149
150
if left.height() > 0 {
151
left_unmerged.push_front(left);
152
}
153
if right.height() > 0 {
154
right_unmerged.push_front(right);
155
}
156
157
return Ok(Some((left_mergeable, right_mergeable)));
158
}
159
}
160
161
fn remove_key_column(df: &mut DataFrame) {
162
// SAFETY:
163
// - We only pop so height stays same.
164
// - We only pop so no new name collisions.
165
// - We clear schema afterwards.
166
unsafe { df.columns_mut().pop().unwrap() };
167
}
168
169
impl ComputeNode for MergeSortedNode {
170
fn name(&self) -> &str {
171
"merge-sorted"
172
}
173
174
fn update_state(
175
&mut self,
176
recv: &mut [PortState],
177
send: &mut [PortState],
178
_state: &StreamingExecutionState,
179
) -> PolarsResult<()> {
180
assert_eq!(send.len(), 1);
181
assert_eq!(recv.len(), 2);
182
183
// Abstraction: we merge buffer state with port state so we can map
184
// to one three possible 'effective' states:
185
// no data now (_blocked); data available (); or no data anymore (_done)
186
let left_done = recv[0] == PortState::Done && self.left_unmerged.is_empty();
187
let right_done = recv[1] == PortState::Done && self.right_unmerged.is_empty();
188
189
// We're done as soon as one side is done.
190
if send[0] == PortState::Done || (left_done && right_done) {
191
recv[0] = PortState::Done;
192
recv[1] = PortState::Done;
193
send[0] = PortState::Done;
194
return Ok(());
195
}
196
197
// Each port is ready to proceed unless one of the other ports is effectively
198
// blocked. For example:
199
// - [Blocked with empty buffer, Ready] [Ready] returns [Ready, Blocked] [Blocked]
200
// - [Blocked with non-empty buffer, Ready] [Ready] returns [Ready, Ready, Ready]
201
let send_blocked = send[0] == PortState::Blocked;
202
let left_blocked = recv[0] == PortState::Blocked && self.left_unmerged.is_empty();
203
let right_blocked = recv[1] == PortState::Blocked && self.right_unmerged.is_empty();
204
send[0] = if left_blocked || right_blocked {
205
PortState::Blocked
206
} else {
207
PortState::Ready
208
};
209
recv[0] = if send_blocked || right_blocked {
210
PortState::Blocked
211
} else {
212
PortState::Ready
213
};
214
recv[1] = if send_blocked || left_blocked {
215
PortState::Blocked
216
} else {
217
PortState::Ready
218
};
219
220
Ok(())
221
}
222
223
fn spawn<'env, 's>(
224
&'env mut self,
225
scope: &'s TaskScope<'s, 'env>,
226
recv_ports: &mut [Option<RecvPort<'_>>],
227
send_ports: &mut [Option<SendPort<'_>>],
228
_state: &'s StreamingExecutionState,
229
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
230
) {
231
assert_eq!(recv_ports.len(), 2);
232
assert_eq!(send_ports.len(), 1);
233
234
let send = send_ports[0].take().unwrap().parallel();
235
236
let seq = &mut self.seq;
237
let starting_nulls = &mut self.starting_nulls;
238
let left_unmerged = &mut self.left_unmerged;
239
let right_unmerged = &mut self.right_unmerged;
240
241
match (recv_ports[0].take(), recv_ports[1].take()) {
242
// If we do not need to merge or flush anymore, just start passing the port in
243
// parallel.
244
(Some(port), None) | (None, Some(port))
245
if left_unmerged.is_empty() && right_unmerged.is_empty() =>
246
{
247
let recv = port.parallel();
248
let inner_handles = recv
249
.into_iter()
250
.zip(send)
251
.map(|(mut recv, mut send)| {
252
let morsel_offset = *seq;
253
scope.spawn_task(TaskPriority::High, async move {
254
let mut max_seq = morsel_offset;
255
while let Ok(mut morsel) = recv.recv().await {
256
// Ensure the morsel sequence id stream is monotone non-decreasing.
257
let seq = morsel.seq().offset_by(morsel_offset);
258
max_seq = max_seq.max(seq);
259
260
remove_key_column(morsel.df_mut());
261
262
morsel.set_seq(seq);
263
if send.send(morsel).await.is_err() {
264
break;
265
}
266
}
267
max_seq
268
})
269
})
270
.collect::<Vec<_>>();
271
272
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
273
// Update our global maximum.
274
for handle in inner_handles {
275
*seq = (*seq).max(handle.await);
276
}
277
Ok(())
278
}));
279
},
280
281
// This is the base case. Either:
282
// - Both streams are still open and we still need to merge.
283
// - One or both streams are closed stream is closed and we still have some buffered
284
// data.
285
(left, right) => {
286
async fn buffer_unmerged(
287
port: &mut PortReceiver,
288
unmerged: &mut VecDeque<DataFrame>,
289
) {
290
// If a stop was requested, we need to buffer the remaining
291
// morsels and trigger a phase transition.
292
293
while let Ok(morsel) = port.recv().await {
294
// Request the port stop producing morsels.
295
morsel.source_token().stop();
296
// Buffer all the morsels that were already produced.
297
unmerged.push_back(morsel.into_df());
298
}
299
}
300
301
let (mut distributor, dist_recv) =
302
distributor_channel(send.len(), *DEFAULT_DISTRIBUTOR_BUFFER_SIZE);
303
304
let mut left = left.map(|p| p.serial());
305
let mut right = right.map(|p| p.serial());
306
307
join_handles.push(scope.spawn_task(TaskPriority::Low, async move {
308
let source_token = SourceToken::new();
309
310
// While we can still load data for the empty side.
311
while (left.is_some() || right.is_some())
312
&& !(left.is_none() && left_unmerged.is_empty())
313
&& !(right.is_none() && right_unmerged.is_empty())
314
{
315
// If we have morsels from both input ports, find until where we can merge
316
// them and send that on to be merged.
317
while let Some((left_mergeable, right_mergeable)) = find_mergeable(
318
left_unmerged,
319
right_unmerged,
320
seq.to_u64() == 0,
321
starting_nulls,
322
)? {
323
let left_mergeable =
324
Morsel::new(left_mergeable, *seq, source_token.clone());
325
*seq = seq.successor();
326
327
if distributor
328
.send((left_mergeable, right_mergeable))
329
.await
330
.is_err()
331
{
332
return Ok(());
333
};
334
}
335
336
if source_token.stop_requested() {
337
// Request that a port stops producing morsels and buffers all the
338
// remaining morsels.
339
if let Some(p) = &mut left {
340
buffer_unmerged(p, left_unmerged).await;
341
}
342
if let Some(p) = &mut right {
343
buffer_unmerged(p, right_unmerged).await;
344
}
345
break;
346
}
347
348
assert!(left_unmerged.is_empty() || right_unmerged.is_empty());
349
let (empty_port, empty_unmerged) = match (
350
left_unmerged.is_empty(),
351
right_unmerged.is_empty(),
352
left.as_mut(),
353
right.as_mut(),
354
) {
355
(true, _, Some(left), _) => (left, &mut *left_unmerged),
356
(_, true, _, Some(right)) => (right, &mut *right_unmerged),
357
358
// If the port that is empty is closed, we don't need to merge anymore.
359
_ => break,
360
};
361
362
// Try to get a new morsel from the empty side.
363
let Ok(m) = empty_port.recv().await else {
364
if let Some(p) = &mut left {
365
buffer_unmerged(p, left_unmerged).await;
366
}
367
if let Some(p) = &mut right {
368
buffer_unmerged(p, right_unmerged).await;
369
}
370
break;
371
};
372
empty_unmerged.push_back(m.into_df());
373
}
374
375
// Clear out buffers until we cannot anymore. This helps allows us to go to the
376
// parallel case faster.
377
while let Some((left_mergeable, right_mergeable)) = find_mergeable(
378
left_unmerged,
379
right_unmerged,
380
seq.to_u64() == 0,
381
starting_nulls,
382
)? {
383
let left_mergeable =
384
Morsel::new(left_mergeable, *seq, source_token.clone());
385
*seq = seq.successor();
386
387
if distributor
388
.send((left_mergeable, right_mergeable))
389
.await
390
.is_err()
391
{
392
return Ok(());
393
};
394
}
395
396
// If one of the ports is done and does not have buffered data anymore, we
397
// flush the data on the other side. After this point, this node just pipes
398
// data through.
399
let pass = if left.is_none() && left_unmerged.is_empty() {
400
Some((right.as_mut(), &mut *right_unmerged))
401
} else if right.is_none() && right_unmerged.is_empty() {
402
Some((left.as_mut(), &mut *left_unmerged))
403
} else {
404
None
405
};
406
if let Some((pass_port, pass_unmerged)) = pass {
407
for df in std::mem::take(pass_unmerged) {
408
let m = Morsel::new(df, *seq, source_token.clone());
409
*seq = seq.successor();
410
if distributor.send((m, DataFrame::empty())).await.is_err() {
411
return Ok(());
412
}
413
}
414
415
// Start passing on the port that is still open.
416
if let Some(pass_port) = pass_port {
417
let Ok(mut m) = pass_port.recv().await else {
418
return Ok(());
419
};
420
if source_token.stop_requested() {
421
m.source_token().stop();
422
}
423
m.set_seq(*seq);
424
*seq = seq.successor();
425
if distributor.send((m, DataFrame::empty())).await.is_err() {
426
return Ok(());
427
}
428
429
while let Ok(mut m) = pass_port.recv().await {
430
m.set_seq(*seq);
431
*seq = seq.successor();
432
if distributor.send((m, DataFrame::empty())).await.is_err() {
433
return Ok(());
434
}
435
}
436
}
437
}
438
439
Ok(())
440
}));
441
442
// Task that actually merges the two dataframes. Since this merge might be very
443
// expensive, this is split over several tasks.
444
join_handles.extend(dist_recv.into_iter().zip(send).map(|(mut recv, mut send)| {
445
let ideal_morsel_size = get_ideal_morsel_size();
446
scope.spawn_task(TaskPriority::High, async move {
447
while let Ok((mut left, mut right)) = recv.recv().await {
448
// When we are flushing the buffer, we will just send one morsel from
449
// the input. We don't want to mess with the source token or wait group
450
// and just pass it on.
451
if right.shape_has_zero() {
452
remove_key_column(left.df_mut());
453
454
if send.send(left).await.is_err() {
455
return Ok(());
456
}
457
continue;
458
}
459
460
let (mut left, seq, source_token, wg) = left.into_inner();
461
assert!(wg.is_none());
462
463
let left_s = left
464
.columns()
465
.last()
466
.unwrap()
467
.as_materialized_series()
468
.clone();
469
let right_s = right
470
.columns()
471
.last()
472
.unwrap()
473
.as_materialized_series()
474
.clone();
475
476
remove_key_column(&mut left);
477
remove_key_column(&mut right);
478
479
let merged =
480
_merge_sorted_dfs(&left, &right, &left_s, &right_s, false)?;
481
482
if ideal_morsel_size > 1 && merged.height() > ideal_morsel_size {
483
// The merged dataframe will have at most doubled in size from the
484
// input so we can divide by half.
485
let (m1, m2) = merged.split_at((merged.height() / 2) as i64);
486
487
// MorselSeq have to be monotonely non-decreasing so we can
488
// pass the same sequence token twice.
489
let morsel = Morsel::new(m1, seq, source_token.clone());
490
if send.send(morsel).await.is_err() {
491
break;
492
}
493
let morsel = Morsel::new(m2, seq, source_token.clone());
494
if send.send(morsel).await.is_err() {
495
break;
496
}
497
} else {
498
let morsel = Morsel::new(merged, seq, source_token.clone());
499
if send.send(morsel).await.is_err() {
500
break;
501
}
502
}
503
}
504
505
Ok(())
506
})
507
}));
508
},
509
}
510
}
511
}
512
513