Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bevyengine
GitHub Repository: bevyengine/bevy
Path: blob/main/crates/bevy_anti_alias/src/dlss/mod.rs
6596 views
1
//! NVIDIA Deep Learning Super Sampling (DLSS).
2
//!
3
//! DLSS uses machine learning models to upscale and anti-alias images.
4
//!
5
//! Requires an NVIDIA RTX GPU and the Vulkan rendering backend on Windows or Linux. Does not work on other platforms.
6
//!
7
//! See <https://github.com/bevyengine/dlss_wgpu> for licensing requirements and setup instructions.
8
//!
9
//! # Usage
10
//! 1. Enable Bevy's `dlss` feature
11
//! 2. During app setup, insert the `DlssProjectId` resource before `DefaultPlugins`
12
//! 3. Check for the presence of `Option<Res<DlssSuperResolutionSupported>>` at runtime to see if DLSS is supported on the current machine
13
//! 4. Add the `Dlss` component to your camera entity, optionally setting a specific `DlssPerfQualityMode` (defaults to `Auto`)
14
//! 5. Optionally add sharpening via `ContrastAdaptiveSharpening`
15
//! 6. Custom rendering code, including third party crates, should account for the optional `MainPassResolutionOverride` to work with DLSS (see the `custom_render_phase` example)
16
17
mod extract;
18
mod node;
19
mod prepare;
20
21
pub use dlss_wgpu::DlssPerfQualityMode;
22
23
use bevy_app::{App, Plugin};
24
use bevy_core_pipeline::{
25
core_3d::graph::{Core3d, Node3d},
26
prepass::{DepthPrepass, MotionVectorPrepass},
27
};
28
use bevy_ecs::prelude::*;
29
use bevy_math::{UVec2, Vec2};
30
use bevy_reflect::{reflect_remote, Reflect};
31
use bevy_render::{
32
camera::{MipBias, TemporalJitter},
33
render_graph::{RenderGraphExt, ViewNodeRunner},
34
renderer::{
35
raw_vulkan_init::{AdditionalVulkanFeatures, RawVulkanInitSettings},
36
RenderDevice, RenderQueue,
37
},
38
texture::CachedTexture,
39
view::{prepare_view_targets, Hdr},
40
ExtractSchedule, Render, RenderApp, RenderSystems,
41
};
42
use dlss_wgpu::{
43
ray_reconstruction::{
44
DlssRayReconstruction, DlssRayReconstructionDepthMode, DlssRayReconstructionRoughnessMode,
45
},
46
super_resolution::DlssSuperResolution,
47
FeatureSupport,
48
};
49
use std::{
50
marker::PhantomData,
51
ops::Deref,
52
sync::{Arc, Mutex},
53
};
54
use tracing::info;
55
use uuid::Uuid;
56
57
/// Initializes DLSS support in the renderer. This must be registered before [`RenderPlugin`](bevy_render::RenderPlugin) because
58
/// it configures render init code.
59
#[derive(Default)]
60
pub struct DlssInitPlugin;
61
62
impl Plugin for DlssInitPlugin {
63
#[allow(unsafe_code)]
64
fn build(&self, app: &mut App) {
65
let dlss_project_id = app.world().get_resource::<DlssProjectId>()
66
.expect("The `dlss` feature is enabled, but DlssProjectId was not added to the App before DlssInitPlugin.").0;
67
let mut raw_vulkan_settings = app
68
.world_mut()
69
.get_resource_or_init::<RawVulkanInitSettings>();
70
71
// SAFETY: this does not remove any instance features and only enables features that are supported
72
unsafe {
73
raw_vulkan_settings.add_create_instance_callback(
74
move |mut args, additional_vulkan_features| {
75
let mut feature_support = FeatureSupport::default();
76
match dlss_wgpu::register_instance_extensions(
77
dlss_project_id,
78
&mut args,
79
&mut feature_support,
80
) {
81
Ok(_) => {
82
if feature_support.super_resolution_supported {
83
additional_vulkan_features.insert::<DlssSuperResolutionSupported>();
84
}
85
if feature_support.ray_reconstruction_supported {
86
additional_vulkan_features
87
.insert::<DlssRayReconstructionSupported>();
88
}
89
}
90
Err(_) => {}
91
}
92
},
93
);
94
}
95
96
// SAFETY: this does not remove any device features and only enables features that are supported
97
unsafe {
98
raw_vulkan_settings.add_create_device_callback(
99
move |mut args, adapter, additional_vulkan_features| {
100
let mut feature_support = FeatureSupport::default();
101
match dlss_wgpu::register_device_extensions(
102
dlss_project_id,
103
&mut args,
104
adapter,
105
&mut feature_support,
106
) {
107
Ok(_) => {
108
if feature_support.super_resolution_supported {
109
additional_vulkan_features.insert::<DlssSuperResolutionSupported>();
110
} else {
111
additional_vulkan_features.remove::<DlssSuperResolutionSupported>();
112
}
113
if feature_support.ray_reconstruction_supported {
114
additional_vulkan_features
115
.insert::<DlssRayReconstructionSupported>();
116
} else {
117
additional_vulkan_features
118
.remove::<DlssRayReconstructionSupported>();
119
}
120
}
121
Err(_) => {}
122
}
123
},
124
)
125
};
126
}
127
}
128
129
/// Enables DLSS support. This requires [`DlssInitPlugin`] to function, which must be manually registered in the correct order
130
/// prior to registering this plugin.
131
#[derive(Default)]
132
pub struct DlssPlugin;
133
134
impl Plugin for DlssPlugin {
135
fn build(&self, app: &mut App) {
136
app.register_type::<Dlss<DlssSuperResolutionFeature>>()
137
.register_type::<Dlss<DlssRayReconstructionFeature>>();
138
}
139
140
fn finish(&self, app: &mut App) {
141
let (super_resolution_supported, ray_reconstruction_supported) = {
142
let features = app
143
.sub_app_mut(RenderApp)
144
.world()
145
.resource::<AdditionalVulkanFeatures>();
146
(
147
features.has::<DlssSuperResolutionSupported>(),
148
features.has::<DlssRayReconstructionSupported>(),
149
)
150
};
151
if !super_resolution_supported {
152
return;
153
}
154
155
let wgpu_device = {
156
let render_world = app.sub_app(RenderApp).world();
157
let render_device = render_world.resource::<RenderDevice>().wgpu_device();
158
render_device.clone()
159
};
160
let project_id = app.world().get_resource::<DlssProjectId>()
161
.expect("The `dlss` feature is enabled, but DlssProjectId was not added to the App before DlssPlugin.");
162
let dlss_sdk = dlss_wgpu::DlssSdk::new(project_id.0, wgpu_device);
163
if dlss_sdk.is_err() {
164
info!("DLSS is not supported on this system");
165
return;
166
}
167
168
app.insert_resource(DlssSuperResolutionSupported);
169
if ray_reconstruction_supported {
170
app.insert_resource(DlssRayReconstructionSupported);
171
}
172
173
app.sub_app_mut(RenderApp)
174
.insert_resource(DlssSdk(dlss_sdk.unwrap()))
175
.add_systems(
176
ExtractSchedule,
177
(
178
extract::extract_dlss::<DlssSuperResolutionFeature>,
179
extract::extract_dlss::<DlssRayReconstructionFeature>,
180
),
181
)
182
.add_systems(
183
Render,
184
(
185
prepare::prepare_dlss::<DlssSuperResolutionFeature>,
186
prepare::prepare_dlss::<DlssRayReconstructionFeature>,
187
)
188
.in_set(RenderSystems::ManageViews)
189
.before(prepare_view_targets),
190
)
191
.add_render_graph_node::<ViewNodeRunner<node::DlssNode<DlssSuperResolutionFeature>>>(
192
Core3d,
193
Node3d::DlssSuperResolution,
194
)
195
.add_render_graph_node::<ViewNodeRunner<node::DlssNode<DlssRayReconstructionFeature>>>(
196
Core3d,
197
Node3d::DlssRayReconstruction,
198
)
199
.add_render_graph_edges(
200
Core3d,
201
(
202
Node3d::EndMainPass,
203
Node3d::MotionBlur, // Running before DLSS reduces edge artifacts and noise
204
Node3d::DlssSuperResolution,
205
Node3d::DlssRayReconstruction,
206
Node3d::Bloom,
207
Node3d::Tonemapping,
208
),
209
);
210
}
211
}
212
213
/// Camera component to enable DLSS.
214
#[derive(Component, Reflect, Clone)]
215
#[reflect(Component)]
216
#[require(TemporalJitter, MipBias, DepthPrepass, MotionVectorPrepass, Hdr)]
217
pub struct Dlss<F: DlssFeature = DlssSuperResolutionFeature> {
218
/// How much upscaling should be applied.
219
#[reflect(remote = DlssPerfQualityModeRemoteReflect)]
220
pub perf_quality_mode: DlssPerfQualityMode,
221
/// Set to true to delete the saved temporal history (past frames).
222
///
223
/// Useful for preventing ghosting when the history is no longer
224
/// representative of the current frame, such as in sudden camera cuts.
225
///
226
/// After setting this to true, it will automatically be toggled
227
/// back to false at the end of the frame.
228
pub reset: bool,
229
#[reflect(ignore)]
230
pub _phantom_data: PhantomData<F>,
231
}
232
233
impl Default for Dlss<DlssSuperResolutionFeature> {
234
fn default() -> Self {
235
Self {
236
perf_quality_mode: Default::default(),
237
reset: Default::default(),
238
_phantom_data: Default::default(),
239
}
240
}
241
}
242
243
/// A DLSS feature that can be selected through the type parameter of [`Dlss`].
///
/// Implementations forward to a `dlss_wgpu` context type (see [`Self::Context`]); this file
/// provides [`DlssSuperResolutionFeature`] and [`DlssRayReconstructionFeature`].
pub trait DlssFeature: Reflect + Clone + Default {
    /// The `dlss_wgpu` context object that drives this feature.
    type Context: Send;

    /// Returns the upscaled (output) resolution of `context`.
    fn upscaled_resolution(context: &Self::Context) -> UVec2;

    /// Returns the render (input) resolution of `context`.
    fn render_resolution(context: &Self::Context) -> UVec2;

    /// Returns the camera jitter `context` suggests for `frame_number` at `render_resolution`.
    ///
    /// NOTE(review): units/range of the returned offset are defined by `dlss_wgpu` — confirm there.
    fn suggested_jitter(
        context: &Self::Context,
        frame_number: u32,
        render_resolution: UVec2,
    ) -> Vec2;

    /// Returns the texture mip bias `context` suggests for `render_resolution`.
    fn suggested_mip_bias(context: &Self::Context, render_resolution: UVec2) -> f32;

    /// Creates a new context for this feature.
    ///
    /// # Errors
    ///
    /// Returns the [`dlss_wgpu::DlssError`] reported by `dlss_wgpu` if the context cannot be created.
    fn new_context(
        upscaled_resolution: UVec2,
        perf_quality_mode: DlssPerfQualityMode,
        feature_flags: dlss_wgpu::DlssFeatureFlags,
        sdk: Arc<Mutex<dlss_wgpu::DlssSdk>>,
        device: &RenderDevice,
        queue: &RenderQueue,
    ) -> Result<Self::Context, dlss_wgpu::DlssError>;
}
267
268
/// DLSS Super Resolution.
269
///
270
/// Only available when the [`DlssSuperResolutionSupported`] resource exists.
271
#[derive(Reflect, Clone, Default)]
272
pub struct DlssSuperResolutionFeature;
273
274
impl DlssFeature for DlssSuperResolutionFeature {
275
type Context = DlssSuperResolution;
276
277
fn upscaled_resolution(context: &Self::Context) -> UVec2 {
278
context.upscaled_resolution()
279
}
280
281
fn render_resolution(context: &Self::Context) -> UVec2 {
282
context.render_resolution()
283
}
284
285
fn suggested_jitter(
286
context: &Self::Context,
287
frame_number: u32,
288
render_resolution: UVec2,
289
) -> Vec2 {
290
context.suggested_jitter(frame_number, render_resolution)
291
}
292
293
fn suggested_mip_bias(context: &Self::Context, render_resolution: UVec2) -> f32 {
294
context.suggested_mip_bias(render_resolution)
295
}
296
297
fn new_context(
298
upscaled_resolution: UVec2,
299
perf_quality_mode: DlssPerfQualityMode,
300
feature_flags: dlss_wgpu::DlssFeatureFlags,
301
sdk: Arc<Mutex<dlss_wgpu::DlssSdk>>,
302
device: &RenderDevice,
303
queue: &RenderQueue,
304
) -> Result<Self::Context, dlss_wgpu::DlssError> {
305
DlssSuperResolution::new(
306
upscaled_resolution,
307
perf_quality_mode,
308
feature_flags,
309
sdk,
310
device.wgpu_device(),
311
queue.deref(),
312
)
313
}
314
}
315
316
/// DLSS Ray Reconstruction.
317
///
318
/// Only available when the [`DlssRayReconstructionSupported`] resource exists.
319
#[derive(Reflect, Clone, Default)]
320
pub struct DlssRayReconstructionFeature;
321
322
impl DlssFeature for DlssRayReconstructionFeature {
323
type Context = DlssRayReconstruction;
324
325
fn upscaled_resolution(context: &Self::Context) -> UVec2 {
326
context.upscaled_resolution()
327
}
328
329
fn render_resolution(context: &Self::Context) -> UVec2 {
330
context.render_resolution()
331
}
332
333
fn suggested_jitter(
334
context: &Self::Context,
335
frame_number: u32,
336
render_resolution: UVec2,
337
) -> Vec2 {
338
context.suggested_jitter(frame_number, render_resolution)
339
}
340
341
fn suggested_mip_bias(context: &Self::Context, render_resolution: UVec2) -> f32 {
342
context.suggested_mip_bias(render_resolution)
343
}
344
345
fn new_context(
346
upscaled_resolution: UVec2,
347
perf_quality_mode: DlssPerfQualityMode,
348
feature_flags: dlss_wgpu::DlssFeatureFlags,
349
sdk: Arc<Mutex<dlss_wgpu::DlssSdk>>,
350
device: &RenderDevice,
351
queue: &RenderQueue,
352
) -> Result<Self::Context, dlss_wgpu::DlssError> {
353
DlssRayReconstruction::new(
354
upscaled_resolution,
355
perf_quality_mode,
356
feature_flags,
357
DlssRayReconstructionRoughnessMode::Packed,
358
DlssRayReconstructionDepthMode::Hardware,
359
sdk,
360
device.wgpu_device(),
361
queue.deref(),
362
)
363
}
364
}
365
366
/// Additional textures needed as inputs for [`DlssRayReconstructionFeature`].
///
/// NOTE(review): these are not created in this file; presumably whichever renderer drives ray
/// reconstruction attaches this component to the view — confirm against the prepare/node code.
#[derive(Component)]
pub struct ViewDlssRayReconstructionTextures {
    /// Diffuse albedo input.
    pub diffuse_albedo: CachedTexture,
    /// Specular albedo input.
    pub specular_albedo: CachedTexture,
    /// Normal and roughness input.
    ///
    /// NOTE(review): roughness is presumably packed into this texture, matching the
    /// `DlssRayReconstructionRoughnessMode::Packed` mode the ray reconstruction context is
    /// created with — confirm.
    pub normal_roughness: CachedTexture,
    /// Motion vectors for specular contributions.
    pub specular_motion_vectors: CachedTexture,
}
374
375
/// Reflection mirror of the foreign [`DlssPerfQualityMode`] enum from `dlss_wgpu`.
///
/// Used through `#[reflect(remote = ...)]` on [`Dlss::perf_quality_mode`]. Its variants mirror
/// `DlssPerfQualityMode`, with `Auto` as the default.
#[reflect_remote(DlssPerfQualityMode)]
#[derive(Default)]
enum DlssPerfQualityModeRemoteReflect {
    #[default]
    Auto,
    Dlaa,
    Quality,
    Balanced,
    Performance,
    UltraPerformance,
}
386
387
/// Render-world resource holding the shared `dlss_wgpu` SDK handle.
///
/// Inserted into the render app by [`DlssPlugin`] only when the SDK was created successfully.
#[derive(Resource)]
struct DlssSdk(Arc<Mutex<dlss_wgpu::DlssSdk>>);
389
390
/// Application-specific ID for DLSS.
391
///
392
/// See the DLSS programming guide for more info.
393
#[derive(Resource, Clone)]
394
pub struct DlssProjectId(pub Uuid);
395
396
/// When DLSS Super Resolution is supported by the current system, this resource will exist in the main world.
397
/// Otherwise this resource will be absent.
398
#[derive(Resource, Clone, Copy)]
399
pub struct DlssSuperResolutionSupported;
400
401
/// When DLSS Ray Reconstruction is supported by the current system, this resource will exist in the main world.
402
/// Otherwise this resource will be absent.
403
#[derive(Resource, Clone, Copy)]
404
pub struct DlssRayReconstructionSupported;
405
406