Path: blob/master/gfpgan/archs/gfpgan_bilinear_arch.py
import math
import random
import torch
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn

from .gfpganv1_arch import ResUpBlock
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
                                      StyleGAN2GeneratorBilinear)


class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the bilinear version: it does not use the complicated UpFirDnSmooth function, which is not
    deployment-friendly. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorBilinearSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow)
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorBilinearSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None


@ARCH_REGISTRY.register()
class GFPGANBilinear(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the bilinear version: it does not use the complicated UpFirDnSmooth function, which is not
    deployment-friendly. It can be easily converted to the clean version: GFPGANv1Clean.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANBilinear, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels))
            in_channels = out_channels

        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels

        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))

        if different_w:
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANBilinear.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether to return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)

        feat = self.final_conv(feat)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs