Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bevyengine
GitHub Repository: bevyengine/bevy
Path: blob/main/crates/bevy_shader/src/shader_cache.rs
6604 views
1
use crate::shader::*;
2
use alloc::sync::Arc;
3
use bevy_asset::AssetId;
4
use bevy_platform::collections::{hash_map::EntryRef, HashMap, HashSet};
5
use core::hash::Hash;
6
use naga::valid::Capabilities;
7
use thiserror::Error;
8
use tracing::{debug, error};
9
use wgpu_types::{DownlevelFlags, Features};
10
11
/// Source of a shader module.
12
///
13
/// The source will be parsed and validated.
14
///
15
/// Any necessary shader translation (e.g. from WGSL to SPIR-V or vice versa)
16
/// will be done internally by wgpu.
17
///
18
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
19
/// only WGSL source code strings are accepted.
20
///
21
/// This is roughly equivalent to `wgpu::ShaderSource`
22
#[cfg_attr(
23
not(feature = "decoupled_naga"),
24
expect(
25
clippy::large_enum_variant,
26
reason = "naga modules are the most common use, and are large"
27
)
28
)]
29
#[derive(Clone, Debug)]
30
pub enum ShaderCacheSource<'a> {
31
/// SPIR-V module represented as a slice of words.
32
SpirV(&'a [u8]),
33
/// WGSL module as a string slice.
34
Wgsl(String),
35
/// Naga module.
36
#[cfg(not(feature = "decoupled_naga"))]
37
Naga(naga::Module),
38
}
39
40
/// Identifier of a pipeline that uses shaders from a [`ShaderCache`].
///
/// The cache records these ids per shader, so that changing or removing a
/// shader can report which pipelines need to be queued again.
pub type CachedPipelineId = usize;
41
42
struct ShaderData<ShaderModule> {
43
pipelines: HashSet<CachedPipelineId>,
44
processed_shaders: HashMap<Box<[ShaderDefVal]>, Arc<ShaderModule>>,
45
resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
46
dependents: HashSet<AssetId<Shader>>,
47
}
48
49
impl<T> Default for ShaderData<T> {
50
fn default() -> Self {
51
Self {
52
pipelines: Default::default(),
53
processed_shaders: Default::default(),
54
resolved_imports: Default::default(),
55
dependents: Default::default(),
56
}
57
}
58
}
59
60
pub struct ShaderCache<ShaderModule, RenderDevice> {
61
data: HashMap<AssetId<Shader>, ShaderData<ShaderModule>>,
62
load_module: fn(
63
&RenderDevice,
64
ShaderCacheSource,
65
&ValidateShader,
66
) -> Result<ShaderModule, PipelineCacheError>,
67
#[cfg(feature = "shader_format_wesl")]
68
asset_paths: HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
69
shaders: HashMap<AssetId<Shader>, Shader>,
70
import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
71
waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
72
pub composer: naga_oil::compose::Composer,
73
}
74
75
#[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq, Debug, Hash)]
76
pub enum ShaderDefVal {
77
Bool(String, bool),
78
Int(String, i32),
79
UInt(String, u32),
80
}
81
82
impl From<&str> for ShaderDefVal {
83
fn from(key: &str) -> Self {
84
ShaderDefVal::Bool(key.to_string(), true)
85
}
86
}
87
88
impl From<String> for ShaderDefVal {
89
fn from(key: String) -> Self {
90
ShaderDefVal::Bool(key, true)
91
}
92
}
93
94
impl ShaderDefVal {
95
pub fn value_as_string(&self) -> String {
96
match self {
97
ShaderDefVal::Bool(_, def) => def.to_string(),
98
ShaderDefVal::Int(_, def) => def.to_string(),
99
ShaderDefVal::UInt(_, def) => def.to_string(),
100
}
101
}
102
}
103
104
impl<ShaderModule, RenderDevice> ShaderCache<ShaderModule, RenderDevice> {
105
pub fn new(
106
features: Features,
107
downlevel: DownlevelFlags,
108
load_module: fn(
109
&RenderDevice,
110
ShaderCacheSource,
111
&ValidateShader,
112
) -> Result<ShaderModule, PipelineCacheError>,
113
) -> Self {
114
let capabilities = get_capabilities(features, downlevel);
115
#[cfg(debug_assertions)]
116
let composer = naga_oil::compose::Composer::default();
117
#[cfg(not(debug_assertions))]
118
let composer = naga_oil::compose::Composer::non_validating();
119
120
let composer = composer.with_capabilities(capabilities);
121
122
Self {
123
composer,
124
load_module,
125
data: Default::default(),
126
#[cfg(feature = "shader_format_wesl")]
127
asset_paths: Default::default(),
128
shaders: Default::default(),
129
import_path_shaders: Default::default(),
130
waiting_on_import: Default::default(),
131
}
132
}
133
134
#[expect(
135
clippy::result_large_err,
136
reason = "See https://github.com/bevyengine/bevy/issues/19220"
137
)]
138
fn add_import_to_composer(
139
composer: &mut naga_oil::compose::Composer,
140
import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
141
shaders: &HashMap<AssetId<Shader>, Shader>,
142
import: &ShaderImport,
143
) -> Result<(), PipelineCacheError> {
144
// Early out if we've already imported this module
145
if composer.contains_module(&import.module_name()) {
146
return Ok(());
147
}
148
149
// Check if the import is available (this handles the recursive import case)
150
let shader = import_path_shaders
151
.get(import)
152
.and_then(|handle| shaders.get(handle))
153
.ok_or(PipelineCacheError::ShaderImportNotYetAvailable)?;
154
155
// Recurse down to ensure all import dependencies are met
156
for import in &shader.imports {
157
Self::add_import_to_composer(composer, import_path_shaders, shaders, import)?;
158
}
159
160
composer.add_composable_module(shader.into())?;
161
// if we fail to add a module the composer will tell us what is missing
162
163
Ok(())
164
}
165
166
#[expect(
167
clippy::result_large_err,
168
reason = "See https://github.com/bevyengine/bevy/issues/19220"
169
)]
170
pub fn get(
171
&mut self,
172
render_device: &RenderDevice,
173
pipeline: CachedPipelineId,
174
id: AssetId<Shader>,
175
shader_defs: &[ShaderDefVal],
176
) -> Result<Arc<ShaderModule>, PipelineCacheError> {
177
let shader = self
178
.shaders
179
.get(&id)
180
.ok_or(PipelineCacheError::ShaderNotLoaded(id))?;
181
182
let data = self.data.entry(id).or_default();
183
let n_asset_imports = shader
184
.imports()
185
.filter(|import| matches!(import, ShaderImport::AssetPath(_)))
186
.count();
187
let n_resolved_asset_imports = data
188
.resolved_imports
189
.keys()
190
.filter(|import| matches!(import, ShaderImport::AssetPath(_)))
191
.count();
192
if n_asset_imports != n_resolved_asset_imports {
193
return Err(PipelineCacheError::ShaderImportNotYetAvailable);
194
}
195
196
data.pipelines.insert(pipeline);
197
198
// PERF: this shader_defs clone isn't great. use raw_entry_mut when it stabilizes
199
let module = match data.processed_shaders.entry_ref(shader_defs) {
200
EntryRef::Occupied(entry) => entry.into_mut(),
201
EntryRef::Vacant(entry) => {
202
debug!(
203
"processing shader {}, with shader defs {:?}",
204
id, shader_defs
205
);
206
let shader_source = match &shader.source {
207
Source::SpirV(data) => ShaderCacheSource::SpirV(data.as_ref()),
208
#[cfg(feature = "shader_format_wesl")]
209
Source::Wesl(_) => {
210
if let ShaderImport::AssetPath(path) = shader.import_path() {
211
let shader_resolver =
212
ShaderResolver::new(&self.asset_paths, &self.shaders);
213
let module_path = wesl::syntax::ModulePath::from_path(path);
214
let mut compiler_options = wesl::CompileOptions {
215
imports: true,
216
condcomp: true,
217
lower: true,
218
..Default::default()
219
};
220
221
for shader_def in shader_defs {
222
match shader_def {
223
ShaderDefVal::Bool(key, value) => {
224
compiler_options.features.insert(key.clone(), *value);
225
}
226
_ => debug!(
227
"ShaderDefVal::Int and ShaderDefVal::UInt are not supported in wesl",
228
),
229
}
230
}
231
232
let compiled = wesl::compile(
233
&module_path,
234
&shader_resolver,
235
&wesl::EscapeMangler,
236
&compiler_options,
237
)
238
.unwrap();
239
240
ShaderCacheSource::Wgsl(compiled.to_string())
241
} else {
242
panic!("Wesl shaders must be imported from a file");
243
}
244
}
245
_ => {
246
for import in shader.imports() {
247
Self::add_import_to_composer(
248
&mut self.composer,
249
&self.import_path_shaders,
250
&self.shaders,
251
import,
252
)?;
253
}
254
255
let shader_defs = shader_defs
256
.iter()
257
.chain(shader.shader_defs.iter())
258
.map(|def| match def.clone() {
259
ShaderDefVal::Bool(k, v) => {
260
(k, naga_oil::compose::ShaderDefValue::Bool(v))
261
}
262
ShaderDefVal::Int(k, v) => {
263
(k, naga_oil::compose::ShaderDefValue::Int(v))
264
}
265
ShaderDefVal::UInt(k, v) => {
266
(k, naga_oil::compose::ShaderDefValue::UInt(v))
267
}
268
})
269
.collect::<std::collections::HashMap<_, _>>();
270
271
let naga = self.composer.make_naga_module(
272
naga_oil::compose::NagaModuleDescriptor {
273
shader_defs,
274
..shader.into()
275
},
276
)?;
277
278
#[cfg(not(feature = "decoupled_naga"))]
279
{
280
ShaderCacheSource::Naga(naga)
281
}
282
283
#[cfg(feature = "decoupled_naga")]
284
{
285
let mut validator = naga::valid::Validator::new(
286
naga::valid::ValidationFlags::all(),
287
self.composer.capabilities,
288
);
289
let module_info = validator.validate(&naga).unwrap();
290
let wgsl = naga::back::wgsl::write_string(
291
&naga,
292
&module_info,
293
naga::back::wgsl::WriterFlags::empty(),
294
)
295
.unwrap();
296
ShaderCacheSource::Wgsl(wgsl)
297
}
298
}
299
};
300
301
let shader_module =
302
(self.load_module)(render_device, shader_source, &shader.validate_shader)?;
303
304
entry.insert(Arc::new(shader_module))
305
}
306
};
307
308
Ok(module.clone())
309
}
310
311
fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
312
let mut shaders_to_clear = vec![id];
313
let mut pipelines_to_queue = Vec::new();
314
while let Some(handle) = shaders_to_clear.pop() {
315
if let Some(data) = self.data.get_mut(&handle) {
316
data.processed_shaders.clear();
317
pipelines_to_queue.extend(data.pipelines.iter().copied());
318
shaders_to_clear.extend(data.dependents.iter().copied());
319
320
if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
321
self.composer
322
.remove_composable_module(&import_path.module_name());
323
}
324
}
325
}
326
327
pipelines_to_queue
328
}
329
330
pub fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
331
let pipelines_to_queue = self.clear(id);
332
let path = shader.import_path();
333
self.import_path_shaders.insert(path.clone(), id);
334
if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
335
for waiting_shader in waiting_shaders.drain(..) {
336
// resolve waiting shader import
337
let data = self.data.entry(waiting_shader).or_default();
338
data.resolved_imports.insert(path.clone(), id);
339
// add waiting shader as dependent of this shader
340
let data = self.data.entry(id).or_default();
341
data.dependents.insert(waiting_shader);
342
}
343
}
344
345
for import in shader.imports() {
346
if let Some(import_id) = self.import_path_shaders.get(import).copied() {
347
// resolve import because it is currently available
348
let data = self.data.entry(id).or_default();
349
data.resolved_imports.insert(import.clone(), import_id);
350
// add this shader as a dependent of the import
351
let data = self.data.entry(import_id).or_default();
352
data.dependents.insert(id);
353
} else {
354
let waiting = self.waiting_on_import.entry(import.clone()).or_default();
355
waiting.push(id);
356
}
357
}
358
359
#[cfg(feature = "shader_format_wesl")]
360
if let Source::Wesl(_) = shader.source
361
&& let ShaderImport::AssetPath(path) = shader.import_path()
362
{
363
self.asset_paths
364
.insert(wesl::syntax::ModulePath::from_path(path), id);
365
}
366
self.shaders.insert(id, shader);
367
pipelines_to_queue
368
}
369
370
pub fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
371
let pipelines_to_queue = self.clear(id);
372
if let Some(shader) = self.shaders.remove(&id) {
373
self.import_path_shaders.remove(shader.import_path());
374
}
375
376
pipelines_to_queue
377
}
378
}
379
380
/// Resolves WESL module paths to the sources of registered shaders, for use
/// with `wesl::compile`.
#[cfg(feature = "shader_format_wesl")]
pub struct ShaderResolver<'a> {
    /// WESL module path -> asset id of the shader registered under that path.
    asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
    /// Registered shaders, by asset id.
    shaders: &'a HashMap<AssetId<Shader>, Shader>,
}

#[cfg(feature = "shader_format_wesl")]
impl<'a> ShaderResolver<'a> {
    /// Creates a resolver that borrows the given lookup tables for its lifetime.
    pub fn new(
        asset_paths: &'a HashMap<wesl::syntax::ModulePath, AssetId<Shader>>,
        shaders: &'a HashMap<AssetId<Shader>, Shader>,
    ) -> Self {
        ShaderResolver {
            shaders,
            asset_paths,
        }
    }
}
398
399
#[cfg(feature = "shader_format_wesl")]
impl<'a> wesl::Resolver for ShaderResolver<'a> {
    /// Returns the source text of the shader registered under `module_path`.
    ///
    /// # Errors
    ///
    /// Returns [`wesl::ResolveError::ModuleNotFound`] if no asset id is registered
    /// for `module_path`, or if that asset id has no shader in `shaders`.
    fn resolve_source(
        &self,
        module_path: &wesl::syntax::ModulePath,
    ) -> Result<alloc::borrow::Cow<'_, str>, wesl::ResolveError> {
        let asset_id = self.asset_paths.get(module_path).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(module_path.clone(), "Invalid asset id".to_string())
        })?;

        // The two maps are borrowed independently and nothing here guarantees
        // they agree, so a dangling asset id is reported instead of panicking.
        let shader = self.shaders.get(asset_id).ok_or_else(|| {
            wesl::ResolveError::ModuleNotFound(module_path.clone(), "Shader not loaded".to_string())
        })?;
        Ok(alloc::borrow::Cow::Borrowed(shader.source.as_str()))
    }
}
413
414
/// Type of error returned by a `PipelineCache` when the creation of a GPU pipeline object failed.
///
/// [`ShaderCache`] returns the same type from its shader-processing methods.
#[cfg_attr(
    not(target_arch = "wasm32"),
    expect(
        clippy::large_enum_variant,
        reason = "See https://github.com/bevyengine/bevy/issues/19220"
    )
)]
#[derive(Error, Debug)]
pub enum PipelineCacheError {
    /// No shader source is registered for the given asset id.
    #[error(
        "Pipeline could not be compiled because the following shader could not be loaded: {0:?}"
    )]
    ShaderNotLoaded(AssetId<Shader>),
    /// `naga_oil` failed to compose the shader; converted from its error via `From`.
    #[error(transparent)]
    ProcessShaderError(#[from] naga_oil::compose::ComposerError),
    /// A shader imported by the requested shader has not been registered yet.
    /// NOTE(review): callers presumably treat this as "retry later" — confirm.
    #[error("Shader import not yet available.")]
    ShaderImportNotYetAvailable,
    /// Creating the backend shader module failed; carries a description of why.
    #[error("Could not create shader module: {0}")]
    CreateShaderModule(String),
}
435
436
// TODO: This needs to be kept up to date with the capabilities in the `create_validator` function in wgpu-core
437
// https://github.com/gfx-rs/wgpu/blob/trunk/wgpu-core/src/device/mod.rs#L449
438
// We can't use the `wgpu-core` function to detect the device's capabilities because `wgpu-core` isn't included in WebGPU builds.
439
/// Get the device's capabilities for use in `naga_oil`.
440
fn get_capabilities(features: Features, downlevel: DownlevelFlags) -> Capabilities {
441
let mut capabilities = Capabilities::empty();
442
capabilities.set(
443
Capabilities::PUSH_CONSTANT,
444
features.contains(Features::PUSH_CONSTANTS),
445
);
446
capabilities.set(
447
Capabilities::FLOAT64,
448
features.contains(Features::SHADER_F64),
449
);
450
capabilities.set(
451
Capabilities::PRIMITIVE_INDEX,
452
features.contains(Features::SHADER_PRIMITIVE_INDEX),
453
);
454
capabilities.set(
455
Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
456
features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
457
);
458
capabilities.set(
459
Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
460
features.contains(Features::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
461
);
462
capabilities.set(
463
Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
464
features.contains(Features::UNIFORM_BUFFER_BINDING_ARRAYS),
465
);
466
// TODO: This needs a proper wgpu feature
467
capabilities.set(
468
Capabilities::SAMPLER_NON_UNIFORM_INDEXING,
469
features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
470
);
471
capabilities.set(
472
Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
473
features.contains(Features::TEXTURE_FORMAT_16BIT_NORM),
474
);
475
capabilities.set(
476
Capabilities::MULTIVIEW,
477
features.contains(Features::MULTIVIEW),
478
);
479
capabilities.set(
480
Capabilities::EARLY_DEPTH_TEST,
481
features.contains(Features::SHADER_EARLY_DEPTH_TEST),
482
);
483
capabilities.set(
484
Capabilities::SHADER_INT64,
485
features.contains(Features::SHADER_INT64),
486
);
487
capabilities.set(
488
Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
489
features.intersects(
490
Features::SHADER_INT64_ATOMIC_MIN_MAX | Features::SHADER_INT64_ATOMIC_ALL_OPS,
491
),
492
);
493
capabilities.set(
494
Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
495
features.contains(Features::SHADER_INT64_ATOMIC_ALL_OPS),
496
);
497
capabilities.set(
498
Capabilities::MULTISAMPLED_SHADING,
499
downlevel.contains(DownlevelFlags::MULTISAMPLED_SHADING),
500
);
501
capabilities.set(
502
Capabilities::RAY_QUERY,
503
features.contains(Features::EXPERIMENTAL_RAY_QUERY),
504
);
505
capabilities.set(
506
Capabilities::DUAL_SOURCE_BLENDING,
507
features.contains(Features::DUAL_SOURCE_BLENDING),
508
);
509
capabilities.set(
510
Capabilities::CLIP_DISTANCE,
511
features.contains(Features::CLIP_DISTANCES),
512
);
513
capabilities.set(
514
Capabilities::CUBE_ARRAY_TEXTURES,
515
downlevel.contains(DownlevelFlags::CUBE_ARRAY_TEXTURES),
516
);
517
capabilities.set(
518
Capabilities::SUBGROUP,
519
features.intersects(Features::SUBGROUP | Features::SUBGROUP_VERTEX),
520
);
521
capabilities.set(
522
Capabilities::SUBGROUP_BARRIER,
523
features.intersects(Features::SUBGROUP_BARRIER),
524
);
525
capabilities.set(
526
Capabilities::SUBGROUP_VERTEX_STAGE,
527
features.contains(Features::SUBGROUP_VERTEX),
528
);
529
capabilities.set(
530
Capabilities::SHADER_FLOAT32_ATOMIC,
531
features.contains(Features::SHADER_FLOAT32_ATOMIC),
532
);
533
capabilities.set(
534
Capabilities::TEXTURE_ATOMIC,
535
features.contains(Features::TEXTURE_ATOMIC),
536
);
537
capabilities.set(
538
Capabilities::TEXTURE_INT64_ATOMIC,
539
features.contains(Features::TEXTURE_INT64_ATOMIC),
540
);
541
capabilities.set(
542
Capabilities::SHADER_FLOAT16,
543
features.contains(Features::SHADER_F16),
544
);
545
capabilities.set(
546
Capabilities::RAY_HIT_VERTEX_POSITION,
547
features.intersects(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN),
548
);
549
550
capabilities
551
}
552
553