Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/async_primitives/connector.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::cell::UnsafeCell;
3
use std::mem::MaybeUninit;
4
use std::pin::Pin;
5
use std::sync::Arc;
6
use std::sync::atomic::{AtomicU8, Ordering};
7
use std::task::{Context, Poll, Waker};
8
9
use atomic_waker::AtomicWaker;
10
use pin_project_lite::pin_project;
11
12
/// Single-producer, single-consumer capacity-one channel.
13
pub fn connector<T>() -> (Sender<T>, Receiver<T>) {
14
let connector = Arc::new(Connector::default());
15
(
16
Sender {
17
connector: connector.clone(),
18
},
19
Receiver { connector },
20
)
21
}
22
23
/*
24
For UnsafeCell safety, the sender may only ever set the FULL_BIT (handing the
receiver exclusive access to the value), and the receiver may only ever clear
the FULL_BIT (handing exclusive access back to the sender). Setting/clearing
the FULL_BIT must be done with Release ordering, and before reading/writing
the value the FULL_BIT must be checked with Acquire ordering.

The exception is when the CLOSED_BIT is set: from then on, the end that has
not been closed has full, exclusive access.
32
*/
33
34
/// Set while the value slot holds a value the receiver has not taken yet.
const FULL_BIT: u8 = 1 << 0;
/// Set once either end has closed (on drop of the `Sender` or `Receiver`).
const CLOSED_BIT: u8 = 1 << 1;
/// Set by an end that has registered its waker and waits for the other end.
const WAITING_BIT: u8 = 1 << 2;
37
38
#[repr(align(64))]
39
struct Connector<T> {
40
send_waker: AtomicWaker,
41
recv_waker: AtomicWaker,
42
value: UnsafeCell<MaybeUninit<T>>,
43
state: AtomicU8,
44
}
45
46
impl<T> Default for Connector<T> {
47
fn default() -> Self {
48
Self {
49
send_waker: AtomicWaker::new(),
50
recv_waker: AtomicWaker::new(),
51
value: UnsafeCell::new(MaybeUninit::uninit()),
52
state: AtomicU8::new(0),
53
}
54
}
55
}
56
57
/// Error from a non-blocking send; always hands the unsent value back.
pub enum SendError<T> {
    /// The slot still holds an earlier value that has not been received.
    Full(T),
    /// The receiver has been dropped.
    Closed(T),
}
61
62
/// Error from a non-blocking receive.
pub enum RecvError {
    /// No value is available right now, but the sender has not closed.
    Empty,
    /// The sender has closed and no value is left to receive.
    Closed,
}
66
67
// SAFETY: all the send methods may only be called from a single sender at a
68
// time, and similarly for all the recv methods from a single receiver.
69
impl<T> Connector<T> {
70
unsafe fn poll_send(&self, value: &mut Option<T>, waker: &Waker) -> Poll<Result<(), T>> {
71
if let Some(v) = value.take() {
72
let mut state = self.state.load(Ordering::Acquire);
73
if state & FULL_BIT == FULL_BIT {
74
self.send_waker.register(waker);
75
let (Ok(s) | Err(s)) = self.state.compare_exchange(
76
state,
77
state | WAITING_BIT,
78
Ordering::Relaxed,
79
Ordering::Acquire, // Receiver updated, re-acquire.
80
);
81
state = s;
82
}
83
84
match self.try_send_impl(v, state) {
85
Ok(()) => {},
86
Err(SendError::Closed(v)) => return Poll::Ready(Err(v)),
87
Err(SendError::Full(v)) => {
88
*value = Some(v);
89
return Poll::Pending;
90
},
91
}
92
}
93
94
Poll::Ready(Ok(()))
95
}
96
97
unsafe fn try_send_impl(&self, value: T, state: u8) -> Result<(), SendError<T>> {
98
if state & CLOSED_BIT == CLOSED_BIT {
99
return Err(SendError::Closed(value));
100
}
101
if state & FULL_BIT == FULL_BIT {
102
return Err(SendError::Full(value));
103
}
104
105
unsafe {
106
self.value.get().write(MaybeUninit::new(value));
107
let state = self.state.swap(FULL_BIT, Ordering::Release);
108
if state & WAITING_BIT == WAITING_BIT {
109
self.recv_waker.wake();
110
}
111
if state & CLOSED_BIT == CLOSED_BIT {
112
// SAFETY: no synchronization needed, we are the only one left.
113
// Restore the closed bit we just overwrote.
114
self.state.store(CLOSED_BIT, Ordering::Relaxed);
115
return Err(SendError::Closed(self.value.get().read().assume_init()));
116
}
117
}
118
119
Ok(())
120
}
121
122
unsafe fn poll_recv(&self, waker: &Waker) -> Poll<Result<T, ()>> {
123
let mut state = self.state.load(Ordering::Acquire);
124
if state & FULL_BIT == 0 {
125
self.recv_waker.register(waker);
126
let (Ok(s) | Err(s)) = self.state.compare_exchange(
127
state,
128
state | WAITING_BIT,
129
Ordering::Relaxed,
130
Ordering::Acquire, // Sender updated, re-acquire.
131
);
132
state = s;
133
}
134
135
match self.try_recv_impl(state) {
136
Ok(v) => Poll::Ready(Ok(v)),
137
Err(RecvError::Empty) => Poll::Pending,
138
Err(RecvError::Closed) => Poll::Ready(Err(())),
139
}
140
}
141
142
unsafe fn try_recv_impl(&self, state: u8) -> Result<T, RecvError> {
143
if state & FULL_BIT == FULL_BIT {
144
unsafe {
145
let ret = self.value.get().read().assume_init();
146
let state = self.state.swap(0, Ordering::Release);
147
if state & WAITING_BIT == WAITING_BIT {
148
self.send_waker.wake();
149
}
150
if state & CLOSED_BIT == CLOSED_BIT {
151
// Restore the closed bit we just overwrote.
152
self.state.store(CLOSED_BIT, Ordering::Relaxed);
153
}
154
return Ok(ret);
155
}
156
}
157
158
// Check closed bit last so we do receive any last element sent before
159
// closing sender.
160
if state & CLOSED_BIT == CLOSED_BIT {
161
return Err(RecvError::Closed);
162
}
163
164
Err(RecvError::Empty)
165
}
166
167
unsafe fn try_send(&self, value: T) -> Result<(), SendError<T>> {
168
self.try_send_impl(value, self.state.load(Ordering::Acquire))
169
}
170
171
unsafe fn try_recv(&self) -> Result<T, RecvError> {
172
self.try_recv_impl(self.state.load(Ordering::Acquire))
173
}
174
175
/// # Safety
176
/// You may not access this connector anymore as a sender after this call.
177
unsafe fn close_send(&self) {
178
self.state.fetch_or(CLOSED_BIT, Ordering::Relaxed);
179
self.recv_waker.wake();
180
}
181
182
/// # Safety
183
/// You may not access this connector anymore as a receiver after this call.
184
unsafe fn close_recv(&self) {
185
let state = self.state.fetch_or(CLOSED_BIT, Ordering::Acquire);
186
drop(self.try_recv_impl(state));
187
self.send_waker.wake();
188
}
189
}
190
191
pub struct Sender<T> {
192
connector: Arc<Connector<T>>,
193
}
194
195
unsafe impl<T: Send> Send for Sender<T> {}
196
197
impl<T> Drop for Sender<T> {
198
fn drop(&mut self) {
199
unsafe { self.connector.close_send() }
200
}
201
}
202
203
pub struct Receiver<T> {
204
connector: Arc<Connector<T>>,
205
}
206
207
unsafe impl<T: Send> Send for Receiver<T> {}
208
209
impl<T> Drop for Receiver<T> {
210
fn drop(&mut self) {
211
unsafe { self.connector.close_recv() }
212
}
213
}
214
215
pin_project! {
216
pub struct SendFuture<'a, T> {
217
connector: &'a Connector<T>,
218
value: Option<T>,
219
}
220
}
221
222
unsafe impl<T: Send> Send for SendFuture<'_, T> {}
223
224
impl<T: Send> Sender<T> {
225
/// Returns a future that when awaited will send the value to the [`Receiver`].
226
/// Returns Err(value) if the connector is closed.
227
#[must_use]
228
pub fn send(&mut self, value: T) -> SendFuture<'_, T> {
229
SendFuture {
230
connector: &self.connector,
231
value: Some(value),
232
}
233
}
234
235
#[allow(unused)]
236
pub fn try_send(&mut self, value: T) -> Result<(), SendError<T>> {
237
unsafe { self.connector.try_send(value) }
238
}
239
}
240
241
impl<T> std::future::Future for SendFuture<'_, T> {
242
type Output = Result<(), T>;
243
244
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245
assert!(
246
self.value.is_some(),
247
"re-poll after Poll::Ready in connector SendFuture"
248
);
249
unsafe { self.connector.poll_send(self.project().value, cx.waker()) }
250
}
251
}
252
253
pin_project! {
254
pub struct RecvFuture<'a, T> {
255
connector: &'a Connector<T>,
256
done: bool,
257
}
258
}
259
260
unsafe impl<T: Send> Send for RecvFuture<'_, T> {}
261
262
impl<T: Send> Receiver<T> {
263
/// Returns a future that when awaited will return `Ok(value)` once the
264
/// value is received, or returns `Err(())` if the [`Sender`] was dropped
265
/// before sending a value.
266
#[must_use]
267
pub fn recv(&mut self) -> RecvFuture<'_, T> {
268
RecvFuture {
269
connector: &self.connector,
270
done: false,
271
}
272
}
273
274
#[allow(unused)]
275
pub fn try_recv(&mut self) -> Result<T, RecvError> {
276
unsafe { self.connector.try_recv() }
277
}
278
}
279
280
impl<T> std::future::Future for RecvFuture<'_, T> {
281
type Output = Result<T, ()>;
282
283
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
284
assert!(
285
!self.done,
286
"re-poll after Poll::Ready in connector SendFuture"
287
);
288
unsafe { self.connector.poll_recv(cx.waker()) }
289
}
290
}
291
292