GitHub Repository: bevyengine/bevy
Path: blob/main/crates/bevy_tasks/src/task_pool.rs

use alloc::{boxed::Box, format, string::String, vec::Vec};
use core::{future::Future, marker::PhantomData, mem, panic::AssertUnwindSafe};
use std::{
    thread::{self, JoinHandle},
    thread_local,
};

use crate::executor::FallibleTask;
use bevy_platform::sync::Arc;
use concurrent_queue::ConcurrentQueue;
use futures_lite::FutureExt;

use crate::{
    block_on,
    thread_executor::{ThreadExecutor, ThreadExecutorTicker},
    Task,
};

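/// Calls the wrapped callback (if any) when dropped. The pool uses this to run
/// `on_thread_destroy` when a worker thread's executor loop exits.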
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);

impl Drop for CallOnDrop {
    fn drop(&mut self) {
        if let Some(call) = self.0.as_ref() {
            call();
        }
    }
}

/// Used to create a [`TaskPool`]
#[derive(Default)]
#[must_use]
pub struct TaskPoolBuilder {
    /// If set, we'll set up the thread pool to use at most `num_threads` threads.
    /// Otherwise, the logical core count of the system is used.
    num_threads: Option<usize>,
    /// If set, we'll use the given stack size rather than the system default
    stack_size: Option<usize>,
    /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
    /// be named `<thread_name> (<thread_index>)`, e.g. `"MyThreadPool (2)"`.
    thread_name: Option<String>,

    on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl TaskPoolBuilder {
    /// Creates a new [`TaskPoolBuilder`] instance
    pub fn new() -> Self {
        Self::default()
    }

    /// Override the number of threads created for the pool. If unset, we default to the number
    /// of logical cores of the system
    pub fn num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = Some(num_threads);
        self
    }

    /// Override the stack size of the threads created for the pool
    pub fn stack_size(mut self, stack_size: usize) -> Self {
        self.stack_size = Some(stack_size);
        self
    }

    /// Override the name of the threads created for the pool. If set, threads will
    /// be named `<thread_name> (<thread_index>)`, e.g. `MyThreadPool (2)`
    pub fn thread_name(mut self, thread_name: String) -> Self {
        self.thread_name = Some(thread_name);
        self
    }

    /// Sets a callback that is invoked once for every created thread as it starts.
    ///
    /// This is called on the thread itself and has access to all thread-local storage.
    /// This will block running async tasks on the thread until the callback completes.
    pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
        let arc = Arc::new(f);

        #[cfg(not(target_has_atomic = "ptr"))]
        #[expect(
            unsafe_code,
            reason = "unsized coercion is an unstable feature for non-std types"
        )]
        // SAFETY:
        // - Coercion from `impl Fn` to `dyn Fn` is valid
        // - `Arc::from_raw` receives a valid pointer from a previous call to `Arc::into_raw`
        let arc = unsafe {
            Arc::from_raw(Arc::into_raw(arc) as *const (dyn Fn() + Send + Sync + 'static))
        };

        self.on_thread_spawn = Some(arc);
        self
    }

    /// Sets a callback that is invoked once for every created thread as it terminates.
    ///
    /// This is called on the thread itself and has access to all thread-local storage.
    /// This will block thread termination until the callback completes.
    pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
        let arc = Arc::new(f);

        #[cfg(not(target_has_atomic = "ptr"))]
        #[expect(
            unsafe_code,
            reason = "unsized coercion is an unstable feature for non-std types"
        )]
        // SAFETY:
        // - Coercion from `impl Fn` to `dyn Fn` is valid
        // - `Arc::from_raw` receives a valid pointer from a previous call to `Arc::into_raw`
        let arc = unsafe {
            Arc::from_raw(Arc::into_raw(arc) as *const (dyn Fn() + Send + Sync + 'static))
        };

        self.on_thread_destroy = Some(arc);
        self
    }

    /// Creates a new [`TaskPool`] based on the current options.
    pub fn build(self) -> TaskPool {
        TaskPool::new_internal(self)
    }
}

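// Example (editor's sketch, not from the upstream file): putting the builder
// options above together. All of the methods shown are defined in this impl block.
//
//     let pool = TaskPoolBuilder::new()
//         .num_threads(4)
//         .stack_size(3 * 1024 * 1024)
//         .thread_name("MyPool".to_string())
//         .build();
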
/// A thread pool for executing tasks.
///
/// While futures usually need to be polled to be executed, Bevy tasks are
/// automatically driven by the pool on threads owned by the pool. The [`Task`]
/// future only needs to be polled in order to receive the result. (For that
/// purpose, it is often stored in a component or resource; see the
/// `async_compute` example.)
///
/// If the result is not required, one may also use [`Task::detach`], and the pool
/// will still execute the task, even if it is dropped.
#[derive(Debug)]
pub struct TaskPool {
    /// The executor for the pool.
    executor: Arc<crate::executor::Executor<'static>>,

    // The inner state of the pool.
    threads: Vec<JoinHandle<()>>,
    shutdown_tx: async_channel::Sender<()>,
}

impl TaskPool {
    thread_local! {
        static LOCAL_EXECUTOR: crate::executor::LocalExecutor<'static> = const { crate::executor::LocalExecutor::new() };
        static THREAD_EXECUTOR: Arc<ThreadExecutor<'static>> = Arc::new(ThreadExecutor::new());
    }

    /// Each thread should only create one `ThreadExecutor`; otherwise, there is a good chance they will deadlock
    pub fn get_thread_executor() -> Arc<ThreadExecutor<'static>> {
        Self::THREAD_EXECUTOR.with(Clone::clone)
    }

    /// Create a `TaskPool` with the default configuration.
    pub fn new() -> Self {
        TaskPoolBuilder::new().build()
    }

    fn new_internal(builder: TaskPoolBuilder) -> Self {
        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();

        let executor = Arc::new(crate::executor::Executor::new());

        let num_threads = builder
            .num_threads
            .unwrap_or_else(crate::available_parallelism);

        // Spawn the worker threads. Each one ticks its thread-local executor and the
        // shared executor until the shutdown channel is closed.
        let threads = (0..num_threads)
            .map(|i| {
                let ex = Arc::clone(&executor);
                let shutdown_rx = shutdown_rx.clone();

                let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
                    format!("{thread_name} ({i})")
                } else {
                    format!("TaskPool ({i})")
                };
                let mut thread_builder = thread::Builder::new().name(thread_name);

                if let Some(stack_size) = builder.stack_size {
                    thread_builder = thread_builder.stack_size(stack_size);
                }

                let on_thread_spawn = builder.on_thread_spawn.clone();
                let on_thread_destroy = builder.on_thread_destroy.clone();

                thread_builder
                    .spawn(move || {
                        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
                            if let Some(on_thread_spawn) = on_thread_spawn {
                                on_thread_spawn();
                                drop(on_thread_spawn);
                            }
                            let _destructor = CallOnDrop(on_thread_destroy);
                            loop {
                                let res = std::panic::catch_unwind(|| {
                                    let tick_forever = async move {
                                        loop {
                                            local_executor.tick().await;
                                        }
                                    };
                                    block_on(ex.run(tick_forever.or(shutdown_rx.recv())))
                                });
                                if let Ok(value) = res {
                                    // Use unwrap_err because we expect a Closed error
                                    value.unwrap_err();
                                    break;
                                }
                            }
                        });
                    })
                    .expect("Failed to spawn thread.")
            })
            .collect();

        Self {
            executor,
            threads,
            shutdown_tx,
        }
    }

    /// Return the number of threads owned by the task pool
    pub fn thread_num(&self) -> usize {
        self.threads.len()
    }

    /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback,
    /// passing a scope object into it. The scope object provided to the callback can be used
    /// to spawn tasks. This function will await the completion of all tasks before returning.
    ///
    /// This is similar to [`thread::scope`] and `rayon::scope`.
    ///
    /// # Example
    ///
    /// ```
    /// use bevy_tasks::TaskPool;
    ///
    /// let pool = TaskPool::new();
    /// let mut x = 0;
    /// let results = pool.scope(|s| {
    ///     s.spawn(async {
    ///         // you can borrow the spawner inside a task and spawn tasks from within the task
    ///         s.spawn(async {
    ///             // borrow x and mutate it.
    ///             x = 2;
    ///             // return a value from the task
    ///             1
    ///         });
    ///         // return some other value from the first task
    ///         0
    ///     });
    /// });
    ///
    /// // The ordering of results is non-deterministic if you spawn from within tasks as above.
    /// // If you're doing this, you'll have to write your code to not depend on the ordering.
    /// assert!(results.contains(&0));
    /// assert!(results.contains(&1));
    ///
    /// // The ordering is deterministic if you only spawn directly from the closure function.
    /// let results = pool.scope(|s| {
    ///     s.spawn(async { 0 });
    ///     s.spawn(async { 1 });
    /// });
    /// assert_eq!(&results[..], &[0, 1]);
    ///
    /// // You can access x after scope runs, since it was only temporarily borrowed in the scope.
    /// assert_eq!(x, 2);
    /// ```
    ///
    /// # Lifetimes
    ///
    /// The [`Scope`] object takes two lifetimes: `'scope` and `'env`.
    ///
    /// The `'scope` lifetime represents the lifetime of the scope. That is the time during
    /// which the provided closure and tasks that are spawned into the scope are run.
    ///
    /// The `'env` lifetime represents the lifetime of whatever is borrowed by the scope.
    /// Thus this lifetime must outlive `'scope`.
    ///
    /// ```compile_fail
    /// use bevy_tasks::TaskPool;
    /// fn scope_escapes_closure() {
    ///     let pool = TaskPool::new();
    ///     let foo = Box::new(42);
    ///     pool.scope(|scope| {
    ///         std::thread::spawn(move || {
    ///             // UB. This could spawn on the scope after `.scope` returns and the internal Scope is dropped.
    ///             scope.spawn(async move {
    ///                 assert_eq!(*foo, 42);
    ///             });
    ///         });
    ///     });
    /// }
    /// ```
    ///
    /// ```compile_fail
    /// use bevy_tasks::TaskPool;
    /// fn cannot_borrow_from_closure() {
    ///     let pool = TaskPool::new();
    ///     pool.scope(|scope| {
    ///         let x = 1;
    ///         let y = &x;
    ///         scope.spawn(async move {
    ///             assert_eq!(*y, 1);
    ///         });
    ///     });
    /// }
    /// ```
    pub fn scope<'env, F, T>(&self, f: F) -> Vec<T>
    where
        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
        T: Send + 'static,
    {
        Self::THREAD_EXECUTOR.with(|scope_executor| {
            self.scope_with_executor_inner(true, scope_executor, scope_executor, f)
        })
    }

    /// This allows passing an external executor to spawn tasks on. When you pass an external executor,
    /// tasks spawned via [`Scope::spawn_on_scope`] are run on the thread that the [`ThreadExecutor`] is being ticked on.
    /// If [`None`] is passed, the scope will use a [`ThreadExecutor`] that is ticked on the current thread.
    ///
    /// When `tick_task_pool_executor` is set to `true`, the multithreaded task stealing executor is ticked on the scope
    /// thread. Disabling this can be useful when finishing the scope is latency sensitive. Pulling tasks from the
    /// global executor can run tasks unrelated to the scope and delay when the scope returns.
    ///
    /// See [`Self::scope`] for more details in general about how scopes work.
    pub fn scope_with_executor<'env, F, T>(
        &self,
        tick_task_pool_executor: bool,
        external_executor: Option<&ThreadExecutor>,
        f: F,
    ) -> Vec<T>
    where
        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
        T: Send + 'static,
    {
        Self::THREAD_EXECUTOR.with(|scope_executor| {
            // If an `external_executor` is passed, use that. Otherwise, get the executor stored
            // in the `THREAD_EXECUTOR` thread local.
            if let Some(external_executor) = external_executor {
                self.scope_with_executor_inner(
                    tick_task_pool_executor,
                    external_executor,
                    scope_executor,
                    f,
                )
            } else {
                self.scope_with_executor_inner(
                    tick_task_pool_executor,
                    scope_executor,
                    scope_executor,
                    f,
                )
            }
        })
    }

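    // Example (editor's sketch, not from the upstream file): calling
    // `scope_with_executor` with no external executor and without ticking the
    // global executor on the scope thread; the pool's own threads still drive
    // tasks spawned onto the scope.
    //
    //     let pool = TaskPool::new();
    //     let results = pool.scope_with_executor(false, None, |s| {
    //         s.spawn(async { 1 });
    //     });
    //     assert_eq!(results, vec![1]);
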
    #[expect(unsafe_code, reason = "Required to transmute lifetimes.")]
    fn scope_with_executor_inner<'env, F, T>(
        &self,
        tick_task_pool_executor: bool,
        external_executor: &ThreadExecutor,
        scope_executor: &ThreadExecutor,
        f: F,
    ) -> Vec<T>
    where
        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
        T: Send + 'static,
    {
        // SAFETY: This safety comment applies to all references transmuted to 'env.
        // Any futures spawned with these references need to return before this function completes.
        // This is guaranteed because we drive all the futures spawned onto the Scope
        // to completion in this function. However, Rust has no way of knowing this, so we
        // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
        // Any usages of the references passed into `Scope` must be accessed through
        // the transmuted reference for the rest of this function.
        let executor: &crate::executor::Executor = &self.executor;
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let executor: &'env crate::executor::Executor = unsafe { mem::transmute(executor) };
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let external_executor: &'env ThreadExecutor<'env> =
            unsafe { mem::transmute(external_executor) };
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) };
        let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<dyn core::any::Any + Send>>>> =
            ConcurrentQueue::unbounded();
        // shadow the variable so that the owned value cannot be used for the rest of the function
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let spawned: &'env ConcurrentQueue<
            FallibleTask<Result<T, Box<dyn core::any::Any + Send>>>,
        > = unsafe { mem::transmute(&spawned) };

        let scope = Scope {
            executor,
            external_executor,
            scope_executor,
            spawned,
            scope: PhantomData,
            env: PhantomData,
        };

        // shadow the variable so that the owned value cannot be used for the rest of the function
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let scope: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };

        f(scope);

        if spawned.is_empty() {
            Vec::new()
        } else {
            block_on(async move {
                let get_results = async {
                    let mut results = Vec::with_capacity(spawned.len());
                    while let Ok(task) = spawned.pop() {
                        if let Some(res) = task.await {
                            match res {
                                Ok(res) => results.push(res),
                                Err(payload) => std::panic::resume_unwind(payload),
                            }
                        } else {
                            panic!("Failed to catch panic!");
                        }
                    }
                    results
                };

                let tick_task_pool_executor = tick_task_pool_executor || self.threads.is_empty();

                // we get this from a thread local so we should always be on the scope executor's thread.
                // note: it is possible `scope_executor` and `external_executor` are the same executor;
                // in that case, we should only tick one of them, otherwise it may cause a deadlock.
                let scope_ticker = scope_executor.ticker().unwrap();
                let external_ticker = if !external_executor.is_same(scope_executor) {
                    external_executor.ticker()
                } else {
                    None
                };

                match (external_ticker, tick_task_pool_executor) {
                    (Some(external_ticker), true) => {
                        Self::execute_global_external_scope(
                            executor,
                            external_ticker,
                            scope_ticker,
                            get_results,
                        )
                        .await
                    }
                    (Some(external_ticker), false) => {
                        Self::execute_external_scope(external_ticker, scope_ticker, get_results)
                            .await
                    }
                    // either external_executor is none or it is the same as scope_executor
                    (None, true) => {
                        Self::execute_global_scope(executor, scope_ticker, get_results).await
                    }
                    (None, false) => Self::execute_scope(scope_ticker, get_results).await,
                }
            })
        }
    }

    #[inline]
    async fn execute_global_external_scope<'scope, 'ticker, T>(
        executor: &'scope crate::executor::Executor<'scope>,
        external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        get_results: impl Future<Output = Vec<T>>,
    ) -> Vec<T> {
        // we restart the executors if a task errors. If a scoped
        // task errors it will panic the scope on the call to get_results
        let execute_forever = async move {
            loop {
                let tick_forever = async {
                    loop {
                        external_ticker.tick().or(scope_ticker.tick()).await;
                    }
                };
                // we don't care if it errors. If a scoped task errors it will propagate
                // to get_results
                let _result = AssertUnwindSafe(executor.run(tick_forever))
                    .catch_unwind()
                    .await
                    .is_ok();
            }
        };
        get_results.or(execute_forever).await
    }

    #[inline]
    async fn execute_external_scope<'scope, 'ticker, T>(
        external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        get_results: impl Future<Output = Vec<T>>,
    ) -> Vec<T> {
        let execute_forever = async {
            loop {
                let tick_forever = async {
                    loop {
                        external_ticker.tick().or(scope_ticker.tick()).await;
                    }
                };
                let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
            }
        };
        get_results.or(execute_forever).await
    }

    #[inline]
    async fn execute_global_scope<'scope, 'ticker, T>(
        executor: &'scope crate::executor::Executor<'scope>,
        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        get_results: impl Future<Output = Vec<T>>,
    ) -> Vec<T> {
        let execute_forever = async {
            loop {
                let tick_forever = async {
                    loop {
                        scope_ticker.tick().await;
                    }
                };
                let _result = AssertUnwindSafe(executor.run(tick_forever))
                    .catch_unwind()
                    .await
                    .is_ok();
            }
        };
        get_results.or(execute_forever).await
    }

    #[inline]
    async fn execute_scope<'scope, 'ticker, T>(
        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        get_results: impl Future<Output = Vec<T>>,
    ) -> Vec<T> {
        let execute_forever = async {
            loop {
                let tick_forever = async {
                    loop {
                        scope_ticker.tick().await;
                    }
                };
                let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
            }
        };
        get_results.or(execute_forever).await
    }

    /// Spawns a static future onto the thread pool. The returned [`Task`] is a
    /// future that can be polled for the result. It can also be canceled and
    /// "detached", allowing the task to continue running even if dropped. In
    /// any case, the pool will execute the task even without polling by the
    /// end-user.
    ///
    /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should
    /// be used instead.
    pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
    where
        T: Send + 'static,
    {
        Task::new(self.executor.spawn(future))
    }

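    // Example (editor's sketch, not from the upstream file): the pool drives the
    // spawned future on its own threads; `block_on` (imported from this crate at
    // the top of the file) only waits for the result.
    //
    //     let pool = TaskPool::new();
    //     let task = pool.spawn(async { 1 + 2 });
    //     assert_eq!(block_on(task), 3);
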
    /// Spawns a static future on the thread-local async executor for the
    /// current thread. The task will run entirely on the thread the task was
    /// spawned on.
    ///
    /// The returned [`Task`] is a future that can be polled for the
    /// result. It can also be canceled and "detached", allowing the task to
    /// continue running even if dropped. In any case, the pool will execute the
    /// task even without polling by the end-user.
    ///
    /// Users should generally prefer to use [`TaskPool::spawn`] instead,
    /// unless the provided future is not `Send`.
    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
    where
        T: 'static,
    {
        Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
    }

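    // Example (editor's sketch, not from the upstream file): on a non-pool thread,
    // a locally spawned task only makes progress when that thread's local executor
    // is ticked, e.g. via `with_local_executor` below.
    //
    //     let pool = TaskPool::new();
    //     let task = pool.spawn_local(async { 42 });
    //     pool.with_local_executor(|local_executor| {
    //         local_executor.try_tick();
    //     });
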
    /// Runs a function with the local executor. Typically used to tick
    /// the local executor on the main thread as it needs to share time with
    /// other things.
    ///
    /// ```
    /// use bevy_tasks::TaskPool;
    ///
    /// TaskPool::new().with_local_executor(|local_executor| {
    ///     local_executor.try_tick();
    /// });
    /// ```
    pub fn with_local_executor<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&crate::executor::LocalExecutor) -> R,
    {
        Self::LOCAL_EXECUTOR.with(f)
    }
}

impl Default for TaskPool {
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for TaskPool {
    fn drop(&mut self) {
        self.shutdown_tx.close();

        let panicking = thread::panicking();
        for join_handle in self.threads.drain(..) {
            let res = join_handle.join();
            if !panicking {
                res.expect("Task thread panicked while executing.");
            }
        }
    }
}

/// A [`TaskPool`] scope for running one or more non-`'static` futures.
///
/// For more information, see [`TaskPool::scope`].
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope, T> {
    executor: &'scope crate::executor::Executor<'scope>,
    external_executor: &'scope ThreadExecutor<'scope>,
    scope_executor: &'scope ThreadExecutor<'scope>,
    spawned: &'scope ConcurrentQueue<FallibleTask<Result<T, Box<dyn core::any::Any + Send>>>>,
    // make `Scope` invariant over 'scope and 'env
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
    /// Spawns a scoped future onto the thread pool. The scope *must* outlive
    /// the provided future. The results of the future will be returned as a part of
    /// [`TaskPool::scope`]'s return value.
    ///
    /// For futures that should run on the thread `scope` is called on, [`Scope::spawn_on_scope`] should be used
    /// instead.
    ///
    /// For more information, see [`TaskPool::scope`].
    pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
        let task = self
            .executor
            .spawn(AssertUnwindSafe(f).catch_unwind())
            .fallible();
        // ConcurrentQueue only errors when closed or full, but we never
        // close and use an unbounded queue, so it is safe to unwrap
        self.spawned.push(task).unwrap();
    }

    /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive
    /// the provided future. The results of the future will be returned as a part of
    /// [`TaskPool::scope`]'s return value. Users should generally prefer to use
    /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
    ///
    /// For more information, see [`TaskPool::scope`].
    pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
        let task = self
            .scope_executor
            .spawn(AssertUnwindSafe(f).catch_unwind())
            .fallible();
        // ConcurrentQueue only errors when closed or full, but we never
        // close and use an unbounded queue, so it is safe to unwrap
        self.spawned.push(task).unwrap();
    }

    /// Spawns a scoped future onto the thread of the external thread executor.
    /// This is typically the main thread. The scope *must* outlive
    /// the provided future. The results of the future will be returned as a part of
    /// [`TaskPool::scope`]'s return value. Users should generally prefer to use
    /// [`Scope::spawn`] instead, unless the provided future needs to run on the external thread.
    ///
    /// For more information, see [`TaskPool::scope`].
    pub fn spawn_on_external<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
        let task = self
            .external_executor
            .spawn(AssertUnwindSafe(f).catch_unwind())
            .fallible();
        // ConcurrentQueue only errors when closed or full, but we never
        // close and use an unbounded queue, so it is safe to unwrap
        self.spawned.push(task).unwrap();
    }
}

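// Example (editor's sketch, not from the upstream file): choosing where scoped
// work runs. `spawn` uses the pool's threads, while `spawn_on_scope` stays on the
// thread that called `TaskPool::scope`; results come back in spawn order.
//
//     let pool = TaskPool::new();
//     let results = pool.scope(|s| {
//         s.spawn(async { 1 });          // may run on any pool thread
//         s.spawn_on_scope(async { 2 }); // runs on the calling thread
//     });
//     assert_eq!(&results[..], &[1, 2]);
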
impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
where
    T: 'scope,
{
    fn drop(&mut self) {
        block_on(async {
            while let Ok(task) = self.spawned.pop() {
                task.cancel().await;
            }
        });
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use core::sync::atomic::{AtomicBool, AtomicI32, Ordering};
    use std::sync::Barrier;

    #[test]
    fn test_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let count = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for _ in 0..100 {
                let count_clone = count.clone();
                scope.spawn(async move {
                    if *foo != 42 {
                        panic!("not 42!?!?")
                    } else {
                        count_clone.fetch_add(1, Ordering::Relaxed);
                        *foo
                    }
                });
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        assert_eq!(outputs.len(), 100);
        assert_eq!(count.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_thread_callbacks() {
        let counter = Arc::new(AtomicI32::new(0));
        let start_counter = counter.clone();
        {
            let barrier = Arc::new(Barrier::new(11));
            let last_barrier = barrier.clone();
            // Build and immediately drop to terminate
            let _pool = TaskPoolBuilder::new()
                .num_threads(10)
                .on_thread_spawn(move || {
                    start_counter.fetch_add(1, Ordering::Relaxed);
                    barrier.clone().wait();
                })
                .build();
            last_barrier.wait();
            assert_eq!(10, counter.load(Ordering::Relaxed));
        }
        assert_eq!(10, counter.load(Ordering::Relaxed));
        let end_counter = counter.clone();
        {
            let _pool = TaskPoolBuilder::new()
                .num_threads(20)
                .on_thread_destroy(move || {
                    end_counter.fetch_sub(1, Ordering::Relaxed);
                })
                .build();
            assert_eq!(10, counter.load(Ordering::Relaxed));
        }
        assert_eq!(-10, counter.load(Ordering::Relaxed));
        let start_counter = counter.clone();
        let end_counter = counter.clone();
        {
            let barrier = Arc::new(Barrier::new(6));
            let last_barrier = barrier.clone();
            let _pool = TaskPoolBuilder::new()
                .num_threads(5)
                .on_thread_spawn(move || {
                    start_counter.fetch_add(1, Ordering::Relaxed);
                    barrier.wait();
                })
                .on_thread_destroy(move || {
                    end_counter.fetch_sub(1, Ordering::Relaxed);
                })
                .build();
            last_barrier.wait();
            assert_eq!(-5, counter.load(Ordering::Relaxed));
        }
        assert_eq!(-10, counter.load(Ordering::Relaxed));
    }

    #[test]
    fn test_mixed_spawn_on_scope_and_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let local_count = Arc::new(AtomicI32::new(0));
        let non_local_count = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for i in 0..100 {
                if i % 2 == 0 {
                    let count_clone = non_local_count.clone();
                    scope.spawn(async move {
                        if *foo != 42 {
                            panic!("not 42!?!?")
                        } else {
                            count_clone.fetch_add(1, Ordering::Relaxed);
                            *foo
                        }
                    });
                } else {
                    let count_clone = local_count.clone();
                    scope.spawn_on_scope(async move {
                        if *foo != 42 {
                            panic!("not 42!?!?")
                        } else {
                            count_clone.fetch_add(1, Ordering::Relaxed);
                            *foo
                        }
                    });
                }
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        assert_eq!(outputs.len(), 100);
        assert_eq!(local_count.load(Ordering::Relaxed), 50);
        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
    }

    #[test]
    fn test_thread_locality() {
        let pool = Arc::new(TaskPool::new());
        let count = Arc::new(AtomicI32::new(0));
        let barrier = Arc::new(Barrier::new(101));
        let thread_check_failed = Arc::new(AtomicBool::new(false));

        for _ in 0..100 {
            let inner_barrier = barrier.clone();
            let count_clone = count.clone();
            let inner_pool = pool.clone();
            let inner_thread_check_failed = thread_check_failed.clone();
            thread::spawn(move || {
                inner_pool.scope(|scope| {
                    let inner_count_clone = count_clone.clone();
                    scope.spawn(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);
                    });
                    let spawner = thread::current().id();
                    let inner_count_clone = count_clone.clone();
                    scope.spawn_on_scope(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);
                        if thread::current().id() != spawner {
                            // NOTE: This check is using an atomic rather than simply panicking the
                            // thread to avoid deadlocking the barrier on failure
                            inner_thread_check_failed.store(true, Ordering::Release);
                        }
                    });
                });
                inner_barrier.wait();
            });
        }
        barrier.wait();
        assert!(!thread_check_failed.load(Ordering::Acquire));
        assert_eq!(count.load(Ordering::Acquire), 200);
    }

    #[test]
    fn test_nested_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let count = Arc::new(AtomicI32::new(0));

        let outputs: Vec<i32> = pool.scope(|scope| {
            for _ in 0..10 {
                let count_clone = count.clone();
                scope.spawn(async move {
                    for _ in 0..10 {
                        let count_clone_clone = count_clone.clone();
                        scope.spawn(async move {
                            if *foo != 42 {
                                panic!("not 42!?!?")
                            } else {
                                count_clone_clone.fetch_add(1, Ordering::Relaxed);
                                *foo
                            }
                        });
                    }
                    *foo
                });
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        // the inner loop runs 100 times and the outer one runs 10. 100 + 10
        assert_eq!(outputs.len(), 110);
        assert_eq!(count.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_nested_locality() {
        let pool = Arc::new(TaskPool::new());
        let count = Arc::new(AtomicI32::new(0));
        let barrier = Arc::new(Barrier::new(101));
        let thread_check_failed = Arc::new(AtomicBool::new(false));

        for _ in 0..100 {
            let inner_barrier = barrier.clone();
            let count_clone = count.clone();
            let inner_pool = pool.clone();
            let inner_thread_check_failed = thread_check_failed.clone();
            thread::spawn(move || {
                inner_pool.scope(|scope| {
                    let spawner = thread::current().id();
                    let inner_count_clone = count_clone.clone();
                    scope.spawn(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);

                        // spawning on the scope from another thread runs the futures on the scope's thread
                        scope.spawn_on_scope(async move {
                            inner_count_clone.fetch_add(1, Ordering::Release);
                            if thread::current().id() != spawner {
                                // NOTE: This check is using an atomic rather than simply panicking the
                                // thread to avoid deadlocking the barrier on failure
                                inner_thread_check_failed.store(true, Ordering::Release);
                            }
                        });
                    });
                });
                inner_barrier.wait();
            });
        }
        barrier.wait();
        assert!(!thread_check_failed.load(Ordering::Acquire));
        assert_eq!(count.load(Ordering::Acquire), 200);
    }

    // This test will often freeze on other executors.
    #[test]
    fn test_nested_scopes() {
        let pool = TaskPool::new();
        let count = Arc::new(AtomicI32::new(0));

        pool.scope(|scope| {
            scope.spawn(async {
                pool.scope(|scope| {
                    scope.spawn(async {
                        count.fetch_add(1, Ordering::Relaxed);
                    });
                });
            });
        });

        assert_eq!(count.load(Ordering::Acquire), 1);
    }
}