Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/async_primitives/wait_group.rs
6939 views
1
use std::future::Future;
2
use std::pin::Pin;
3
use std::sync::Arc;
4
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
5
use std::task::{Context, Poll, Waker};
6
7
use parking_lot::Mutex;
8
9
#[derive(Default, Debug)]
10
struct WaitGroupInner {
11
waker: Mutex<Option<Waker>>,
12
token_count: AtomicUsize,
13
is_waiting: AtomicBool,
14
}
15
16
#[derive(Default)]
17
pub struct WaitGroup {
18
inner: Arc<WaitGroupInner>,
19
}
20
21
impl WaitGroup {
22
/// Creates a token.
23
pub fn token(&self) -> WaitToken {
24
self.inner.token_count.fetch_add(1, Ordering::Relaxed);
25
WaitToken {
26
inner: Arc::clone(&self.inner),
27
}
28
}
29
30
/// Waits until all created tokens are dropped.
31
///
32
/// # Panics
33
/// Panics if there is more than one simultaneous waiter.
34
pub async fn wait(&self) {
35
let was_waiting = self.inner.is_waiting.swap(true, Ordering::Relaxed);
36
assert!(!was_waiting);
37
WaitGroupFuture { inner: &self.inner }.await
38
}
39
}
40
41
struct WaitGroupFuture<'a> {
42
inner: &'a Arc<WaitGroupInner>,
43
}
44
45
impl Future for WaitGroupFuture<'_> {
46
type Output = ();
47
48
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
49
if self.inner.token_count.load(Ordering::Acquire) == 0 {
50
return Poll::Ready(());
51
}
52
53
// Check count again while holding lock to prevent missed notifications.
54
let mut waker_lock = self.inner.waker.lock();
55
if self.inner.token_count.load(Ordering::Acquire) == 0 {
56
return Poll::Ready(());
57
}
58
59
let waker = cx.waker().clone();
60
*waker_lock = Some(waker);
61
Poll::Pending
62
}
63
}
64
65
impl Drop for WaitGroupFuture<'_> {
66
fn drop(&mut self) {
67
self.inner.is_waiting.store(false, Ordering::Relaxed);
68
}
69
}
70
71
#[derive(Debug)]
72
pub struct WaitToken {
73
inner: Arc<WaitGroupInner>,
74
}
75
76
impl Clone for WaitToken {
77
fn clone(&self) -> Self {
78
self.inner.token_count.fetch_add(1, Ordering::Relaxed);
79
Self {
80
inner: self.inner.clone(),
81
}
82
}
83
}
84
85
impl Drop for WaitToken {
86
fn drop(&mut self) {
87
// Token count was 1, we must notify.
88
if self.inner.token_count.fetch_sub(1, Ordering::Release) == 1 {
89
if let Some(w) = self.inner.waker.lock().take() {
90
w.wake();
91
}
92
}
93
}
94
}
95
96