Path: blob/master/modules/models/sd3/other_impls.py
### This file contains impls for underlying related models (CLIP, T5, etc)

import torch
import math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast

from modules import sd_hijack


#################################################################################################
### Core/Utility
#################################################################################################


class AutocastLinear(nn.Linear):
    """Same as the usual linear layer, but casts its weights to the dtype of the input.

    This differs from torch.autocast in that a float16 layer processing float32 input
    returns float16 with autocast on, and float32 with this. T5 seems to break if you
    run it in full float16 (returning almost all zeros in the final output).
    """

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)


def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation"""
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        self.act = act_layer
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
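

# Illustrative sketch (not part of the original file): demonstrates the dtype behaviour the
# AutocastLinear docstring describes - float16 weights applied to a float32 input produce a
# float32 result, because the weights are cast to the input dtype rather than the other way
# around. The function name and sizes are made up for illustration only.
def _example_autocast_linear_dtype():
    layer = AutocastLinear(4, 4, bias=False, dtype=torch.float16, device="cpu")
    x = torch.randn(2, 4, dtype=torch.float32)
    out = layer(x)
    assert out.dtype == torch.float32  # with torch.autocast, a float16 layer would return float16
    return out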


#################################################################################################
### CLIP
#################################################################################################


class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device):
        super().__init__()
        self.heads = heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = attention(q, k, v, self.heads, mask)
        return self.out_proj(out)


ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}

class CLIPLayer(torch.nn.Module):
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
        self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)

    def forward(self, x, mask=None):
        x += self.self_attn(self.layer_norm1(x), mask)
        x += self.mlp(self.layer_norm2(x))
        return x


class CLIPEncoder(torch.nn.Module):
    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])

    def forward(self, x, mask=None, intermediate_output=None):
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output
        intermediate = None
        for i, layer in enumerate(self.layers):
            x = layer(x, mask)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class CLIPEmbeddings(torch.nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
        super().__init__()
        self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight


class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
        x = self.embeddings(input_tokens)
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)
        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
        return x, i, pooled_output


class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
        self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])
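

# Illustrative sketch (not part of the original file): CLIPTextModel_.forward pools the hidden
# state at the end-of-text position, located with argmax because the EOS id (49407) is the
# largest id in the CLIP vocabulary. Token ids below are arbitrary, padded with 0 as CLIP-G does.
def _example_pooled_index():
    input_tokens = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])
    eos_positions = input_tokens.to(dtype=torch.int).argmax(dim=-1)
    assert eos_positions.tolist() == [3]
    return eos_positions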


class SDTokenizer:
    def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length
        empty = self.tokenizer('')["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.max_word_length = 8

    def tokenize_with_weights(self, text:str):
        """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and the weights themselves have a weak effect on SD3."""
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0))
        to_tokenize = text.replace("\n", " ").split(' ')
        to_tokenize = [x for x in to_tokenize if x != ""]
        for word in to_tokenize:
            batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
        batch.append((self.end_token, 1.0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
        return [batch]


class SDXLClipGTokenizer(SDTokenizer):
    def __init__(self, tokenizer):
        super().__init__(pad_with_end=False, tokenizer=tokenizer)


class SD3Tokenizer:
    def __init__(self):
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
        self.t5xxl = T5XXLTokenizer()

    def tokenize_with_weights(self, text:str):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text)
        out["l"] = self.clip_l.tokenize_with_weights(text)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
        return out


class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
        tokens = [a[0] for a in token_weight_pairs[0]]
        out, pooled = self([tokens])
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled
        output = [out[0:1]]
        return torch.cat(output, dim=-2).cpu(), first_pooled
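

# Illustrative sketch (not part of the original file): the shape of what SD3Tokenizer produces.
# Calling it downloads the HF tokenizers, so this is only a sketch; the prompt is arbitrary.
def _example_tokenize():
    tokenizer = SD3Tokenizer()
    out = tokenizer.tokenize_with_weights("a photo of a cat")
    # out["l"] and out["g"] are [[(token_id, weight), ...]] padded to 77 entries;
    # out["t5xxl"] is padded up to min_length=77 but never truncated.
    assert len(out["l"][0]) == 77 and len(out["g"][0]) == 77 and len(out["t5xxl"][0]) == 77
    return out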


class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = ["last", "pooled", "hidden"]
    def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
        super().__init__()
        assert layer in self.LAYERS
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
        tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device)
        outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
        self.transformer.set_input_embeddings(backup_embeds)
        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]
        pooled_output = None
        if len(outputs) >= 3:
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()
        return z.float(), pooled_output


class SDXLClipG(SDClipModel):
    """Wraps the CLIP-G model into the SD-CLIP-Model interface"""
    def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
        if layer == "penultimate":
            layer = "hidden"
            layer_idx = -2
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)


class T5XXLModel(SDClipModel):
    """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
    def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)


#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################

class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
    def __init__(self):
        super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)


class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(device=x.device, dtype=x.dtype) * x
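

# Illustrative sketch (not part of the original file): T5LayerNorm above is an RMS norm - it
# rescales by the root mean square of the last dimension without subtracting the mean,
# unlike nn.LayerNorm. Values are arbitrary.
def _example_t5_layernorm():
    norm = T5LayerNorm(4)
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    rms = x.pow(2).mean(-1, keepdim=True).add(1e-6).sqrt()
    assert torch.allclose(norm(x), x / rms)
    return norm(x)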


class T5DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
        hidden_linear = self.wi_1(x)
        x = hidden_gelu * hidden_linear
        x = self.wo(x)
        return x


class T5LayerFF(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        x += forwarded_states
        return x
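

# Illustrative sketch (not part of the original file): the feed-forward above is T5 v1.1's
# gated GELU ("GEGLU") variant, i.e. wo(gelu(wi_0(x)) * wi_1(x)). Dimensions are arbitrary.
def _example_gated_ff():
    ff = T5DenseGatedActDense(model_dim=8, ff_dim=16, dtype=torch.float32, device="cpu")
    x = torch.randn(1, 3, 8)
    expected = ff.wo(torch.nn.functional.gelu(ff.wi_0(x), approximate="tanh") * ff.wi_1(x))
    assert torch.allclose(ff(x), expected)
    return ff(x)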


class T5Attention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, past_bias=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
        if past_bias is not None:
            mask = past_bias
        else:
            mask = None

        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)

        return self.o(out), past_bias


class T5LayerSelfAttention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x, past_bias=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
        x += output
        return x, past_bias
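

# Illustrative sketch (not part of the original file): a worked example of the bucketing scheme
# described in the docstring above. With the default 32 buckets and bidirectional=True, half the
# range encodes the sign, nearby offsets get exact buckets, and distant offsets share
# logarithmically sized buckets capped at max_distance.
def _example_relative_position_buckets():
    relative_position = torch.tensor([-200, -1, 0, 1, 200])
    buckets = T5Attention._relative_position_bucket(relative_position)
    assert buckets.tolist() == [15, 1, 0, 17, 31]
    return buckets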


class T5Block(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))
        self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))

    def forward(self, x, past_bias=None):
        x, past_bias = self.layer[0](x, past_bias)
        x = self.layer[-1](x)
        return x, past_bias


class T5Stack(torch.nn.Module):
    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
        intermediate = None
        x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
        past_bias = None
        for i, layer in enumerate(self.block):
            x, past_bias = layer(x, past_bias)
            if i == intermediate_output:
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate


class T5(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.encoder.embed_tokens = embeddings

    def forward(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)
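

# Illustrative sketch (not part of the original file): builds a toy T5 with made-up config values
# (the real T5-XXL uses far larger dimensions) and runs one forward pass, returning the final
# hidden states and the requested intermediate layer output.
def _example_tiny_t5():
    toy_config = {"num_layers": 2, "d_model": 32, "d_ff": 64, "num_heads": 4, "vocab_size": 128}
    model = T5(toy_config, dtype=torch.float32, device="cpu")
    input_ids = torch.randint(0, 128, (1, 8))
    with torch.no_grad():
        last_hidden, intermediate = model(input_ids, intermediate_output=0)
    assert last_hidden.shape == (1, 8, 32) and intermediate.shape == (1, 8, 32)
    return last_hidden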