Path: blob/master/modules/models/sd3/sd3_cond.py
import os
import safetensors
import torch
import typing

from transformers import CLIPTokenizer, T5TokenizerFast

from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer


class SafetensorsMapping(typing.Mapping):
    """Read-only mapping over an open safetensors file, letting load_state_dict() pull tensors lazily instead of loading the whole file at once."""

    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        for key in self.file.keys():
            yield key

    def __getitem__(self, key):
        return self.file.get_tensor(key)


CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
    "textual_inversion_key": "clip_g",
}

T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}


class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
    """Joint CLIP-L + CLIP-G conditioner: both encoders share one tokenization; their outputs are concatenated and zero-padded to T5's 4096 channels."""

    def __init__(self, clip_l, clip_g):
        super().__init__()

        self.clip_l = clip_l
        self.clip_g = clip_g

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        empty = self.tokenizer('')["input_ids"]
        self.id_start = empty[0]
        self.id_end = empty[1]
        self.id_pad = empty[1]

        self.return_pooled = True

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_transformers(self, tokens):
        tokens_g = tokens.clone()

        # CLIP-G pads with zeros rather than repeated end-of-text tokens, so
        # zero out everything past the first end token in its copy
        for batch_pos in range(tokens_g.shape[0]):
            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
            tokens_g[batch_pos, index + 1:tokens_g.shape[1]] = 0

        l_out, l_pooled = self.clip_l(tokens)
        g_out, g_pooled = self.clip_g(tokens_g)

        # concatenate along channels (768 + 1280) and zero-pad to 4096 so the
        # result can later be stacked with T5-XXL output along the token axis
        lg_out = torch.cat([l_out, g_out], dim=-1)
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

        vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

        lg_out.pooled = vector_out
        return lg_out

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.zeros((nvpt, 768 + 1280), device=devices.device)  # XXX


class Sd3T5(torch.nn.Module):
    def __init__(self, t5xxl):
        super().__init__()

        self.t5xxl = t5xxl
        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

        # tokenizing an empty prompt padded to length 2 yields [end, pad] ids
        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
        self.id_end = empty[0]
        self.id_pad = empty[1]

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def tokenize_line(self, line, *, target_token_count=None):
        if shared.opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        tokens = []
        multipliers = []

        for text_tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                continue

            tokens += text_tokens
            multipliers += [weight] * len(text_tokens)

        tokens += [self.id_end]
        multipliers += [1.0]

        if target_token_count is not None:
            if len(tokens) < target_token_count:
                # compute the pad length before mutating tokens; otherwise
                # multipliers would be extended by zero elements
                pad_count = target_token_count - len(tokens)
                tokens += [self.id_pad] * pad_count
                multipliers += [1.0] * pad_count
            else:
                tokens = tokens[0:target_token_count]
                multipliers = multipliers[0:target_token_count]

        return tokens, multipliers

    def forward(self, texts, *, token_count):
        # with T5 disabled or its weights never loaded, contribute an all-zero
        # context so the rest of the pipeline is unaffected
        if not self.t5xxl or not shared.opts.sd3_enable_t5:
            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)

        tokens_batch = []

        for text in texts:
            tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
            tokens_batch.append(tokens)

        t5_out, t5_pooled = self.t5xxl(tokens_batch)

        return t5_out

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.zeros((nvpt, 4096), device=devices.device)  # XXX


class SD3Cond(torch.nn.Module):
    """Full SD3 text-conditioning stack: CLIP-L, CLIP-G and optional T5-XXL."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

            if shared.opts.sd3_enable_t5:
                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
            else:
                self.t5xxl = None

            self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
            self.model_t5 = Sd3T5(self.t5xxl)

    def forward(self, prompts: list[str]):
        with devices.without_autocast():
            lg_out, vector_out = self.model_lg(prompts)
            t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])

            # stack the CLIP and T5 contexts along the token axis
            lgt_out = torch.cat([lg_out, t5_out], dim=-2)

        return {
            'crossattn': lgt_out,
            'vector': vector_out,
        }

    def before_load_weights(self, state_dict):
        # any text encoder missing from the checkpoint's state_dict is
        # downloaded and loaded from its standalone safetensors file instead
        clip_path = os.path.join(shared.models_path, "CLIP")

        if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
            clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
            with safetensors.safe_open(clip_g_file, framework="pt") as file:
                self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
            clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
            with safetensors.safe_open(clip_l_file, framework="pt") as file:
                self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
            with safetensors.safe_open(t5_file, framework="pt") as file:
                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

    def encode_embedding_init_text(self, init_text, nvpt):
        return self.model_lg.encode_embedding_init_text(init_text, nvpt)

    def tokenize(self, texts):
        return self.model_lg.tokenize(texts)

    def medvram_modules(self):
        return [self.clip_g, self.clip_l, self.t5xxl]

    def get_token_count(self, text):
        _, token_count = self.model_lg.process_texts([text])

        return token_count

    def get_target_prompt_token_count(self, token_count):
        return self.model_lg.get_target_prompt_token_count(token_count)
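

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module), assuming a
# checkpoint `state_dict` is already available from the model loader:
#
#     cond_model = SD3Cond()
#     cond_model.before_load_weights(state_dict)  # fetch any missing encoders
#     cond = cond_model(["a photo of an astronaut riding a horse"])
#     cond["crossattn"]  # (batch, clip_tokens + t5_tokens, 4096) context
#     cond["vector"]     # (batch, 2048) pooled CLIP-L (768) + CLIP-G (1280)
# ---------------------------------------------------------------------------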