Path: blob/main/crates/bevy_app/src/task_pool_plugin.rs
6595 views
use crate::{App, Plugin};12use alloc::string::ToString;3use bevy_platform::sync::Arc;4use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};5use core::fmt::Debug;6use log::trace;78cfg_if::cfg_if! {9if #[cfg(not(all(target_arch = "wasm32", feature = "web")))] {10use {crate::Last, bevy_tasks::tick_global_task_pools_on_main_thread};11use bevy_ecs::system::NonSendMarker;1213/// A system used to check and advanced our task pools.14///15/// Calls [`tick_global_task_pools_on_main_thread`],16/// and uses [`NonSendMarker`] to ensure that this system runs on the main thread17fn tick_global_task_pools(_main_thread_marker: NonSendMarker) {18tick_global_task_pools_on_main_thread();19}20}21}2223/// Setup of default task pools: [`AsyncComputeTaskPool`], [`ComputeTaskPool`], [`IoTaskPool`].24#[derive(Default)]25pub struct TaskPoolPlugin {26/// Options for the [`TaskPool`](bevy_tasks::TaskPool) created at application start.27pub task_pool_options: TaskPoolOptions,28}2930impl Plugin for TaskPoolPlugin {31fn build(&self, _app: &mut App) {32// Setup the default bevy task pools33self.task_pool_options.create_default_pools();3435#[cfg(not(all(target_arch = "wasm32", feature = "web")))]36_app.add_systems(Last, tick_global_task_pools);37}38}3940/// Defines a simple way to determine how many threads to use given the number of remaining cores41/// and number of total cores42#[derive(Clone)]43pub struct TaskPoolThreadAssignmentPolicy {44/// Force using at least this many threads45pub min_threads: usize,46/// Under no circumstance use more than this many threads for this pool47pub max_threads: usize,48/// Target using this percentage of total cores, clamped by `min_threads` and `max_threads`. It is49/// permitted to use 1.0 to try to use all remaining threads50pub percent: f32,51/// Callback that is invoked once for every created thread as it starts.52/// This configuration will be ignored under wasm platform.53pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,54/// Callback that is invoked once for every created thread as it terminates55/// This configuration will be ignored under wasm platform.56pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,57}5859impl Debug for TaskPoolThreadAssignmentPolicy {60fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {61f.debug_struct("TaskPoolThreadAssignmentPolicy")62.field("min_threads", &self.min_threads)63.field("max_threads", &self.max_threads)64.field("percent", &self.percent)65.finish()66}67}6869impl TaskPoolThreadAssignmentPolicy {70/// Determine the number of threads to use for this task pool71fn get_number_of_threads(&self, remaining_threads: usize, total_threads: usize) -> usize {72assert!(self.percent >= 0.0);73let proportion = total_threads as f32 * self.percent;74let mut desired = proportion as usize;7576// Equivalent to round() for positive floats without libm requirement for77// no_std compatibility78if proportion - desired as f32 >= 0.5 {79desired += 1;80}8182// Limit ourselves to the number of cores available83desired = desired.min(remaining_threads);8485// Clamp by min_threads, max_threads. (This may result in us using more threads than are86// available, this is intended. An example case where this might happen is a device with87// <= 2 threads.88desired.clamp(self.min_threads, self.max_threads)89}90}9192/// Helper for configuring and creating the default task pools. For end-users who want full control,93/// set up [`TaskPoolPlugin`]94#[derive(Clone, Debug)]95pub struct TaskPoolOptions {96/// If the number of physical cores is less than `min_total_threads`, force using97/// `min_total_threads`98pub min_total_threads: usize,99/// If the number of physical cores is greater than `max_total_threads`, force using100/// `max_total_threads`101pub max_total_threads: usize,102103/// Used to determine number of IO threads to allocate104pub io: TaskPoolThreadAssignmentPolicy,105/// Used to determine number of async compute threads to allocate106pub async_compute: TaskPoolThreadAssignmentPolicy,107/// Used to determine number of compute threads to allocate108pub compute: TaskPoolThreadAssignmentPolicy,109}110111impl Default for TaskPoolOptions {112fn default() -> Self {113TaskPoolOptions {114// By default, use however many cores are available on the system115min_total_threads: 1,116max_total_threads: usize::MAX,117118// Use 25% of cores for IO, at least 1, no more than 4119io: TaskPoolThreadAssignmentPolicy {120min_threads: 1,121max_threads: 4,122percent: 0.25,123on_thread_spawn: None,124on_thread_destroy: None,125},126127// Use 25% of cores for async compute, at least 1, no more than 4128async_compute: TaskPoolThreadAssignmentPolicy {129min_threads: 1,130max_threads: 4,131percent: 0.25,132on_thread_spawn: None,133on_thread_destroy: None,134},135136// Use all remaining cores for compute (at least 1)137compute: TaskPoolThreadAssignmentPolicy {138min_threads: 1,139max_threads: usize::MAX,140percent: 1.0, // This 1.0 here means "whatever is left over"141on_thread_spawn: None,142on_thread_destroy: None,143},144}145}146}147148impl TaskPoolOptions {149/// Create a configuration that forces using the given number of threads.150pub fn with_num_threads(thread_count: usize) -> Self {151TaskPoolOptions {152min_total_threads: thread_count,153max_total_threads: thread_count,154..Default::default()155}156}157158/// Inserts the default thread pools into the given resource map based on the configured values159pub fn create_default_pools(&self) {160let total_threads = bevy_tasks::available_parallelism()161.clamp(self.min_total_threads, self.max_total_threads);162trace!("Assigning {total_threads} cores to default task pools");163164let mut remaining_threads = total_threads;165166{167// Determine the number of IO threads we will use168let io_threads = self169.io170.get_number_of_threads(remaining_threads, total_threads);171172trace!("IO Threads: {io_threads}");173remaining_threads = remaining_threads.saturating_sub(io_threads);174175IoTaskPool::get_or_init(|| {176let builder = TaskPoolBuilder::default()177.num_threads(io_threads)178.thread_name("IO Task Pool".to_string());179180#[cfg(not(all(target_arch = "wasm32", feature = "web")))]181let builder = {182let mut builder = builder;183if let Some(f) = self.io.on_thread_spawn.clone() {184builder = builder.on_thread_spawn(move || f());185}186if let Some(f) = self.io.on_thread_destroy.clone() {187builder = builder.on_thread_destroy(move || f());188}189builder190};191192builder.build()193});194}195196{197// Determine the number of async compute threads we will use198let async_compute_threads = self199.async_compute200.get_number_of_threads(remaining_threads, total_threads);201202trace!("Async Compute Threads: {async_compute_threads}");203remaining_threads = remaining_threads.saturating_sub(async_compute_threads);204205AsyncComputeTaskPool::get_or_init(|| {206let builder = TaskPoolBuilder::default()207.num_threads(async_compute_threads)208.thread_name("Async Compute Task Pool".to_string());209210#[cfg(not(all(target_arch = "wasm32", feature = "web")))]211let builder = {212let mut builder = builder;213if let Some(f) = self.async_compute.on_thread_spawn.clone() {214builder = builder.on_thread_spawn(move || f());215}216if let Some(f) = self.async_compute.on_thread_destroy.clone() {217builder = builder.on_thread_destroy(move || f());218}219builder220};221222builder.build()223});224}225226{227// Determine the number of compute threads we will use228// This is intentionally last so that an end user can specify 1.0 as the percent229let compute_threads = self230.compute231.get_number_of_threads(remaining_threads, total_threads);232233trace!("Compute Threads: {compute_threads}");234235ComputeTaskPool::get_or_init(|| {236let builder = TaskPoolBuilder::default()237.num_threads(compute_threads)238.thread_name("Compute Task Pool".to_string());239240#[cfg(not(all(target_arch = "wasm32", feature = "web")))]241let builder = {242let mut builder = builder;243if let Some(f) = self.compute.on_thread_spawn.clone() {244builder = builder.on_thread_spawn(move || f());245}246if let Some(f) = self.compute.on_thread_destroy.clone() {247builder = builder.on_thread_destroy(move || f());248}249builder250};251252builder.build()253});254}255}256}257258#[cfg(test)]259mod tests {260use super::*;261use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};262263#[test]264fn runs_spawn_local_tasks() {265let mut app = App::new();266app.add_plugins(TaskPoolPlugin::default());267268let (async_tx, async_rx) = crossbeam_channel::unbounded();269AsyncComputeTaskPool::get()270.spawn_local(async move {271async_tx.send(()).unwrap();272})273.detach();274275let (compute_tx, compute_rx) = crossbeam_channel::unbounded();276ComputeTaskPool::get()277.spawn_local(async move {278compute_tx.send(()).unwrap();279})280.detach();281282let (io_tx, io_rx) = crossbeam_channel::unbounded();283IoTaskPool::get()284.spawn_local(async move {285io_tx.send(()).unwrap();286})287.detach();288289app.run();290291async_rx.try_recv().unwrap();292compute_rx.try_recv().unwrap();293io_rx.try_recv().unwrap();294}295}296297298