# Source path: gfpgan/archs/stylegan2_clean_arch.py (blob/master)
import math
import random
import torch
from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F


class NormStyleCode(nn.Module):
    """Normalize style codes by their root-mean-square over the channel dimension."""

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor with the same shape as ``x``.
        """
        # The 1e-8 term keeps rsqrt finite for an all-zero code.
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)


class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    The convolution weight is scaled per sample and per input channel by a style vector and
    optionally demodulated. Resampling, when requested, is done with bilinear ``F.interpolate``
    on the input before the convolution. There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Any other value
            is treated like None (no resampling). Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv: maps a style vector to one scale per input channel
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization; bias_fill=1 presumably starts the per-channel scales near 1 (see basicsr
        # default_init_weights)
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        # Shared base weight with a leading broadcast dim, scaled by 1/sqrt(fan_in).
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        # keeps the spatial size unchanged for odd kernel sizes
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution, shape (b, out_channels, h', w').
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            # Rescale each output filter to (approximately) unit L2 norm over (c_in, k, k).
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        # upsample or downsample if necessary
        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        # Fold the batch into the channel dim so a single grouped conv (groups=b) applies
        # each sample's own modulated weights to that sample.
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')


class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Modulated convolution, then noise injection, a learned bias and LeakyReLU(0.2).

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        self.weight = nn.Parameter(torch.zeros(1))  # noise injection strength, initialized to 0
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        """Forward function.

        Args:
            x (Tensor): Input features with shape (b, c, h, w).
            style (Tensor): Style latent with shape (b, num_style_feat).
            noise (Tensor | None): Noise broadcastable against the conv output. If None, fresh
                standard-normal noise of shape (b, 1, h', w') is sampled. Default: None.

        Returns:
            Tensor: Activated features with shape (b, out_channels, h', w').
        """
        # modulate; the sqrt(2) gain presumably mirrors the fused leaky-ReLU of the original
        # StyleGAN2 so converted weights give matching activations -- TODO confirm
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out


class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample the incoming ``skip`` image. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        # 1x1 modulated conv to 3 channels, without demodulation
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                # bring the lower-resolution RGB image up to this layer's resolution
                skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
            out = out + skip
        return out


class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        """Return the learned constant repeated ``batch`` times along dim 0."""
        out = self.weight.repeat(batch, 1, 1, 1)
        return out


@ARCH_REGISTRY.register()
class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs. Expected to be a power of two no larger than
            1024; larger sizes have no entry in the channel table.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for _ in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        # channel list, keyed by spatial resolution
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)

        # math.log2 is exact for powers of two (math.log(x, 2) goes through a float division)
        self.log_size = int(math.log2(out_size))
        # one conv at 4x4, then two convs per resolution 8 .. out_size
        self.num_layers = (self.log_size - 2) * 2 + 1
        # one latent per conv and per to_rgb layer
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise buffers: layer 0 is 4x4, then two layers each at 8, 16, ...
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection.

        Returns:
            list[Tensor]: One standard-normal tensor per style conv layer, on the model's device,
                with the same resolutions as the stored ``noises`` buffers.
        """
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        """Map style codes to latents with the style MLP."""
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        """Estimate the mean latent by averaging the style MLP output over ``num_latent`` random codes."""
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles: one tensor, or two for style mixing.
                A single 2-D code is repeated for every layer; a single code with ndim >= 3 is used
                as per-layer latents directly.
            input_is_latent (bool): Whether input is latent style (skips the style MLP). Default: False.
            noise (list[Tensor | None] | None): Per-layer noise for the style convs. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. If False, the stored
                ``noises`` buffers are used. Default: True.
            truncation (float): The truncation ratio. Values < 1 require ``truncation_latent``. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): Latent index at which the second style takes over when mixing
                two styles; chosen at random if None. Ignored for a single style. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple[Tensor, Tensor | None]: The generated image, and the per-layer latents if
                ``return_latents`` else None.

        Raises:
            ValueError: If ``styles`` does not contain one or two tensors, or if ``truncation < 1``
                is requested without a ``truncation_latent``.
        """
        # Fail early with a clear message; otherwise `latent` would be unbound further down.
        if len(styles) not in (1, 2):
            raise ValueError(f'styles must contain 1 or 2 tensors, got {len(styles)}.')
        if truncation < 1 and truncation_latent is None:
            raise ValueError('truncation_latent is required when truncation < 1.')

        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                # for each style conv layer; StyleConv samples fresh noise when given None
                noise = [None] * self.num_layers
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: pull each latent toward truncation_latent
        if truncation < 1:
            styles = [truncation_latent + truncation * (style - truncation_latent) for style in styles]
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        else:  # two styles: style mixing
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # each resolution step: upsampling conv, plain conv, then to_rgb; it consumes three
        # latents, the third of which is shared with the next step's first conv
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None