// Source: pola-rs/polars, crates/polars-stream/src/async_executor/task.rs
#![allow(unsafe_op_in_unsafe_fn)]
use std::any::Any;
use std::future::Future;
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll, Wake, Waker};

use atomic_waker::AtomicWaker;
use parking_lot::Mutex;
use polars_error::signals::try_raise_keyboard_interrupt;

/// The state of the task. Can't be part of the TaskData enum as it needs to be
/// atomically updateable, even when we hold the lock on the data.
#[derive(Default)]
struct TaskState {
    state: AtomicU8,
}

impl TaskState {
    /// Default state, not running, not scheduled.
    const IDLE: u8 = 0;

    /// Task is scheduled, that is (task.schedule)(task) was called.
    const SCHEDULED: u8 = 1;

    /// Task is currently running.
    const RUNNING: u8 = 2;

    /// Task notified while running.
    const NOTIFIED_WHILE_RUNNING: u8 = 3;

    /// Wake this task. Returns true if task.schedule should be called.
    fn wake(&self) -> bool {
        self.state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {
                Self::SCHEDULED | Self::NOTIFIED_WHILE_RUNNING => None,
                Self::RUNNING => Some(Self::NOTIFIED_WHILE_RUNNING),
                Self::IDLE => Some(Self::SCHEDULED),
                _ => unreachable!("invalid TaskState"),
            })
            .map(|state| state == Self::IDLE)
            .unwrap_or(false)
    }

    /// Start running this task.
    fn start_running(&self) {
        assert_eq!(self.state.load(Ordering::Acquire), Self::SCHEDULED);
        self.state.store(Self::RUNNING, Ordering::Relaxed);
    }

    /// Done running this task. Returns true if task.schedule should be called.
    fn reschedule_after_running(&self) -> bool {
        self.state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {
                Self::RUNNING => Some(Self::IDLE),
                Self::NOTIFIED_WHILE_RUNNING => Some(Self::SCHEDULED),
                _ => panic!("TaskState::reschedule_after_running() called on invalid state"),
            })
            .map(|old_state| old_state == Self::NOTIFIED_WHILE_RUNNING)
            .unwrap_or(false)
    }
}
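
// The state machine implemented above:
//
//   IDLE --wake--> SCHEDULED --start_running--> RUNNING
//   RUNNING --wake--> NOTIFIED_WHILE_RUNNING
//   RUNNING --reschedule_after_running--> IDLE
//   NOTIFIED_WHILE_RUNNING --reschedule_after_running--> SCHEDULED
//
// Only the transitions into SCHEDULED report that the schedule callback
// should be invoked, so a task is never queued more than once at a time.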

/// The lifecycle stage of a task, protected by the Task::data mutex.
enum TaskData<F: Future> {
    /// Transient placeholder, only observed during task construction.
    Empty,
    /// The future is still live, together with the waker used to poll it.
    Polling(F, Waker),
    /// The future completed; the output is waiting to be joined.
    Ready(F::Output),
    /// The future panicked; the payload is rethrown when joined.
    Panic(Box<dyn Any + Send + 'static>),
    /// The task was cancelled; it will never produce an output.
    Cancelled,
    /// The output or panic payload has been consumed by the JoinHandle.
    Joined,
}

struct Task<F: Future, S, M> {
    state: TaskState,
    data: Mutex<TaskData<F>>,
    join_waker: AtomicWaker,
    schedule: S,
    metadata: M,
}

impl<'a, F, S, M> Task<F, S, M>
where
    F: Future + Send + 'a,
    F::Output: Send + 'static,
    S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    /// # Safety
    /// It is the responsibility of the caller that before lifetime 'a ends the
    /// task is either polled to completion or cancelled.
    unsafe fn spawn(future: F, schedule: S, metadata: M) -> Arc<Self> {
        let task = Arc::new(Self {
            state: TaskState::default(),
            data: Mutex::new(TaskData::Empty),
            join_waker: AtomicWaker::new(),
            schedule,
            metadata,
        });

        let waker = unsafe { Waker::from_raw(std_shim::raw_waker(task.clone())) };
        *task.data.try_lock().unwrap() = TaskData::Polling(future, waker);
        task
    }

    fn into_runnable(self: Arc<Self>) -> Runnable<M> {
        // SAFETY: the transmute only erases the 'a lifetime, which is sound
        // because the contract on spawn guarantees the task is polled to
        // completion or cancelled before 'a ends.
        let arc: Arc<dyn DynTask<M> + 'a> = self;
        let arc: Arc<dyn DynTask<M>> = unsafe { std::mem::transmute(arc) };
        Runnable(arc)
    }

    fn into_join_handle(self: Arc<Self>) -> JoinHandle<F::Output> {
        // SAFETY: see into_runnable.
        let arc: Arc<dyn Joinable<F::Output> + 'a> = self;
        let arc: Arc<dyn Joinable<F::Output>> = unsafe { std::mem::transmute(arc) };
        JoinHandle(Some(arc))
    }

    fn into_cancel_handle(self: Arc<Self>) -> CancelHandle {
        // SAFETY: see into_runnable.
        let arc: Arc<dyn Cancellable + 'a> = self;
        let arc: Arc<dyn Cancellable> = unsafe { std::mem::transmute(arc) };
        CancelHandle(Arc::downgrade(&arc))
    }
}

impl<F, S, M> Wake for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn wake(self: Arc<Self>) {
        if self.state.wake() {
            let schedule = self.schedule;
            (schedule)(self.into_runnable());
        }
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.clone().wake()
    }
}

pub trait DynTask<M>: Send + Sync {
    fn metadata(&self) -> &M;
    fn run(self: Arc<Self>) -> bool;
    fn schedule(self: Arc<Self>);
}

impl<F, S, M> DynTask<M> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn metadata(&self) -> &M {
        &self.metadata
    }

    fn run(self: Arc<Self>) -> bool {
        let mut data = self.data.lock();

        let poll_result = match &mut *data {
            TaskData::Polling(future, waker) => {
                self.state.start_running();
                // SAFETY: we always store a Task in an Arc and never move it.
                let fut = unsafe { Pin::new_unchecked(future) };
                let mut ctx = Context::from_waker(waker);
                catch_unwind(AssertUnwindSafe(|| {
                    try_raise_keyboard_interrupt();
                    fut.poll(&mut ctx)
                }))
            },
            TaskData::Cancelled => return true,
            _ => unreachable!("invalid TaskData when polling"),
        };

        *data = match poll_result {
            Err(error) => TaskData::Panic(error),
            Ok(Poll::Ready(output)) => TaskData::Ready(output),
            Ok(Poll::Pending) => {
                // Drop the lock before rescheduling: the schedule callback may
                // run the task again, and run() needs to take the lock itself.
                drop(data);
                if self.state.reschedule_after_running() {
                    let schedule = self.schedule;
                    (schedule)(self.into_runnable());
                }
                return false;
            },
        };

        drop(data);
        self.join_waker.wake();
        true
    }

    fn schedule(self: Arc<Self>) {
        if self.state.wake() {
            (self.schedule)(self.clone().into_runnable());
        }
    }
}
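
// Scheduling protocol: an executor calls schedule() once on a fresh task, the
// schedule callback queues the Runnable, and a worker later calls run().
// run() returns true when the task is finished (completed, panicked, or
// cancelled) and false when it is pending. In the pending case the task
// re-queues itself above if it was woken while running; otherwise it waits
// for its Waker to fire.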

trait Joinable<T>: Send + Sync {
    fn cancel_handle(self: Arc<Self>) -> CancelHandle;
    fn poll_join(&self, cx: &mut Context<'_>) -> Poll<T>;
}

impl<F, S, M> Joinable<F::Output> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn cancel_handle(self: Arc<Self>) -> CancelHandle {
        self.into_cancel_handle()
    }

    fn poll_join(&self, cx: &mut Context<'_>) -> Poll<F::Output> {
        // Register the waker before inspecting the data so that a completion
        // racing with this poll still wakes us through join_waker.
        self.join_waker.register(cx.waker());
        if let Some(mut data) = self.data.try_lock() {
            if matches!(*data, TaskData::Empty | TaskData::Polling(..)) {
                return Poll::Pending;
            }

            match core::mem::replace(&mut *data, TaskData::Joined) {
                TaskData::Ready(output) => Poll::Ready(output),
                TaskData::Panic(error) => resume_unwind(error),
                TaskData::Cancelled => panic!("joined on cancelled task"),
                _ => unreachable!("invalid TaskData when joining"),
            }
        } else {
            // Someone else holds the lock (e.g. the task is mid-poll); the
            // registered join_waker guarantees we get woken when it finishes.
            Poll::Pending
        }
    }
}

trait Cancellable: Send + Sync {
    fn cancel(&self);
}

impl<F, S, M> Cancellable for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Send + Sync + 'static,
    M: Send + Sync + 'static,
{
    fn cancel(&self) {
        // Taking the data lock means cancellation waits for any in-progress
        // poll to finish; replacing the data drops the future in place.
        let mut data = self.data.lock();
        match *data {
            // Already done.
            TaskData::Panic(_) | TaskData::Joined => {},

            // Still in-progress, cancel.
            _ => {
                *data = TaskData::Cancelled;
                if let Some(join_waker) = self.join_waker.take() {
                    join_waker.wake();
                }
            },
        }
    }
}

pub struct Runnable<M>(Arc<dyn DynTask<M>>);

impl<M> Runnable<M> {
    /// Returns the metadata for this task.
    pub fn metadata(&self) -> &M {
        self.0.metadata()
    }

    /// Runs the task, returning true if the task is done.
    pub fn run(self) -> bool {
        self.0.run()
    }

    /// Schedules this task.
    pub fn schedule(self) {
        self.0.schedule()
    }
}

pub struct JoinHandle<T>(Option<Arc<dyn Joinable<T>>>);
pub struct CancelHandle(Weak<dyn Cancellable>);
pub struct AbortOnDropHandle<T> {
    join_handle: JoinHandle<T>,
    cancel_handle: CancelHandle,
}

impl<T> JoinHandle<T> {
    pub fn cancel_handle(&self) -> CancelHandle {
        let arc = self
            .0
            .as_ref()
            .expect("called cancel_handle on joined JoinHandle");
        Arc::clone(arc).cancel_handle()
    }
}

impl<T> Future for JoinHandle<T> {
    type Output = T;

    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let joinable = self.0.take().expect("JoinHandle polled after completion");

        if let Poll::Ready(output) = joinable.poll_join(ctx) {
            return Poll::Ready(output);
        }

        self.0 = Some(joinable);
        Poll::Pending
    }
}

impl CancelHandle {
    pub fn cancel(&self) {
        // The handle is weak, so cancelling a task that has already been
        // dropped is a no-op.
        if let Some(t) = self.0.upgrade() {
            t.cancel();
        }
    }
}

impl<T> AbortOnDropHandle<T> {
    pub fn new(join_handle: JoinHandle<T>) -> Self {
        let cancel_handle = join_handle.cancel_handle();
        Self {
            join_handle,
            cancel_handle,
        }
    }
}

impl<T> Future for AbortOnDropHandle<T> {
    type Output = T;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        Pin::new(&mut self.join_handle).poll(cx)
    }
}

impl<T> Drop for AbortOnDropHandle<T> {
    fn drop(&mut self) {
        self.cancel_handle.cancel();
    }
}

pub fn spawn<F, S, M>(future: F, schedule: S, metadata: M) -> (Runnable<M>, JoinHandle<F::Output>)
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
    S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    // SAFETY: F is 'static, so the contract on Task::spawn (complete or
    // cancel the task before the future's lifetime ends) is trivially met.
    let task = unsafe { Task::spawn(future, schedule, metadata) };
    (task.clone().into_runnable(), task.into_join_handle())
}

/// Takes a future and turns it into a runnable task with associated metadata.
///
/// When the task is pending, its waker will be set to call schedule
/// with the runnable.
///
/// # Safety
/// It is the responsibility of the caller that before lifetime 'a ends the
/// task is either polled to completion or cancelled.
pub unsafe fn spawn_with_lifetime<'a, F, S, M>(
    future: F,
    schedule: S,
    metadata: M,
) -> (Runnable<M>, JoinHandle<F::Output>)
where
    F: Future + Send + 'a,
    F::Output: Send + 'static,
    S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    let task = Task::spawn(future, schedule, metadata);
    (task.clone().into_runnable(), task.into_join_handle())
}

// Copied from the standard library, except without the 'static bound.
mod std_shim {
    use std::mem::ManuallyDrop;
    use std::sync::Arc;
    use std::task::{RawWaker, RawWakerVTable, Wake};

    #[inline(always)]
    pub unsafe fn raw_waker<'a, W: Wake + Send + Sync + 'a>(waker: Arc<W>) -> RawWaker {
        // Increment the reference count of the arc to clone it.
        //
        // The #[inline(always)] is to ensure that raw_waker and clone_waker are
        // always generated in the same code generation unit as one another, and
        // therefore that the structurally identical const-promoted RawWakerVTable
        // within both functions is deduplicated at LLVM IR code generation time.
        // This allows optimizing Waker::will_wake to a single pointer comparison of
        // the vtable pointers, rather than comparing all four function pointers
        // within the vtables.
        #[inline(always)]
        unsafe fn clone_waker<W: Wake + Send + Sync>(waker: *const ()) -> RawWaker {
            unsafe { Arc::increment_strong_count(waker as *const W) };
            RawWaker::new(
                waker,
                &RawWakerVTable::new(
                    clone_waker::<W>,
                    wake::<W>,
                    wake_by_ref::<W>,
                    drop_waker::<W>,
                ),
            )
        }

        // Wake by value, moving the Arc into the Wake::wake function.
        unsafe fn wake<W: Wake + Send + Sync>(waker: *const ()) {
            let waker = unsafe { Arc::from_raw(waker as *const W) };
            <W as Wake>::wake(waker);
        }

        // Wake by reference, wrap the waker in ManuallyDrop to avoid dropping it.
        unsafe fn wake_by_ref<W: Wake + Send + Sync>(waker: *const ()) {
            let waker = unsafe { ManuallyDrop::new(Arc::from_raw(waker as *const W)) };
            <W as Wake>::wake_by_ref(&waker);
        }

        // Decrement the reference count of the Arc on drop.
        unsafe fn drop_waker<W: Wake + Send + Sync>(waker: *const ()) {
            unsafe { Arc::decrement_strong_count(waker as *const W) };
        }

        RawWaker::new(
            Arc::into_raw(waker) as *const (),
            &RawWakerVTable::new(
                clone_waker::<W>,
                wake::<W>,
                wake_by_ref::<W>,
                drop_waker::<W>,
            ),
        )
    }
}
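
// A minimal sketch of how this API can be driven, assuming a single-threaded
// executor whose schedule callback pushes runnables onto a global queue. The
// queue, the no-op waker, and the test names below are illustrative
// assumptions for this sketch, not part of this module's API.
#[cfg(test)]
mod example_tests {
    use std::collections::VecDeque;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    use super::{spawn, Runnable};

    // Hypothetical global run queue; the schedule callback must be Copy +
    // 'static, so a fn pointer writing to a static is the simplest choice.
    static QUEUE: std::sync::Mutex<VecDeque<Runnable<()>>> =
        std::sync::Mutex::new(VecDeque::new());

    fn push_to_queue(runnable: Runnable<()>) {
        QUEUE.lock().unwrap().push_back(runnable);
    }

    // A waker that does nothing; sufficient for polling a JoinHandle whose
    // task we drive to completion by hand.
    struct NoopWake;
    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }

    #[test]
    fn drive_task_to_completion() {
        let (runnable, mut join_handle) = spawn(async { 21 * 2 }, push_to_queue, ());
        runnable.schedule();

        // Worker loop: pop and run until the queue drains. The inner block
        // drops the queue lock before run(), which may push again on wake.
        while let Some(runnable) = { QUEUE.lock().unwrap().pop_front() } {
            runnable.run();
        }

        // The task has finished, so a single poll of the JoinHandle is Ready.
        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        assert_eq!(Pin::new(&mut join_handle).poll(&mut cx), Poll::Ready(42));
    }

    #[test]
    #[should_panic(expected = "joined on cancelled task")]
    fn cancelled_task_panics_on_join() {
        let (runnable, mut join_handle) = spawn(async { 1 }, push_to_queue, ());
        join_handle.cancel_handle().cancel();

        // A cancelled task reports itself as done without polling the future.
        assert!(runnable.run());

        // Joining a cancelled task panics, per poll_join above.
        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        let _ = Pin::new(&mut join_handle).poll(&mut cx);
    }
}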