Path: blob/master/gfpgan/archs/stylegan2_bilinear_arch.py
import math
import random
import torch
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F


class NormStyleCode(nn.Module):

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)


class EqualLinear(nn.Module):
    """Equalized Linear as StyleGAN2.

    Args:
        in_channels (int): Size of each sample.
        out_channels (int): Size of each output sample.
        bias (bool): If set to ``False``, the layer will not learn an additive
            bias. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        lr_mul (float): Learning rate multiplier. Default: 1.
        activation (None | str): The activation after ``linear`` operation.
            Supported: 'fused_lrelu', None. Default: None.
    """

    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
        super(EqualLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lr_mul = lr_mul
        self.activation = activation
        if self.activation not in ['fused_lrelu', None]:
            raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
                             "Supported ones are: ['fused_lrelu', None].")
        self.scale = (1 / math.sqrt(in_channels)) * lr_mul

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        if self.bias is None:
            bias = None
        else:
            bias = self.bias * self.lr_mul
        if self.activation == 'fused_lrelu':
            out = F.linear(x, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(x, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, bias={self.bias is not None})')
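

# Illustrative sketch (not part of the original module; the helper name
# `_equal_linear_sketch` is hypothetical): how the equalized learning rate in
# EqualLinear behaves. Weights are stored divided by `lr_mul` and rescaled at
# runtime by `scale = (1 / sqrt(in_channels)) * lr_mul`, so the realized weight
# keeps std ~ 1 / sqrt(in_channels) while `lr_mul` changes the effective
# optimizer step size for this layer.
def _equal_linear_sketch():
    layer = EqualLinear(512, 256, lr_mul=0.01, activation=None)
    z = torch.randn(4, 512)
    out = layer(z)  # (4, 256); realized weight std is ~ 1 / sqrt(512)
    # The stored parameter is inflated by 1 / lr_mul; the runtime `scale`
    # cancels this inflation.
    assert layer.weight.shape == (256, 512)
    return out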


class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
        interpolation_mode (str): Interpolation mode passed to F.interpolate
            for 'upsample' / 'downsample'. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8,
                 interpolation_mode='bilinear'):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps
        self.interpolation_mode = interpolation_mode
        if self.interpolation_mode == 'nearest':
            self.align_corners = None
        else:
            self.align_corners = False

        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        # modulation inside each modulated conv
        self.modulation = EqualLinear(
            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)

        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)

        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')


class StyleConv(nn.Module):
    """Style conv.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 interpolation_mode='bilinear'):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
            interpolation_mode=interpolation_mode)
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.activate = FusedLeakyReLU(out_channels)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style)
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # activation (with bias)
        out = self.activate(out)
        return out
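

# Illustrative sketch (not part of the original module; `_grouped_conv_sketch`
# is a hypothetical helper): the grouped convolution in
# ModulatedConv2d.forward applies a different modulated kernel to every sample
# in the batch. Folding the batch into the channel axis with groups=b is
# equivalent to a per-sample loop:
def _grouped_conv_sketch():
    b, c_in, c_out, k, h, w = 2, 4, 8, 3, 16, 16
    x = torch.randn(b, c_in, h, w)
    weight = torch.randn(b, c_out, c_in, k, k)  # one kernel per sample
    # batched path, as in ModulatedConv2d.forward
    out1 = F.conv2d(
        x.view(1, b * c_in, h, w), weight.view(b * c_out, c_in, k, k), padding=k // 2,
        groups=b).view(b, c_out, h, w)
    # per-sample loop, for comparison
    out2 = torch.stack([F.conv2d(x[i:i + 1], weight[i], padding=k // 2)[0] for i in range(b)])
    assert torch.allclose(out1, out2, atol=1e-5)
    return out1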


class ToRGB(nn.Module):
    """To RGB from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.interpolation_mode = interpolation_mode
        if self.interpolation_mode == 'nearest':
            self.align_corners = None
        else:
            self.align_corners = False
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(
                    skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
            out = out + skip
        return out


class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out
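

# Illustrative sketch (not part of the original module; `_to_rgb_sketch` is a
# hypothetical helper): ToRGB converts features to a 3-channel image and adds
# it onto a running `skip` image that is upsampled 2x first, which is how the
# generator below accumulates its output across resolutions.
def _to_rgb_sketch():
    to_rgb = ToRGB(in_channels=64, num_style_feat=512, upsample=True)
    feat = torch.randn(2, 64, 8, 8)
    style = torch.randn(2, 512)
    skip = torch.randn(2, 3, 4, 4)  # RGB image from the previous resolution
    img = to_rgb(feat, style, skip)  # skip is upsampled to 8x8, then added
    assert img.shape == (2, 3, 8, 8)
    return img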


@ARCH_REGISTRY.register()
class StyleGAN2GeneratorBilinear(nn.Module):
    """StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): Narrow ratio for channels. Default: 1.0.
        interpolation_mode (str): Interpolation mode used by all
            up/downsampling layers. Default: 'bilinear'.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 interpolation_mode='bilinear'):
        super(StyleGAN2GeneratorBilinear, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.append(
                EqualLinear(
                    num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
                    activation='fused_lrelu'))
        self.style_mlp = nn.Sequential(*style_mlp_layers)

        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample',
                    interpolation_mode=interpolation_mode))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                    interpolation_mode=interpolation_mode))
            self.to_rgbs.append(
                ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                None. Default: True.
            truncation (float): Truncation ratio for the truncation trick;
                1 means no truncation. Default: 1.
            truncation_latent (Tensor | None): The latent (usually the mean
                latent) that truncated styles are pulled towards.
                Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.

        Returns:
            tuple[Tensor, Tensor | None]: RGB images and, if
                ``return_latents`` is True, the style latents.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
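

# Illustrative sketch (not part of the original module; `_generator_sketch` is
# a hypothetical helper): sampling from the generator. A single style code is
# broadcast to all layers; passing two codes would mix styles at a random
# injection index. Note that the FusedLeakyReLU layers may require the
# compiled basicsr CUDA extension, depending on the installed basicsr version.
def _generator_sketch():
    gen = StyleGAN2GeneratorBilinear(out_size=256, num_style_feat=512)
    z = torch.randn(2, 512)
    image, _ = gen([z], randomize_noise=True)
    assert image.shape == (2, 3, 256, 256)
    # truncation trick: pull styles towards the mean latent
    mean_w = gen.mean_latent(1024)
    image_trunc, latent = gen([z], truncation=0.7, truncation_latent=mean_w, return_latents=True)
    return image, image_trunc, latent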


class ScaledLeakyReLU(nn.Module):
    """Scaled LeakyReLU.

    Args:
        negative_slope (float): Negative slope. Default: 0.2.
    """

    def __init__(self, negative_slope=0.2):
        super(ScaledLeakyReLU, self).__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        out = F.leaky_relu(x, negative_slope=self.negative_slope)
        return out * math.sqrt(2)


class EqualConv2d(nn.Module):
    """Equalized Conv2d as StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1.
        padding (int): Zero-padding added to both sides of the input.
            Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output.
            Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
        super(EqualConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        out = F.conv2d(
            x,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'stride={self.stride}, padding={self.padding}, '
                f'bias={self.bias is not None})')
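

# Illustrative sketch (not part of the original module; `_equal_conv_sketch`
# is a hypothetical helper): EqualConv2d applies the same runtime scaling as
# EqualLinear, 1 / sqrt(fan_in) with fan_in = in_channels * kernel_size**2,
# so a unit-variance init keeps activations at a consistent scale, and
# ScaledLeakyReLU's sqrt(2) gain compensates for the variance lost by the
# leaky rectifier.
def _equal_conv_sketch():
    conv = EqualConv2d(32, 64, kernel_size=3, padding=1)
    act = ScaledLeakyReLU(0.2)
    x = torch.randn(8, 32, 16, 16)
    y = act(conv(x))
    assert y.shape == (8, 64, 16, 16)
    return y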


class ConvLayer(nn.Sequential):
    """Conv Layer used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Kernel size.
        downsample (bool): Whether to downsample by a factor of 2.
            Default: False.
        bias (bool): Whether with bias. Default: True.
        activate (bool): Whether to use activation. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 bias=True,
                 activate=True,
                 interpolation_mode='bilinear'):
        layers = []
        self.interpolation_mode = interpolation_mode
        # downsample
        if downsample:
            if self.interpolation_mode == 'nearest':
                self.align_corners = None
            else:
                self.align_corners = False

            layers.append(
                torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
        stride = 1
        self.padding = kernel_size // 2
        # conv
        layers.append(
            EqualConv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
                and not activate))
        # activation
        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channels))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super(ConvLayer, self).__init__(*layers)


class ResBlock(nn.Module):
    """Residual block used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
    """

    def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
        super(ResBlock, self).__init__()

        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvLayer(
            in_channels,
            out_channels,
            3,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=True,
            activate=True)
        self.skip = ConvLayer(
            in_channels,
            out_channels,
            1,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=False,
            activate=False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        skip = self.skip(x)
        out = (out + skip) / math.sqrt(2)
        return out
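

# Illustrative sketch (not part of the original module; `_res_block_sketch` is
# a hypothetical helper): each ResBlock halves the spatial resolution through
# its downsampling branch, so chaining blocks builds a StyleGAN2-style
# discriminator trunk. Channel numbers here are arbitrary; the FusedLeakyReLU
# inside ConvLayer may require the compiled basicsr CUDA extension, depending
# on the installed basicsr version.
def _res_block_sketch():
    blocks = nn.Sequential(
        ResBlock(32, 64, interpolation_mode='bilinear'),   # 64x64 -> 32x32
        ResBlock(64, 128, interpolation_mode='bilinear'),  # 32x32 -> 16x16
    )
    x = torch.randn(2, 32, 64, 64)
    y = blocks(x)
    assert y.shape == (2, 128, 16, 16)
    return y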