Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
automatic1111
GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/models/sd3/sd3_model.py
3081 views
1
import contextlib
2
3
import torch
4
5
import k_diffusion
6
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
7
from modules.models.sd3.sd3_cond import SD3Cond
8
9
from modules import shared, devices
10
11
12
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
    """k-diffusion discrete-schedule denoiser that delegates to an SD3 wrapper.

    The sigma table is handed to ``DiscreteSchedule.__init__`` together with the
    user's ``enable_quantization`` option; each forward pass is passed through
    unchanged to ``inner_model.apply_model``.
    """

    def __init__(self, inner_model, sigmas):
        quantize = shared.opts.enable_quantization
        super().__init__(sigmas, quantize=quantize)
        # Assigned after super().__init__() so the base class is initialised first.
        self.inner_model = inner_model

    def forward(self, input, sigma, **kwargs):
        # sigma is forwarded as-is (not converted to a timestep) to apply_model.
        apply_model = self.inner_model.apply_model
        return apply_model(input, sigma, **kwargs)
19
20
21
class SD3Inferencer(torch.nn.Module):
    """Wrapper that assembles a Stable Diffusion 3 checkpoint for the webui.

    Holds the diffusion model (``BaseModel``), the VAE (``SDVAE``) and the text
    encoders (``SD3Cond``), and exposes them under the attribute and method
    names the webui's other model wrappers use (``first_stage_model``,
    ``cond_stage_model``, ``apply_model``, ``decode_first_stage``, ...) --
    presumably so the shared sampling code can drive SD3 unchanged.
    """

    def __init__(self, state_dict, shift=3, use_ema=False):
        """Build all SD3 components on the CPU.

        Args:
            state_dict: checkpoint weights; ``BaseModel`` reads the diffusion
                model's weights under the ``"model.diffusion_model."`` prefix.
            shift: forwarded to ``BaseModel`` and kept as ``self.shift``.
            use_ema: accepted for signature compatibility; not used here.
        """
        super().__init__()

        self.shift = shift

        with torch.no_grad():
            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
            self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
            # NOTE(review): the VAE is constructed with devices.dtype_vae, but its
            # ``dtype`` attribute is then overwritten with the diffusion model's
            # dtype -- confirm this override is intended.
            self.first_stage_model.dtype = self.model.diffusion_model.dtype

        # alphas_cumprod equivalent of the model's sigma table
        # (alpha_bar = 1 / (sigma^2 + 1)); presumably kept for code that expects
        # this attribute on every model wrapper.
        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)

        self.text_encoders = SD3Cond()
        self.cond_stage_key = 'txt'

        self.parameterization = "eps"
        self.model.conditioning_key = "crossattn"

        self.latent_format = SD3LatentFormat()
        self.latent_channels = 16

    @property
    def cond_stage_model(self):
        """Alias for ``self.text_encoders``."""
        return self.text_encoders

    def before_load_weights(self, state_dict):
        """Forward the pre-load hook to the text encoders."""
        self.cond_stage_model.before_load_weights(state_dict)

    def ema_scope(self):
        """Return a no-op context manager; this wrapper swaps no EMA weights."""
        return contextlib.nullcontext()

    def get_learned_conditioning(self, batch: list[str]):
        """Encode a batch of prompts with the text encoders."""
        return self.cond_stage_model(batch)

    def apply_model(self, x, t, cond):
        """Run the diffusion model.

        Args:
            x: noisy latent.
            t: noise level; ``SD3Denoiser.forward`` passes sigma here.
            cond: mapping providing ``'crossattn'`` (passed as ``c_crossattn``)
                and ``'vector'`` (passed as ``y``).
        """
        return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])

    def decode_first_stage(self, latent):
        """Undo the latent format's processing, then decode with the VAE."""
        latent = self.latent_format.process_out(latent)
        return self.first_stage_model.decode(latent)

    def encode_first_stage(self, image):
        """Encode with the VAE, then apply the latent format's processing."""
        latent = self.first_stage_model.encode(image)
        return self.latent_format.process_in(latent)

    def get_first_stage_encoding(self, x):
        """Identity: ``encode_first_stage`` already returns the processed latent."""
        return x

    def create_denoiser(self):
        """Wrap this model in an ``SD3Denoiser`` over the model's sigma table."""
        return SD3Denoiser(self, self.model.model_sampling.sigmas)

    def medvram_fields(self):
        """Return ``(owner, attribute_name)`` pairs of the large submodules.

        Presumably consumed by the low/medium-VRAM logic to move these modules
        between devices -- TODO confirm against the caller.
        """
        return [
            (self, 'first_stage_model'),
            (self, 'text_encoders'),
            (self, 'model'),
        ]

    def add_noise_to_latent(self, x, noise, amount):
        """Linearly interpolate from ``x`` (amount=0) to ``noise`` (amount=1)."""
        return x * (1 - amount) + noise * amount

    def fix_dimensions(self, width, height):
        """Floor ``width`` and ``height`` to multiples of 16."""
        return width // 16 * 16, height // 16 * 16

    def diffusers_weight_mapping(self):
        """Yield ``(diffusers_key, webui_key)`` pairs for every joint block.

        Eight pairs are produced per block, for blocks ``0 .. self.model.depth - 1``.
        The webui-side names use ``_`` separators throughout, e.g.
        ``diffusion_model_joint_blocks_0_context_block_attn_qkv_q_proj``.
        """
        for i in range(self.model.depth):
            # Shared per-block prefixes; building every key from them keeps the
            # separator style identical across all eight targets.
            src = f"transformer.transformer_blocks.{i}.attn"
            dst = f"diffusion_model_joint_blocks_{i}"

            # x_block attention targets
            yield f"{src}.to_q", f"{dst}_x_block_attn_qkv_q_proj"
            yield f"{src}.to_k", f"{dst}_x_block_attn_qkv_k_proj"
            yield f"{src}.to_v", f"{dst}_x_block_attn_qkv_v_proj"
            yield f"{src}.to_out.0", f"{dst}_x_block_attn_proj"

            # context_block attention targets
            yield f"{src}.add_q_proj", f"{dst}_context_block_attn_qkv_q_proj"
            yield f"{src}.add_k_proj", f"{dst}_context_block_attn_qkv_k_proj"
            yield f"{src}.add_v_proj", f"{dst}_context_block_attn_qkv_v_proj"
            # NOTE(review): diffusers' Attention names the context output
            # projection ``to_add_out``; confirm which LoRA key format
            # ``add_out_proj.0`` is meant to match.
            yield f"{src}.add_out_proj.0", f"{dst}_context_block_attn_proj"
97
98