Path: blob/main/crates/polars-stream/src/async_primitives/task_parker.rs
6939 views
use std::future::Future;1use std::pin::Pin;2use std::sync::atomic::{AtomicU8, Ordering};3use std::task::{Context, Poll, Waker};45use parking_lot::Mutex;67#[derive(Default)]8pub struct TaskParker {9state: AtomicU8,10waker: Mutex<Option<Waker>>,11}1213impl TaskParker {14const RUNNING: u8 = 0;15const PREPARING_TO_PARK: u8 = 1;16const PARKED: u8 = 2;1718/// Returns a future that when awaited parks this task.19///20/// Any notifications between calls to park and the await will cancel21/// the park attempt.22pub fn park(&self) -> TaskParkFuture<'_> {23self.state.store(Self::PREPARING_TO_PARK, Ordering::SeqCst);24TaskParkFuture { parker: self }25}2627/// Unparks the parked task, if it was parked.28pub fn unpark(&self) {29let state = self.state.load(Ordering::SeqCst);30if state != Self::RUNNING {31let old_state = self.state.swap(Self::RUNNING, Ordering::SeqCst);32if old_state == Self::PARKED {33if let Some(w) = self.waker.lock().take() {34w.wake();35}36}37}38}39}4041pub struct TaskParkFuture<'a> {42parker: &'a TaskParker,43}4445impl Future for TaskParkFuture<'_> {46type Output = ();4748fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {49let mut state = self.parker.state.load(Ordering::SeqCst);50loop {51match state {52TaskParker::RUNNING => return Poll::Ready(()),5354TaskParker::PARKED => {55// Refresh our waker.56match &mut *self.parker.waker.lock() {57Some(w) => w.clone_from(cx.waker()),58None => return Poll::Ready(()), // Apparently someone woke us up.59}60},61TaskParker::PREPARING_TO_PARK => {62// Install waker first before publishing that we're parked63// to prevent missed notifications.64*self.parker.waker.lock() = Some(cx.waker().clone());65match self.parker.state.compare_exchange_weak(66TaskParker::PREPARING_TO_PARK,67TaskParker::PARKED,68Ordering::SeqCst,69Ordering::SeqCst,70) {71Ok(_) => return Poll::Pending,72Err(s) => state = s,73}74},75_ => unreachable!(),76}77}78}79}808182