GitHub Repository: automatic1111/stable-diffusion-webui
Path: modules/models/sd3/sd3_cond.py
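# Text-conditioning stack for Stable Diffusion 3: CLIP-L and CLIP-G prompt encoders
# plus an optional T5-XXL encoder, combined into the 'crossattn' and 'vector'
# conditioning tensors returned by SD3Cond.forward().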
import os
import safetensors
import torch
import typing

from transformers import CLIPTokenizer, T5TokenizerFast

from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer

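# Read-only Mapping over an open safetensors file, so load_state_dict() can pull
# tensors by key without materializing the whole file in memory first.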
class SafetensorsMapping(typing.Mapping):
    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        for key in self.file.keys():
            yield key

    def __getitem__(self, key):
        return self.file.get_tensor(key)

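# Download URLs and model configs for the three SD3 text encoders; used by
# SD3Cond.before_load_weights() to fetch any encoder missing from a checkpoint.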
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
    "textual_inversion_key": "clip_g",
}

T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}

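# Combined CLIP-L + CLIP-G prompt encoder. Both models consume the same CLIP
# tokenization; their outputs are merged in encode_with_transformers() below.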
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
    def __init__(self, clip_l, clip_g):
        super().__init__()

        self.clip_l = clip_l
        self.clip_g = clip_g

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        empty = self.tokenizer('')["input_ids"]
        self.id_start = empty[0]
        self.id_end = empty[1]
        self.id_pad = empty[1]

        self.return_pooled = True

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

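    # Runs both CLIP encoders on the prompt. For CLIP-G, token ids after the first
    # end-of-text token are zeroed; the per-token outputs (768 + 1280 channels) are
    # concatenated and zero-padded to T5's 4096 channels, and the pooled outputs are
    # concatenated into a 2048-dim vector exposed as lg_out.pooled.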
    def encode_with_transformers(self, tokens):
        tokens_g = tokens.clone()

        for batch_pos in range(tokens_g.shape[0]):
            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
            tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0

        l_out, l_pooled = self.clip_l(tokens)
        g_out, g_pooled = self.clip_g(tokens_g)

        lg_out = torch.cat([l_out, g_out], dim=-1)
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

        vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

        lg_out.pooled = vector_out
        return lg_out

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.zeros((nvpt, 768+1280), device=devices.device)  # XXX

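# Optional T5-XXL prompt encoder. Tokenizes with the T5 tokenizer (emphasis syntax
# is parsed out via prompt_parser) and pads or trims each prompt to the CLIP token count.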
class Sd3T5(torch.nn.Module):
    def __init__(self, t5xxl):
        super().__init__()

        self.t5xxl = t5xxl
        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
        self.id_end = empty[0]
        self.id_pad = empty[1]

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def tokenize_line(self, line, *, target_token_count=None):
        if shared.opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        tokens = []
        multipliers = []

        for text_tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                continue

            tokens += text_tokens
            multipliers += [weight] * len(text_tokens)

        tokens += [self.id_end]
        multipliers += [1.0]

        if target_token_count is not None:
            if len(tokens) < target_token_count:
                # compute the padding amount before extending tokens, so multipliers are padded to the same length
                padding = target_token_count - len(tokens)
                tokens += [self.id_pad] * padding
                multipliers += [1.0] * padding
            else:
                tokens = tokens[0:target_token_count]
                multipliers = multipliers[0:target_token_count]

        return tokens, multipliers

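    # If T5 is disabled or was never loaded, stand in a zero tensor of the expected
    # shape so downstream concatenation in SD3Cond.forward() still works.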
    def forward(self, texts, *, token_count):
        if not self.t5xxl or not shared.opts.sd3_enable_t5:
            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)

        tokens_batch = []

        for text in texts:
            tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
            tokens_batch.append(tokens)

        t5_out, t5_pooled = self.t5xxl(tokens_batch)

        return t5_out

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.zeros((nvpt, 4096), device=devices.device)  # XXX

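# Top-level conditioning module: builds the tokenizer and the three encoders
# (T5 only if the sd3_enable_t5 option is set) and produces the conditioning dict
# with 'crossattn' and 'vector' entries.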
class SD3Cond(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

            if shared.opts.sd3_enable_t5:
                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
            else:
                self.t5xxl = None

            self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
            self.model_t5 = Sd3T5(self.t5xxl)

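    # CLIP-L/G features and T5 features are stacked along the token dimension
    # (dim=-2) to form 'crossattn'; when T5 is disabled its half is all zeros.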
    def forward(self, prompts: list[str]):
        with devices.without_autocast():
            lg_out, vector_out = self.model_lg(prompts)
            t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
            lgt_out = torch.cat([lg_out, t5_out], dim=-2)

        return {
            'crossattn': lgt_out,
            'vector': vector_out,
        }

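    # If the checkpoint's state_dict does not contain the text encoders, download
    # the reference safetensors files into <models_path>/CLIP and load them directly.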
    def before_load_weights(self, state_dict):
        clip_path = os.path.join(shared.models_path, "CLIP")

        if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
            clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
            with safetensors.safe_open(clip_g_file, framework="pt") as file:
                self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
            clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
            with safetensors.safe_open(clip_l_file, framework="pt") as file:
                self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
            with safetensors.safe_open(t5_file, framework="pt") as file:
                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

    def encode_embedding_init_text(self, init_text, nvpt):
        return self.model_lg.encode_embedding_init_text(init_text, nvpt)

    def tokenize(self, texts):
        return self.model_lg.tokenize(texts)

    def medvram_modules(self):
        return [self.clip_g, self.clip_l, self.t5xxl]

    def get_token_count(self, text):
        _, token_count = self.model_lg.process_texts([text])

        return token_count

    def get_target_prompt_token_count(self, token_count):
        return self.model_lg.get_target_prompt_token_count(token_count)