GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/async_primitives/distributor_channel.rs
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

use crossbeam_utils::CachePadded;
use rand::prelude::*;

use super::task_parker::TaskParker;

/// Single-producer, multi-consumer FIFO channel.
///
/// Each [`Receiver`] has an internal buffer of `bufsize` elements. Thus it is
/// possible that when the [`Sender`] is exhausted some receivers still have
/// data available.
///
/// The FIFO order is only guaranteed per receiver. That is, each receiver is
/// guaranteed to see a subset of the data sent by the sender in the order the
/// sender sent it in, but not necessarily contiguously.
///
/// When one or more receivers are closed, no attempt is made to avoid filling
/// those receivers' buffers. The values in the buffer of a closed receiver are
/// lost forever; they're not redistributed among the others, but simply
/// dropped when the channel is dropped.
pub fn distributor_channel<T>(
    num_receivers: usize,
    bufsize: usize,
) -> (Sender<T>, Vec<Receiver<T>>) {
    // Round the per-receiver ring buffer up to a power of two so indices can
    // be reduced with a cheap mask (see `DistributorInner::reduce_index`).
    let capacity = bufsize.next_power_of_two();
    let receivers = (0..num_receivers)
        .map(|_| {
            CachePadded::new(ReceiverSlot {
                closed: AtomicBool::new(false),
                read_head: AtomicUsize::new(0),
                parker: TaskParker::default(),
                data: (0..capacity)
                    .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
                    .collect(),
            })
        })
        .collect();
    let inner = Arc::new(DistributorInner {
        send_closed: AtomicBool::new(false),
        send_parker: TaskParker::default(),
        write_heads: (0..num_receivers).map(|_| AtomicUsize::new(0)).collect(),
        receivers,

        bufsize,
        mask: capacity - 1,
    });

    let receivers = (0..num_receivers)
        .map(|index| Receiver {
            inner: inner.clone(),
            index,
        })
        .collect();

    let sender = Sender {
        inner,
        round_robin_idx: 0,
        rng: SmallRng::from_rng(&mut rand::rng()),
    };

    (sender, receivers)
}
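
// Illustrative usage sketch, never called: the whole exchange is driven on a
// single task by keeping the number of sent values below the total buffer
// capacity, so no future ever has to park. In real code the `Sender` and each
// `Receiver` would run on their own tasks on some async executor; the function
// name and sizes here are arbitrary illustrative choices.
#[allow(dead_code)]
async fn distributor_channel_usage_sketch() {
    let (mut tx, receivers) = distributor_channel::<u32>(4, 8);

    // 16 values fit in 4 receivers x 8 slots, so every send succeeds directly.
    for v in 0..16u32 {
        tx.send(v).await.expect("all receivers are still open");
    }
    // Dropping the sender marks the channel as closed for the receivers.
    drop(tx);

    // Drain every receiver; `Err(())` means the sender is gone and this
    // receiver's buffer is empty.
    let mut total = 0u32;
    for mut rx in receivers {
        while let Ok(v) = rx.recv().await {
            total += v;
        }
    }
    assert_eq!(total, (0..16u32).sum::<u32>());
}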

pub enum SendError<T> {
    Full(T),
    Closed(T),
}

pub enum RecvError {
    Empty,
    Closed,
}

struct ReceiverSlot<T> {
    closed: AtomicBool,
    read_head: AtomicUsize,
    parker: TaskParker,
    data: Box<[UnsafeCell<MaybeUninit<T>>]>,
}

struct DistributorInner<T> {
    send_closed: AtomicBool,
    send_parker: TaskParker,
    write_heads: Vec<AtomicUsize>,
    receivers: Vec<CachePadded<ReceiverSlot<T>>>,

    bufsize: usize,
    mask: usize,
}

impl<T> DistributorInner<T> {
    fn reduce_index(&self, idx: usize) -> usize {
        idx & self.mask
    }
}

pub struct Sender<T> {
    inner: Arc<DistributorInner<T>>,
    round_robin_idx: usize,
    rng: SmallRng,
}

pub struct Receiver<T> {
    inner: Arc<DistributorInner<T>>,
    index: usize,
}

unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Send for Receiver<T> {}

impl<T: Send> Sender<T> {
    /// Note: This intentionally takes `&mut` to ensure it is only accessed in a single-threaded
    /// manner.
    pub async fn send(&mut self, mut value: T) -> Result<(), T> {
        let num_receivers = self.inner.receivers.len();
        loop {
            // Fast-path: try the receiver with the (apparently) shortest queue
            // among the round-robin candidate and a few random samples.
            self.round_robin_idx += 1;
            if self.round_robin_idx >= num_receivers {
                self.round_robin_idx -= num_receivers;
            }

            let mut hungriest_idx = self.round_robin_idx;
            let mut shortest_len = self.upper_bound_len(self.round_robin_idx);
            for _ in 0..4 {
                let idx = ((self.rng.random::<u32>() as u64 * num_receivers as u64) >> 32) as usize;
                let len = self.upper_bound_len(idx);
                if len < shortest_len {
                    shortest_len = len;
                    hungriest_idx = idx;
                }
            }

            match unsafe { self.try_send(hungriest_idx, value) } {
                Ok(()) => return Ok(()),
                Err(SendError::Full(v)) => value = v,
                Err(SendError::Closed(v)) => value = v,
            }

            // Do one proper search before parking.
            let park = self.inner.send_parker.park();

            // Try all receivers, starting at a random index.
            let mut idx = ((self.rng.random::<u32>() as u64 * num_receivers as u64) >> 32) as usize;
            let mut all_closed = true;
            for _ in 0..num_receivers {
                match unsafe { self.try_send(idx, value) } {
                    Ok(()) => return Ok(()),
                    Err(SendError::Full(v)) => {
                        all_closed = false;
                        value = v;
                    },
                    Err(SendError::Closed(v)) => value = v,
                }

                idx += 1;
                if idx >= num_receivers {
                    idx -= num_receivers;
                }
            }

            if all_closed {
                return Err(value);
            }

            park.await;
        }
    }

    // Returns an upper bound on the length of the queue of the given receiver.
    // It is an upper bound because racy reads can reduce it in the meantime.
    fn upper_bound_len(&self, recv_idx: usize) -> usize {
        let read_head = self.inner.receivers[recv_idx]
            .read_head
            .load(Ordering::SeqCst);
        let write_head = self.inner.write_heads[recv_idx].load(Ordering::Relaxed);
        write_head.wrapping_sub(read_head)
    }

    /// # Safety
    /// May only be called from one thread at a time.
    unsafe fn try_send(&self, recv_idx: usize, value: T) -> Result<(), SendError<T>> {
        let read_head = self.inner.receivers[recv_idx]
            .read_head
            .load(Ordering::SeqCst);
        let write_head = self.inner.write_heads[recv_idx].load(Ordering::Relaxed);
        let len = write_head.wrapping_sub(read_head);
        if len < self.inner.bufsize {
            let idx = self.inner.reduce_index(write_head);
            unsafe {
                self.inner.receivers[recv_idx].data[idx]
                    .get()
                    .write(MaybeUninit::new(value));
                self.inner.write_heads[recv_idx]
                    .store(write_head.wrapping_add(1), Ordering::SeqCst);
            }
            self.inner.receivers[recv_idx].parker.unpark();
            Ok(())
        } else if self.inner.receivers[recv_idx].closed.load(Ordering::SeqCst) {
            Err(SendError::Closed(value))
        } else {
            Err(SendError::Full(value))
        }
    }
}
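
// Standalone sketch of the index mapping used in `send` above; this helper is
// illustrative only and not used by the channel itself. A widening multiply of
// a uniform `u32` by `n`, keeping the high 32 bits, yields an index in `0..n`
// that is close to uniform without a modulo. As in `send`, `n` is assumed to
// fit in a `u32`.
#[allow(dead_code)]
fn random_index_below(rng: &mut SmallRng, n: usize) -> usize {
    ((rng.random::<u32>() as u64 * n as u64) >> 32) as usize
}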

impl<T: Send> Receiver<T> {
    /// Note: This intentionally takes `&mut` to ensure it is only accessed in a single-threaded
    /// manner.
    pub async fn recv(&mut self) -> Result<T, ()> {
        loop {
            // Fast-path.
            match unsafe { self.try_recv() } {
                Ok(v) => return Ok(v),
                Err(RecvError::Closed) => return Err(()),
                Err(RecvError::Empty) => {},
            }

            // Try again, threatening to park if there's still nothing.
            let park = self.inner.receivers[self.index].parker.park();
            match unsafe { self.try_recv() } {
                Ok(v) => return Ok(v),
                Err(RecvError::Closed) => return Err(()),
                Err(RecvError::Empty) => {},
            }
            park.await;
        }
    }

    /// # Safety
    /// May only be called from one thread at a time.
    unsafe fn try_recv(&self) -> Result<T, RecvError> {
        loop {
            let read_head = self.inner.receivers[self.index]
                .read_head
                .load(Ordering::Relaxed);
            let write_head = self.inner.write_heads[self.index].load(Ordering::SeqCst);
            if read_head != write_head {
                let idx = self.inner.reduce_index(read_head);
                let read;
                unsafe {
                    let ptr = self.inner.receivers[self.index].data[idx].get();
                    read = ptr.read().assume_init();
                    self.inner.receivers[self.index]
                        .read_head
                        .store(read_head.wrapping_add(1), Ordering::SeqCst);
                }
                self.inner.send_parker.unpark();
                return Ok(read);
            } else if self.inner.send_closed.load(Ordering::SeqCst) {
                // Check write head again, sender could've sent something right
                // before closing. We can do this relaxed because we'll read it
                // again in the next iteration with SeqCst if it's a new value.
                if write_head == self.inner.write_heads[self.index].load(Ordering::Relaxed) {
                    return Err(RecvError::Closed);
                }
            } else {
                return Err(RecvError::Empty);
            }
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        self.inner.send_closed.store(true, Ordering::SeqCst);
        for recv in &self.inner.receivers {
            recv.parker.unpark();
        }
    }
}

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        self.inner.receivers[self.index]
            .closed
            .store(true, Ordering::SeqCst);
        self.inner.send_parker.unpark();
    }
}

impl<T> Drop for DistributorInner<T> {
    fn drop(&mut self) {
        for r in 0..self.receivers.len() {
            // We have exclusive access, so we only need to atomically load once.
            let write_head = self.write_heads[r].load(Ordering::SeqCst);
            let mut read_head = self.receivers[r].read_head.load(Ordering::Relaxed);
            while read_head != write_head {
                let idx = self.reduce_index(read_head);
                unsafe {
                    (*self.receivers[r].data[idx].get()).assume_init_drop();
                }
                read_head = read_head.wrapping_add(1);
            }
        }
    }
}
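
// Sketch of the close semantics described in the module docs, never called:
// buffers of closed receivers are still filled, so `send` only hands the value
// back once every receiver is both closed and full. The function name and the
// small sizes are illustrative choices.
#[allow(dead_code)]
async fn send_into_closed_receivers_sketch() {
    let (mut tx, receivers) = distributor_channel::<i32>(2, 2);
    // Dropping the receivers marks them closed, but their buffers remain.
    drop(receivers);

    // 2 receivers x 2 slots: the first four sends still succeed...
    for v in 0..4 {
        assert!(tx.send(v).await.is_ok());
    }
    // ...and only now, with every closed buffer full, does `send` return the
    // value as an error.
    assert!(tx.send(4).await.is_err());
}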