// Source: pola-rs/polars (GitHub repository)
// Path: blob/main/crates/polars-stream/src/async_primitives/task_parker.rs
1
use std::future::Future;
2
use std::pin::Pin;
3
use std::sync::atomic::{AtomicU8, Ordering};
4
use std::task::{Context, Poll, Waker};
5
6
use parking_lot::Mutex;
7
8
#[derive(Default)]
9
pub struct TaskParker {
10
state: AtomicU8,
11
waker: Mutex<Option<Waker>>,
12
}
13
14
impl TaskParker {
15
const RUNNING: u8 = 0;
16
const PREPARING_TO_PARK: u8 = 1;
17
const PARKED: u8 = 2;
18
19
/// Returns a future that when awaited parks this task.
20
///
21
/// Any notifications between calls to park and the await will cancel
22
/// the park attempt.
23
pub fn park(&self) -> TaskParkFuture<'_> {
24
self.state.store(Self::PREPARING_TO_PARK, Ordering::SeqCst);
25
TaskParkFuture { parker: self }
26
}
27
28
/// Unparks the parked task, if it was parked.
29
pub fn unpark(&self) {
30
let state = self.state.load(Ordering::SeqCst);
31
if state != Self::RUNNING {
32
let old_state = self.state.swap(Self::RUNNING, Ordering::SeqCst);
33
if old_state == Self::PARKED {
34
if let Some(w) = self.waker.lock().take() {
35
w.wake();
36
}
37
}
38
}
39
}
40
}
41
42
pub struct TaskParkFuture<'a> {
43
parker: &'a TaskParker,
44
}
45
46
impl Future for TaskParkFuture<'_> {
47
type Output = ();
48
49
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
50
let mut state = self.parker.state.load(Ordering::SeqCst);
51
loop {
52
match state {
53
TaskParker::RUNNING => return Poll::Ready(()),
54
55
TaskParker::PARKED => {
56
// Refresh our waker.
57
match &mut *self.parker.waker.lock() {
58
Some(w) => w.clone_from(cx.waker()),
59
None => return Poll::Ready(()), // Apparently someone woke us up.
60
}
61
},
62
TaskParker::PREPARING_TO_PARK => {
63
// Install waker first before publishing that we're parked
64
// to prevent missed notifications.
65
*self.parker.waker.lock() = Some(cx.waker().clone());
66
match self.parker.state.compare_exchange_weak(
67
TaskParker::PREPARING_TO_PARK,
68
TaskParker::PARKED,
69
Ordering::SeqCst,
70
Ordering::SeqCst,
71
) {
72
Ok(_) => return Poll::Pending,
73
Err(s) => state = s,
74
}
75
},
76
_ => unreachable!(),
77
}
78
}
79
}
80
}
81
82