#![allow(clippy::disallowed_types)]

mod park_group;
mod task;

use std::cell::{Cell, UnsafeCell};
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::panic::{AssertUnwindSafe, Location};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, LazyLock, OnceLock, Weak};
use std::time::Duration;

use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue};
use crossbeam_utils::CachePadded;
use park_group::ParkGroup;
use parking_lot::Mutex;
use polars_utils::relaxed_cell::RelaxedCell;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use slotmap::SlotMap;
pub use task::{AbortOnDropHandle, JoinHandle};
use task::{CancelHandle, Runnable};

static NUM_EXECUTOR_THREADS: RelaxedCell<usize> = RelaxedCell::new_usize(0);
pub fn set_num_threads(t: usize) {
    NUM_EXECUTOR_THREADS.store(t);
}

static GLOBAL_SCHEDULER: OnceLock<Executor> = OnceLock::new();

thread_local!(
    /// Used to store which executor thread this is.
    static TLS_THREAD_ID: Cell<usize> = const { Cell::new(usize::MAX) };
);

static NS_SPENT_BLOCKED: LazyLock<Mutex<HashMap<&'static Location<'static>, u64>>> =
    LazyLock::new(Mutex::default);

static TRACK_WAIT_STATISTICS: RelaxedCell<bool> = RelaxedCell::new_bool(false);

pub fn track_task_wait_statistics(should_track: bool) {
    TRACK_WAIT_STATISTICS.store(should_track);
}

pub fn get_task_wait_statistics() -> Vec<(&'static Location<'static>, Duration)> {
    NS_SPENT_BLOCKED
        .lock()
        .iter()
        .map(|(l, ns)| (*l, Duration::from_nanos(*ns)))
        .collect()
}

pub fn clear_task_wait_statistics() {
    NS_SPENT_BLOCKED.lock().clear()
}
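
// Illustrative sketch of the knobs above (the thread count and the println!
// formatting are made up). `set_num_threads` only has an effect if it is
// called before the global executor is first used.
//
//     set_num_threads(8);
//     track_task_wait_statistics(true);
//     // ... run some queries ...
//     for (spawn_location, blocked) in get_task_wait_statistics() {
//         println!("{spawn_location}: blocked for {blocked:?}");
//     }
//     clear_task_wait_statistics();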

slotmap::new_key_type! {
    struct TaskKey;
}

/// High priority tasks are scheduled preferentially over low priority tasks.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low,
    High,
}

/// Metadata associated with a task to help schedule it and clean it up.
struct ScopedTaskMetadata {
    task_key: TaskKey,
    completed_tasks: Weak<Mutex<Vec<TaskKey>>>,
}

struct TaskMetadata {
    spawn_location: &'static Location<'static>,
    ns_spent_blocked: RelaxedCell<u64>,
    priority: TaskPriority,
    freshly_spawned: AtomicBool,
    scoped: Option<ScopedTaskMetadata>,
}

impl Drop for TaskMetadata {
    fn drop(&mut self) {
        *NS_SPENT_BLOCKED
            .lock()
            .entry(self.spawn_location)
            .or_default() += self.ns_spent_blocked.load();
        if let Some(scoped) = &self.scoped {
            if let Some(completed_tasks) = scoped.completed_tasks.upgrade() {
                completed_tasks.lock().push(scoped.task_key);
            }
        }
    }
}

/// A task ready to run.
type ReadyTask = Runnable<TaskMetadata>;

/// A per-thread task list.
struct ThreadLocalTaskList {
    // May be used from any thread.
    high_prio_tasks_stealer: Stealer<ReadyTask>,

    // SAFETY: these may only be used on the thread this task list belongs to.
    high_prio_tasks: WorkQueue<ReadyTask>,
    local_slot: UnsafeCell<Option<ReadyTask>>,
}

unsafe impl Sync for ThreadLocalTaskList {}
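
/// Shared state of the global executor: a `ParkGroup` used to park and wake
/// idle worker threads, one `ThreadLocalTaskList` per worker thread, and two
/// global injector queues for tasks that are freshly spawned or scheduled from
/// outside a worker thread.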
struct Executor {
    park_group: ParkGroup,
    thread_task_lists: Vec<CachePadded<ThreadLocalTaskList>>,
    global_high_prio_task_queue: Injector<ReadyTask>,
    global_low_prio_task_queue: Injector<ReadyTask>,
}

impl Executor {
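    /// Schedules a runnable task. Freshly spawned tasks and tasks scheduled
    /// from a non-executor thread go to the global queue for their priority.
    /// On a worker thread, a high-priority task goes into the thread-local
    /// LIFO slot, displacing any previous occupant onto the local
    /// high-priority queue; a low-priority task claims the slot only if the
    /// thread has no local high-priority work and the slot is empty, and
    /// otherwise goes to the global low-priority queue.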
    fn schedule_task(&self, task: ReadyTask) {
        let thread = TLS_THREAD_ID.get();
        let meta = task.metadata();
        let opt_ttl = self.thread_task_lists.get(thread);

        let mut use_global_queue = opt_ttl.is_none();
        if meta.freshly_spawned.load(Ordering::Relaxed) {
            use_global_queue = true;
            meta.freshly_spawned.store(false, Ordering::Relaxed);
        }

        if use_global_queue {
            // Scheduled from an unknown thread, add to global queue.
            if meta.priority == TaskPriority::High {
                self.global_high_prio_task_queue.push(task);
            } else {
                self.global_low_prio_task_queue.push(task);
            }
            self.park_group.unpark_one();
        } else {
            let ttl = opt_ttl.unwrap();
            // SAFETY: this slot may only be accessed from the local thread, which we are.
            let slot = unsafe { &mut *ttl.local_slot.get() };

            if meta.priority == TaskPriority::High {
                // Insert the new task into the thread-local slot, taking out the old task.
                let Some(task) = slot.replace(task) else {
                    // We pushed a task into our local slot which was empty. Since
                    // we are already awake, no need to notify anyone.
                    return;
                };

                ttl.high_prio_tasks.push(task);
                self.park_group.unpark_one();
            } else {
                // Optimization: although this is a low-priority task, if this
                // thread has no other pending work we'll run it here.
                if ttl.high_prio_tasks.is_empty() && slot.is_none() {
                    *slot = Some(task);
                } else {
                    self.global_low_prio_task_queue.push(task);
                    self.park_group.unpark_one();
                }
            }
        }
    }

    fn try_steal_task<R: Rng>(&self, thread: usize, rng: &mut R) -> Option<ReadyTask> {
        // Try to get a global task.
        loop {
            match self.global_high_prio_task_queue.steal() {
                Steal::Empty => break,
                Steal::Success(task) => return Some(task),
                Steal::Retry => std::hint::spin_loop(),
            }
        }

        loop {
            match self.global_low_prio_task_queue.steal() {
                Steal::Empty => break,
                Steal::Success(task) => return Some(task),
                Steal::Retry => std::hint::spin_loop(),
            }
        }

        // Try to steal tasks from other threads.
        let ttl = &self.thread_task_lists[thread];
        for _ in 0..4 {
            let mut retry = true;
            while retry {
                retry = false;

                for idx in random_permutation(self.thread_task_lists.len() as u32, rng) {
                    let foreign_ttl = &self.thread_task_lists[idx as usize];
                    match foreign_ttl
                        .high_prio_tasks_stealer
                        .steal_batch_and_pop(&ttl.high_prio_tasks)
                    {
                        Steal::Empty => {},
                        Steal::Success(task) => return Some(task),
                        Steal::Retry => retry = true,
                    }
                }

                std::hint::spin_loop()
            }
        }

        None
    }
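
    /// The main loop of an executor worker thread: take work from the
    /// thread-local LIFO slot first, then the local high-priority queue, then
    /// try to steal from the global queues and other workers, and finally park
    /// until new work is scheduled. When wait statistics are enabled, time
    /// spent parked is attributed to the spawn location of the task that ends
    /// the wait.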
    fn runner(&self, thread: usize) {
        TLS_THREAD_ID.set(thread);

        let mut rng = SmallRng::from_rng(&mut rand::rng());
        let mut worker = self.park_group.new_worker();
        let mut last_block_start = None;

        loop {
            let ttl = &self.thread_task_lists[thread];
            let task = (|| {
                // Try to get a task from LIFO slot.
                if let Some(task) = unsafe { (*ttl.local_slot.get()).take() } {
                    return Some(task);
                }

                // Try to get a local high-priority task.
                if let Some(task) = ttl.high_prio_tasks.pop() {
                    return Some(task);
                }

                // Try to steal a task.
                if let Some(task) = self.try_steal_task(thread, &mut rng) {
                    return Some(task);
                }

                // Prepare to park, then try one more steal attempt.
                let park = worker.prepare_park();
                if let Some(task) = self.try_steal_task(thread, &mut rng) {
                    return Some(task);
                }

                if last_block_start.is_none() && TRACK_WAIT_STATISTICS.load() {
                    last_block_start = Some(std::time::Instant::now());
                }
                park.park();
                None
            })();

            if let Some(task) = task {
                if let Some(t) = last_block_start.take() {
                    if TRACK_WAIT_STATISTICS.load() {
                        let ns: u64 = t.elapsed().as_nanos().try_into().unwrap();
                        task.metadata().ns_spent_blocked.fetch_add(ns);
                    }
                }
                worker.recruit_next();
                task.run();
            }
        }
    }

    fn global() -> &'static Executor {
        GLOBAL_SCHEDULER.get_or_init(|| {
            let mut n_threads = NUM_EXECUTOR_THREADS.load();
            if n_threads == 0 {
                n_threads = std::thread::available_parallelism()
                    .map(|n| n.get())
                    .unwrap_or(4);
            }

            let thread_task_lists = (0..n_threads)
                .map(|t| {
                    std::thread::Builder::new()
                        .name(format!("async-executor-{t}"))
                        .spawn(move || Self::global().runner(t))
                        .unwrap();

                    let high_prio_tasks = WorkQueue::new_lifo();
                    CachePadded::new(ThreadLocalTaskList {
                        high_prio_tasks_stealer: high_prio_tasks.stealer(),
                        high_prio_tasks,
                        local_slot: UnsafeCell::new(None),
                    })
                })
                .collect();
            Self {
                park_group: ParkGroup::new(),
                thread_task_lists,
                global_high_prio_task_queue: Injector::new(),
                global_low_prio_task_queue: Injector::new(),
            }
        })
    }
}

pub struct TaskScope<'scope, 'env: 'scope> {
    // Keep track of in-progress tasks so we can forcibly cancel them
    // when the scope ends, to ensure the lifetimes are respected.
    // Tasks add their own key to completed_tasks when done so we can
    // reclaim the memory used by the cancel_handles.
    cancel_handles: Mutex<SlotMap<TaskKey, CancelHandle>>,
    completed_tasks: Arc<Mutex<Vec<TaskKey>>>,

    // Copied from std::thread::scope. Necessary to prevent unsoundness.
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope> TaskScope<'scope, '_> {
    // Not Drop because that extends lifetimes.
    fn destroy(&self) {
        // Make sure all tasks are cancelled.
        for (_, t) in self.cancel_handles.lock().drain() {
            t.cancel();
        }
    }

    fn clear_completed_tasks(&self) {
        let mut cancel_handles = self.cancel_handles.lock();
        for t in self.completed_tasks.lock().drain(..) {
            cancel_handles.remove(t);
        }
    }

    #[track_caller]
    pub fn spawn_task<F: Future + Send + 'scope>(
        &self,
        priority: TaskPriority,
        fut: F,
    ) -> JoinHandle<F::Output>
    where
        <F as Future>::Output: Send + 'static,
    {
        let spawn_location = Location::caller();
        self.clear_completed_tasks();

        let mut runnable = None;
        let mut join_handle = None;
        self.cancel_handles.lock().insert_with_key(|task_key| {
            let (run, jh) = unsafe {
                // SAFETY: we make sure to cancel this task before 'scope ends.
                let executor = Executor::global();
                let on_wake = move |task| executor.schedule_task(task);
                task::spawn_with_lifetime(
                    fut,
                    on_wake,
                    TaskMetadata {
                        spawn_location,
                        ns_spent_blocked: RelaxedCell::new_u64(0),
                        priority,
                        freshly_spawned: AtomicBool::new(true),
                        scoped: Some(ScopedTaskMetadata {
                            task_key,
                            completed_tasks: Arc::downgrade(&self.completed_tasks),
                        }),
                    },
                )
            };
            let cancel_handle = jh.cancel_handle();
            runnable = Some(run);
            join_handle = Some(jh);
            cancel_handle
        });
        runnable.unwrap().schedule();
        join_handle.unwrap()
    }
}
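
/// Runs `f` with a [`TaskScope`] on which tasks borrowing from `'env` can be
/// spawned. Any task still running when `f` returns (or panics) is cancelled
/// before `task_scope` returns, and a panic from `f` is propagated.
///
/// Illustrative sketch; `work_on` and `shared` are hypothetical, and how the
/// join handles are awaited is left out:
///
/// ```ignore
/// task_scope(|scope| {
///     let hi = scope.spawn_task(TaskPriority::High, work_on(&shared));
///     let lo = scope.spawn_task(TaskPriority::Low, work_on(&shared));
///     (hi, lo)
/// });
/// ```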
pub fn task_scope<'env, F, T>(f: F) -> T
where
    F: for<'scope> FnOnce(&'scope TaskScope<'scope, 'env>) -> T,
{
    // By keeping this local variable inaccessible to anyone else we guarantee
    // that either abort is called, killing the entire process, or that this
    // scope is properly destroyed.
    let scope = TaskScope {
        cancel_handles: Mutex::default(),
        completed_tasks: Arc::new(Mutex::default()),
        scope: PhantomData,
        env: PhantomData,
    };

    let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&scope)));

    // Make sure all tasks are properly destroyed.
    scope.destroy();

    match result {
        Err(e) => std::panic::resume_unwind(e),
        Ok(result) => result,
    }
}
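
/// Spawns a `'static` task with the given priority onto the global executor
/// and returns its [`JoinHandle`].
///
/// Illustrative sketch; `compute_partition` is a hypothetical async fn:
///
/// ```ignore
/// let handle = spawn(TaskPriority::Low, compute_partition(42));
/// ```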
#[track_caller]
pub fn spawn<F: Future + Send + 'static>(priority: TaskPriority, fut: F) -> JoinHandle<F::Output>
where
    <F as Future>::Output: Send + 'static,
{
    let spawn_location = Location::caller();
    let executor = Executor::global();
    let on_wake = move |task| executor.schedule_task(task);
    let (runnable, join_handle) = task::spawn(
        fut,
        on_wake,
        TaskMetadata {
            spawn_location,
            ns_spent_blocked: RelaxedCell::new_u64(0),
            priority,
            freshly_spawned: AtomicBool::new(true),
            scoped: None,
        },
    );
    runnable.schedule();
    join_handle
}
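
/// Yields the indices `0..len` in a pseudorandom order: an invertible mixing
/// function permutes `0..modulus` (`len` rounded up to a power of two),
/// indices `>= len` are filtered out, and a uniformly random offset is added
/// modulo `len` so the first yielded index is uniformly distributed. Used by
/// `try_steal_task` to randomize the order in which other workers are visited.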
fn random_permutation<R: Rng>(len: u32, rng: &mut R) -> impl Iterator<Item = u32> {
    let modulus = len.next_power_of_two();
    let halfwidth = modulus.trailing_zeros() / 2;
    let mask = modulus - 1;
    let displace_zero = rng.random::<u32>();
    let odd1 = rng.random::<u32>() | 1;
    let odd2 = rng.random::<u32>() | 1;
    let uniform_first = ((rng.random::<u32>() as u64 * len as u64) >> 32) as u32;

    (0..modulus)
        .map(move |mut i| {
            // Invertible permutation on [0, modulus).
            i = i.wrapping_add(displace_zero);
            i = i.wrapping_mul(odd1);
            i ^= (i & mask) >> halfwidth;
            i = i.wrapping_mul(odd2);
            i & mask
        })
        .filter(move |i| *i < len)
        .map(move |mut i| {
            i += uniform_first;
            if i >= len {
                i -= len;
            }
            i
        })
}