Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bevyengine
GitHub Repository: bevyengine/bevy
Path: blob/main/examples/shader_advanced/custom_shader_instancing.rs
9334 views
1
//! A shader that renders a mesh multiple times in one draw call.
//!
//! Bevy will automatically batch and instance your meshes assuming you use the same
//! `Handle<Material>` and `Handle<Mesh>` for all of your instances.
//!
//! This example is intended for advanced users and shows how to make a custom instancing
//! implementation using Bevy's low-level rendering API.
//! It's generally recommended to try the built-in instancing before going with this approach.
9
10
use bevy::pbr::{SetMeshViewBindingArrayBindGroup, ViewKeyCache};
11
use bevy::{
12
camera::visibility::NoFrustumCulling,
13
core_pipeline::core_3d::Transparent3d,
14
ecs::{
15
query::QueryItem,
16
system::{lifetimeless::*, SystemParamItem},
17
},
18
mesh::{MeshVertexBufferLayoutRef, VertexBufferLayout},
19
pbr::{
20
MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
21
},
22
prelude::*,
23
render::{
24
extract_component::{ExtractComponent, ExtractComponentPlugin},
25
mesh::{allocator::MeshAllocator, RenderMesh, RenderMeshBufferInfo},
26
render_asset::RenderAssets,
27
render_phase::{
28
AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
29
RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
30
},
31
render_resource::*,
32
renderer::RenderDevice,
33
sync_component::SyncComponent,
34
sync_world::MainEntity,
35
view::{ExtractedView, NoIndirectDrawing},
36
Render, RenderApp, RenderStartup, RenderSystems,
37
},
38
};
39
use bytemuck::{Pod, Zeroable};
40
41
/// Path, relative to the `assets` directory, of the WGSL shader that the custom pipeline
/// uses for both its vertex and fragment stages.
const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";
43
44
fn main() {
45
App::new()
46
.add_plugins((DefaultPlugins, CustomMaterialPlugin))
47
.add_systems(Startup, setup)
48
.run();
49
}
50
51
fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
52
commands.spawn((
53
Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
54
InstanceMaterialData(
55
(1..=10)
56
.flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
57
.map(|(x, y)| InstanceData {
58
position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
59
scale: 1.0,
60
color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
61
})
62
.collect(),
63
),
64
// NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform.
65
// As the cube is at the origin, if its Aabb moves outside the view frustum, all the
66
// instanced cubes will be culled.
67
// The InstanceMaterialData contains the 'GlobalTransform' information for this custom
68
// instancing, and that is not taken into account with the built-in frustum culling.
69
// We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker
70
// component to avoid incorrect culling.
71
NoFrustumCulling,
72
));
73
74
// camera
75
commands.spawn((
76
Camera3d::default(),
77
Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
78
// We need this component because we use `draw_indexed` and `draw`
79
// instead of `draw_indirect_indexed` and `draw_indirect` in
80
// `DrawMeshInstanced::render`.
81
NoIndirectDrawing,
82
));
83
}
84
85
#[derive(Component, Deref)]
86
struct InstanceMaterialData(Vec<InstanceData>);
87
88
/// Syncs the component to the render world as the same type.
impl SyncComponent for InstanceMaterialData {
    type Out = InstanceMaterialData;
}
91
92
impl ExtractComponent for InstanceMaterialData {
93
type QueryData = &'static InstanceMaterialData;
94
type QueryFilter = ();
95
96
fn extract_component(item: QueryItem<'_, '_, Self::QueryData>) -> Option<Self> {
97
Some(InstanceMaterialData(item.0.clone()))
98
}
99
}
100
101
/// Plugin that registers the extraction, pipeline and render systems for custom instancing.
struct CustomMaterialPlugin;
102
103
impl Plugin for CustomMaterialPlugin {
104
fn build(&self, app: &mut App) {
105
app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
106
app.sub_app_mut(RenderApp)
107
.add_render_command::<Transparent3d, DrawCustom>()
108
.init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
109
.add_systems(RenderStartup, init_custom_pipeline)
110
.add_systems(
111
Render,
112
(
113
queue_custom.in_set(RenderSystems::QueueMeshes),
114
prepare_instance_buffers.in_set(RenderSystems::PrepareResources),
115
),
116
);
117
}
118
}
119
120
#[derive(Clone, Copy, Pod, Zeroable)]
121
#[repr(C)]
122
struct InstanceData {
123
position: Vec3,
124
scale: f32,
125
color: [f32; 4],
126
}
127
128
fn queue_custom(
129
transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
130
custom_pipeline: Res<CustomPipeline>,
131
mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
132
pipeline_cache: Res<PipelineCache>,
133
meshes: Res<RenderAssets<RenderMesh>>,
134
render_mesh_instances: Res<RenderMeshInstances>,
135
material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
136
mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
137
views: Query<&ExtractedView>,
138
view_key_cache: Res<ViewKeyCache>,
139
) {
140
let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
141
142
for view in &views {
143
let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
144
else {
145
continue;
146
};
147
148
let Some(&view_key) = view_key_cache.get(&view.retained_view_entity) else {
149
continue;
150
};
151
152
let rangefinder = view.rangefinder3d();
153
for (entity, main_entity) in &material_meshes {
154
let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
155
else {
156
continue;
157
};
158
let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else {
159
continue;
160
};
161
let key =
162
view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
163
let pipeline = pipelines
164
.specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
165
.unwrap();
166
transparent_phase.add(Transparent3d {
167
entity: (entity, *main_entity),
168
pipeline,
169
draw_function: draw_custom,
170
distance: rangefinder.distance(&mesh_instance.center),
171
batch_range: 0..1,
172
extra_index: PhaseItemExtraIndex::None,
173
indexed: true,
174
});
175
}
176
}
177
}
178
179
#[derive(Component)]
180
struct InstanceBuffer {
181
buffer: Buffer,
182
length: usize,
183
}
184
185
fn prepare_instance_buffers(
186
mut commands: Commands,
187
query: Query<(Entity, &InstanceMaterialData)>,
188
render_device: Res<RenderDevice>,
189
) {
190
for (entity, instance_data) in &query {
191
let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
192
label: Some("instance data buffer"),
193
contents: bytemuck::cast_slice(instance_data.as_slice()),
194
usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
195
});
196
commands.entity(entity).insert(InstanceBuffer {
197
buffer,
198
length: instance_data.len(),
199
});
200
}
201
}
202
203
#[derive(Resource)]
204
struct CustomPipeline {
205
shader: Handle<Shader>,
206
mesh_pipeline: MeshPipeline,
207
}
208
209
fn init_custom_pipeline(
210
mut commands: Commands,
211
asset_server: Res<AssetServer>,
212
mesh_pipeline: Res<MeshPipeline>,
213
) {
214
commands.insert_resource(CustomPipeline {
215
shader: asset_server.load(SHADER_ASSET_PATH),
216
mesh_pipeline: mesh_pipeline.clone(),
217
});
218
}
219
220
impl SpecializedMeshPipeline for CustomPipeline {
221
type Key = MeshPipelineKey;
222
223
fn specialize(
224
&self,
225
key: Self::Key,
226
layout: &MeshVertexBufferLayoutRef,
227
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
228
let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
229
230
descriptor.vertex.shader = self.shader.clone();
231
descriptor.vertex.buffers.push(VertexBufferLayout {
232
array_stride: size_of::<InstanceData>() as u64,
233
step_mode: VertexStepMode::Instance,
234
attributes: vec![
235
VertexAttribute {
236
format: VertexFormat::Float32x4,
237
offset: 0,
238
shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes
239
},
240
VertexAttribute {
241
format: VertexFormat::Float32x4,
242
offset: VertexFormat::Float32x4.size(),
243
shader_location: 4,
244
},
245
],
246
});
247
descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
248
Ok(descriptor)
249
}
250
}
251
252
type DrawCustom = (
253
SetItemPipeline,
254
SetMeshViewBindGroup<0>,
255
SetMeshViewBindingArrayBindGroup<1>,
256
SetMeshBindGroup<2>,
257
DrawMeshInstanced,
258
);
259
260
/// Last command of `DrawCustom`.
///
/// It binds the mesh and instance vertex buffers and draws all instances in one call.
struct DrawMeshInstanced;
261
262
impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
263
type Param = (
264
SRes<RenderAssets<RenderMesh>>,
265
SRes<RenderMeshInstances>,
266
SRes<MeshAllocator>,
267
);
268
type ViewQuery = ();
269
type ItemQuery = Read<InstanceBuffer>;
270
271
#[inline]
272
fn render<'w>(
273
item: &P,
274
_view: (),
275
instance_buffer: Option<&'w InstanceBuffer>,
276
(meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
277
pass: &mut TrackedRenderPass<'w>,
278
) -> RenderCommandResult {
279
// A borrow check workaround.
280
let mesh_allocator = mesh_allocator.into_inner();
281
282
let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
283
else {
284
return RenderCommandResult::Skip;
285
};
286
let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else {
287
return RenderCommandResult::Skip;
288
};
289
let Some(instance_buffer) = instance_buffer else {
290
return RenderCommandResult::Skip;
291
};
292
let Some(vertex_buffer_slice) =
293
mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id)
294
else {
295
return RenderCommandResult::Skip;
296
};
297
298
pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
299
pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
300
301
match &gpu_mesh.buffer_info {
302
RenderMeshBufferInfo::Indexed {
303
index_format,
304
count,
305
} => {
306
let Some(index_buffer_slice) =
307
mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id)
308
else {
309
return RenderCommandResult::Skip;
310
};
311
312
pass.set_index_buffer(index_buffer_slice.buffer.slice(..), *index_format);
313
pass.draw_indexed(
314
index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
315
vertex_buffer_slice.range.start as i32,
316
0..instance_buffer.length as u32,
317
);
318
}
319
RenderMeshBufferInfo::NonIndexed => {
320
pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
321
}
322
}
323
RenderCommandResult::Success
324
}
325
}
326
327