Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/async_primitives/linearizer.rs
6939 views
1
use std::collections::BinaryHeap;
2
3
use tokio::sync::mpsc::{Receiver, Sender, channel};
4
5
/// Records which inserter channel(s) the next call to `Linearizer::get` has to
/// receive from before it can hand out a value.
enum PollState {
    /// Nothing left to receive: the heap ran empty, so every stream is finished.
    NoPoll,
    /// Receive one value from the inserter with this index (the one whose value
    /// was handed out last).
    Poll(usize),
    /// Receive one value from every inserter; this is the starting state.
    PollAll,
}
11
12
/// A value received from one inserter, tagged with that inserter's index so the
/// linearizer knows which channel to refill after the value is popped.
///
/// Equality and ordering look only at `value`; `sender_id` is ignored, so equal
/// values coming from different inserters compare as equal.
struct LinearedItem<T> {
    value: T,
    sender_id: usize,
}

impl<T: Ord> PartialEq for LinearedItem<T> {
    fn eq(&self, other: &Self) -> bool {
        self.value.eq(&other.value)
    }
}

impl<T: Ord> Eq for LinearedItem<T> {}

impl<T: Ord> PartialOrd for LinearedItem<T> {
    // Canonical form: defer to `Ord` so the two orderings can never disagree.
    // This also makes the `clippy::non_canonical_partial_ord_impl` allow unnecessary.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T: Ord> Ord for LinearedItem<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.value.cmp(&other.value)
    }
}
34
35
/// Utility to convert the input of `N` senders of ordered data into `1` stream of ordered data.
36
pub struct Linearizer<T> {
37
receivers: Vec<Receiver<T>>,
38
poll_state: PollState,
39
40
heap: BinaryHeap<LinearedItem<T>>,
41
}
42
43
impl<T: Ord> Linearizer<T> {
44
pub fn new(num_inserters: usize, buffer_size: usize) -> (Self, Vec<Inserter<T>>) {
45
let mut receivers = Vec::with_capacity(num_inserters);
46
let mut inserters = Vec::with_capacity(num_inserters);
47
48
for _ in 0..num_inserters {
49
// We could perhaps use a bespoke spsc bounded channel here in the
50
// future, instead of tokio's mpsc channel.
51
let (sender, receiver) = channel(buffer_size);
52
receivers.push(receiver);
53
inserters.push(Inserter { sender });
54
}
55
let slf = Self {
56
receivers,
57
poll_state: PollState::PollAll,
58
heap: BinaryHeap::with_capacity(num_inserters),
59
};
60
(slf, inserters)
61
}
62
63
pub fn new_with_maintain_order(
64
num_inserters: usize,
65
buffer_size: usize,
66
maintain_order: bool,
67
) -> (Self, Vec<Inserter<T>>) {
68
if maintain_order {
69
return Self::new(num_inserters, buffer_size);
70
}
71
72
let (sender, receiver) = channel(buffer_size * num_inserters);
73
let receivers = vec![receiver];
74
let inserters = (0..num_inserters)
75
.map(|_| Inserter {
76
sender: sender.clone(),
77
})
78
.collect();
79
80
let slf = Self {
81
receivers,
82
poll_state: PollState::PollAll,
83
heap: BinaryHeap::with_capacity(1),
84
};
85
(slf, inserters)
86
}
87
88
/// Fetch the next ordered item produced by senders.
89
///
90
/// This may wait for at each sender to have sent at least one value before the [`Linearizer`]
91
/// starts producing.
92
///
93
/// If all senders have closed their channels and there are no more buffered values, this
94
/// returns `None`.
95
pub async fn get(&mut self) -> Option<T> {
96
// The idea is that we have exactly one value per inserter in the
97
// binary heap, and when we take one out we must refill it. This way we
98
// always ensure we have the value with the highest global order.
99
let poll_range = match self.poll_state {
100
PollState::NoPoll => 0..0,
101
PollState::Poll(i) => i..i + 1,
102
PollState::PollAll => 0..self.receivers.len(),
103
};
104
105
for sender_id in poll_range {
106
// If no value was received from that particular inserter, that
107
// stream is done and thus we no longer need to consider it for the
108
// global order.
109
if let Some(value) = self.receivers[sender_id].recv().await {
110
self.heap.push(LinearedItem { value, sender_id });
111
}
112
}
113
114
if let Some(first_in_merged_streams) = self.heap.pop() {
115
let LinearedItem { value, sender_id } = first_in_merged_streams;
116
self.poll_state = PollState::Poll(sender_id);
117
Some(value)
118
} else {
119
self.poll_state = PollState::NoPoll;
120
None
121
}
122
}
123
}
124
125
pub struct Inserter<T> {
126
sender: Sender<T>,
127
}
128
129
impl<T: Ord> Inserter<T> {
130
pub async fn insert(&mut self, value: T) -> Result<(), T> {
131
self.sender.send(value).await.map_err(|e| e.0)
132
}
133
}
134
135