Path: blob/master/modules/models/sd3/mmdit.py
3081 views
### This file contains impls for MM-DiT, the core model component of SD312import math3from typing import Dict, Optional4import numpy as np5import torch6import torch.nn as nn7from einops import rearrange, repeat8from modules.models.sd3.other_impls import attention, Mlp91011class PatchEmbed(nn.Module):12""" 2D Image to Patch Embedding"""13def __init__(14self,15img_size: Optional[int] = 224,16patch_size: int = 16,17in_chans: int = 3,18embed_dim: int = 768,19flatten: bool = True,20bias: bool = True,21strict_img_size: bool = True,22dynamic_img_pad: bool = False,23dtype=None,24device=None,25):26super().__init__()27self.patch_size = (patch_size, patch_size)28if img_size is not None:29self.img_size = (img_size, img_size)30self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])31self.num_patches = self.grid_size[0] * self.grid_size[1]32else:33self.img_size = None34self.grid_size = None35self.num_patches = None3637# flatten spatial dim and transpose to channels last, kept for bwd compat38self.flatten = flatten39self.strict_img_size = strict_img_size40self.dynamic_img_pad = dynamic_img_pad4142self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)4344def forward(self, x):45B, C, H, W = x.shape46x = self.proj(x)47if self.flatten:48x = x.flatten(2).transpose(1, 2) # NCHW -> NLC49return x505152def modulate(x, shift, scale):53if shift is None:54shift = torch.zeros_like(scale)55return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)565758#################################################################################59# Sine/Cosine Positional Embedding Functions #60#################################################################################616263def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):64"""65grid_size: int of the grid height and width66return:67pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)68"""69grid_h = np.arange(grid_size, dtype=np.float32)70grid_w = np.arange(grid_size, dtype=np.float32)71grid = np.meshgrid(grid_w, grid_h) # here w goes first72grid = np.stack(grid, axis=0)73if scaling_factor is not None:74grid = grid / scaling_factor75if offset is not None:76grid = grid - offset77grid = grid.reshape([2, 1, grid_size, grid_size])78pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)79if cls_token and extra_tokens > 0:80pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)81return pos_embed828384def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):85assert embed_dim % 2 == 086# use half of dimensions to encode grid_h87emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)88emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)89emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)90return emb919293def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):94"""95embed_dim: output dimension for each position96pos: a list of positions to be encoded: size (M,)97out: (M, D)98"""99assert embed_dim % 2 == 0100omega = np.arange(embed_dim // 2, dtype=np.float64)101omega /= embed_dim / 2.0102omega = 1.0 / 10000**omega # (D/2,)103pos = pos.reshape(-1) # (M,)104out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product105emb_sin = np.sin(out) # (M, D/2)106emb_cos = np.cos(out) # (M, D/2)107return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)108109110#################################################################################111# Embedding Layers for Timesteps and Class Labels #112#################################################################################113114115class TimestepEmbedder(nn.Module):116"""Embeds scalar timesteps into vector representations."""117118def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):119super().__init__()120self.mlp = nn.Sequential(121nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),122nn.SiLU(),123nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),124)125self.frequency_embedding_size = frequency_embedding_size126127@staticmethod128def timestep_embedding(t, dim, max_period=10000):129"""130Create sinusoidal timestep embeddings.131:param t: a 1-D Tensor of N indices, one per batch element.132These may be fractional.133:param dim: the dimension of the output.134:param max_period: controls the minimum frequency of the embeddings.135:return: an (N, D) Tensor of positional embeddings.136"""137half = dim // 2138freqs = torch.exp(139-math.log(max_period)140* torch.arange(start=0, end=half, dtype=torch.float32)141/ half142).to(device=t.device)143args = t[:, None].float() * freqs[None]144embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)145if dim % 2:146embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)147if torch.is_floating_point(t):148embedding = embedding.to(dtype=t.dtype)149return embedding150151def forward(self, t, dtype, **kwargs):152t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)153t_emb = self.mlp(t_freq)154return t_emb155156157class VectorEmbedder(nn.Module):158"""Embeds a flat vector of dimension input_dim"""159160def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):161super().__init__()162self.mlp = nn.Sequential(163nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),164nn.SiLU(),165nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),166)167168def forward(self, x: torch.Tensor) -> torch.Tensor:169return self.mlp(x)170171172#################################################################################173# Core DiT Model #174#################################################################################175176177class QkvLinear(torch.nn.Linear):178pass179180def split_qkv(qkv, head_dim):181qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)182return qkv[0], qkv[1], qkv[2]183184def optimized_attention(qkv, num_heads):185return attention(qkv[0], qkv[1], qkv[2], num_heads)186187class SelfAttention(nn.Module):188ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")189190def __init__(191self,192dim: int,193num_heads: int = 8,194qkv_bias: bool = False,195qk_scale: Optional[float] = None,196attn_mode: str = "xformers",197pre_only: bool = False,198qk_norm: Optional[str] = None,199rmsnorm: bool = False,200dtype=None,201device=None,202):203super().__init__()204self.num_heads = num_heads205self.head_dim = dim // num_heads206207self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)208if not pre_only:209self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)210assert attn_mode in self.ATTENTION_MODES211self.attn_mode = attn_mode212self.pre_only = pre_only213214if qk_norm == "rms":215self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)216self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)217elif qk_norm == "ln":218self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)219self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)220elif qk_norm is None:221self.ln_q = nn.Identity()222self.ln_k = nn.Identity()223else:224raise ValueError(qk_norm)225226def pre_attention(self, x: torch.Tensor):227B, L, C = x.shape228qkv = self.qkv(x)229q, k, v = split_qkv(qkv, self.head_dim)230q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)231k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)232return (q, k, v)233234def post_attention(self, x: torch.Tensor) -> torch.Tensor:235assert not self.pre_only236x = self.proj(x)237return x238239def forward(self, x: torch.Tensor) -> torch.Tensor:240(q, k, v) = self.pre_attention(x)241x = attention(q, k, v, self.num_heads)242x = self.post_attention(x)243return x244245246class RMSNorm(torch.nn.Module):247def __init__(248self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None249):250"""251Initialize the RMSNorm normalization layer.252Args:253dim (int): The dimension of the input tensor.254eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.255Attributes:256eps (float): A small value added to the denominator for numerical stability.257weight (nn.Parameter): Learnable scaling parameter.258"""259super().__init__()260self.eps = eps261self.learnable_scale = elementwise_affine262if self.learnable_scale:263self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))264else:265self.register_parameter("weight", None)266267def _norm(self, x):268"""269Apply the RMSNorm normalization to the input tensor.270Args:271x (torch.Tensor): The input tensor.272Returns:273torch.Tensor: The normalized tensor.274"""275return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)276277def forward(self, x):278"""279Forward pass through the RMSNorm layer.280Args:281x (torch.Tensor): The input tensor.282Returns:283torch.Tensor: The output tensor after applying RMSNorm.284"""285x = self._norm(x)286if self.learnable_scale:287return x * self.weight.to(device=x.device, dtype=x.dtype)288else:289return x290291292class SwiGLUFeedForward(nn.Module):293def __init__(294self,295dim: int,296hidden_dim: int,297multiple_of: int,298ffn_dim_multiplier: Optional[float] = None,299):300"""301Initialize the FeedForward module.302303Args:304dim (int): Input dimension.305hidden_dim (int): Hidden dimension of the feedforward layer.306multiple_of (int): Value to ensure hidden dimension is a multiple of this value.307ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.308309Attributes:310w1 (ColumnParallelLinear): Linear transformation for the first layer.311w2 (RowParallelLinear): Linear transformation for the second layer.312w3 (ColumnParallelLinear): Linear transformation for the third layer.313314"""315super().__init__()316hidden_dim = int(2 * hidden_dim / 3)317# custom dim factor multiplier318if ffn_dim_multiplier is not None:319hidden_dim = int(ffn_dim_multiplier * hidden_dim)320hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)321322self.w1 = nn.Linear(dim, hidden_dim, bias=False)323self.w2 = nn.Linear(hidden_dim, dim, bias=False)324self.w3 = nn.Linear(dim, hidden_dim, bias=False)325326def forward(self, x):327return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))328329330class DismantledBlock(nn.Module):331"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""332333ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")334335def __init__(336self,337hidden_size: int,338num_heads: int,339mlp_ratio: float = 4.0,340attn_mode: str = "xformers",341qkv_bias: bool = False,342pre_only: bool = False,343rmsnorm: bool = False,344scale_mod_only: bool = False,345swiglu: bool = False,346qk_norm: Optional[str] = None,347dtype=None,348device=None,349**block_kwargs,350):351super().__init__()352assert attn_mode in self.ATTENTION_MODES353if not rmsnorm:354self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)355else:356self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)357self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device)358if not pre_only:359if not rmsnorm:360self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)361else:362self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)363mlp_hidden_dim = int(hidden_size * mlp_ratio)364if not pre_only:365if not swiglu:366self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device)367else:368self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)369self.scale_mod_only = scale_mod_only370if not scale_mod_only:371n_mods = 6 if not pre_only else 2372else:373n_mods = 4 if not pre_only else 1374self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))375self.pre_only = pre_only376377def pre_attention(self, x: torch.Tensor, c: torch.Tensor):378assert x is not None, "pre_attention called with None input"379if not self.pre_only:380if not self.scale_mod_only:381shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)382else:383shift_msa = None384shift_mlp = None385scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)386qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))387return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)388else:389if not self.scale_mod_only:390shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)391else:392shift_msa = None393scale_msa = self.adaLN_modulation(c)394qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))395return qkv, None396397def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):398assert not self.pre_only399x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)400x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))401return x402403def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:404assert not self.pre_only405(q, k, v), intermediates = self.pre_attention(x, c)406attn = attention(q, k, v, self.attn.num_heads)407return self.post_attention(attn, *intermediates)408409410def block_mixing(context, x, context_block, x_block, c):411assert context is not None, "block_mixing called with None context"412context_qkv, context_intermediates = context_block.pre_attention(context, c)413414x_qkv, x_intermediates = x_block.pre_attention(x, c)415416o = []417for t in range(3):418o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))419q, k, v = tuple(o)420421attn = attention(q, k, v, x_block.attn.num_heads)422context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :])423424if not context_block.pre_only:425context = context_block.post_attention(context_attn, *context_intermediates)426else:427context = None428x = x_block.post_attention(x_attn, *x_intermediates)429return context, x430431432class JointBlock(nn.Module):433"""just a small wrapper to serve as a fsdp unit"""434435def __init__(self, *args, **kwargs):436super().__init__()437pre_only = kwargs.pop("pre_only")438qk_norm = kwargs.pop("qk_norm", None)439self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)440self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)441442def forward(self, *args, **kwargs):443return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)444445446class FinalLayer(nn.Module):447"""448The final layer of DiT.449"""450451def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):452super().__init__()453self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)454self.linear = (455nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)456if (total_out_channels is None)457else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)458)459self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))460461def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:462shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)463x = modulate(self.norm_final(x), shift, scale)464x = self.linear(x)465return x466467468class MMDiT(nn.Module):469"""Diffusion model with a Transformer backbone."""470471def __init__(472self,473input_size: int = 32,474patch_size: int = 2,475in_channels: int = 4,476depth: int = 28,477mlp_ratio: float = 4.0,478learn_sigma: bool = False,479adm_in_channels: Optional[int] = None,480context_embedder_config: Optional[Dict] = None,481register_length: int = 0,482attn_mode: str = "torch",483rmsnorm: bool = False,484scale_mod_only: bool = False,485swiglu: bool = False,486out_channels: Optional[int] = None,487pos_embed_scaling_factor: Optional[float] = None,488pos_embed_offset: Optional[float] = None,489pos_embed_max_size: Optional[int] = None,490num_patches = None,491qk_norm: Optional[str] = None,492qkv_bias: bool = True,493dtype = None,494device = None,495):496super().__init__()497self.dtype = dtype498self.learn_sigma = learn_sigma499self.in_channels = in_channels500default_out_channels = in_channels * 2 if learn_sigma else in_channels501self.out_channels = out_channels if out_channels is not None else default_out_channels502self.patch_size = patch_size503self.pos_embed_scaling_factor = pos_embed_scaling_factor504self.pos_embed_offset = pos_embed_offset505self.pos_embed_max_size = pos_embed_max_size506507# apply magic --> this defines a head_size of 64508hidden_size = 64 * depth509num_heads = depth510511self.num_heads = num_heads512513self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)514self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)515516if adm_in_channels is not None:517assert isinstance(adm_in_channels, int)518self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)519520self.context_embedder = nn.Identity()521if context_embedder_config is not None:522if context_embedder_config["target"] == "torch.nn.Linear":523self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)524525self.register_length = register_length526if self.register_length > 0:527self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))528529# num_patches = self.x_embedder.num_patches530# Will use fixed sin-cos embedding:531# just use a buffer already532if num_patches is not None:533self.register_buffer(534"pos_embed",535torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),536)537else:538self.pos_embed = None539540self.joint_blocks = nn.ModuleList(541[542JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device)543for i in range(depth)544]545)546547self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)548549def cropped_pos_embed(self, hw):550assert self.pos_embed_max_size is not None551p = self.x_embedder.patch_size[0]552h, w = hw553# patched size554h = h // p555w = w // p556assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)557assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)558top = (self.pos_embed_max_size - h) // 2559left = (self.pos_embed_max_size - w) // 2560spatial_pos_embed = rearrange(561self.pos_embed,562"1 (h w) c -> 1 h w c",563h=self.pos_embed_max_size,564w=self.pos_embed_max_size,565)566spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]567spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")568return spatial_pos_embed569570def unpatchify(self, x, hw=None):571"""572x: (N, T, patch_size**2 * C)573imgs: (N, H, W, C)574"""575c = self.out_channels576p = self.x_embedder.patch_size[0]577if hw is None:578h = w = int(x.shape[1] ** 0.5)579else:580h, w = hw581h = h // p582w = w // p583assert h * w == x.shape[1]584585x = x.reshape(shape=(x.shape[0], h, w, p, p, c))586x = torch.einsum("nhwpqc->nchpwq", x)587imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))588return imgs589590def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:591if self.register_length > 0:592context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)593594# context is B, L', D595# x is B, L, D596for block in self.joint_blocks:597context, x = block(context, x, c=c_mod)598599x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)600return x601602def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor:603"""604Forward pass of DiT.605x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)606t: (N,) tensor of diffusion timesteps607y: (N,) tensor of class labels608"""609hw = x.shape[-2:]610x = self.x_embedder(x) + self.cropped_pos_embed(hw)611c = self.t_embedder(t, dtype=x.dtype) # (N, D)612if y is not None:613y = self.y_embedder(y) # (N, D)614c = c + y # (N, D)615616context = self.context_embedder(context)617618x = self.forward_core_with_concat(x, c, context)619620x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)621return x622623624