GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/models/sd3/other_impls.py
### This file contains impls for underlying related models (CLIP, T5, etc)

import torch
import math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast

from modules import sd_hijack


#################################################################################################
### Core/Utility
#################################################################################################


class AutocastLinear(nn.Linear):
    """Same as the usual linear layer, but casts its weights to the dtype of the incoming tensor.

    This differs from torch.autocast: with autocast enabled, a float16 layer processing float32
    input returns float16, whereas this layer returns float32. T5 appears to break if run
    entirely in float16 (the final output is almost all zeros).
    """

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)


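# Editorial sketch (not part of the upstream file): the dtype behaviour described in the
# docstring above, using arbitrary example sizes. A float16 AutocastLinear fed float32 input
# returns float32, because the weights are cast to the input dtype.
def _autocast_linear_dtype_example():
    layer = AutocastLinear(8, 4, bias=True, dtype=torch.float16)
    x = torch.randn(2, 8, dtype=torch.float32)
    y = layer(x)
    assert y.dtype == torch.float32
    return y

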
def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation"""
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)


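# Editorial sketch (not part of the upstream file): attention() expects q/k/v packed as
# (batch, seq, heads * dim_head), splits them into heads for scaled_dot_product_attention,
# and merges the heads back afterwards. Sizes below are arbitrary.
def _attention_shape_example():
    b, seq, heads, dim_head = 2, 5, 4, 16
    q, k, v = (torch.randn(b, seq, heads * dim_head) for _ in range(3))
    out = attention(q, k, v, heads)
    assert out.shape == (b, seq, heads * dim_head)
    return out

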
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        self.act = act_layer
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


#################################################################################################
### CLIP
#################################################################################################


class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device):
        super().__init__()
        self.heads = heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = attention(q, k, v, self.heads, mask)
        return self.out_proj(out)


ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}

class CLIPLayer(torch.nn.Module):
    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
        self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)

    def forward(self, x, mask=None):
        x += self.self_attn(self.layer_norm1(x), mask)
        x += self.mlp(self.layer_norm2(x))
        return x


class CLIPEncoder(torch.nn.Module):
    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
        super().__init__()
        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])

    def forward(self, x, mask=None, intermediate_output=None):
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output
        intermediate = None
        for i, layer in enumerate(self.layers):
            x = layer(x, mask)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class CLIPEmbeddings(torch.nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
        super().__init__()
        self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight


class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
        x = self.embeddings(input_tokens)
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)
        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
        return x, i, pooled_output


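# Editorial note (not in the upstream file): the pooled_output above indexes the final hidden
# states at input_tokens.argmax(dim=-1). In the CLIP vocabulary the end-of-text token (49407)
# has the highest id, so this selects the hidden state at the end-of-text token, which is how
# the HuggingFace CLIP text model computes its pooled output.

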
class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
        self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])


class SDTokenizer:
    def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length
        empty = self.tokenizer('')["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.max_word_length = 8


    def tokenize_with_weights(self, text: str):
        """Tokenize the text into (token, weight) pairs - presume a weight of 1.0 for all tokens and ignore other features here. The details aren't relevant for a reference impl, and the weights themselves have only a weak effect on SD3."""
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0))
        to_tokenize = text.replace("\n", " ").split(' ')
        to_tokenize = [x for x in to_tokenize if x != ""]
        for word in to_tokenize:
            batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
        batch.append((self.end_token, 1.0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
        return [batch]


class SDXLClipGTokenizer(SDTokenizer):
    def __init__(self, tokenizer):
        super().__init__(pad_with_end=False, tokenizer=tokenizer)


class SD3Tokenizer:
    def __init__(self):
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
        self.t5xxl = T5XXLTokenizer()

    def tokenize_with_weights(self, text: str):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text)
        out["l"] = self.clip_l.tokenize_with_weights(text)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
        return out


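# Editorial sketch (not part of the upstream file): what SD3Tokenizer produces. Each entry is a
# list containing one list of (token_id, weight) pairs; both CLIP branches are padded out to 77
# entries, while the T5 branch is only padded up to its min_length of 77. Constructing the
# tokenizer downloads the HuggingFace tokenizer files on first use.
def _sd3_tokenizer_example():
    tokens = SD3Tokenizer().tokenize_with_weights("a photo of a cat")
    assert len(tokens["l"][0]) == 77 and len(tokens["g"][0]) == 77
    return tokens

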
class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
        tokens = [a[0] for a in token_weight_pairs[0]]
        out, pooled = self([tokens])
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled
        output = [out[0:1]]
        return torch.cat(output, dim=-2).cpu(), first_pooled


class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = ["last", "pooled", "hidden"]
    def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
        super().__init__()
        assert layer in self.LAYERS
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
        tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device)
        outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
        self.transformer.set_input_embeddings(backup_embeds)
        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]
        pooled_output = None
        if len(outputs) >= 3:
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()
        return z.float(), pooled_output


class SDXLClipG(SDClipModel):
    """Wraps the CLIP-G model into the SD-CLIP-Model interface"""
    def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
        if layer == "penultimate":
            layer="hidden"
            layer_idx=-2
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)


class T5XXLModel(SDClipModel):
    """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
    def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)


#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################

class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
    def __init__(self):
        super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)


class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(device=x.device, dtype=x.dtype) * x


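# Editorial note (not in the upstream file): T5LayerNorm is an RMSNorm - it rescales by the
# root mean square of the features without subtracting the mean or adding a bias, so it is not
# interchangeable with nn.LayerNorm. Minimal equivalence sketch with the default weight of ones:
def _t5_layernorm_is_rmsnorm_example():
    x = torch.randn(2, 3, 8)
    reference = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(T5LayerNorm(8)(x), reference)
    return reference

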
class T5DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
        hidden_linear = self.wi_1(x)
        x = hidden_gelu * hidden_linear
        x = self.wo(x)
        return x


class T5LayerFF(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        x += forwarded_states
        return x


class T5Attention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads

        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

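    # Editorial note (not in the upstream file): a worked example of the bucketing above with the
    # defaults used by this file (bidirectional=True, num_buckets=32, max_distance=128). After
    # halving, 16 buckets cover each direction; offsets 0..7 keep their own bucket and offsets
    # 8..127 are spread logarithmically over buckets 8..15:
    #   relative_position = +1   -> bucket 17  (16 for the positive direction + exact offset 1)
    #   relative_position = -1   -> bucket 1
    #   relative_position = +20  -> bucket 16 + 8 + floor(log(20/8) / log(128/8) * 8) = 26
    #   relative_position >= 128 -> bucket 31, and <= -128 -> bucket 15 (saturated).
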
    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, past_bias=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
        if past_bias is not None:
            mask = past_bias
        else:
            mask = None

        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)

        return self.o(out), past_bias


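# Editorial note (not in the upstream file): T5 uses unscaled dot-product attention, while
# scaled_dot_product_attention divides the scores by sqrt(dim_head). Multiplying k by
# (k.shape[-1] / num_heads) ** 0.5 == sqrt(dim_head) in T5Attention.forward cancels that
# division, so the effective scores match the original T5 formulation.

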
class T5LayerSelfAttention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x, past_bias=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
        x += output
        return x, past_bias


class T5Block(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))
        self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))

    def forward(self, x, past_bias=None):
        x, past_bias = self.layer[0](x, past_bias)
        x = self.layer[-1](x)
        return x, past_bias


class T5Stack(torch.nn.Module):
    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
        intermediate = None
        x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes
        past_bias = None
        for i, layer in enumerate(self.block):
            x, past_bias = layer(x, past_bias)
            if i == intermediate_output:
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate


class T5(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.encoder.embed_tokens = embeddings

    def forward(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)

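
# Editorial sketch (not part of the upstream file): a tiny smoke test for the T5 wiring. The
# config keys match what T5.__init__ reads, but the sizes are made-up toy values rather than
# the real T5-XXL dimensions, so this only exercises the plumbing.
def _t5_smoke_test():
    config = {"num_layers": 2, "d_model": 64, "d_ff": 128, "num_heads": 4, "vocab_size": 32128}
    model = T5(config, dtype=torch.float32, device="cpu")
    input_ids = torch.zeros(1, 77, dtype=torch.long)
    with torch.no_grad():
        x, intermediate = model(input_ids)
    assert x.shape == (1, 77, 64) and intermediate is None
    return x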