GitHub Repository: TencentARC/GFPGAN
Path: blob/master/gfpgan/archs/stylegan2_clean_arch.py
import math
import random
import torch
from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F


class NormStyleCode(nn.Module):

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
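        # Pixel norm: scale x by 1 / sqrt(mean(x^2) + eps) over the channel dim.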
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)


class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        # upsample or downsample if necessary
        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')

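# Usage sketch (illustrative; not part of the original file). The style vector is
# projected to one scale per input channel, folded into the conv weight, and the
# whole batch is processed in a single grouped convolution:
#
#     conv = ModulatedConv2d(64, 128, kernel_size=3, num_style_feat=512)
#     out = conv(torch.randn(2, 64, 16, 16), torch.randn(2, 512))  # -> (2, 128, 16, 16)
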
class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style) * 2**0.5  # scale by sqrt(2) for conversion from the original StyleGAN2
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out

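# Usage sketch (illustrative; not part of the original file). When no noise map is
# supplied, fresh Gaussian noise is drawn per call and scaled by a learned scalar:
#
#     sconv = StyleConv(64, 64, kernel_size=3, num_style_feat=512)
#     out = sconv(torch.randn(2, 64, 16, 16), torch.randn(2, 512))  # noise drawn internally
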
class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
            out = out + skip
        return out

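# Usage sketch (illustrative; not part of the original file). Each ToRGB emits a
# 3-channel image and adds it to the upsampled running skip, so the final output is
# a sum of RGB contributions from every resolution. With hypothetical tensors
# feat (b, 64, 16, 16), style (b, 512) and skip (b, 3, 8, 8):
#
#     to_rgb = ToRGB(64, num_style_feat=512, upsample=True)
#     img = to_rgb(feat, style, skip)  # skip upsampled to 16x16; result (b, 3, 16, 16)
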
class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out

@ARCH_REGISTRY.register()
class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        # channel list
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels
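        # With defaults (channel_multiplier=2, narrow=1): 512 channels through 64x64,
        # then 256 at 128, 128 at 256, 64 at 512 and 32 at 1024.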

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

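        # e.g. out_size=512: log_size=9, num_layers=(9-2)*2+1=15 noise maps,
        # num_latent=9*2-2=16 latent slots consumed in forward().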
        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
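            # layer_idx 0 -> 4x4; layers 1,2 -> 8x8; 3,4 -> 16x16; and so on.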
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

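    # Usage sketch (illustrative; not part of the original file): the truncation
    # trick pulls sampled latents toward the mean latent for higher fidelity.
    # Assuming g is a StyleGAN2GeneratorClean instance with num_style_feat=512:
    #
    #     mean_w = g.mean_latent(4096)
    #     img, _ = g([torch.randn(1, 512)], truncation=0.7, truncation_latent=mean_w)
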
    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether the input is latent style. Default: False.
            noise (list[Tensor] | None): Per-layer input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple[Tensor, Tensor | None]: Generated images and, when return_latents
                is True, the corresponding style latents.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
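

if __name__ == '__main__':
    # Minimal smoke test (added for illustration; not part of the original file).
    # Builds a small generator and samples a batch of 64x64 images from random codes.
    g = StyleGAN2GeneratorClean(out_size=64, num_style_feat=512, num_mlp=8)
    z = torch.randn(2, 512)
    img, _ = g([z], randomize_noise=True)
    print(img.shape)  # expected: torch.Size([2, 3, 64, 64])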