use alloc::{boxed::Box, format, string::String, vec::Vec};
use core::{future::Future, marker::PhantomData, mem, panic::AssertUnwindSafe};
use std::{
thread::{self, JoinHandle},
thread_local,
};
use crate::executor::FallibleTask;
use bevy_platform::sync::Arc;
use concurrent_queue::ConcurrentQueue;
use futures_lite::FutureExt;
use crate::{
block_on,
thread_executor::{ThreadExecutor, ThreadExecutorTicker},
Task,
};
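/// Invokes the wrapped callback, if any, when dropped. Used below to run the
/// `on_thread_destroy` hook when a pool worker thread exits.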
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);
impl Drop for CallOnDrop {
fn drop(&mut self) {
if let Some(call) = self.0.as_ref() {
call();
}
}
}
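/// Builder for configuring and constructing a [`TaskPool`].
///
/// Unset options fall back to the defaults used by `new_internal`: a thread
/// count from `crate::available_parallelism`, the platform's default stack
/// size, and threads named `"TaskPool (<index>)"`.
///
/// Illustrative sketch of the builder flow; it only calls the methods defined
/// below and assumes this crate is compiled as `bevy_tasks` (the example is
/// `no_run`, so it only needs to compile):
///
/// ```no_run
/// use bevy_tasks::TaskPoolBuilder;
///
/// let pool = TaskPoolBuilder::new()
///     .num_threads(4)
///     .stack_size(2 * 1024 * 1024) // 2 MiB per worker thread
///     .thread_name("MyTaskPool".to_string())
///     .build();
/// assert_eq!(pool.thread_num(), 4);
/// ```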
#[derive(Default)]
#[must_use]
pub struct TaskPoolBuilder {
num_threads: Option<usize>,
stack_size: Option<usize>,
thread_name: Option<String>,
on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}
impl TaskPoolBuilder {
pub fn new() -> Self {
Self::default()
}
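    /// Overrides the number of worker threads created for the pool. If unset,
    /// `crate::available_parallelism` is used.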
pub fn num_threads(mut self, num_threads: usize) -> Self {
self.num_threads = Some(num_threads);
self
}
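    /// Overrides the stack size, in bytes, of the pool's worker threads.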
pub fn stack_size(mut self, stack_size: usize) -> Self {
self.stack_size = Some(stack_size);
self
}
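    /// Overrides the name prefix of the pool's worker threads; each thread is
    /// named `"<thread_name> (<index>)"`.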
pub fn thread_name(mut self, thread_name: String) -> Self {
self.thread_name = Some(thread_name);
self
}
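    /// Sets a callback invoked once on each newly spawned worker thread before
    /// it begins running tasks.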
pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
let arc = Arc::new(f);
#[cfg(not(target_has_atomic = "ptr"))]
#[expect(
unsafe_code,
reason = "unsized coercion is an unstable feature for non-std types"
)]
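        // SAFETY: this only coerces `Arc<impl Fn()>` to `Arc<dyn Fn()>`; the raw
        // pointer passed to `Arc::from_raw` comes directly from `Arc::into_raw`.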
let arc = unsafe {
Arc::from_raw(Arc::into_raw(arc) as *const (dyn Fn() + Send + Sync + 'static))
};
self.on_thread_spawn = Some(arc);
self
}
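    /// Sets a callback invoked once on each worker thread as it shuts down
    /// (run via `CallOnDrop` when the thread's closure exits).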
pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
let arc = Arc::new(f);
#[cfg(not(target_has_atomic = "ptr"))]
#[expect(
unsafe_code,
reason = "unsized coercion is an unstable feature for non-std types"
)]
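        // SAFETY: this only coerces `Arc<impl Fn()>` to `Arc<dyn Fn()>`; the raw
        // pointer passed to `Arc::from_raw` comes directly from `Arc::into_raw`.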
let arc = unsafe {
Arc::from_raw(Arc::into_raw(arc) as *const (dyn Fn() + Send + Sync + 'static))
};
self.on_thread_destroy = Some(arc);
self
}
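    /// Consumes the builder and constructs the configured [`TaskPool`].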
pub fn build(self) -> TaskPool {
TaskPool::new_internal(self)
}
}
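/// A pool of OS threads that run async [`Task`]s on a shared
/// `crate::executor::Executor`. Dropping the pool closes the shutdown channel
/// and joins all worker threads.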
#[derive(Debug)]
pub struct TaskPool {
executor: Arc<crate::executor::Executor<'static>>,
threads: Vec<JoinHandle<()>>,
shutdown_tx: async_channel::Sender<()>,
}
impl TaskPool {
thread_local! {
static LOCAL_EXECUTOR: crate::executor::LocalExecutor<'static> = const { crate::executor::LocalExecutor::new() };
static THREAD_EXECUTOR: Arc<ThreadExecutor<'static>> = Arc::new(ThreadExecutor::new());
}
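    /// Returns the [`ThreadExecutor`] bound to the calling thread (a clone of
    /// the `THREAD_EXECUTOR` thread-local).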
pub fn get_thread_executor() -> Arc<ThreadExecutor<'static>> {
Self::THREAD_EXECUTOR.with(Clone::clone)
}
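    /// Creates a pool with default settings, equivalent to
    /// `TaskPoolBuilder::new().build()`.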
pub fn new() -> Self {
TaskPoolBuilder::new().build()
}
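    /// Spawns the worker threads described by `builder`. Each worker loops,
    /// ticking its thread-local executor and the shared executor until the
    /// shutdown channel is closed; panics escaping the executor are caught so
    /// the worker can resume.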
fn new_internal(builder: TaskPoolBuilder) -> Self {
let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
let executor = Arc::new(crate::executor::Executor::new());
let num_threads = builder
.num_threads
.unwrap_or_else(crate::available_parallelism);
let threads = (0..num_threads)
.map(|i| {
let ex = Arc::clone(&executor);
let shutdown_rx = shutdown_rx.clone();
let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
format!("{thread_name} ({i})")
} else {
format!("TaskPool ({i})")
};
let mut thread_builder = thread::Builder::new().name(thread_name);
if let Some(stack_size) = builder.stack_size {
thread_builder = thread_builder.stack_size(stack_size);
}
let on_thread_spawn = builder.on_thread_spawn.clone();
let on_thread_destroy = builder.on_thread_destroy.clone();
thread_builder
.spawn(move || {
TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
if let Some(on_thread_spawn) = on_thread_spawn {
on_thread_spawn();
drop(on_thread_spawn);
}
let _destructor = CallOnDrop(on_thread_destroy);
loop {
let res = std::panic::catch_unwind(|| {
let tick_forever = async move {
loop {
local_executor.tick().await;
}
};
block_on(ex.run(tick_forever.or(shutdown_rx.recv())))
});
if let Ok(value) = res {
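                                    // The executor only returns once `shutdown_rx.recv()`
                                    // resolves; after shutdown the channel is closed, so the
                                    // expected value is the `Closed` error.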
value.unwrap_err();
break;
}
}
});
})
.expect("Failed to spawn thread.")
})
.collect();
Self {
executor,
threads,
shutdown_tx,
}
}
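    /// Returns the number of worker threads owned by the pool.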
pub fn thread_num(&self) -> usize {
self.threads.len()
}
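    /// Runs `f` with a [`Scope`] on which borrowing, non-`'static` futures can
    /// be spawned, blocks until every spawned task has completed, and returns
    /// their outputs as a `Vec<T>`.
    ///
    /// A minimal usage sketch, modeled on the tests below; it assumes this
    /// crate is compiled as `bevy_tasks` and is `no_run` so it only needs to
    /// compile:
    ///
    /// ```no_run
    /// use bevy_tasks::TaskPool;
    ///
    /// let pool = TaskPool::new();
    /// let value = 21;
    /// let value_ref = &value; // a non-'static borrow is fine inside a scope
    /// let results = pool.scope(|scope| {
    ///     for _ in 0..2 {
    ///         scope.spawn(async move { *value_ref * 2 });
    ///     }
    /// });
    /// assert_eq!(results, vec![42, 42]);
    /// ```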
pub fn scope<'env, F, T>(&self, f: F) -> Vec<T>
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
T: Send + 'static,
{
Self::THREAD_EXECUTOR.with(|scope_executor| {
self.scope_with_executor_inner(true, scope_executor, scope_executor, f)
})
}
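    /// Like [`TaskPool::scope`], but with explicit control over which executors
    /// are ticked while blocking: `tick_task_pool_executor` additionally runs
    /// the pool's shared executor on this thread, and `external_executor`
    /// (falling back to this thread's executor) is the target of
    /// [`Scope::spawn_on_external`].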
pub fn scope_with_executor<'env, F, T>(
&self,
tick_task_pool_executor: bool,
external_executor: Option<&ThreadExecutor>,
f: F,
) -> Vec<T>
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
T: Send + 'static,
{
Self::THREAD_EXECUTOR.with(|scope_executor| {
if let Some(external_executor) = external_executor {
self.scope_with_executor_inner(
tick_task_pool_executor,
external_executor,
scope_executor,
f,
)
} else {
self.scope_with_executor_inner(
tick_task_pool_executor,
scope_executor,
scope_executor,
f,
)
}
})
}
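    // Shared implementation behind `scope` and `scope_with_executor`: builds a
    // `Scope`, runs `f`, then blocks while ticking the chosen executors until
    // every spawned task has yielded a result.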
#[expect(unsafe_code, reason = "Required to transmute lifetimes.")]
fn scope_with_executor_inner<'env, F, T>(
&self,
tick_task_pool_executor: bool,
external_executor: &ThreadExecutor,
scope_executor: &ThreadExecutor,
f: F,
) -> Vec<T>
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
T: Send + 'static,
{
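        // SAFETY rationale for the transmutes below: this function blocks until every
        // task spawned on the `Scope` has completed, so nothing borrowed here outlives
        // the call. The compiler cannot prove that, so the references are extended to
        // `'env` and must only be used through these transmuted bindings.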
let executor: &crate::executor::Executor = &self.executor;
let executor: &'env crate::executor::Executor = unsafe { mem::transmute(executor) };
let external_executor: &'env ThreadExecutor<'env> =
unsafe { mem::transmute(external_executor) };
let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) };
let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<dyn core::any::Any + Send>>>> =
ConcurrentQueue::unbounded();
let spawned: &'env ConcurrentQueue<
FallibleTask<Result<T, Box<dyn core::any::Any + Send>>>,
> = unsafe { mem::transmute(&spawned) };
let scope = Scope {
executor,
external_executor,
scope_executor,
spawned,
scope: PhantomData,
env: PhantomData,
};
let scope: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
f(scope);
if spawned.is_empty() {
Vec::new()
} else {
block_on(async move {
let get_results = async {
let mut results = Vec::with_capacity(spawned.len());
while let Ok(task) = spawned.pop() {
if let Some(res) = task.await {
match res {
Ok(res) => results.push(res),
Err(payload) => std::panic::resume_unwind(payload),
}
} else {
panic!("Failed to catch panic!");
}
}
results
};
let tick_task_pool_executor = tick_task_pool_executor || self.threads.is_empty();
let scope_ticker = scope_executor.ticker().unwrap();
let external_ticker = if !external_executor.is_same(scope_executor) {
external_executor.ticker()
} else {
None
};
match (external_ticker, tick_task_pool_executor) {
(Some(external_ticker), true) => {
Self::execute_global_external_scope(
executor,
external_ticker,
scope_ticker,
get_results,
)
.await
}
(Some(external_ticker), false) => {
Self::execute_external_scope(external_ticker, scope_ticker, get_results)
.await
}
(None, true) => {
Self::execute_global_scope(executor, scope_ticker, get_results).await
}
(None, false) => Self::execute_scope(scope_ticker, get_results).await,
}
})
}
}
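    /// Ticks both the external and scope thread executors while also running
    /// the pool's shared executor on this thread, until `get_results` resolves.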
#[inline]
async fn execute_global_external_scope<'scope, 'ticker, T>(
executor: &'scope crate::executor::Executor<'scope>,
external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
get_results: impl Future<Output = Vec<T>>,
) -> Vec<T> {
let execute_forever = async move {
loop {
let tick_forever = async {
loop {
external_ticker.tick().or(scope_ticker.tick()).await;
}
};
let _result = AssertUnwindSafe(executor.run(tick_forever))
.catch_unwind()
.await
.is_ok();
}
};
get_results.or(execute_forever).await
}
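    /// Ticks both the external and scope thread executors, without running the
    /// pool's shared executor, until `get_results` resolves.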
#[inline]
async fn execute_external_scope<'scope, 'ticker, T>(
external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
get_results: impl Future<Output = Vec<T>>,
) -> Vec<T> {
let execute_forever = async {
loop {
let tick_forever = async {
loop {
external_ticker.tick().or(scope_ticker.tick()).await;
}
};
let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
}
};
get_results.or(execute_forever).await
}
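    /// Ticks the scope thread executor while running the pool's shared executor
    /// on this thread, until `get_results` resolves.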
#[inline]
async fn execute_global_scope<'scope, 'ticker, T>(
executor: &'scope crate::executor::Executor<'scope>,
scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
get_results: impl Future<Output = Vec<T>>,
) -> Vec<T> {
let execute_forever = async {
loop {
let tick_forever = async {
loop {
scope_ticker.tick().await;
}
};
let _result = AssertUnwindSafe(executor.run(tick_forever))
.catch_unwind()
.await
.is_ok();
}
};
get_results.or(execute_forever).await
}
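    /// Ticks only the scope thread executor until `get_results` resolves.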
#[inline]
async fn execute_scope<'scope, 'ticker, T>(
scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
get_results: impl Future<Output = Vec<T>>,
) -> Vec<T> {
let execute_forever = async {
loop {
let tick_forever = async {
loop {
scope_ticker.tick().await;
}
};
let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
}
};
get_results.or(execute_forever).await
}
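    /// Spawns a `'static`, `Send` future onto the pool's shared executor and
    /// returns a [`Task`] handle to its output. The task starts running right
    /// away on the worker threads.
    ///
    /// A small sketch; it assumes the crate is compiled as `bevy_tasks` and
    /// that `block_on` (imported above from the crate root) is re-exported
    /// publicly. Marked `no_run` so it only needs to compile:
    ///
    /// ```no_run
    /// use bevy_tasks::{block_on, TaskPool};
    ///
    /// let pool = TaskPool::new();
    /// let task = pool.spawn(async { 21 * 2 });
    /// assert_eq!(block_on(task), 42);
    /// ```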
pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
where
T: Send + 'static,
{
Task::new(self.executor.spawn(future))
}
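    /// Spawns a future onto the calling thread's thread-local `LOCAL_EXECUTOR`.
    /// The returned [`Task`] only makes progress while that executor is ticked
    /// on this same thread (pool worker threads tick theirs continuously).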
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
where
T: 'static,
{
Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
}
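    /// Runs `f` with a reference to the calling thread's `LOCAL_EXECUTOR`, for
    /// example to tick it manually.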
pub fn with_local_executor<F, R>(&self, f: F) -> R
where
F: FnOnce(&crate::executor::LocalExecutor) -> R,
{
Self::LOCAL_EXECUTOR.with(f)
}
}
impl Default for TaskPool {
fn default() -> Self {
Self::new()
}
}
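// Closes the shutdown channel and joins every worker thread. Join failures are
// surfaced unless this thread is already panicking.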
impl Drop for TaskPool {
fn drop(&mut self) {
self.shutdown_tx.close();
let panicking = thread::panicking();
for join_handle in self.threads.drain(..) {
let res = join_handle.join();
if !panicking {
res.expect("Task thread panicked while executing.");
}
}
}
}
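/// A scope for spawning borrowing (non-`'static`) futures, passed to the
/// closure given to [`TaskPool::scope`]. Every spawned task is tracked in
/// `spawned` and is either driven to completion or cancelled before the scope
/// ends.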
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope, T> {
executor: &'scope crate::executor::Executor<'scope>,
external_executor: &'scope ThreadExecutor<'scope>,
scope_executor: &'scope ThreadExecutor<'scope>,
spawned: &'scope ConcurrentQueue<FallibleTask<Result<T, Box<dyn core::any::Any + Send>>>>,
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}
impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
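    /// Spawns a future onto the pool's shared executor, so it may run on any
    /// worker thread. Its output is returned by [`TaskPool::scope`] once all
    /// tasks finish; a panic inside the future is caught and re-raised by the
    /// scope.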
pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self
.executor
.spawn(AssertUnwindSafe(f).catch_unwind())
.fallible();
self.spawned.push(task).unwrap();
}
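    /// Spawns a future onto the scope's [`ThreadExecutor`], so it runs on the
    /// thread that called [`TaskPool::scope`].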
pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self
.scope_executor
.spawn(AssertUnwindSafe(f).catch_unwind())
.fallible();
self.spawned.push(task).unwrap();
}
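    /// Spawns a future onto the external [`ThreadExecutor`] given to
    /// [`TaskPool::scope_with_executor`] (falling back to the scope's own
    /// thread executor when none was provided).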
pub fn spawn_on_external<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self
.external_executor
.spawn(AssertUnwindSafe(f).catch_unwind())
.fallible();
self.spawned.push(task).unwrap();
}
}
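// Cancels any tasks that are still queued when the scope is dropped.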
impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
where
T: 'scope,
{
fn drop(&mut self) {
block_on(async {
while let Ok(task) = self.spawned.pop() {
task.cancel().await;
}
});
}
}
#[cfg(test)]
mod tests {
use super::*;
use core::sync::atomic::{AtomicBool, AtomicI32, Ordering};
use std::sync::Barrier;
#[test]
fn test_spawn() {
let pool = TaskPool::new();
let foo = Box::new(42);
let foo = &*foo;
let count = Arc::new(AtomicI32::new(0));
let outputs = pool.scope(|scope| {
for _ in 0..100 {
let count_clone = count.clone();
scope.spawn(async move {
if *foo != 42 {
panic!("not 42!?!?")
} else {
count_clone.fetch_add(1, Ordering::Relaxed);
*foo
}
});
}
});
for output in &outputs {
assert_eq!(*output, 42);
}
assert_eq!(outputs.len(), 100);
assert_eq!(count.load(Ordering::Relaxed), 100);
}
#[test]
fn test_thread_callbacks() {
let counter = Arc::new(AtomicI32::new(0));
let start_counter = counter.clone();
{
let barrier = Arc::new(Barrier::new(11));
let last_barrier = barrier.clone();
let _pool = TaskPoolBuilder::new()
.num_threads(10)
.on_thread_spawn(move || {
start_counter.fetch_add(1, Ordering::Relaxed);
barrier.clone().wait();
})
.build();
last_barrier.wait();
assert_eq!(10, counter.load(Ordering::Relaxed));
}
assert_eq!(10, counter.load(Ordering::Relaxed));
let end_counter = counter.clone();
{
let _pool = TaskPoolBuilder::new()
.num_threads(20)
.on_thread_destroy(move || {
end_counter.fetch_sub(1, Ordering::Relaxed);
})
.build();
assert_eq!(10, counter.load(Ordering::Relaxed));
}
assert_eq!(-10, counter.load(Ordering::Relaxed));
let start_counter = counter.clone();
let end_counter = counter.clone();
{
let barrier = Arc::new(Barrier::new(6));
let last_barrier = barrier.clone();
let _pool = TaskPoolBuilder::new()
.num_threads(5)
.on_thread_spawn(move || {
start_counter.fetch_add(1, Ordering::Relaxed);
barrier.wait();
})
.on_thread_destroy(move || {
end_counter.fetch_sub(1, Ordering::Relaxed);
})
.build();
last_barrier.wait();
assert_eq!(-5, counter.load(Ordering::Relaxed));
}
assert_eq!(-10, counter.load(Ordering::Relaxed));
}
#[test]
fn test_mixed_spawn_on_scope_and_spawn() {
let pool = TaskPool::new();
let foo = Box::new(42);
let foo = &*foo;
let local_count = Arc::new(AtomicI32::new(0));
let non_local_count = Arc::new(AtomicI32::new(0));
let outputs = pool.scope(|scope| {
for i in 0..100 {
if i % 2 == 0 {
let count_clone = non_local_count.clone();
scope.spawn(async move {
if *foo != 42 {
panic!("not 42!?!?")
} else {
count_clone.fetch_add(1, Ordering::Relaxed);
*foo
}
});
} else {
let count_clone = local_count.clone();
scope.spawn_on_scope(async move {
if *foo != 42 {
panic!("not 42!?!?")
} else {
count_clone.fetch_add(1, Ordering::Relaxed);
*foo
}
});
}
}
});
for output in &outputs {
assert_eq!(*output, 42);
}
assert_eq!(outputs.len(), 100);
assert_eq!(local_count.load(Ordering::Relaxed), 50);
assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
}
#[test]
fn test_thread_locality() {
let pool = Arc::new(TaskPool::new());
let count = Arc::new(AtomicI32::new(0));
let barrier = Arc::new(Barrier::new(101));
let thread_check_failed = Arc::new(AtomicBool::new(false));
for _ in 0..100 {
let inner_barrier = barrier.clone();
let count_clone = count.clone();
let inner_pool = pool.clone();
let inner_thread_check_failed = thread_check_failed.clone();
thread::spawn(move || {
inner_pool.scope(|scope| {
let inner_count_clone = count_clone.clone();
scope.spawn(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
});
let spawner = thread::current().id();
let inner_count_clone = count_clone.clone();
scope.spawn_on_scope(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
if thread::current().id() != spawner {
inner_thread_check_failed.store(true, Ordering::Release);
}
});
});
inner_barrier.wait();
});
}
barrier.wait();
assert!(!thread_check_failed.load(Ordering::Acquire));
assert_eq!(count.load(Ordering::Acquire), 200);
}
#[test]
fn test_nested_spawn() {
let pool = TaskPool::new();
let foo = Box::new(42);
let foo = &*foo;
let count = Arc::new(AtomicI32::new(0));
let outputs: Vec<i32> = pool.scope(|scope| {
for _ in 0..10 {
let count_clone = count.clone();
scope.spawn(async move {
for _ in 0..10 {
let count_clone_clone = count_clone.clone();
scope.spawn(async move {
if *foo != 42 {
panic!("not 42!?!?")
} else {
count_clone_clone.fetch_add(1, Ordering::Relaxed);
*foo
}
});
}
*foo
});
}
});
for output in &outputs {
assert_eq!(*output, 42);
}
assert_eq!(outputs.len(), 110);
assert_eq!(count.load(Ordering::Relaxed), 100);
}
#[test]
fn test_nested_locality() {
let pool = Arc::new(TaskPool::new());
let count = Arc::new(AtomicI32::new(0));
let barrier = Arc::new(Barrier::new(101));
let thread_check_failed = Arc::new(AtomicBool::new(false));
for _ in 0..100 {
let inner_barrier = barrier.clone();
let count_clone = count.clone();
let inner_pool = pool.clone();
let inner_thread_check_failed = thread_check_failed.clone();
thread::spawn(move || {
inner_pool.scope(|scope| {
let spawner = thread::current().id();
let inner_count_clone = count_clone.clone();
scope.spawn(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
scope.spawn_on_scope(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
if thread::current().id() != spawner {
inner_thread_check_failed.store(true, Ordering::Release);
}
});
});
});
inner_barrier.wait();
});
}
barrier.wait();
assert!(!thread_check_failed.load(Ordering::Acquire));
assert_eq!(count.load(Ordering::Acquire), 200);
}
#[test]
fn test_nested_scopes() {
let pool = TaskPool::new();
let count = Arc::new(AtomicI32::new(0));
pool.scope(|scope| {
scope.spawn(async {
pool.scope(|scope| {
scope.spawn(async {
count.fetch_add(1, Ordering::Relaxed);
});
});
});
});
assert_eq!(count.load(Ordering::Acquire), 1);
}
}