Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/async_primitives/connector.rs
8424 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::cell::UnsafeCell;
3
use std::mem::MaybeUninit;
4
use std::pin::Pin;
5
use std::sync::Arc;
6
use std::sync::atomic::{AtomicU8, Ordering};
7
use std::task::{Context, Poll, Waker};
8
9
use atomic_waker::AtomicWaker;
10
use pin_project_lite::pin_project;
11
12
/// Sending half of a connector created by [`connector`], with `()` as shared value.
pub type Sender<T> = SenderExt<T, ()>;
/// Receiving half of a connector created by [`connector`], with `()` as shared value.
pub type Receiver<T> = ReceiverExt<T, ()>;
14
15
/// Single-producer, single-consumer capacity-one channel.
16
pub fn connector<T>() -> (Sender<T>, Receiver<T>) {
17
let connector = Arc::new(Connector::new(()));
18
(
19
Sender {
20
connector: connector.clone(),
21
},
22
Receiver { connector },
23
)
24
}
25
26
/// Single-producer, single-consumer capacity-one channel, with a shared common
27
/// value.
28
pub fn connector_with<T, S>(shared: S) -> (SenderExt<T, S>, ReceiverExt<T, S>) {
29
let connector = Arc::new(Connector::new(shared));
30
(
31
SenderExt {
32
connector: connector.clone(),
33
},
34
ReceiverExt { connector },
35
)
36
}
37
38
/*
For UnsafeCell safety, a sender may only set the FULL_BIT (giving exclusive
access to the value to the receiver), and a receiver may only unset the
FULL_BIT (giving exclusive access back to the sender). Setting/clearing the
FULL_BIT must be done with Release ordering, and before reading/writing the
value the FULL_BIT must be checked with Acquire ordering.

The exception is when the CLOSED_BIT is set: at that point the end that has
not closed has full exclusive access.
*/
48
49
/// Set by the sender after writing the value slot; cleared by the receiver
/// after reading it.
const FULL_BIT: u8 = 1 << 0;
/// Set when either end closes the connector.
const CLOSED_BIT: u8 = 1 << 1;
/// Set by an end after registering its waker, to request a wake-up from the
/// other end.
const WAITING_BIT: u8 = 1 << 2;
52
53
/// State shared by one sender and one receiver: a single value slot whose
/// ownership is governed by the bit flags in `state`, one waker per end, and a
/// user-provided shared value.
// NOTE(review): the 128-byte alignment presumably keeps separate connectors on
// separate cache lines to avoid false sharing — confirm.
#[repr(align(128))]
struct Connector<T, S> {
    /// Registered by a sender waiting on a full slot; woken by the receiver.
    send_waker: AtomicWaker,
    /// Registered by a receiver waiting on an empty slot; woken by the sender.
    recv_waker: AtomicWaker,
    /// The value slot; holds an initialized value while `FULL_BIT` is set.
    value: UnsafeCell<MaybeUninit<T>>,
    /// Combination of `FULL_BIT`, `CLOSED_BIT` and `WAITING_BIT`.
    state: AtomicU8,
    /// Value readable from both ends via `shared()`.
    shared: S,
}
61
62
impl<T, S> Connector<T, S> {
63
fn new(shared: S) -> Self {
64
Self {
65
send_waker: AtomicWaker::new(),
66
recv_waker: AtomicWaker::new(),
67
value: UnsafeCell::new(MaybeUninit::uninit()),
68
state: AtomicU8::new(0),
69
shared,
70
}
71
}
72
}
73
74
/// Why a non-blocking send failed; both variants hand the value back.
pub enum SendError<T> {
    /// The slot still holds a value the receiver has not taken yet.
    Full(T),
    /// The connector is closed, so the value can no longer be delivered.
    Closed(T),
}
78
79
/// Why a non-blocking receive produced no value.
pub enum RecvError {
    /// No value is available right now, but the connector is still open.
    Empty,
    /// No value is available and the connector has been closed.
    Closed,
}
83
84
// SAFETY: all the send methods may only be called from a single sender at a
85
// time, and similarly for all the recv methods from a single receiver.
86
impl<T, S> Connector<T, S> {
87
unsafe fn poll_send(&self, value: &mut Option<T>, waker: &Waker) -> Poll<Result<(), T>> {
88
if let Some(v) = value.take() {
89
let mut state = self.state.load(Ordering::Acquire);
90
if state & FULL_BIT == FULL_BIT {
91
self.send_waker.register(waker);
92
let (Ok(s) | Err(s)) = self.state.compare_exchange(
93
state,
94
state | WAITING_BIT,
95
Ordering::Relaxed,
96
Ordering::Acquire, // Receiver updated, re-acquire.
97
);
98
state = s;
99
}
100
101
match self.try_send_impl(v, state) {
102
Ok(()) => {},
103
Err(SendError::Closed(v)) => return Poll::Ready(Err(v)),
104
Err(SendError::Full(v)) => {
105
*value = Some(v);
106
return Poll::Pending;
107
},
108
}
109
}
110
111
Poll::Ready(Ok(()))
112
}
113
114
unsafe fn try_send_impl(&self, value: T, state: u8) -> Result<(), SendError<T>> {
115
if state & CLOSED_BIT == CLOSED_BIT {
116
return Err(SendError::Closed(value));
117
}
118
if state & FULL_BIT == FULL_BIT {
119
return Err(SendError::Full(value));
120
}
121
122
unsafe {
123
self.value.get().write(MaybeUninit::new(value));
124
let state = self.state.swap(FULL_BIT, Ordering::Release);
125
if state & WAITING_BIT == WAITING_BIT {
126
self.recv_waker.wake();
127
}
128
if state & CLOSED_BIT == CLOSED_BIT {
129
// SAFETY: no synchronization needed, we are the only one left.
130
// Restore the closed bit we just overwrote.
131
self.state.store(CLOSED_BIT, Ordering::Relaxed);
132
return Err(SendError::Closed(self.value.get().read().assume_init()));
133
}
134
}
135
136
Ok(())
137
}
138
139
unsafe fn poll_recv(&self, waker: &Waker) -> Poll<Result<T, ()>> {
140
let mut state = self.state.load(Ordering::Acquire);
141
if state & FULL_BIT == 0 {
142
self.recv_waker.register(waker);
143
let (Ok(s) | Err(s)) = self.state.compare_exchange(
144
state,
145
state | WAITING_BIT,
146
Ordering::Relaxed,
147
Ordering::Acquire, // Sender updated, re-acquire.
148
);
149
state = s;
150
}
151
152
match self.try_recv_impl(state) {
153
Ok(v) => Poll::Ready(Ok(v)),
154
Err(RecvError::Empty) => Poll::Pending,
155
Err(RecvError::Closed) => Poll::Ready(Err(())),
156
}
157
}
158
159
unsafe fn try_recv_impl(&self, state: u8) -> Result<T, RecvError> {
160
if state & FULL_BIT == FULL_BIT {
161
unsafe {
162
let ret = self.value.get().read().assume_init();
163
let state = self.state.swap(0, Ordering::Release);
164
if state & WAITING_BIT == WAITING_BIT {
165
self.send_waker.wake();
166
}
167
if state & CLOSED_BIT == CLOSED_BIT {
168
// Restore the closed bit we just overwrote.
169
self.state.store(CLOSED_BIT, Ordering::Relaxed);
170
}
171
return Ok(ret);
172
}
173
}
174
175
// Check closed bit last so we do receive any last element sent before
176
// closing sender.
177
if state & CLOSED_BIT == CLOSED_BIT {
178
return Err(RecvError::Closed);
179
}
180
181
Err(RecvError::Empty)
182
}
183
184
unsafe fn try_send(&self, value: T) -> Result<(), SendError<T>> {
185
self.try_send_impl(value, self.state.load(Ordering::Acquire))
186
}
187
188
unsafe fn try_recv(&self) -> Result<T, RecvError> {
189
self.try_recv_impl(self.state.load(Ordering::Acquire))
190
}
191
192
/// # Safety
193
/// You may not access this connector anymore as a sender after this call.
194
unsafe fn close_send(&self) {
195
self.state.fetch_or(CLOSED_BIT, Ordering::Relaxed);
196
self.recv_waker.wake();
197
}
198
199
/// # Safety
200
/// You may not access this connector anymore as a receiver after this call.
201
unsafe fn close_recv(&self) {
202
let state = self.state.fetch_or(CLOSED_BIT, Ordering::Acquire);
203
drop(self.try_recv_impl(state));
204
self.send_waker.wake();
205
}
206
}
207
208
pub struct SenderExt<T, S> {
209
connector: Arc<Connector<T, S>>,
210
}
211
212
unsafe impl<T: Send, S: Sync> Send for SenderExt<T, S> {}
213
214
impl<T, S> Drop for SenderExt<T, S> {
215
fn drop(&mut self) {
216
unsafe { self.connector.close_send() }
217
}
218
}
219
220
pub struct ReceiverExt<T, S> {
221
connector: Arc<Connector<T, S>>,
222
}
223
224
unsafe impl<T: Send, S: Sync> Send for ReceiverExt<T, S> {}
225
226
impl<T, S> Drop for ReceiverExt<T, S> {
227
fn drop(&mut self) {
228
unsafe { self.connector.close_recv() }
229
}
230
}
231
232
pin_project! {
    /// Future returned by [`SenderExt::send`]; resolves to `Ok(())` once the
    /// value is handed over, or to `Err(value)` if the connector is closed.
    pub struct SendFuture<'a, T, S> {
        // Connector shared with the receiver.
        connector: &'a Connector<T, S>,
        // Value still to be sent; taken out when the future completes.
        value: Option<T>,
    }
}

// SAFETY: the future may move its `T` to the receiver's side (`T: Send`) and
// only holds a shared reference to the connector (`S: Sync`).
unsafe impl<T: Send, S: Sync> Send for SendFuture<'_, T, S> {}
240
241
impl<T: Send, S: Sync> SenderExt<T, S> {
242
/// Returns a future that when awaited will send the value to the [`ReceiverExt`].
243
/// Returns Err(value) if the connector is closed.
244
#[must_use]
245
pub fn send(&mut self, value: T) -> SendFuture<'_, T, S> {
246
SendFuture {
247
connector: &self.connector,
248
value: Some(value),
249
}
250
}
251
252
#[allow(unused)]
253
pub fn try_send(&mut self, value: T) -> Result<(), SendError<T>> {
254
unsafe { self.connector.try_send(value) }
255
}
256
257
pub fn shared(&self) -> &S {
258
&self.connector.shared
259
}
260
}
261
262
impl<T, S> std::future::Future for SendFuture<'_, T, S> {
263
type Output = Result<(), T>;
264
265
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
266
assert!(
267
self.value.is_some(),
268
"re-poll after Poll::Ready in connector SendFuture"
269
);
270
unsafe { self.connector.poll_send(self.project().value, cx.waker()) }
271
}
272
}
273
274
pin_project! {
    /// Future returned by [`ReceiverExt::recv`]; resolves to `Ok(value)` once a
    /// value is received, or to `Err(())` if the sender closed without sending.
    pub struct RecvFuture<'a, T, S> {
        // Connector shared with the sender.
        connector: &'a Connector<T, S>,
        // Re-poll guard read by `poll`'s assertion; `recv` initializes it to `false`.
        done: bool,
    }
}

// SAFETY: the future may receive a `T` produced on another thread (`T: Send`)
// and only holds a shared reference to the connector (`S: Sync`).
unsafe impl<T: Send, S: Sync> Send for RecvFuture<'_, T, S> {}
282
283
impl<T: Send, S: Sync> ReceiverExt<T, S> {
284
/// Returns a future that when awaited will return `Ok(value)` once the
285
/// value is received, or returns `Err(())` if the [`SenderExt`] was dropped
286
/// before sending a value.
287
#[must_use]
288
pub fn recv(&mut self) -> RecvFuture<'_, T, S> {
289
RecvFuture {
290
connector: &self.connector,
291
done: false,
292
}
293
}
294
295
#[allow(unused)]
296
pub fn try_recv(&mut self) -> Result<T, RecvError> {
297
unsafe { self.connector.try_recv() }
298
}
299
300
pub fn shared(&self) -> &S {
301
&self.connector.shared
302
}
303
}
304
305
impl<T, S> std::future::Future for RecvFuture<'_, T, S> {
306
type Output = Result<T, ()>;
307
308
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309
assert!(
310
!self.done,
311
"re-poll after Poll::Ready in connector SendFuture"
312
);
313
unsafe { self.connector.poll_recv(cx.waker()) }
314
}
315
}
316
317