Path: blob/main/crates/polars-stream/src/async_executor/task.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use std::any::Any;2use std::future::Future;3use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};4use std::pin::Pin;5use std::sync::atomic::{AtomicU8, Ordering};6use std::sync::{Arc, Weak};7use std::task::{Context, Poll, Wake, Waker};89use atomic_waker::AtomicWaker;10use parking_lot::Mutex;11use polars_error::signals::try_raise_keyboard_interrupt;1213/// The state of the task. Can't be part of the TaskData enum as it needs to be14/// atomically updateable, even when we hold the lock on the data.15#[derive(Default)]16struct TaskState {17state: AtomicU8,18}1920impl TaskState {21/// Default state, not running, not scheduled.22const IDLE: u8 = 0;2324/// Task is scheduled, that is (task.schedule)(task) was called.25const SCHEDULED: u8 = 1;2627/// Task is currently running.28const RUNNING: u8 = 2;2930/// Task notified while running.31const NOTIFIED_WHILE_RUNNING: u8 = 3;3233/// Wake this task. Returns true if task.schedule should be called.34fn wake(&self) -> bool {35self.state36.fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {37Self::SCHEDULED | Self::NOTIFIED_WHILE_RUNNING => None,38Self::RUNNING => Some(Self::NOTIFIED_WHILE_RUNNING),39Self::IDLE => Some(Self::SCHEDULED),40_ => unreachable!("invalid TaskState"),41})42.map(|state| state == Self::IDLE)43.unwrap_or(false)44}4546/// Start running this task.47fn start_running(&self) {48assert_eq!(self.state.load(Ordering::Acquire), Self::SCHEDULED);49self.state.store(Self::RUNNING, Ordering::Relaxed);50}5152/// Done running this task. Returns true if task.schedule should be called.53fn reschedule_after_running(&self) -> bool {54self.state55.fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {56Self::RUNNING => Some(Self::IDLE),57Self::NOTIFIED_WHILE_RUNNING => Some(Self::SCHEDULED),58_ => panic!("TaskState::reschedule_after_running() called on invalid state"),59})60.map(|old_state| old_state == Self::NOTIFIED_WHILE_RUNNING)61.unwrap_or(false)62}63}6465enum TaskData<F: Future> {66Empty,67Polling(F, Waker),68Ready(F::Output),69Panic(Box<dyn Any + Send + 'static>),70Cancelled,71Joined,72}7374struct Task<F: Future, S, M> {75state: TaskState,76data: Mutex<TaskData<F>>,77join_waker: AtomicWaker,78schedule: S,79metadata: M,80}8182impl<'a, F, S, M> Task<F, S, M>83where84F: Future + Send + 'a,85F::Output: Send + 'static,86S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,87M: Send + Sync + 'static,88{89/// # Safety90/// It is the responsibility of the caller that before lifetime 'a ends the91/// task is either polled to completion or cancelled.92unsafe fn spawn(future: F, schedule: S, metadata: M) -> Arc<Self> {93let task = Arc::new(Self {94state: TaskState::default(),95data: Mutex::new(TaskData::Empty),96join_waker: AtomicWaker::new(),97schedule,98metadata,99});100101let waker = unsafe { Waker::from_raw(std_shim::raw_waker(task.clone())) };102*task.data.try_lock().unwrap() = TaskData::Polling(future, waker);103task104}105106fn into_runnable(self: Arc<Self>) -> Runnable<M> {107let arc: Arc<dyn DynTask<M> + 'a> = self;108let arc: Arc<dyn DynTask<M>> = unsafe { std::mem::transmute(arc) };109Runnable(arc)110}111112fn into_join_handle(self: Arc<Self>) -> JoinHandle<F::Output> {113let arc: Arc<dyn Joinable<F::Output> + 'a> = self;114let arc: Arc<dyn Joinable<F::Output>> = unsafe { std::mem::transmute(arc) };115JoinHandle(Some(arc))116}117118fn into_cancel_handle(self: Arc<Self>) -> CancelHandle {119let arc: Arc<dyn Cancellable + 'a> = self;120let arc: Arc<dyn Cancellable> = unsafe { std::mem::transmute(arc) };121CancelHandle(Arc::downgrade(&arc))122}123}124125impl<F, S, M> Wake for Task<F, S, M>126where127F: Future + Send,128F::Output: Send + 'static,129S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,130M: Send + Sync + 'static,131{132fn wake(self: Arc<Self>) {133if self.state.wake() {134let schedule = self.schedule;135(schedule)(self.into_runnable());136}137}138139fn wake_by_ref(self: &Arc<Self>) {140self.clone().wake()141}142}143144pub trait DynTask<M>: Send + Sync {145fn metadata(&self) -> &M;146fn run(self: Arc<Self>) -> bool;147fn schedule(self: Arc<Self>);148}149150impl<F, S, M> DynTask<M> for Task<F, S, M>151where152F: Future + Send,153F::Output: Send + 'static,154S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,155M: Send + Sync + 'static,156{157fn metadata(&self) -> &M {158&self.metadata159}160161fn run(self: Arc<Self>) -> bool {162let mut data = self.data.lock();163164let poll_result = match &mut *data {165TaskData::Polling(future, waker) => {166self.state.start_running();167// SAFETY: we always store a Task in an Arc and never move it.168let fut = unsafe { Pin::new_unchecked(future) };169let mut ctx = Context::from_waker(waker);170catch_unwind(AssertUnwindSafe(|| {171try_raise_keyboard_interrupt();172fut.poll(&mut ctx)173}))174},175TaskData::Cancelled => return true,176_ => unreachable!("invalid TaskData when polling"),177};178179*data = match poll_result {180Err(error) => TaskData::Panic(error),181Ok(Poll::Ready(output)) => TaskData::Ready(output),182Ok(Poll::Pending) => {183drop(data);184if self.state.reschedule_after_running() {185let schedule = self.schedule;186(schedule)(self.into_runnable());187}188return false;189},190};191192drop(data);193self.join_waker.wake();194true195}196197fn schedule(self: Arc<Self>) {198if self.state.wake() {199(self.schedule)(self.clone().into_runnable());200}201}202}203204trait Joinable<T>: Send + Sync {205fn cancel_handle(self: Arc<Self>) -> CancelHandle;206fn poll_join(&self, ctx: &mut Context<'_>) -> Poll<T>;207}208209impl<F, S, M> Joinable<F::Output> for Task<F, S, M>210where211F: Future + Send,212F::Output: Send + 'static,213S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,214M: Send + Sync + 'static,215{216fn cancel_handle(self: Arc<Self>) -> CancelHandle {217self.into_cancel_handle()218}219220fn poll_join(&self, cx: &mut Context<'_>) -> Poll<F::Output> {221self.join_waker.register(cx.waker());222if let Some(mut data) = self.data.try_lock() {223if matches!(*data, TaskData::Empty | TaskData::Polling(..)) {224return Poll::Pending;225}226227match core::mem::replace(&mut *data, TaskData::Joined) {228TaskData::Ready(output) => Poll::Ready(output),229TaskData::Panic(error) => resume_unwind(error),230TaskData::Cancelled => panic!("joined on cancelled task"),231_ => unreachable!("invalid TaskData when joining"),232}233} else {234Poll::Pending235}236}237}238239trait Cancellable: Send + Sync {240fn cancel(&self);241}242243impl<F, S, M> Cancellable for Task<F, S, M>244where245F: Future + Send,246F::Output: Send + 'static,247S: Send + Sync + 'static,248M: Send + Sync + 'static,249{250fn cancel(&self) {251let mut data = self.data.lock();252match *data {253// Already done.254TaskData::Panic(_) | TaskData::Joined => {},255256// Still in-progress, cancel.257_ => {258*data = TaskData::Cancelled;259if let Some(join_waker) = self.join_waker.take() {260join_waker.wake();261}262},263}264}265}266267pub struct Runnable<M>(Arc<dyn DynTask<M>>);268269impl<M> Runnable<M> {270/// Gives the metadata for this task.271pub fn metadata(&self) -> &M {272self.0.metadata()273}274275/// Runs a task, and returns true if the task is done.276pub fn run(self) -> bool {277self.0.run()278}279280/// Schedules this task.281pub fn schedule(self) {282self.0.schedule()283}284}285286pub struct JoinHandle<T>(Option<Arc<dyn Joinable<T>>>);287pub struct CancelHandle(Weak<dyn Cancellable>);288pub struct AbortOnDropHandle<T> {289join_handle: JoinHandle<T>,290cancel_handle: CancelHandle,291}292293impl<T> JoinHandle<T> {294pub fn cancel_handle(&self) -> CancelHandle {295let arc = self296.0297.as_ref()298.expect("called cancel_handle on joined JoinHandle");299Arc::clone(arc).cancel_handle()300}301}302303impl<T> Future for JoinHandle<T> {304type Output = T;305306fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {307let joinable = self.0.take().expect("JoinHandle polled after completion");308309if let Poll::Ready(output) = joinable.poll_join(ctx) {310return Poll::Ready(output);311}312313self.0 = Some(joinable);314Poll::Pending315}316}317318impl CancelHandle {319pub fn cancel(&self) {320if let Some(t) = self.0.upgrade() {321t.cancel();322}323}324}325326impl<T> AbortOnDropHandle<T> {327pub fn new(join_handle: JoinHandle<T>) -> Self {328let cancel_handle = join_handle.cancel_handle();329Self {330join_handle,331cancel_handle,332}333}334}335336impl<T> Future for AbortOnDropHandle<T> {337type Output = T;338339fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {340Pin::new(&mut self.join_handle).poll(cx)341}342}343344impl<T> Drop for AbortOnDropHandle<T> {345fn drop(&mut self) {346self.cancel_handle.cancel();347}348}349350pub fn spawn<F, S, M>(future: F, schedule: S, metadata: M) -> (Runnable<M>, JoinHandle<F::Output>)351where352F: Future + Send + 'static,353F::Output: Send + 'static,354S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,355M: Send + Sync + 'static,356{357let task = unsafe { Task::spawn(future, schedule, metadata) };358(task.clone().into_runnable(), task.into_join_handle())359}360361/// Takes a future and turns it into a runnable task with associated metadata.362///363/// When the task is pending its waker will be set to call schedule364/// with the runnable.365pub unsafe fn spawn_with_lifetime<'a, F, S, M>(366future: F,367schedule: S,368metadata: M,369) -> (Runnable<M>, JoinHandle<F::Output>)370where371F: Future + Send + 'a,372F::Output: Send + 'static,373S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,374M: Send + Sync + 'static,375{376let task = Task::spawn(future, schedule, metadata);377(task.clone().into_runnable(), task.into_join_handle())378}379380// Copied from the standard library, except without the 'static bound.381mod std_shim {382use std::mem::ManuallyDrop;383use std::sync::Arc;384use std::task::{RawWaker, RawWakerVTable, Wake};385386#[inline(always)]387pub unsafe fn raw_waker<'a, W: Wake + Send + Sync + 'a>(waker: Arc<W>) -> RawWaker {388// Increment the reference count of the arc to clone it.389//390// The #[inline(always)] is to ensure that raw_waker and clone_waker are391// always generated in the same code generation unit as one another, and392// therefore that the structurally identical const-promoted RawWakerVTable393// within both functions is deduplicated at LLVM IR code generation time.394// This allows optimizing Waker::will_wake to a single pointer comparison of395// the vtable pointers, rather than comparing all four function pointers396// within the vtables.397#[inline(always)]398unsafe fn clone_waker<W: Wake + Send + Sync>(waker: *const ()) -> RawWaker {399unsafe { Arc::increment_strong_count(waker as *const W) };400RawWaker::new(401waker,402&RawWakerVTable::new(403clone_waker::<W>,404wake::<W>,405wake_by_ref::<W>,406drop_waker::<W>,407),408)409}410411// Wake by value, moving the Arc into the Wake::wake function412unsafe fn wake<W: Wake + Send + Sync>(waker: *const ()) {413let waker = unsafe { Arc::from_raw(waker as *const W) };414<W as Wake>::wake(waker);415}416417// Wake by reference, wrap the waker in ManuallyDrop to avoid dropping it418unsafe fn wake_by_ref<W: Wake + Send + Sync>(waker: *const ()) {419let waker = unsafe { ManuallyDrop::new(Arc::from_raw(waker as *const W)) };420<W as Wake>::wake_by_ref(&waker);421}422423// Decrement the reference count of the Arc on drop424unsafe fn drop_waker<W: Wake + Send + Sync>(waker: *const ()) {425unsafe { Arc::decrement_strong_count(waker as *const W) };426}427428RawWaker::new(429Arc::into_raw(waker) as *const (),430&RawWakerVTable::new(431clone_waker::<W>,432wake::<W>,433wake_by_ref::<W>,434drop_waker::<W>,435),436)437}438}439440441