// Path: blob/main/crates/polars-stream/src/async_primitives/wait_group.rs
// 6939 views
use std::future::Future;1use std::pin::Pin;2use std::sync::Arc;3use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};4use std::task::{Context, Poll, Waker};56use parking_lot::Mutex;78#[derive(Default, Debug)]9struct WaitGroupInner {10waker: Mutex<Option<Waker>>,11token_count: AtomicUsize,12is_waiting: AtomicBool,13}1415#[derive(Default)]16pub struct WaitGroup {17inner: Arc<WaitGroupInner>,18}1920impl WaitGroup {21/// Creates a token.22pub fn token(&self) -> WaitToken {23self.inner.token_count.fetch_add(1, Ordering::Relaxed);24WaitToken {25inner: Arc::clone(&self.inner),26}27}2829/// Waits until all created tokens are dropped.30///31/// # Panics32/// Panics if there is more than one simultaneous waiter.33pub async fn wait(&self) {34let was_waiting = self.inner.is_waiting.swap(true, Ordering::Relaxed);35assert!(!was_waiting);36WaitGroupFuture { inner: &self.inner }.await37}38}3940struct WaitGroupFuture<'a> {41inner: &'a Arc<WaitGroupInner>,42}4344impl Future for WaitGroupFuture<'_> {45type Output = ();4647fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {48if self.inner.token_count.load(Ordering::Acquire) == 0 {49return Poll::Ready(());50}5152// Check count again while holding lock to prevent missed notifications.53let mut waker_lock = self.inner.waker.lock();54if self.inner.token_count.load(Ordering::Acquire) == 0 {55return Poll::Ready(());56}5758let waker = cx.waker().clone();59*waker_lock = Some(waker);60Poll::Pending61}62}6364impl Drop for WaitGroupFuture<'_> {65fn drop(&mut self) {66self.inner.is_waiting.store(false, Ordering::Relaxed);67}68}6970#[derive(Debug)]71pub struct WaitToken {72inner: Arc<WaitGroupInner>,73}7475impl Clone for WaitToken {76fn clone(&self) -> Self {77self.inner.token_count.fetch_add(1, Ordering::Relaxed);78Self {79inner: self.inner.clone(),80}81}82}8384impl Drop for WaitToken {85fn drop(&mut self) {86// Token count was 1, we must notify.87if self.inner.token_count.fetch_sub(1, Ordering::Release) == 1 {88if 
let Some(w) = self.inner.waker.lock().take() {89w.wake();90}91}92}93}949596