// Path: crates/polars-stream/src/async_executor/task.rs
#![allow(unsafe_op_in_unsafe_fn)]
//! Reference-counted tasks with manual scheduling: a [`Task`] bundles a future
//! with a schedule callback and metadata, and exposes partially type-erased
//! views ([`Runnable`], [`Joinable`], [`Cancellable`]) so executors, join
//! handles and cancellation paths each see only what they need.

use std::any::Any;
use std::future::Future;
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::{Context, Poll, Wake, Waker};

use atomic_waker::AtomicWaker;
use parking_lot::Mutex;
use polars_error::signals::try_raise_keyboard_interrupt;

/// The state of the task. Can't be part of the TaskData enum as it needs to be
/// atomically updateable, even when we hold the lock on the data.
///
/// State machine (see the transition tables in `wake` and
/// `reschedule_after_running`):
///   IDLE -> SCHEDULED            (wake; caller must invoke task.schedule)
///   SCHEDULED -> RUNNING         (start_running)
///   RUNNING -> NOTIFIED_WHILE_RUNNING  (wake during poll; no reschedule yet)
///   RUNNING -> IDLE              (poll returned Pending, no wake in between)
///   NOTIFIED_WHILE_RUNNING -> SCHEDULED (reschedule_after_running; caller
///                                        must invoke task.schedule again)
#[derive(Default)]
struct TaskState {
    state: AtomicU8,
}

impl TaskState {
    /// Default state, not running, not scheduled.
    const IDLE: u8 = 0;

    /// Task is scheduled, that is (task.schedule)(task) was called.
    const SCHEDULED: u8 = 1;

    /// Task is currently running.
    const RUNNING: u8 = 2;

    /// Task notified while running.
    const NOTIFIED_WHILE_RUNNING: u8 = 3;

    /// Wake this task. Returns true if task.schedule should be called.
    ///
    /// Only the IDLE -> SCHEDULED transition reports true, so at most one
    /// caller takes responsibility for scheduling; concurrent wakes on an
    /// already-scheduled/notified task are no-ops (fetch_update returns Err).
    fn wake(&self) -> bool {
        self.state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {
                Self::SCHEDULED | Self::NOTIFIED_WHILE_RUNNING => None,
                Self::RUNNING => Some(Self::NOTIFIED_WHILE_RUNNING),
                Self::IDLE => Some(Self::SCHEDULED),
                _ => unreachable!("invalid TaskState"),
            })
            .map(|state| state == Self::IDLE)
            .unwrap_or(false)
    }

    /// Start running this task.
    ///
    /// Panics if the task was not in the SCHEDULED state; only a scheduled
    /// task may be run.
    fn start_running(&self) {
        assert_eq!(self.state.load(Ordering::Acquire), Self::SCHEDULED);
        self.state.store(Self::RUNNING, Ordering::Relaxed);
    }

    /// Done running this task. Returns true if task.schedule should be called.
    ///
    /// Returns true exactly when a wake arrived while we were RUNNING
    /// (NOTIFIED_WHILE_RUNNING), in which case the state is already back to
    /// SCHEDULED and the caller must hand the task to the scheduler again.
    fn reschedule_after_running(&self) -> bool {
        self.state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {
                Self::RUNNING => Some(Self::IDLE),
                Self::NOTIFIED_WHILE_RUNNING => Some(Self::SCHEDULED),
                _ => panic!("TaskState::reschedule_after_running() called on invalid state"),
            })
            .map(|old_state| old_state == Self::NOTIFIED_WHILE_RUNNING)
            .unwrap_or(false)
    }
}

/// Lifecycle of the task's payload, guarded by the Mutex in `Task::data`.
enum TaskData<F: Future> {
    /// Placeholder used only between `Arc::new` and installing the future in
    /// `Task::spawn` (the waker needs the Arc before the future is stored).
    Empty,
    /// Future not yet complete, plus the waker passed to its polls.
    Polling(F, Waker),
    /// Future completed; output waiting to be claimed by `poll_join`.
    Ready(F::Output),
    /// Future (or the keyboard-interrupt check) panicked; payload is
    /// re-thrown to the joiner via `resume_unwind`.
    Panic(Box<dyn Any + Send + 'static>),
    /// Task was cancelled; the future (if any) has been dropped with it.
    Cancelled,
    /// Output was taken by `poll_join`; terminal state.
    Joined,
}

struct Task<F: Future, S, M> {
    // Atomic scheduling state; deliberately outside the `data` mutex so wakes
    // can happen while the runner holds the lock during a poll.
    state: TaskState,
    data: Mutex<TaskData<F>>,
    // Waker of the joiner, notified when the task finishes or is cancelled.
    join_waker: AtomicWaker,
    // Callback that hands a runnable task back to the executor.
    schedule: S,
    metadata: M,
}

impl<'a, F, S, M> Task<F, S, M>
where
    F: Future + Send + 'a,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    /// # Safety
    /// It is the responsibility of the caller that before lifetime 'a ends the
    /// task is either polled to completion or cancelled.
    unsafe fn spawn(future: F, schedule: S, metadata: M) -> Arc<Self> {
        let task = Arc::new(Self {
            state: TaskState::default(),
            data: Mutex::new(TaskData::Empty),
            join_waker: AtomicWaker::new(),
            schedule,
            metadata,
        });

        // The waker is a clone of the task's own Arc (self-waking task);
        // it is built first because it needs `task`, then stored together
        // with the future. try_lock cannot fail here: the Arc was just
        // created and nothing else can hold the lock yet.
        let waker = unsafe { Waker::from_raw(std_shim::raw_waker(task.clone())) };
        *task.data.try_lock().unwrap() = TaskData::Polling(future, waker);
        task
    }

    fn into_dyn(self: Arc<Self>) -> Arc<dyn DynTask<F::Output, M>> {
        // Erase the `'a` lifetime of the future. Sound only because of the
        // contract on `spawn`: the task is polled to completion or cancelled
        // before `'a` ends, so the erased object never outlives its borrows.
        let arc: Arc<dyn DynTask<F::Output, M> + 'a> = self;
        let arc: Arc<dyn DynTask<F::Output, M>> = unsafe { std::mem::transmute(arc) };
        arc
    }
}

impl<F, S, M> Wake for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn wake(self: Arc<Self>) {
        // Only schedule if we won the IDLE -> SCHEDULED transition.
        if self.state.wake() {
            let schedule = self.schedule;
            (schedule)(self.into_dyn());
        }
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.clone().wake()
    }
}

/// Partially type-erased task: no future.
pub trait DynTask<T, M>: Send + Sync + Runnable<M> + Joinable<T> + Cancellable {}

impl<F, S, M> DynTask<F::Output, M> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
}

/// Partially type-erased task: no future or return type.
pub trait Runnable<M>: Send + Sync {
    /// Gives the metadata for this task.
    fn metadata(&self) -> &M;

    /// Runs a task, and returns true if the task is done.
    fn run(self: Arc<Self>) -> bool;

    /// Schedules this task.
    fn schedule(self: Arc<Self>);
}

impl<F, S, M> Runnable<M> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn metadata(&self) -> &M {
        &self.metadata
    }

    fn run(self: Arc<Self>) -> bool {
        // The data lock is held for the whole poll; `poll_join` uses
        // try_lock and treats failure as Pending, so joiners never block
        // on a running task.
        let mut data = self.data.lock();

        let poll_result = match &mut *data {
            TaskData::Polling(future, waker) => {
                self.state.start_running();
                // SAFETY: we always store a Task in an Arc and never move it.
                let fut = unsafe { Pin::new_unchecked(future) };
                let mut ctx = Context::from_waker(waker);
                // catch_unwind so a panicking future poisons only this task
                // (stored as TaskData::Panic, re-thrown at join), not the
                // executor thread.
                catch_unwind(AssertUnwindSafe(|| {
                    try_raise_keyboard_interrupt();
                    fut.poll(&mut ctx)
                }))
            },
            // Cancelled between being scheduled and being run: nothing to do.
            TaskData::Cancelled => return true,
            _ => unreachable!("invalid TaskData when polling"),
        };

        *data = match poll_result {
            Err(error) => TaskData::Panic(error),
            Ok(Poll::Ready(output)) => TaskData::Ready(output),
            Ok(Poll::Pending) => {
                // Release the lock before a possible reschedule so the next
                // run (potentially on another thread) can take it.
                drop(data);
                if self.state.reschedule_after_running() {
                    let schedule = self.schedule;
                    (schedule)(self.into_dyn());
                }
                return false;
            },
        };

        // Finished (Ready or Panic): release the lock, then notify the joiner.
        drop(data);
        self.join_waker.wake();
        true
    }

    fn schedule(self: Arc<Self>) {
        if self.state.wake() {
            (self.schedule)(self.clone().into_dyn());
        }
    }
}

/// Partially type-erased task: no future or metadata.
pub trait Joinable<T>: Send + Sync + Cancellable {
    fn poll_join(&self, ctx: &mut Context<'_>) -> Poll<T>;
}

impl<F, S, M> Joinable<F::Output> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn poll_join(&self, cx: &mut Context<'_>) -> Poll<F::Output> {
        // Register the join waker before inspecting the data so a completion
        // racing with this poll still wakes us via join_waker.wake() in `run`.
        self.join_waker.register(cx.waker());
        if let Some(mut data) = self.data.try_lock() {
            if matches!(*data, TaskData::Empty | TaskData::Polling(..)) {
                return Poll::Pending;
            }

            // Take the terminal payload, leaving Joined behind.
            match core::mem::replace(&mut *data, TaskData::Joined) {
                TaskData::Ready(output) => Poll::Ready(output),
                // Propagate the task's panic into the joiner.
                TaskData::Panic(error) => resume_unwind(error),
                TaskData::Cancelled => panic!("joined on cancelled task"),
                _ => unreachable!("invalid TaskData when joining"),
            }
        } else {
            // Lock held by the runner mid-poll; it will wake us when done.
            Poll::Pending
        }
    }
}

/// Fully type-erased task.
pub trait Cancellable: Send + Sync {
    fn cancel(&self);
}

impl<F, S, M> Cancellable for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Send + Sync + 'static,
    M: Send + Sync + 'static,
{
    fn cancel(&self) {
        // Blocking lock: if the task is mid-poll we wait for the poll to end
        // before replacing the data (which drops the future, if any).
        let mut data = self.data.lock();
        match *data {
            // Already done.
            TaskData::Panic(_) | TaskData::Joined => {},

            // Still in-progress, cancel.
            _ => {
                *data = TaskData::Cancelled;
                if let Some(join_waker) = self.join_waker.take() {
                    join_waker.wake();
                }
            },
        }
    }
}

/// Takes a `'static` future and turns it into a runnable task with associated
/// metadata. Safe wrapper over [`Task::spawn`]: with `F: 'static` the
/// lifetime-erasure in `into_dyn` is trivially sound.
pub fn spawn<F, S, M>(future: F, schedule: S, metadata: M) -> Arc<dyn DynTask<F::Output, M>>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    unsafe { Task::spawn(future, schedule, metadata) }.into_dyn()
}

/// Takes a future and turns it into a runnable task with associated metadata.
///
/// When the task is pending its waker will be set to call schedule
/// with the runnable.
///
/// # Safety
/// The caller must ensure that before lifetime `'a` ends the task is either
/// polled to completion or cancelled (see [`Task::spawn`]); the returned
/// handle type-erases `'a`.
pub unsafe fn spawn_with_lifetime<'a, F, S, M>(
    future: F,
    schedule: S,
    metadata: M,
) -> Arc<dyn DynTask<F::Output, M>>
where
    F: Future + Send + 'a,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    Task::spawn(future, schedule, metadata).into_dyn()
}

// Copied from the standard library, except without the 'static bound.
mod std_shim {
    use std::mem::ManuallyDrop;
    use std::sync::Arc;
    use std::task::{RawWaker, RawWakerVTable, Wake};

    #[inline(always)]
    pub unsafe fn raw_waker<'a, W: Wake + Send + Sync + 'a>(waker: Arc<W>) -> RawWaker {
        // Increment the reference count of the arc to clone it.
        //
        // The #[inline(always)] is to ensure that raw_waker and clone_waker are
        // always generated in the same code generation unit as one another, and
        // therefore that the structurally identical const-promoted RawWakerVTable
        // within both functions is deduplicated at LLVM IR code generation time.
        // This allows optimizing Waker::will_wake to a single pointer comparison of
        // the vtable pointers, rather than comparing all four function pointers
        // within the vtables.
        #[inline(always)]
        unsafe fn clone_waker<W: Wake + Send + Sync>(waker: *const ()) -> RawWaker {
            unsafe { Arc::increment_strong_count(waker as *const W) };
            RawWaker::new(
                waker,
                &RawWakerVTable::new(
                    clone_waker::<W>,
                    wake::<W>,
                    wake_by_ref::<W>,
                    drop_waker::<W>,
                ),
            )
        }

        // Wake by value, moving the Arc into the Wake::wake function
        unsafe fn wake<W: Wake + Send + Sync>(waker: *const ()) {
            let waker = unsafe { Arc::from_raw(waker as *const W) };
            <W as Wake>::wake(waker);
        }

        // Wake by reference, wrap the waker in ManuallyDrop to avoid dropping it
        unsafe fn wake_by_ref<W: Wake + Send + Sync>(waker: *const ()) {
            let waker = unsafe { ManuallyDrop::new(Arc::from_raw(waker as *const W)) };
            <W as Wake>::wake_by_ref(&waker);
        }

        // Decrement the reference count of the Arc on drop
        unsafe fn drop_waker<W: Wake + Send + Sync>(waker: *const ()) {
            unsafe { Arc::decrement_strong_count(waker as *const W) };
        }

        RawWaker::new(
            Arc::into_raw(waker) as *const (),
            &RawWakerVTable::new(
                clone_waker::<W>,
                wake::<W>,
                wake_by_ref::<W>,
                drop_waker::<W>,
            ),
        )
    }
}