GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/async_executor/task.rs
#![allow(unsafe_op_in_unsafe_fn)]
use std::any::Any;
use std::future::Future;
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::{Context, Poll, Wake, Waker};

use atomic_waker::AtomicWaker;
use parking_lot::Mutex;
use polars_error::signals::try_raise_keyboard_interrupt;

/// The state of the task. Can't be part of the TaskData enum as it needs to be
/// atomically updateable, even when we hold the lock on the data.
#[derive(Default)]
struct TaskState {
    state: AtomicU8,
}

impl TaskState {
    /// Default state, not running, not scheduled.
    const IDLE: u8 = 0;

    /// Task is scheduled, that is (task.schedule)(task) was called.
    const SCHEDULED: u8 = 1;

    /// Task is currently running.
    const RUNNING: u8 = 2;

    /// Task notified while running.
    const NOTIFIED_WHILE_RUNNING: u8 = 3;

    /// Wake this task. Returns true if task.schedule should be called.
    fn wake(&self) -> bool {
        self.state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {
                Self::SCHEDULED | Self::NOTIFIED_WHILE_RUNNING => None,
                Self::RUNNING => Some(Self::NOTIFIED_WHILE_RUNNING),
                Self::IDLE => Some(Self::SCHEDULED),
                _ => unreachable!("invalid TaskState"),
            })
            .map(|state| state == Self::IDLE)
            .unwrap_or(false)
    }

    /// Start running this task.
    fn start_running(&self) {
        assert_eq!(self.state.load(Ordering::Acquire), Self::SCHEDULED);
        self.state.store(Self::RUNNING, Ordering::Relaxed);
    }

    /// Done running this task. Returns true if task.schedule should be called.
    fn reschedule_after_running(&self) -> bool {
        self.state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {
                Self::RUNNING => Some(Self::IDLE),
                Self::NOTIFIED_WHILE_RUNNING => Some(Self::SCHEDULED),
                _ => panic!("TaskState::reschedule_after_running() called on invalid state"),
            })
            .map(|old_state| old_state == Self::NOTIFIED_WHILE_RUNNING)
            .unwrap_or(false)
    }
}
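// A minimal sketch exercising the wake protocol above; illustrative only and
// not part of the original module. It traces the intended transitions:
// IDLE -> SCHEDULED -> RUNNING -> NOTIFIED_WHILE_RUNNING -> SCHEDULED.
#[cfg(test)]
mod task_state_sketch {
    use super::TaskState;

    #[test]
    fn wake_protocol() {
        let s = TaskState::default();
        assert!(s.wake()); // IDLE -> SCHEDULED: the caller must now schedule.
        assert!(!s.wake()); // Already SCHEDULED: no double-scheduling.
        s.start_running(); // SCHEDULED -> RUNNING.
        assert!(!s.wake()); // RUNNING -> NOTIFIED_WHILE_RUNNING: runner keeps ownership.
        assert!(s.reschedule_after_running()); // NOTIFIED_WHILE_RUNNING -> SCHEDULED.
    }
}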
enum TaskData<F: Future> {
    /// Initial placeholder state while the task is being constructed.
    Empty,
    /// The future is still being polled, along with the waker used to poll it.
    Polling(F, Waker),
    /// The future completed with this output.
    Ready(F::Output),
    /// The future panicked with this payload; it is rethrown on join.
    Panic(Box<dyn Any + Send + 'static>),
    /// The task was cancelled.
    Cancelled,
    /// The output (or panic) was already consumed by a joiner.
    Joined,
}

struct Task<F: Future, S, M> {
    /// Scheduling state, see TaskState.
    state: TaskState,
    /// The future and/or its result, guarded by a lock.
    data: Mutex<TaskData<F>>,
    /// Waker registered by poll_join, woken when the task finishes.
    join_waker: AtomicWaker,
    /// Callback invoked when this task should be scheduled to run.
    schedule: S,
    /// User-supplied metadata.
    metadata: M,
}

impl<'a, F, S, M> Task<F, S, M>
where
    F: Future + Send + 'a,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    /// # Safety
    /// It is the responsibility of the caller that before lifetime 'a ends the
    /// task is either polled to completion or cancelled.
    unsafe fn spawn(future: F, schedule: S, metadata: M) -> Arc<Self> {
        let task = Arc::new(Self {
            state: TaskState::default(),
            data: Mutex::new(TaskData::Empty),
            join_waker: AtomicWaker::new(),
            schedule,
            metadata,
        });

        let waker = unsafe { Waker::from_raw(std_shim::raw_waker(task.clone())) };
        *task.data.try_lock().unwrap() = TaskData::Polling(future, waker);
        task
    }

    fn into_dyn(self: Arc<Self>) -> Arc<dyn DynTask<F::Output, M>> {
        let arc: Arc<dyn DynTask<F::Output, M> + 'a> = self;
        // SAFETY: this transmute only erases the 'a lifetime of the trait
        // object; the caller of spawn guaranteed the task is polled to
        // completion or cancelled before 'a ends.
        let arc: Arc<dyn DynTask<F::Output, M>> = unsafe { std::mem::transmute(arc) };
        arc
    }
}

impl<F, S, M> Wake for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn wake(self: Arc<Self>) {
        if self.state.wake() {
            let schedule = self.schedule;
            (schedule)(self.into_dyn());
        }
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.clone().wake()
    }
}

/// Partially type-erased task: no future.
pub trait DynTask<T, M>: Send + Sync + Runnable<M> + Joinable<T> + Cancellable {}

impl<F, S, M> DynTask<F::Output, M> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
}

/// Partially type-erased task: no future or return type.
pub trait Runnable<M>: Send + Sync {
    /// Gives the metadata for this task.
    fn metadata(&self) -> &M;

    /// Runs a task, and returns true if the task is done.
    fn run(self: Arc<Self>) -> bool;

    /// Schedules this task.
    fn schedule(self: Arc<Self>);
}

impl<F, S, M> Runnable<M> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn metadata(&self) -> &M {
        &self.metadata
    }

    fn run(self: Arc<Self>) -> bool {
        let mut data = self.data.lock();

        let poll_result = match &mut *data {
            TaskData::Polling(future, waker) => {
                self.state.start_running();
                // SAFETY: we always store a Task in an Arc and never move it.
                let fut = unsafe { Pin::new_unchecked(future) };
                let mut ctx = Context::from_waker(waker);
                catch_unwind(AssertUnwindSafe(|| {
                    try_raise_keyboard_interrupt();
                    fut.poll(&mut ctx)
                }))
            },
            TaskData::Cancelled => return true,
            _ => unreachable!("invalid TaskData when polling"),
        };

        *data = match poll_result {
            Err(error) => TaskData::Panic(error),
            Ok(Poll::Ready(output)) => TaskData::Ready(output),
            Ok(Poll::Pending) => {
                drop(data);
                if self.state.reschedule_after_running() {
                    let schedule = self.schedule;
                    (schedule)(self.into_dyn());
                }
                return false;
            },
        };

        drop(data);
        self.join_waker.wake();
        true
    }

    fn schedule(self: Arc<Self>) {
        if self.state.wake() {
            (self.schedule)(self.clone().into_dyn());
        }
    }
}
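// A minimal sketch of how type-erased tasks are meant to be driven by an
// executor; the FIFO queue here is hypothetical and not part of this module.
// run() returns true when the task is done (ready, panicked, or cancelled);
// false means the future is still pending and will be handed back to the
// schedule callback by its waker once it can make progress.
#[cfg(test)]
#[allow(dead_code)]
fn drive_queue<M>(queue: &mut std::collections::VecDeque<Arc<dyn Runnable<M>>>) {
    while let Some(task) = queue.pop_front() {
        let _done = task.run();
    }
}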
/// Partially type-erased task: no future or metadata.
pub trait Joinable<T>: Send + Sync + Cancellable {
    fn poll_join(&self, ctx: &mut Context<'_>) -> Poll<T>;
}

impl<F, S, M> Joinable<F::Output> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn poll_join(&self, cx: &mut Context<'_>) -> Poll<F::Output> {
        self.join_waker.register(cx.waker());
        if let Some(mut data) = self.data.try_lock() {
            if matches!(*data, TaskData::Empty | TaskData::Polling(..)) {
                return Poll::Pending;
            }

            match core::mem::replace(&mut *data, TaskData::Joined) {
                TaskData::Ready(output) => Poll::Ready(output),
                TaskData::Panic(error) => resume_unwind(error),
                TaskData::Cancelled => panic!("joined on cancelled task"),
                _ => unreachable!("invalid TaskData when joining"),
            }
        } else {
            Poll::Pending
        }
    }
}
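// A minimal sketch of adapting poll_join to an awaitable future; this Join
// wrapper is hypothetical and not part of this module.
#[cfg(test)]
#[allow(dead_code)]
struct Join<T>(Arc<dyn Joinable<T>>);

#[cfg(test)]
impl<T> Future for Join<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        // poll_join registers cx's waker and is woken again once the task
        // completes; a panic in the task is rethrown here via resume_unwind.
        self.0.poll_join(cx)
    }
}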
/// Fully type-erased task.
pub trait Cancellable: Send + Sync {
    fn cancel(&self);
}

impl<F, S, M> Cancellable for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Send + Sync + 'static,
    M: Send + Sync + 'static,
{
    fn cancel(&self) {
        let mut data = self.data.lock();
        match *data {
            // Already done.
            TaskData::Panic(_) | TaskData::Joined => {},

            // Still in progress, cancel.
            _ => {
                *data = TaskData::Cancelled;
                if let Some(join_waker) = self.join_waker.take() {
                    join_waker.wake();
                }
            },
        }
    }
}

/// Takes a 'static future and turns it into a runnable task with associated
/// metadata. When the task is pending, its waker will be set to call schedule
/// with the runnable.
pub fn spawn<F, S, M>(future: F, schedule: S, metadata: M) -> Arc<dyn DynTask<F::Output, M>>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    unsafe { Task::spawn(future, schedule, metadata) }.into_dyn()
}
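// A minimal end-to-end sketch of spawn with a hypothetical static run queue;
// the manual drain loop and the no-op waker (Waker::noop, Rust 1.85+) are
// purely illustrative and not part of this module.
#[cfg(test)]
mod spawn_sketch {
    use std::collections::VecDeque;

    use super::*;

    static QUEUE: Mutex<VecDeque<Arc<dyn Runnable<()>>>> = Mutex::new(VecDeque::new());

    fn enqueue(task: Arc<dyn Runnable<()>>) {
        QUEUE.lock().push_back(task);
    }

    #[test]
    fn spawn_and_join() {
        let task = spawn(async { 1 + 1 }, enqueue, ());
        task.clone().schedule(); // IDLE -> SCHEDULED: pushes the task onto QUEUE.

        loop {
            // Pop in its own statement so the queue lock is not held while
            // polling, since run() may re-enqueue the task.
            let next = QUEUE.lock().pop_front();
            match next {
                Some(runnable) => {
                    runnable.run();
                },
                None => break,
            }
        }

        let mut cx = Context::from_waker(Waker::noop());
        assert!(matches!(task.poll_join(&mut cx), Poll::Ready(2)));
    }
}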
/// Takes a future and turns it into a runnable task with associated metadata.
///
/// When the task is pending its waker will be set to call schedule
/// with the runnable.
///
/// # Safety
/// It is the responsibility of the caller that before lifetime 'a ends the
/// task is either polled to completion or cancelled.
pub unsafe fn spawn_with_lifetime<'a, F, S, M>(
    future: F,
    schedule: S,
    metadata: M,
) -> Arc<dyn DynTask<F::Output, M>>
where
    F: Future + Send + 'a,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    Task::spawn(future, schedule, metadata).into_dyn()
}

// Copied from the standard library, except without the 'static bound.
mod std_shim {
    use std::mem::ManuallyDrop;
    use std::sync::Arc;
    use std::task::{RawWaker, RawWakerVTable, Wake};

    /// # Safety
    /// The caller must ensure the returned RawWaker (and any of its clones)
    /// is not used after lifetime 'a ends.
    #[inline(always)]
    pub unsafe fn raw_waker<'a, W: Wake + Send + Sync + 'a>(waker: Arc<W>) -> RawWaker {
        // Increment the reference count of the arc to clone it.
        //
        // The #[inline(always)] is to ensure that raw_waker and clone_waker are
        // always generated in the same code generation unit as one another, and
        // therefore that the structurally identical const-promoted RawWakerVTable
        // within both functions is deduplicated at LLVM IR code generation time.
        // This allows optimizing Waker::will_wake to a single pointer comparison of
        // the vtable pointers, rather than comparing all four function pointers
        // within the vtables.
        #[inline(always)]
        unsafe fn clone_waker<W: Wake + Send + Sync>(waker: *const ()) -> RawWaker {
            unsafe { Arc::increment_strong_count(waker as *const W) };
            RawWaker::new(
                waker,
                &RawWakerVTable::new(
                    clone_waker::<W>,
                    wake::<W>,
                    wake_by_ref::<W>,
                    drop_waker::<W>,
                ),
            )
        }

        // Wake by value, moving the Arc into the Wake::wake function.
        unsafe fn wake<W: Wake + Send + Sync>(waker: *const ()) {
            let waker = unsafe { Arc::from_raw(waker as *const W) };
            <W as Wake>::wake(waker);
        }

        // Wake by reference, wrap the waker in ManuallyDrop to avoid dropping it.
        unsafe fn wake_by_ref<W: Wake + Send + Sync>(waker: *const ()) {
            let waker = unsafe { ManuallyDrop::new(Arc::from_raw(waker as *const W)) };
            <W as Wake>::wake_by_ref(&waker);
        }

        // Decrement the reference count of the Arc on drop.
        unsafe fn drop_waker<W: Wake + Send + Sync>(waker: *const ()) {
            unsafe { Arc::decrement_strong_count(waker as *const W) };
        }

        RawWaker::new(
            Arc::into_raw(waker) as *const (),
            &RawWakerVTable::new(
                clone_waker::<W>,
                wake::<W>,
                wake_by_ref::<W>,
                drop_waker::<W>,
            ),
        )
    }
}