# Path: blob/master/modules/models/sd3/sd3_model.py
# 3081 views
import contextlib

import torch

import k_diffusion
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
from modules.models.sd3.sd3_cond import SD3Cond

from modules import shared, devices


class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
    """k-diffusion discrete-schedule wrapper that delegates denoising to an SD3Inferencer.

    Args:
        inner_model: object exposing ``apply_model(x, t, cond)`` (an SD3Inferencer).
        sigmas: the discrete sigma schedule handed to ``DiscreteSchedule``.
    """

    def __init__(self, inner_model, sigmas):
        # Whether sigmas are snapped to the discrete schedule follows the user
        # option as it is set at construction time.
        super().__init__(sigmas, quantize=shared.opts.enable_quantization)
        self.inner_model = inner_model

    def forward(self, input, sigma, **kwargs):
        # sigma is passed straight through as apply_model's `t` argument;
        # presumably BaseModel interprets it as a sigma value -- TODO confirm.
        return self.inner_model.apply_model(input, sigma, **kwargs)


class SD3Inferencer(torch.nn.Module):
    """SD3 checkpoint wrapper exposing the attribute/method surface the webui expects.

    Holds the diffusion model (``model``), the VAE (``first_stage_model``) and the
    text encoders (``text_encoders``, also reachable as ``cond_stage_model``).
    """

    def __init__(self, state_dict, shift=3, use_ema=False):
        """Build the model from a checkpoint state dict.

        Args:
            state_dict: checkpoint weights; the diffusion model is loaded from keys
                under the ``model.diffusion_model.`` prefix.
            shift: schedule shift forwarded to ``BaseModel``.
            use_ema: accepted for interface compatibility; not used here.
        """
        super().__init__()

        self.shift = shift

        with torch.no_grad():
            # Both sub-models are created on CPU; presumably moved to the
            # compute device by the caller.
            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
            self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
            # NOTE(review): the VAE is constructed with devices.dtype_vae but its
            # recorded dtype is then overwritten with the diffusion model's dtype
            # -- confirm this is intended.
            self.first_stage_model.dtype = self.model.diffusion_model.dtype

        # Equivalent alphas_cumprod for the sigma schedule, from sigma^2 = (1 - a) / a.
        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)

        self.text_encoders = SD3Cond()
        self.cond_stage_key = 'txt'

        # Attributes presumably read by generic webui code paths shared with
        # other model families -- confirm before changing.
        self.parameterization = "eps"
        self.model.conditioning_key = "crossattn"

        self.latent_format = SD3LatentFormat()
        self.latent_channels = 16

    @property
    def cond_stage_model(self):
        """Alias for ``text_encoders``."""
        return self.text_encoders

    def before_load_weights(self, state_dict):
        """Let the text encoders inspect ``state_dict`` before weights are loaded."""
        self.cond_stage_model.before_load_weights(state_dict)

    def ema_scope(self):
        """Return a no-op context manager (no EMA weights are swapped in)."""
        return contextlib.nullcontext()

    def get_learned_conditioning(self, batch: list[str]):
        """Encode a batch of prompts with the text encoders."""
        return self.cond_stage_model(batch)

    def apply_model(self, x, t, cond):
        """Run the diffusion model.

        ``cond`` must provide ``'crossattn'`` (passed as ``c_crossattn``) and
        ``'vector'`` (passed as ``y``).
        """
        return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])

    def decode_first_stage(self, latent):
        """Undo the latent-format transform, then decode with the VAE."""
        latent = self.latent_format.process_out(latent)
        return self.first_stage_model.decode(latent)

    def encode_first_stage(self, image):
        """Encode with the VAE, then apply the latent-format transform."""
        latent = self.first_stage_model.encode(image)
        return self.latent_format.process_in(latent)

    def get_first_stage_encoding(self, x):
        """Identity: ``encode_first_stage`` already returns the processed latent."""
        return x

    def create_denoiser(self):
        """Wrap this model in an ``SD3Denoiser`` using the model's sigma schedule."""
        return SD3Denoiser(self, self.model.model_sampling.sigmas)

    def medvram_fields(self):
        """Return ``(owner, attribute_name)`` pairs for the large sub-modules.

        Presumably consumed by the low/med-VRAM logic to offload each sub-module
        independently -- confirm against the caller.
        """
        return [
            (self, 'first_stage_model'),
            (self, 'text_encoders'),
            (self, 'model'),
        ]

    def add_noise_to_latent(self, x, noise, amount):
        """Linearly interpolate from ``x`` (amount=0) to ``noise`` (amount=1)."""
        return x * (1 - amount) + noise * amount

    def fix_dimensions(self, width, height):
        """Round ``width`` and ``height`` down to the nearest multiple of 16."""
        return width // 16 * 16, height // 16 * 16

    def diffusers_weight_mapping(self):
        """Yield ``(diffusers_name, compvis_name)`` pairs for each joint block's attention layers.

        For every block index ``i`` in ``range(self.model.depth)``, the diffusers
        module path ``transformer.transformer_blocks.{i}.<suffix>`` is paired with
        ``diffusion_model_joint_blocks_{i}_<compvis suffix>``.
        """
        # Compvis names are underscore-joined module paths. Keep every entry in
        # that form: the context_block q/k/v entries previously contained a stray
        # '.', unlike all of their siblings.
        suffix_pairs = (
            ("attn.to_q", "x_block_attn_qkv_q_proj"),
            ("attn.to_k", "x_block_attn_qkv_k_proj"),
            ("attn.to_v", "x_block_attn_qkv_v_proj"),
            ("attn.to_out.0", "x_block_attn_proj"),
            ("attn.add_q_proj", "context_block_attn_qkv_q_proj"),
            ("attn.add_k_proj", "context_block_attn_qkv_k_proj"),
            ("attn.add_v_proj", "context_block_attn_qkv_v_proj"),
            ("attn.add_out_proj.0", "context_block_attn_proj"),
        )
        for i in range(self.model.depth):
            for diffusers_suffix, compvis_suffix in suffix_pairs:
                yield f"transformer.transformer_blocks.{i}.{diffusers_suffix}", f"diffusion_model_joint_blocks_{i}_{compvis_suffix}"