GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/models/sd3/sd3_impls.py
### Impls of the SD3 core diffusion model and VAE

import torch
import math
import einops
from modules.models.sd3.mmdit import MMDiT
from PIL import Image


#################################################################################################
### MMDiT Model Wrapping
#################################################################################################

class ModelSamplingDiscreteFlow(torch.nn.Module):
    """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
    def __init__(self, shift=1.0):
        super().__init__()
        self.shift = shift
        timesteps = 1000
        ts = self.sigma(torch.arange(1, timesteps + 1, 1))
        self.register_buffer('sigmas', ts)

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        return sigma * 1000

    def sigma(self, timestep: torch.Tensor):
        timestep = timestep / 1000.0
        if self.shift == 1.0:
            return timestep
        return self.shift * timestep / (1 + (self.shift - 1) * timestep)

    def calculate_denoised(self, sigma, model_output, model_input):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        return sigma * noise + (1.0 - sigma) * latent_image

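# A minimal sketch of how the schedule above behaves (illustrative only; the shift value
# is an assumption, not read from this file): sigma is the rectified-flow mix weight
# between noise and data, so shift=1.0 gives sigma = t/1000, and larger shifts push the
# schedule toward the high-noise end.
#
#   sampling = ModelSamplingDiscreteFlow(shift=3.0)
#   sampling.sigma(torch.tensor(500.0))   # 3*0.5 / (1 + 2*0.5) = 0.75
#   sampling.sigma(torch.tensor(1000.0))  # 1.0 at the final timestep
#   sampling.noise_scaling(torch.tensor(0.75), noise, latent)  # 0.75*noise + 0.25*latent
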
class BaseModel(torch.nn.Module):
    """Wrapper around the core MM-DiT model"""
    def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
        super().__init__()
        # Important configuration values can be quickly determined by checking shapes in the source file
        # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
        patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
        depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
        num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
        pos_embed_max_size = round(math.sqrt(num_patches))
        adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
        context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
        context_embedder_config = {
            "target": "torch.nn.Linear",
            "params": {
                "in_features": context_shape[1],
                "out_features": context_shape[0]
            }
        }
        self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
        self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
        self.depth = depth

    def apply_model(self, x, sigma, c_crossattn=None, y=None):
        dtype = self.get_dtype()
        timestep = self.model_sampling.timestep(sigma).float()
        model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def forward(self, *args, **kwargs):
        return self.apply_model(*args, **kwargs)

    def get_dtype(self):
        return self.diffusion_model.dtype

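# The shape probing above lets the same wrapper size itself from any SD3 checkpoint:
# e.g. an x_embedder.proj.weight of shape [1536, 16, 2, 2] yields patch_size = 2 and
# depth = 1536 // 64 = 24. A rough loading sketch (file name, prefix and shift are
# illustrative assumptions, not values taken from this module):
#
#   import safetensors.torch
#   sd = safetensors.torch.load_file("sd3_medium.safetensors")
#   model = BaseModel(shift=3.0, state_dict=sd, prefix="model.diffusion_model.", dtype=torch.float16)
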
class CFGDenoiser(torch.nn.Module):
    """Helper for applying CFG Scaling to diffusion outputs"""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, timestep, cond, uncond, cond_scale):
        # Run cond and uncond in a batch together
        batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]]))
        # Then split and apply CFG Scaling
        pos_out, neg_out = batched.chunk(2)
        scaled = neg_out + (pos_out - neg_out) * cond_scale
        return scaled

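# A worked example of the guidance formula above: with pos_out = 2.0, neg_out = 1.0 and
# cond_scale = 5.0 the result is 1.0 + (2.0 - 1.0) * 5.0 = 6.0, while cond_scale = 1.0
# simply returns the conditional prediction. Minimal wiring sketch (the cond/uncond dicts
# are assumed to come from the text encoders elsewhere in the webui):
#
#   denoiser = CFGDenoiser(base_model)
#   denoised = denoiser(x, sigma * x.new_ones([x.shape[0]]), cond, uncond, cond_scale=5.0)
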
class SD3LatentFormat:
    """Latents are slightly shifted from center - process_in/process_out scale and shift them when moving between the VAE and the diffusion model to correct for this"""
    def __init__(self):
        self.scale_factor = 1.5305
        self.shift_factor = 0.0609

    def process_in(self, latent):
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent):
        return (latent / self.scale_factor) + self.shift_factor

    def decode_latent_to_preview(self, x0):
        """Quick RGB approximate preview of sd3 latents"""
        factors = torch.tensor([
            [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
            [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
            [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
            [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
            [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
            [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
            [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
            [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259]
        ], device="cpu")
        latent_image = x0[0].permute(1, 2, 0).cpu() @ factors

        latents_ubyte = (((latent_image + 1) / 2)
                         .clamp(0, 1)  # change scale from -1..1 to 0..1
                         .mul(0xFF)  # to 0..255
                         .byte()).cpu()

        return Image.fromarray(latents_ubyte.numpy())

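# A quick check of the latent shift above (shapes are illustrative): process_in and
# process_out are exact inverses, so a latent survives the round trip unchanged.
#
#   fmt = SD3LatentFormat()
#   z = torch.randn(1, 16, 64, 64)
#   torch.allclose(fmt.process_out(fmt.process_in(z)), z, atol=1e-6)  # True
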
#################################################################################################
### K-Diffusion Sampling
#################################################################################################


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    return x[(...,) + (None,) * dims_to_append]


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)


@torch.no_grad()
@torch.autocast("cuda", dtype=torch.float16)
def sample_euler(model, x, sigmas, extra_args=None):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in range(len(sigmas) - 1):
        sigma_hat = sigmas[i]
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x

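# A worked Euler step from the loop above (numbers are illustrative): with sigma_hat = 1.0,
# sigma_next = 0.8 and denoised estimate d_hat, the derivative is d = (x - d_hat) / 1.0 and
# the update is x <- x + d * (0.8 - 1.0), i.e. x moves 20% of the way toward d_hat. With a
# flow schedule from ModelSamplingDiscreteFlow, sigmas runs from sigma_max down to 0, so the
# final iteration lands exactly on the denoised sample.
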
#################################################################################################
### VAE
#################################################################################################


def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)


class ResnetBlock(torch.nn.Module):
    def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        else:
            self.nin_shortcut = None
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        hidden = x
        hidden = self.norm1(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv1(hidden)
        hidden = self.norm2(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv2(hidden)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + hidden

class AttnBlock(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.norm = Normalize(in_channels, dtype=dtype, device=device)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)

    def forward(self, x):
        hidden = self.norm(x)
        q = self.q(hidden)
        k = self.k(hidden)
        v = self.v(hidden)
        b, c, h, w = q.shape
        q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
        hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
        hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
        hidden = self.proj_out(hidden)
        return x + hidden

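# Shape note for the attention block above: the rearrange flattens the feature map into a
# single-head sequence, so a [B, C, H, W] input becomes q/k/v of shape [B, 1, H*W, C],
# scaled_dot_product_attention mixes the H*W spatial positions, and the result is folded
# back to [B, C, H, W] before the residual add. E.g. (channel count is illustrative):
#
#   attn = AttnBlock(512)
#   attn(torch.randn(1, 512, 32, 32)).shape  # torch.Size([1, 512, 32, 32])
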
class Downsample(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x

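# Downsample pads only the right/bottom edge (the pad tuple (0, 1, 0, 1) means
# (width-left, width-right, height-top, height-bottom)), so the stride-2 conv halves the
# spatial size exactly, and Upsample undoes it. Illustrative shapes:
#
#   Downsample(128)(torch.randn(1, 128, 64, 64)).shape  # torch.Size([1, 128, 32, 32])
#   Upsample(128)(torch.randn(1, 128, 32, 32)).shape    # torch.Size([1, 128, 64, 64])
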
class VAEEncoder(torch.nn.Module):
    def __init__(self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = torch.nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = torch.nn.ModuleList()
            attn = torch.nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            down = torch.nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, dtype=dtype, device=device)
            self.down.append(down)
        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        # end
        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = self.swish(h)
        h = self.conv_out(h)
        return h

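# With the default ch_mult of (1, 2, 4, 4) the encoder downsamples three times (8x overall)
# and conv_out emits 2 * z_channels = 32 channels (mean and logvar stacked along dim 1).
# Illustrative shape check:
#
#   enc = VAEEncoder()
#   enc(torch.randn(1, 3, 512, 512)).shape  # torch.Size([1, 32, 64, 64])
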
class VAEDecoder(torch.nn.Module):
    def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        # upsampling
        self.up = torch.nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = torch.nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            up = torch.nn.Module()
            up.block = block
            if i_level != 0:
                up.upsample = Upsample(block_in, dtype=dtype, device=device)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, z):
        # z to block_in
        hidden = self.conv_in(z)
        # middle
        hidden = self.mid.block_1(hidden)
        hidden = self.mid.attn_1(hidden)
        hidden = self.mid.block_2(hidden)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                hidden = self.up[i_level].block[i_block](hidden)
            if i_level != 0:
                hidden = self.up[i_level].upsample(hidden)
        # end
        hidden = self.norm_out(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv_out(hidden)
        return hidden

class SDVAE(torch.nn.Module):
    def __init__(self, dtype=torch.float32, device=None):
        super().__init__()
        self.encoder = VAEEncoder(dtype=dtype, device=device)
        self.decoder = VAEDecoder(dtype=dtype, device=device)

    @torch.autocast("cuda", dtype=torch.float16)
    def decode(self, latent):
        return self.decoder(latent)

    @torch.autocast("cuda", dtype=torch.float16)
    def encode(self, image):
        hidden = self.encoder(image)
        mean, logvar = torch.chunk(hidden, 2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)
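
# End-to-end sketch of how the pieces above fit together (variable names, step count and
# cond_scale are assumptions for illustration; the real wiring lives elsewhere in the
# webui). encode() returns a sample from the latent posterior, mean + std * eps, and
# process_in/process_out bracket the diffusion model around the VAE:
#
#   vae = SDVAE()
#   fmt = SD3LatentFormat()
#   sigmas = model.model_sampling.sigma(torch.linspace(1000, 0, 21))
#   x = torch.randn(1, 16, 128, 128) * sigmas[0]  # pure-noise start for txt2img
#   x = sample_euler(CFGDenoiser(model), x, sigmas,
#                    extra_args={"cond": cond, "uncond": uncond, "cond_scale": 5.0})
#   image_out = vae.decode(fmt.process_out(x))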