Path: blob/master/modules/models/sd3/sd3_impls.py
3072 views
### Impls of the SD3 core diffusion model and VAE12import torch3import math4import einops5from modules.models.sd3.mmdit import MMDiT6from PIL import Image789#################################################################################################10### MMDiT Model Wrapping11#################################################################################################121314class ModelSamplingDiscreteFlow(torch.nn.Module):15"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""16def __init__(self, shift=1.0):17super().__init__()18self.shift = shift19timesteps = 100020ts = self.sigma(torch.arange(1, timesteps + 1, 1))21self.register_buffer('sigmas', ts)2223@property24def sigma_min(self):25return self.sigmas[0]2627@property28def sigma_max(self):29return self.sigmas[-1]3031def timestep(self, sigma):32return sigma * 10003334def sigma(self, timestep: torch.Tensor):35timestep = timestep / 1000.036if self.shift == 1.0:37return timestep38return self.shift * timestep / (1 + (self.shift - 1) * timestep)3940def calculate_denoised(self, sigma, model_output, model_input):41sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))42return model_input - model_output * sigma4344def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):45return sigma * noise + (1.0 - sigma) * latent_image464748class BaseModel(torch.nn.Module):49"""Wrapper around the core MM-DiT model"""50def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):51super().__init__()52# Important configuration values can be quickly determined by checking shapes in the source file53# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)54patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]55depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 6456num_patches = state_dict[f"{prefix}pos_embed"].shape[1]57pos_embed_max_size = round(math.sqrt(num_patches))58adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]59context_shape = state_dict[f"{prefix}context_embedder.weight"].shape60context_embedder_config = {61"target": "torch.nn.Linear",62"params": {63"in_features": context_shape[1],64"out_features": context_shape[0]65}66}67self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)68self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)69self.depth = depth7071def apply_model(self, x, sigma, c_crossattn=None, y=None):72dtype = self.get_dtype()73timestep = self.model_sampling.timestep(sigma).float()74model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()75return self.model_sampling.calculate_denoised(sigma, model_output, x)7677def forward(self, *args, **kwargs):78return self.apply_model(*args, **kwargs)7980def get_dtype(self):81return self.diffusion_model.dtype828384class CFGDenoiser(torch.nn.Module):85"""Helper for applying CFG Scaling to diffusion outputs"""86def __init__(self, model):87super().__init__()88self.model = model8990def forward(self, x, timestep, cond, uncond, cond_scale):91# Run cond and uncond in a batch together92batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]]))93# Then split and apply CFG Scaling94pos_out, neg_out = batched.chunk(2)95scaled = neg_out + (pos_out - neg_out) * cond_scale96return scaled979899class SD3LatentFormat:100"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""101def __init__(self):102self.scale_factor = 1.5305103self.shift_factor = 0.0609104105def process_in(self, latent):106return (latent - self.shift_factor) * self.scale_factor107108def process_out(self, latent):109return (latent / self.scale_factor) + self.shift_factor110111def decode_latent_to_preview(self, x0):112"""Quick RGB approximate preview of sd3 latents"""113factors = torch.tensor([114[-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],115[ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],116[ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],117[ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],118[-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],119[ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],120[ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],121[-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259]122], device="cpu")123latent_image = x0[0].permute(1, 2, 0).cpu() @ factors124125latents_ubyte = (((latent_image + 1) / 2)126.clamp(0, 1) # change scale from -1..1 to 0..1127.mul(0xFF) # to 0..255128.byte()).cpu()129130return Image.fromarray(latents_ubyte.numpy())131132133#################################################################################################134### K-Diffusion Sampling135#################################################################################################136137138def append_dims(x, target_dims):139"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""140dims_to_append = target_dims - x.ndim141return x[(...,) + (None,) * dims_to_append]142143144def to_d(x, sigma, denoised):145"""Converts a denoiser output to a Karras ODE derivative."""146return (x - denoised) / append_dims(sigma, x.ndim)147148149@torch.no_grad()150@torch.autocast("cuda", dtype=torch.float16)151def sample_euler(model, x, sigmas, extra_args=None):152"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""153extra_args = {} if extra_args is None else extra_args154s_in = x.new_ones([x.shape[0]])155for i in range(len(sigmas) - 1):156sigma_hat = sigmas[i]157denoised = model(x, sigma_hat * s_in, **extra_args)158d = to_d(x, sigma_hat, denoised)159dt = sigmas[i + 1] - sigma_hat160# Euler method161x = x + d * dt162return x163164165#################################################################################################166### VAE167#################################################################################################168169170def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):171return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)172173174class ResnetBlock(torch.nn.Module):175def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):176super().__init__()177self.in_channels = in_channels178out_channels = in_channels if out_channels is None else out_channels179self.out_channels = out_channels180181self.norm1 = Normalize(in_channels, dtype=dtype, device=device)182self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)183self.norm2 = Normalize(out_channels, dtype=dtype, device=device)184self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)185if self.in_channels != self.out_channels:186self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)187else:188self.nin_shortcut = None189self.swish = torch.nn.SiLU(inplace=True)190191def forward(self, x):192hidden = x193hidden = self.norm1(hidden)194hidden = self.swish(hidden)195hidden = self.conv1(hidden)196hidden = self.norm2(hidden)197hidden = self.swish(hidden)198hidden = self.conv2(hidden)199if self.in_channels != self.out_channels:200x = self.nin_shortcut(x)201return x + hidden202203204class AttnBlock(torch.nn.Module):205def __init__(self, in_channels, dtype=torch.float32, device=None):206super().__init__()207self.norm = Normalize(in_channels, dtype=dtype, device=device)208self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)209self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)210self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)211self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)212213def forward(self, x):214hidden = self.norm(x)215q = self.q(hidden)216k = self.k(hidden)217v = self.v(hidden)218b, c, h, w = q.shape219q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]220hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default221hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)222hidden = self.proj_out(hidden)223return x + hidden224225226class Downsample(torch.nn.Module):227def __init__(self, in_channels, dtype=torch.float32, device=None):228super().__init__()229self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)230231def forward(self, x):232pad = (0,1,0,1)233x = torch.nn.functional.pad(x, pad, mode="constant", value=0)234x = self.conv(x)235return x236237238class Upsample(torch.nn.Module):239def __init__(self, in_channels, dtype=torch.float32, device=None):240super().__init__()241self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)242243def forward(self, x):244x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")245x = self.conv(x)246return x247248249class VAEEncoder(torch.nn.Module):250def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):251super().__init__()252self.num_resolutions = len(ch_mult)253self.num_res_blocks = num_res_blocks254# downsampling255self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)256in_ch_mult = (1,) + tuple(ch_mult)257self.in_ch_mult = in_ch_mult258self.down = torch.nn.ModuleList()259for i_level in range(self.num_resolutions):260block = torch.nn.ModuleList()261attn = torch.nn.ModuleList()262block_in = ch*in_ch_mult[i_level]263block_out = ch*ch_mult[i_level]264for _ in range(num_res_blocks):265block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))266block_in = block_out267down = torch.nn.Module()268down.block = block269down.attn = attn270if i_level != self.num_resolutions - 1:271down.downsample = Downsample(block_in, dtype=dtype, device=device)272self.down.append(down)273# middle274self.mid = torch.nn.Module()275self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)276self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)277self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)278# end279self.norm_out = Normalize(block_in, dtype=dtype, device=device)280self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)281self.swish = torch.nn.SiLU(inplace=True)282283def forward(self, x):284# downsampling285hs = [self.conv_in(x)]286for i_level in range(self.num_resolutions):287for i_block in range(self.num_res_blocks):288h = self.down[i_level].block[i_block](hs[-1])289hs.append(h)290if i_level != self.num_resolutions-1:291hs.append(self.down[i_level].downsample(hs[-1]))292# middle293h = hs[-1]294h = self.mid.block_1(h)295h = self.mid.attn_1(h)296h = self.mid.block_2(h)297# end298h = self.norm_out(h)299h = self.swish(h)300h = self.conv_out(h)301return h302303304class VAEDecoder(torch.nn.Module):305def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):306super().__init__()307self.num_resolutions = len(ch_mult)308self.num_res_blocks = num_res_blocks309block_in = ch * ch_mult[self.num_resolutions - 1]310curr_res = resolution // 2 ** (self.num_resolutions - 1)311# z to block_in312self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)313# middle314self.mid = torch.nn.Module()315self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)316self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)317self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)318# upsampling319self.up = torch.nn.ModuleList()320for i_level in reversed(range(self.num_resolutions)):321block = torch.nn.ModuleList()322block_out = ch * ch_mult[i_level]323for _ in range(self.num_res_blocks + 1):324block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))325block_in = block_out326up = torch.nn.Module()327up.block = block328if i_level != 0:329up.upsample = Upsample(block_in, dtype=dtype, device=device)330curr_res = curr_res * 2331self.up.insert(0, up) # prepend to get consistent order332# end333self.norm_out = Normalize(block_in, dtype=dtype, device=device)334self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)335self.swish = torch.nn.SiLU(inplace=True)336337def forward(self, z):338# z to block_in339hidden = self.conv_in(z)340# middle341hidden = self.mid.block_1(hidden)342hidden = self.mid.attn_1(hidden)343hidden = self.mid.block_2(hidden)344# upsampling345for i_level in reversed(range(self.num_resolutions)):346for i_block in range(self.num_res_blocks + 1):347hidden = self.up[i_level].block[i_block](hidden)348if i_level != 0:349hidden = self.up[i_level].upsample(hidden)350# end351hidden = self.norm_out(hidden)352hidden = self.swish(hidden)353hidden = self.conv_out(hidden)354return hidden355356357class SDVAE(torch.nn.Module):358def __init__(self, dtype=torch.float32, device=None):359super().__init__()360self.encoder = VAEEncoder(dtype=dtype, device=device)361self.decoder = VAEDecoder(dtype=dtype, device=device)362363@torch.autocast("cuda", dtype=torch.float16)364def decode(self, latent):365return self.decoder(latent)366367@torch.autocast("cuda", dtype=torch.float16)368def encode(self, image):369hidden = self.encoder(image)370mean, logvar = torch.chunk(hidden, 2, dim=1)371logvar = torch.clamp(logvar, -30.0, 20.0)372std = torch.exp(0.5 * logvar)373return mean + std * torch.randn_like(mean)374375376