Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
TencentARC
GitHub Repository: TencentARC/GFPGAN
Path: blob/master/gfpgan/archs/gfpgan_bilinear_arch.py
884 views
1
import math
2
import random
3
import torch
4
from basicsr.utils.registry import ARCH_REGISTRY
5
from torch import nn
6
7
from .gfpganv1_arch import ResUpBlock
8
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
9
StyleGAN2GeneratorBilinear)
10
11
12
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorBilinearSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow)
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorBilinearSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles. Must contain one or two entries.
            conditions (list[Tensor]): SFT conditions to generators, laid out as alternating
                (scale, shift) entries: conditions[0] scales and conditions[1] shifts the first
                modulated feature map, and so on. May cover fewer levels than the generator has.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (list[Tensor | None] | None): Per-layer input noise, or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Required when
                truncation < 1. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: (image, latent) if return_latents is True, otherwise (image, None).

        Raises:
            ValueError: If truncation < 1 but truncation_latent is None, or if styles does not
                contain exactly one or two entries.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            if truncation_latent is None:
                # fail early with a clear message instead of a TypeError from `None + Tensor`
                raise ValueError('truncation_latent must be provided when truncation < 1.')
            styles = [truncation_latent + truncation * (style - truncation_latent) for style in styles]
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        else:
            # previously fell through and crashed later with `latent` unbound
            raise ValueError(f'styles must contain 1 or 2 tensors, but got {len(styles)}.')

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # `i` indexes both the latent for the current conv pair and the (scale, shift) pair in `conditions`
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
129
130
131
@ARCH_REGISTRY.register()
class GFPGANBilinear(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: GFPGANv1Clean.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANBilinear, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # channel count of the U-Net features, keyed by spatial resolution
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        # compute the resolution exponent once and reuse it everywhere below
        self.log_size = int(math.log2(out_size))
        first_out_size = 2**self.log_size  # out_size rounded down to a power of two

        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels))
            in_channels = out_channels

        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels

        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))

        if different_w:
            # one style vector per latent slot: (log_size * 2 - 2) slots in total
            linear_out_channel = (self.log_size * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        # the linear layer consumes the flattened encoder output, sized for a 4x4 feature map
        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            # NOTE(security): torch.load unpickles the checkpoint; only pass paths to trusted files.
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for param in self.stylegan_decoder.parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            # scale branch: last conv bias starts at 1
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            # shift branch: last conv bias starts at 0
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANBilinear.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Forwarded to the decoder, but the
                latents it returns are discarded here. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.

        Returns:
            tuple: (image, out_rgbs). ``image`` is the decoder output; ``out_rgbs`` is the list of
            intermediate rgb images, one per upsampling level, or an empty list when return_rgb is False.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            # prepend so that unet_skips[0] is the deepest feature, matching the decode order below
            unet_skips.insert(0, feat)

        feat = self.final_conv(feat)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers; stored as alternating (scale, shift) entries
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
313
314