GitHub Repository: TencentARC/GFPGAN
Path: blob/master/gfpgan/archs/restoreformer_arch.py
"""Modified from https://github.com/wzhouxiff/RestoreFormer
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        # could possibly replace this here
        # #\start...
        # find closest encodings

        min_value, min_encoding_indices = torch.min(d, dim=1)

        min_encoding_indices = min_encoding_indices.unsqueeze(1)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # dtype min encodings: torch.float32
        # min_encodings shape: torch.Size([2048, 512])
        # min_encoding_indices.shape: torch.Size([2048, 1])

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # .........\end

        # with:
        # .........\start
        # min_encoding_indices = torch.argmin(d, dim=1)
        # z_q = self.embedding(min_encoding_indices)
        # ......\end......... (TODO)

        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity

        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for more easy handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:, None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)

            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
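

# Hypothetical usage sketch (not part of the original file): a quick shape check for
# VectorQuantizer, assuming a codebook of 1024 entries of dimension 256.
def _example_vector_quantizer():
    quantizer = VectorQuantizer(n_e=1024, e_dim=256, beta=0.25)
    z = torch.randn(2, 256, 16, 16)  # encoder feature; channel dim must equal e_dim
    z_q, loss, (perplexity, min_encodings, min_encoding_indices, d) = quantizer(z)
    assert z_q.shape == z.shape  # quantization preserves the (B, C, H, W) shape
    return z_q, loss, perplexity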


# pytorch_diffusion + derived encoder decoder
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
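

# Hypothetical shape check (not part of the original file): the manual (0, 1, 0, 1)
# padding plus a stride-2 conv with padding=0 halves H and W exactly, matching the
# avg_pool2d branch.
def _example_downsample_shapes():
    x = torch.randn(1, 64, 32, 32)
    assert Downsample(64, with_conv=True)(x).shape == (1, 64, 16, 16)
    assert Downsample(64, with_conv=False)(x).shape == (1, 64, 16, 16)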


class ResnetBlock(nn.Module):

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
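

# Hypothetical shape check (not part of the original file): when in_channels and
# out_channels differ, the residual path is projected by a 1x1 nin_shortcut
# (or a 3x3 conv_shortcut when conv_shortcut=True).
def _example_resnet_block():
    block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=0)
    x = torch.randn(1, 64, 32, 32)
    assert block(x, temb=None).shape == (1, 128, 32, 32)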


class MultiHeadAttnBlock(nn.Module):

    def __init__(self, in_channels, head_size=1):
        super().__init__()
        self.in_channels = in_channels
        self.head_size = head_size
        self.att_size = in_channels // head_size
        assert (in_channels % head_size == 0), 'in_channels must be divisible by head_size.'

        self.norm1 = Normalize(in_channels)
        self.norm2 = Normalize(in_channels)

        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.num = 0

    def forward(self, x, y=None):
        h_ = x
        h_ = self.norm1(h_)
        if y is None:
            y = h_
        else:
            y = self.norm2(y)

        q = self.q(y)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, self.head_size, self.att_size, h * w)
        q = q.permute(0, 3, 1, 2)  # b, hw, head, att

        k = k.reshape(b, self.head_size, self.att_size, h * w)
        k = k.permute(0, 3, 1, 2)

        v = v.reshape(b, self.head_size, self.att_size, h * w)
        v = v.permute(0, 3, 1, 2)

        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        k = k.transpose(1, 2).transpose(2, 3)

        scale = int(self.att_size)**(-0.5)
        q.mul_(scale)
        w_ = torch.matmul(q, k)
        w_ = F.softmax(w_, dim=3)

        w_ = w_.matmul(v)

        w_ = w_.transpose(1, 2).contiguous()  # [b, h*w, head, att]
        w_ = w_.view(b, h, w, -1)
        w_ = w_.permute(0, 3, 1, 2)

        w_ = self.proj_out(w_)

        return x + w_
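

# Hypothetical usage sketch (not part of the original file): with y=None the block
# attends over x itself; passing a same-shaped feature y makes it cross-attention
# (queries come from y), which is how MultiHeadDecoderTransformer uses it below.
def _example_multihead_attn():
    attn = MultiHeadAttnBlock(in_channels=64, head_size=8)
    x = torch.randn(1, 64, 16, 16)  # decoder feature
    y = torch.randn(1, 64, 16, 16)  # encoder feature of matching shape
    assert attn(x).shape == attn(x, y).shape == x.shape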


class MultiHeadEncoder(nn.Module):

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 double_z=True,
                 enable_mid=True,
                 head_size=1,
                 **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.enable_mid = enable_mid

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        hs = {}
        # timestep embedding
        temb = None

        # downsampling
        h = self.conv_in(x)
        hs['in'] = h
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)

            if i_level != self.num_resolutions - 1:
                # hs.append(h)
                hs['block_' + str(i_level)] = h
                h = self.down[i_level].downsample(h)

        # middle
        # h = hs[-1]
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            hs['block_' + str(i_level) + '_atten'] = h
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)
            hs['mid_atten'] = h

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        # hs.append(h)
        hs['out'] = h

        return hs
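

# Hypothetical usage sketch (not part of the original file): with the RestoreFormer
# defaults (ch=64, ch_mult=(1, 2, 2, 4, 4, 8), resolution=512), the encoder returns a
# dict of intermediate features keyed 'in', 'block_0'..'block_4', 'block_5_atten',
# 'mid_atten' and 'out', which the transformer decoder below uses for cross-attention.
def _example_multihead_encoder():
    encoder = MultiHeadEncoder(
        ch=64, out_ch=3, ch_mult=(1, 2, 2, 4, 4, 8), resolution=512, z_channels=256,
        double_z=False, head_size=8)
    hs = encoder(torch.randn(1, 3, 512, 512))
    return sorted(hs.keys())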


class MultiHeadDecoder(nn.Module):

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 give_pre_end=False,
                 enable_mid=True,
                 head_size=1,
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h)
            h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class MultiHeadDecoderTransformer(nn.Module):

    def __init__(self,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 give_pre_end=False,
                 enable_mid=True,
                 head_size=1,
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.enable_mid = enable_mid

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        if self.enable_mid:
            self.mid = nn.Module()
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
            self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(MultiHeadAttnBlock(block_in, head_size))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z, hs):
        # assert z.shape[1:] == self.z_shape[1:]
        # self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        if self.enable_mid:
            h = self.mid.block_1(h, temb)
            h = self.mid.attn_1(h, hs['mid_atten'])
            h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, hs['block_' + str(i_level) + '_atten'])
            # hfeature = h.clone()
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class RestoreFormer(nn.Module):

    def __init__(self,
                 n_embed=1024,
                 embed_dim=256,
                 ch=64,
                 out_ch=3,
                 ch_mult=(1, 2, 2, 4, 4, 8),
                 num_res_blocks=2,
                 attn_resolutions=(16, ),
                 dropout=0.0,
                 in_channels=3,
                 resolution=512,
                 z_channels=256,
                 double_z=False,
                 enable_mid=True,
                 fix_decoder=False,
                 fix_codebook=True,
                 fix_encoder=False,
                 head_size=8):
        super(RestoreFormer, self).__init__()

        self.encoder = MultiHeadEncoder(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            double_z=double_z,
            enable_mid=enable_mid,
            head_size=head_size)
        self.decoder = MultiHeadDecoderTransformer(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            enable_mid=enable_mid,
            head_size=head_size)

        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)

        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

        if fix_decoder:
            for _, param in self.decoder.named_parameters():
                param.requires_grad = False
            for _, param in self.post_quant_conv.named_parameters():
                param.requires_grad = False
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False
        elif fix_codebook:
            for _, param in self.quantize.named_parameters():
                param.requires_grad = False

        if fix_encoder:
            for _, param in self.encoder.named_parameters():
                param.requires_grad = False

    def encode(self, x):

        hs = self.encoder(x)
        h = self.quant_conv(hs['out'])
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info, hs

    def decode(self, quant, hs):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant, hs)

        return dec

    def forward(self, input, **kwargs):
        quant, diff, info, hs = self.encode(input)
        dec = self.decode(quant, hs)

        return dec, None
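

# Hypothetical end-to-end sketch (not part of the original file): running the
# default RestoreFormer on a batch of aligned, normalized 512x512 face crops.
def _example_restoreformer():
    model = RestoreFormer().eval()
    faces = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        restored, _ = model(faces)
    assert restored.shape == faces.shape
    return restored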