Path: blob/main/examples/shader_advanced/compute_mesh.rs
9343 views
//! This example shows how to initialize an empty mesh with a Handle1//! and a render-world only usage. That buffer is then filled by a2//! compute shader on the GPU without transferring data back3//! to the CPU.4//!5//! The `mesh_allocator` is used to get references to the relevant slabs6//! that contain the mesh data we're interested in.7//!8//! This example does not remove the `GenerateMesh` component after9//! generating the mesh.1011use std::ops::Not;1213use bevy::{14asset::RenderAssetUsages,15color::palettes::tailwind::{RED_400, SKY_400},16core_pipeline::schedule::camera_driver,17mesh::Indices,18platform::collections::HashSet,19prelude::*,20render::{21extract_component::{ExtractComponent, ExtractComponentPlugin},22mesh::allocator::MeshAllocator,23render_resource::{24binding_types::{storage_buffer, uniform_buffer},25*,26},27renderer::{RenderContext, RenderGraph, RenderQueue},28Render, RenderApp, RenderStartup,29},30};3132/// This example uses a shader source file from the assets subdirectory33const SHADER_ASSET_PATH: &str = "shaders/compute_mesh.wgsl";3435fn main() {36App::new()37.add_plugins((38DefaultPlugins,39ComputeShaderMeshGeneratorPlugin,40ExtractComponentPlugin::<GenerateMesh>::default(),41))42.insert_resource(ClearColor(Color::BLACK))43.add_systems(Startup, setup)44.run();45}4647// We need a plugin to organize all the systems and render node required for this example48struct ComputeShaderMeshGeneratorPlugin;49impl Plugin for ComputeShaderMeshGeneratorPlugin {50fn build(&self, app: &mut App) {51let Some(render_app) = app.get_sub_app_mut(RenderApp) else {52return;53};5455render_app56.init_resource::<ChunksToProcess>()57.add_systems(RenderStartup, init_compute_pipeline)58.add_systems(Render, prepare_chunks)59.add_systems(RenderGraph, compute_mesh.before(camera_driver));60}61fn finish(&self, app: &mut App) {62let Some(render_app) = app.get_sub_app_mut(RenderApp) else {63return;64};65render_app66.world_mut()67.resource_mut::<MeshAllocator>()68// This allows using the mesh allocator slabs as69// storage buffers directly in the compute shader.70// Which means that we can write from our compute71// shader directly to the allocated mesh slabs.72.extra_buffer_usages = BufferUsages::STORAGE;73}74}7576/// Holds a handle to the empty mesh that should be filled77/// by the compute shader.78#[derive(Component, ExtractComponent, Clone)]79struct GenerateMesh(Handle<Mesh>);8081fn setup(82mut commands: Commands,83mut meshes: ResMut<Assets<Mesh>>,84mut materials: ResMut<Assets<StandardMaterial>>,85) {86// a truly empty mesh will error if used in Mesh3d87// so we set up the data to be what we want the compute shader to output88// We're using 36 indices and 24 vertices which is directly taken from89// the Bevy Cuboid mesh implementation.90//91// We allocate 50 spots for each attribute here because92// it is *very important* that the amount of data allocated here is93// *bigger* than (or exactly equal to) the amount of data we intend to94// write from the compute shader. This amount of data defines how big95// the buffer we get from the mesh_allocator will be, which in turn96// defines how big the buffer is when we're in the compute shader.97//98// If it turns out you don't need all of the space when the compute shader99// is writing data, you can write NaN to the rest of the data.100let empty_mesh = {101let mut mesh = Mesh::new(102PrimitiveTopology::TriangleList,103RenderAssetUsages::RENDER_WORLD,104)105.with_inserted_attribute(Mesh::ATTRIBUTE_POSITION, vec![[0.; 3]; 50])106.with_inserted_attribute(Mesh::ATTRIBUTE_NORMAL, vec![[0.; 3]; 50])107.with_inserted_attribute(Mesh::ATTRIBUTE_UV_0, vec![[0.; 2]; 50])108.with_inserted_indices(Indices::U32(vec![0; 50]));109110mesh.asset_usage = RenderAssetUsages::RENDER_WORLD;111mesh112};113114let handle = meshes.add(empty_mesh);115116// we spawn two "users" of the mesh handle,117// but only insert `GenerateMesh` on one of them118// to show that the mesh handle works as usual119commands.spawn((120GenerateMesh(handle.clone()),121Mesh3d(handle.clone()),122MeshMaterial3d(materials.add(StandardMaterial {123base_color: RED_400.into(),124..default()125})),126Transform::from_xyz(-2.5, 1.5, 0.),127));128129commands.spawn((130Mesh3d(handle),131MeshMaterial3d(materials.add(StandardMaterial {132base_color: SKY_400.into(),133..default()134})),135Transform::from_xyz(2.5, 1.5, 0.),136));137138// some additional scene elements.139// This mesh specifically is here so that we don't assume140// mesh_allocator offsets that would only work if we had141// one mesh in the scene.142commands.spawn((143Mesh3d(meshes.add(Circle::new(4.0))),144MeshMaterial3d(materials.add(Color::WHITE)),145Transform::from_rotation(Quat::from_rotation_x(-std::f32::consts::FRAC_PI_2)),146));147commands.spawn((148PointLight {149shadow_maps_enabled: true,150..default()151},152Transform::from_xyz(4.0, 8.0, 4.0),153));154// camera155commands.spawn((156Camera3d::default(),157Transform::from_xyz(-2.5, 4.5, 9.0).looking_at(Vec3::ZERO, Vec3::Y),158));159}160161/// This is called `ChunksToProcess` because this example originated162/// from a use case of generating chunks of landscape or voxels163/// It only exists in the render world.164#[derive(Resource, Default)]165struct ChunksToProcess(Vec<AssetId<Mesh>>);166167/// `processed` is a `HashSet` contains the `AssetId`s that have been168/// processed. We use that to remove `AssetId`s that have already169/// been processed, which means each unique `GenerateMesh` will result170/// in one compute shader mesh generation process instead of generating171/// the mesh every frame.172fn prepare_chunks(173meshes_to_generate: Query<&GenerateMesh>,174mut chunks: ResMut<ChunksToProcess>,175pipeline_cache: Res<PipelineCache>,176pipeline: Res<ComputePipeline>,177mut processed: Local<HashSet<AssetId<Mesh>>>,178) {179// If the pipeline isn't ready, then meshes180// won't be processed. So we want to wait until181// the pipeline is ready before considering any mesh processed.182if pipeline_cache183.get_compute_pipeline(pipeline.pipeline)184.is_some()185{186// get the AssetId for each Handle<Mesh>187// which we'll use later to get the relevant buffers188// from the mesh_allocator189let chunk_data: Vec<AssetId<Mesh>> = meshes_to_generate190.iter()191.filter_map(|gmesh| {192let id = gmesh.0.id();193processed.contains(&id).not().then_some(id)194})195.collect();196197// Cache any meshes we're going to process this frame198for id in &chunk_data {199processed.insert(*id);200}201202chunks.0 = chunk_data;203}204}205206#[derive(Resource)]207struct ComputePipeline {208layout: BindGroupLayoutDescriptor,209pipeline: CachedComputePipelineId,210}211212// init only happens once213fn init_compute_pipeline(214mut commands: Commands,215asset_server: Res<AssetServer>,216pipeline_cache: Res<PipelineCache>,217) {218let layout = BindGroupLayoutDescriptor::new(219"",220&BindGroupLayoutEntries::sequential(221ShaderStages::COMPUTE,222(223// offsets224uniform_buffer::<DataRanges>(false),225// vertices226storage_buffer::<Vec<u32>>(false),227// indices228storage_buffer::<Vec<u32>>(false),229),230),231);232let shader = asset_server.load(SHADER_ASSET_PATH);233let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {234label: Some("Mesh generation compute shader".into()),235layout: vec![layout.clone()],236shader: shader.clone(),237..default()238});239commands.insert_resource(ComputePipeline { layout, pipeline });240}241242// A uniform that holds the vertex and index offsets243// for the vertex/index mesh_allocator buffer slabs244#[derive(ShaderType)]245struct DataRanges {246vertex_start: u32,247vertex_end: u32,248index_start: u32,249index_end: u32,250}251252fn compute_mesh(253mut render_context: RenderContext,254chunks: Res<ChunksToProcess>,255mesh_allocator: Res<MeshAllocator>,256pipeline_cache: Res<PipelineCache>,257pipeline: Res<ComputePipeline>,258render_queue: Res<RenderQueue>,259) {260let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) else {261return;262};263264for mesh_id in &chunks.0 {265info!(?mesh_id, "processing mesh");266267// the mesh_allocator holds slabs of meshes, so the buffers we get here268// can contain more data than just the mesh we're asking for.269// That's why there is a range field.270// You should *not* touch data in these buffers that is outside of the range.271let vertex_buffer_slice = mesh_allocator.mesh_vertex_slice(mesh_id).unwrap();272let index_buffer_slice = mesh_allocator.mesh_index_slice(mesh_id).unwrap();273274let first = DataRanges {275// there are 8 vertex data values (pos, normal, uv) per vertex276// and the vertex_buffer_slice.range.start is in "vertex elements"277// which includes all of that data, so each index is worth 8 indices278// to our shader code.279vertex_start: vertex_buffer_slice.range.start * 8,280vertex_end: vertex_buffer_slice.range.end * 8,281// but each vertex index is a single value, so the index of the282// vertex indices is exactly what the value is283index_start: index_buffer_slice.range.start,284index_end: index_buffer_slice.range.end,285};286287let mut uniforms = UniformBuffer::from(first);288uniforms.write_buffer(render_context.render_device(), &render_queue);289290// pass in the full mesh_allocator slabs as well as the first index291// offsets for the vertex and index buffers292let bind_group = render_context.render_device().create_bind_group(293None,294&pipeline_cache.get_bind_group_layout(&pipeline.layout),295&BindGroupEntries::sequential((296&uniforms,297vertex_buffer_slice.buffer.as_entire_buffer_binding(),298index_buffer_slice.buffer.as_entire_buffer_binding(),299)),300);301302let mut pass =303render_context304.command_encoder()305.begin_compute_pass(&ComputePassDescriptor {306label: Some("Mesh generation compute pass"),307..default()308});309pass.push_debug_group("compute_mesh");310311pass.set_bind_group(0, &bind_group, &[]);312pass.set_pipeline(init_pipeline);313// we only dispatch 1,1,1 workgroup here, but a real compute shader314// would take advantage of more and larger size workgroups315pass.dispatch_workgroups(1, 1, 1);316317pass.pop_debug_group();318}319}320321322