#![allow(clippy::disallowed_types)]

mod park_group;
mod task;

use std::cell::{Cell, UnsafeCell};
use std::future::Future;
use std::marker::PhantomData;
use std::panic::{AssertUnwindSafe, Location};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock, Weak};
use std::task::{Context, Poll};
use std::time::Instant;

use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue};
use crossbeam_utils::CachePadded;
use park_group::ParkGroup;
use parking_lot::Mutex;
use polars_core::ALLOW_RAYON_THREADS;
use polars_utils::relaxed_cell::RelaxedCell;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use slotmap::SlotMap;
use task::{Cancellable, DynTask, Runnable};

static NUM_EXECUTOR_THREADS: RelaxedCell<usize> = RelaxedCell::new_usize(0);
pub fn set_num_threads(t: usize) {
    NUM_EXECUTOR_THREADS.store(t);
}

static TRACK_METRICS: RelaxedCell<bool> = RelaxedCell::new_bool(false);

pub fn track_task_metrics(should_track: bool) {
    TRACK_METRICS.store(should_track);
}
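// Illustrative configuration sketch (added for clarity, not part of the
// original source). `set_num_threads` only takes effect if called before the
// global executor is first used, since `Executor::global()` reads it once in
// its `OnceLock` initializer; `track_task_metrics` is consulted on every spawn.
//
//     set_num_threads(8);       // 0 (the default) falls back to available_parallelism
//     track_task_metrics(true); // attach a `TaskMetrics` block to newly spawned tasks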
static GLOBAL_SCHEDULER: OnceLock<Executor> = OnceLock::new();

thread_local!(
    /// Used to store which executor thread this is.
    static TLS_THREAD_ID: Cell<usize> = const { Cell::new(usize::MAX) };
);

slotmap::new_key_type! {
    struct TaskKey;
}

/// High priority tasks are scheduled preferentially over low priority tasks.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low,
    High,
}

/// Metadata associated with a task to help schedule it and clean it up.
struct ScopedTaskMetadata {
    task_key: TaskKey,
    completed_tasks: Weak<Mutex<Vec<TaskKey>>>,
}
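/// Per-task poll counters, exposed through `JoinHandle::metrics` when metric
/// tracking is enabled. (Descriptive comment added for clarity; the 128-byte
/// alignment presumably keeps each counter block on its own pair of cache
/// lines to avoid false sharing between executor threads.)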
#[derive(Default)]
#[repr(align(128))]
pub struct TaskMetrics {
    pub total_polls: RelaxedCell<u64>,
    pub total_stolen_polls: RelaxedCell<u64>,
    pub total_poll_time_ns: RelaxedCell<u64>,
    pub max_poll_time_ns: RelaxedCell<u64>,
    pub done: RelaxedCell<bool>,
}

struct TaskMetadata {
    spawn_location: &'static Location<'static>,
    priority: TaskPriority,
    freshly_spawned: AtomicBool,
    scoped: Option<ScopedTaskMetadata>,
    metrics: Option<Arc<TaskMetrics>>,
}

impl Drop for TaskMetadata {
    fn drop(&mut self) {
        if let Some(metrics) = self.metrics.as_ref() {
            metrics.done.store(true);
        }

        if let Some(scoped) = &self.scoped {
            if let Some(completed_tasks) = scoped.completed_tasks.upgrade() {
                completed_tasks.lock().push(scoped.task_key);
            }
        }
    }
}

pub struct JoinHandle<T>(Arc<dyn DynTask<T, TaskMetadata>>);
pub struct CancelHandle(Weak<dyn Cancellable>);

impl<T> JoinHandle<T> {
    pub fn metrics(&self) -> Option<&Arc<TaskMetrics>> {
        self.0.metadata().metrics.as_ref()
    }

    #[allow(unused)]
    pub fn spawn_location(&self) -> &'static Location<'static> {
        self.0.metadata().spawn_location
    }

    pub fn cancel_handle(&self) -> CancelHandle {
        let coerce: Weak<dyn DynTask<T, TaskMetadata>> = Arc::downgrade(&self.0);
        CancelHandle(coerce)
    }
}

impl<T> Future for JoinHandle<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        self.0.poll_join(ctx)
    }
}

impl CancelHandle {
    pub fn cancel(&self) {
        if let Some(t) = self.0.upgrade() {
            t.cancel();
        }
    }
}
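/// Wrapper around a `JoinHandle` that cancels the task when dropped.
/// Illustrative usage sketch (added for clarity, not part of the original
/// source):
///
/// ```ignore
/// let handle = AbortOnDropHandle::new(spawn(TaskPriority::Low, async { /* .. */ }));
/// drop(handle); // the spawned task is cancelled here
/// ```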
pub struct AbortOnDropHandle<T> {
    join_handle: JoinHandle<T>,
    cancel_handle: CancelHandle,
}

impl<T> AbortOnDropHandle<T> {
    pub fn new(join_handle: JoinHandle<T>) -> Self {
        let cancel_handle = join_handle.cancel_handle();
        Self {
            join_handle,
            cancel_handle,
        }
    }
}

impl<T> Future for AbortOnDropHandle<T> {
    type Output = T;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        Pin::new(&mut self.join_handle).poll(cx)
    }
}

impl<T> Drop for AbortOnDropHandle<T> {
    fn drop(&mut self) {
        self.cancel_handle.cancel();
    }
}

/// A task ready to run.
type ReadyTask = Arc<dyn Runnable<TaskMetadata>>;

/// A per-thread task list.
struct ThreadLocalTaskList {
    // May be used from any thread.
    high_prio_tasks_stealer: Stealer<ReadyTask>,

    // SAFETY: these may only be used on the thread this task list belongs to.
    high_prio_tasks: WorkQueue<ReadyTask>,
    local_slot: UnsafeCell<Option<ReadyTask>>,
}

unsafe impl Sync for ThreadLocalTaskList {}

struct Executor {
    park_group: ParkGroup,
    thread_task_lists: Vec<CachePadded<ThreadLocalTaskList>>,
    global_high_prio_task_queue: Injector<ReadyTask>,
    global_low_prio_task_queue: Injector<ReadyTask>,
}

impl Executor {
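    // Scheduling policy (descriptive comment added for clarity, not in the
    // original source): freshly spawned tasks and tasks woken from outside the
    // executor go to the global queue for their priority and wake one parked
    // worker. A task woken on an executor thread goes into that thread's LIFO
    // slot if it is high priority, displacing any previous occupant onto the
    // local high-priority deque; a low-priority task stays local only when
    // nothing else is pending, and otherwise falls back to the global
    // low-priority queue.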
    fn schedule_task(&self, task: ReadyTask) {
        let thread = TLS_THREAD_ID.get();
        let meta = task.metadata();
        let opt_ttl = self.thread_task_lists.get(thread);

        let mut use_global_queue = opt_ttl.is_none();
        if meta.freshly_spawned.load(Ordering::Relaxed) {
            use_global_queue = true;
            meta.freshly_spawned.store(false, Ordering::Relaxed);
        }

        if use_global_queue {
            // Scheduled from an unknown thread, add to global queue.
            if meta.priority == TaskPriority::High {
                self.global_high_prio_task_queue.push(task);
            } else {
                self.global_low_prio_task_queue.push(task);
            }
            self.park_group.unpark_one();
        } else {
            let ttl = opt_ttl.unwrap();
            // SAFETY: this slot may only be accessed from the local thread, which we are.
            let slot = unsafe { &mut *ttl.local_slot.get() };

            if meta.priority == TaskPriority::High {
                // Insert new task into thread local slot, taking out the old task.
                let Some(task) = slot.replace(task) else {
                    // We pushed a task into our local slot which was empty. Since
                    // we are already awake, no need to notify anyone.
                    return;
                };

                ttl.high_prio_tasks.push(task);
                self.park_group.unpark_one();
            } else {
                // Optimization: while this is a low priority task, we have no
                // high priority tasks on this thread, so we'll execute this one.
                if ttl.high_prio_tasks.is_empty() && slot.is_none() {
                    *slot = Some(task);
                } else {
                    self.global_low_prio_task_queue.push(task);
                    self.park_group.unpark_one();
                }
            }
        }
    }

    fn try_steal_task<R: Rng>(&self, thread: usize, rng: &mut R) -> Option<ReadyTask> {
        // Try to get a global task.
        loop {
            match self.global_high_prio_task_queue.steal() {
                Steal::Empty => break,
                Steal::Success(task) => return Some(task),
                Steal::Retry => std::hint::spin_loop(),
            }
        }

        loop {
            match self.global_low_prio_task_queue.steal() {
                Steal::Empty => break,
                Steal::Success(task) => return Some(task),
                Steal::Retry => std::hint::spin_loop(),
            }
        }

        // Try to steal tasks from other threads.
        let ttl = &self.thread_task_lists[thread];
        for _ in 0..4 {
            let mut retry = true;
            while retry {
                retry = false;

                for idx in random_permutation(self.thread_task_lists.len() as u32, rng) {
                    let foreign_ttl = &self.thread_task_lists[idx as usize];
                    match foreign_ttl
                        .high_prio_tasks_stealer
                        .steal_batch_and_pop(&ttl.high_prio_tasks)
                    {
                        Steal::Empty => {},
                        Steal::Success(task) => return Some(task),
                        Steal::Retry => retry = true,
                    }
                }

                std::hint::spin_loop()
            }
        }

        None
    }

    fn runner(&self, thread: usize) {
        TLS_THREAD_ID.set(thread);
        ALLOW_RAYON_THREADS.set(false);

        let mut rng = SmallRng::from_rng(&mut rand::rng());
        let mut worker = self.park_group.new_worker();

        loop {
            let ttl = &self.thread_task_lists[thread];
            let mut local = true;
            let task = (|| {
                // Try to get a task from the LIFO slot.
                if let Some(task) = unsafe { (*ttl.local_slot.get()).take() } {
                    return Some(task);
                }

                // Try to get a local high-priority task.
                if let Some(task) = ttl.high_prio_tasks.pop() {
                    return Some(task);
                }

                // Try to steal a task.
                local = false;
                if let Some(task) = self.try_steal_task(thread, &mut rng) {
                    return Some(task);
                }

                // Prepare to park, then try one more steal attempt.
                let park = worker.prepare_park();
                if let Some(task) = self.try_steal_task(thread, &mut rng) {
                    return Some(task);
                }

                park.park();
                None
            })();

            if let Some(task) = task {
                worker.recruit_next();
                if let Some(metrics) = task.metadata().metrics.clone() {
                    let start = Instant::now();
                    task.run();
                    let elapsed_ns = start.elapsed().as_nanos() as u64;
                    metrics.total_polls.fetch_add(1);
                    if !local {
                        metrics.total_stolen_polls.fetch_add(1);
                    }
                    metrics.total_poll_time_ns.fetch_add(elapsed_ns);
                    metrics.max_poll_time_ns.fetch_max(elapsed_ns);
                } else {
                    task.run();
                }
            }
        }
    }
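    // Lazily builds the global executor on first use (descriptive comment
    // added for clarity, not in the original source). The worker threads
    // spawned here immediately call back into `Self::global()`, which blocks
    // on the `OnceLock` until initialization finishes, after which each
    // thread enters `runner`.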
    fn global() -> &'static Executor {
        GLOBAL_SCHEDULER.get_or_init(|| {
            let mut n_threads = NUM_EXECUTOR_THREADS.load();
            if n_threads == 0 {
                n_threads = std::thread::available_parallelism()
                    .map(|n| n.get())
                    .unwrap_or(4);
            }

            let thread_task_lists = (0..n_threads)
                .map(|t| {
                    std::thread::Builder::new()
                        .name(format!("async-executor-{t}"))
                        .spawn(move || Self::global().runner(t))
                        .unwrap();

                    let high_prio_tasks = WorkQueue::new_lifo();
                    CachePadded::new(ThreadLocalTaskList {
                        high_prio_tasks_stealer: high_prio_tasks.stealer(),
                        high_prio_tasks,
                        local_slot: UnsafeCell::new(None),
                    })
                })
                .collect();
            Self {
                park_group: ParkGroup::new(),
                thread_task_lists,
                global_high_prio_task_queue: Injector::new(),
                global_low_prio_task_queue: Injector::new(),
            }
        })
    }
}

pub struct TaskScope<'scope, 'env: 'scope> {
    // Keep track of in-progress tasks so we can forcibly cancel them
    // when the scope ends, to ensure the lifetimes are respected.
    // Tasks add their own key to completed_tasks when done so we can
    // reclaim the memory used by the cancel_handles.
    cancel_handles: Mutex<SlotMap<TaskKey, CancelHandle>>,
    completed_tasks: Arc<Mutex<Vec<TaskKey>>>,

    // Copied from std::thread::scope. Necessary to prevent unsoundness.
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope> TaskScope<'scope, '_> {
    // Not Drop because that extends lifetimes.
    fn destroy(&self) {
        // Make sure all tasks are cancelled.
        for (_, t) in self.cancel_handles.lock().drain() {
            t.cancel();
        }
    }

    fn clear_completed_tasks(&self) {
        let mut cancel_handles = self.cancel_handles.lock();
        for t in self.completed_tasks.lock().drain(..) {
            cancel_handles.remove(t);
        }
    }

    #[track_caller]
    pub fn spawn_task<F: Future + Send + 'scope>(
        &self,
        priority: TaskPriority,
        fut: F,
    ) -> JoinHandle<F::Output>
    where
        <F as Future>::Output: Send + 'static,
    {
        let spawn_location = Location::caller();
        self.clear_completed_tasks();

        let mut runnable = None;
        let mut join_handle = None;
        self.cancel_handles.lock().insert_with_key(|task_key| {
            let metrics = TRACK_METRICS.load().then(Arc::default);
            let dyn_task = unsafe {
                // SAFETY: we make sure to cancel this task before 'scope ends.
                let executor = Executor::global();
                let on_wake = move |task| executor.schedule_task(task);
                task::spawn_with_lifetime(
                    fut,
                    on_wake,
                    TaskMetadata {
                        spawn_location,
                        priority,
                        freshly_spawned: AtomicBool::new(true),
                        scoped: Some(ScopedTaskMetadata {
                            task_key,
                            completed_tasks: Arc::downgrade(&self.completed_tasks),
                        }),
                        metrics,
                    },
                )
            };
            runnable = Some(Arc::clone(&dyn_task));
            let jh = JoinHandle(dyn_task);
            let cancel_handle = jh.cancel_handle();
            join_handle = Some(jh);
            cancel_handle
        });
        runnable.unwrap().schedule();
        join_handle.unwrap()
    }
}
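/// Runs `f` in a scope in which spawned tasks may borrow from the enclosing
/// environment; every task still running when `f` returns is cancelled.
/// Illustrative usage sketch (added for clarity, not part of the original
/// source; `block_on` is a hypothetical placeholder for whatever
/// future-driving mechanism the caller has, not an API of this module):
///
/// ```ignore
/// task_scope(|scope| {
///     let handle = scope.spawn_task(TaskPriority::Low, async { 40 + 2 });
///     // Handles must be driven to completion before the closure returns;
///     // any task still running when the scope ends is cancelled.
///     assert_eq!(block_on(handle), 42);
/// });
/// ```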
pub fn task_scope<'env, F, T>(f: F) -> T
where
    F: for<'scope> FnOnce(&'scope TaskScope<'scope, 'env>) -> T,
{
    // By having this local variable inaccessible to anyone we guarantee
    // that either abort is called, killing the entire process, or that this
    // scope is properly destroyed.
    let scope = TaskScope {
        cancel_handles: Mutex::default(),
        completed_tasks: Arc::new(Mutex::default()),
        scope: PhantomData,
        env: PhantomData,
    };

    let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&scope)));

    // Make sure all tasks are properly destroyed.
    scope.destroy();

    match result {
        Err(e) => std::panic::resume_unwind(e),
        Ok(result) => result,
    }
}
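/// Spawns a `'static` task on the global executor. (Descriptive comment added
/// for clarity: unlike `TaskScope::spawn_task`, the future may not borrow from
/// the caller's stack.)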
#[track_caller]
pub fn spawn<F: Future + Send + 'static>(priority: TaskPriority, fut: F) -> JoinHandle<F::Output>
where
    <F as Future>::Output: Send + 'static,
{
    let spawn_location = Location::caller();
    let executor = Executor::global();
    let on_wake = move |task| executor.schedule_task(task);
    let metrics = TRACK_METRICS.load().then(Arc::default);
    let dyn_task = task::spawn(
        fut,
        on_wake,
        TaskMetadata {
            spawn_location,
            priority,
            freshly_spawned: AtomicBool::new(true),
            scoped: None,
            metrics,
        },
    );
    Arc::clone(&dyn_task).schedule();
    JoinHandle(dyn_task)
}
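/// Yields each integer in `0..len` exactly once, in a pseudorandom order,
/// without allocating. (Descriptive comment added for clarity: an invertible
/// mix over the next power of two is filtered down to values below `len`,
/// then rotated by a uniformly chosen starting offset.)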
fn random_permutation<R: Rng>(len: u32, rng: &mut R) -> impl Iterator<Item = u32> {
    let modulus = len.next_power_of_two();
    let halfwidth = modulus.trailing_zeros() / 2;
    let mask = modulus - 1;
    let displace_zero = rng.random::<u32>();
    let odd1 = rng.random::<u32>() | 1;
    let odd2 = rng.random::<u32>() | 1;
    let uniform_first = ((rng.random::<u32>() as u64 * len as u64) >> 32) as u32;

    (0..modulus)
        .map(move |mut i| {
            // Invertible permutation on [0, modulus).
            i = i.wrapping_add(displace_zero);
            i = i.wrapping_mul(odd1);
            i ^= (i & mask) >> halfwidth;
            i = i.wrapping_mul(odd2);
            i & mask
        })
        .filter(move |i| *i < len)
        .map(move |mut i| {
            i += uniform_first;
            if i >= len {
                i -= len;
            }
            i
        })
}