Path: blob/master/src/models/stylegan2.py
809 views
"""1this code is borrowed from https://github.com/NVlabs/stylegan2-ada-pytorch with few modifications23Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.45NVIDIA CORPORATION and its licensors retain all intellectual property6and proprietary rights in and to this software, related documentation7and any modifications thereto. Any use, reproduction, disclosure or8distribution of this software and related documentation without an express9license agreement from NVIDIA CORPORATION is strictly prohibited.10"""1112import torch13import torch.nn.functional as F14import numpy as np1516import utils.style_misc as misc17from utils.style_ops import conv2d_resample18from utils.style_ops import upfirdn2d19from utils.style_ops import bias_act20from utils.style_ops import fma212223def normalize_2nd_moment(x, dim=1, eps=1e-8):24return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()252627def modulated_conv2d(28x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].29weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].30styles, # Modulation coefficients of shape [batch_size, in_channels].31noise=None, # Optional noise tensor to add to the output activations.32up=1, # Integer upsampling factor.33down=1, # Integer downsampling factor.34padding=0, # Padding with respect to the upsampled image.35resample_filter=None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().36demodulate=True, # Apply weight demodulation?37flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).38fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?39):40batch_size = x.shape[0]41out_channels, in_channels, kh, kw = weight.shape42misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]43misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]44misc.assert_shape(styles, [batch_size, in_channels]) # [NI]4546# Pre-normalize inputs to avoid FP16 overflow.47if x.dtype == torch.float16 and demodulate:48weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True)) # max_Ikk49styles = styles / styles.norm(float("inf"), dim=1, keepdim=True) # max_I5051# Calculate per-sample weights and demodulation coefficients.52w = None53dcoefs = None54if demodulate or fused_modconv:55w = weight.unsqueeze(0) # [NOIkk]56w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]57if demodulate:58dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]59if demodulate and fused_modconv:60w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]6162# Execute by scaling the activations before and after the convolution.63if not fused_modconv:64x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)65x = conv2d_resample.conv2d_resample(x=x,66w=weight.to(x.dtype),67f=resample_filter,68up=up,69down=down,70padding=padding,71flip_weight=flip_weight)72if demodulate and noise is not None:73x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))74elif demodulate:75x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)76elif noise is not None:77x = x.add_(noise.to(x.dtype))78return x7980# Execute as one fused op using grouped convolution.81with misc.suppress_tracer_warnings(): # this value will be treated as a constant82batch_size = int(batch_size)83misc.assert_shape(x, [batch_size, in_channels, None, None])84x = x.reshape(1, -1, *x.shape[2:])85w = w.reshape(-1, in_channels, kh, kw)86x = conv2d_resample.conv2d_resample(x=x,87w=w.to(x.dtype),88f=resample_filter,89up=up,90down=down,91padding=padding,92groups=batch_size,93flip_weight=flip_weight)94x = x.reshape(batch_size, -1, *x.shape[2:])95if noise is not None:96x = x.add_(noise)97return x9899100class FullyConnectedLayer(torch.nn.Module):101def __init__(102self,103in_features, # Number of input features.104out_features, # Number of output features.105bias=True, # Apply additive bias before the activation function?106activation="linear", # Activation function: "relu", "lrelu", etc.107lr_multiplier=1, # Learning rate multiplier.108bias_init=0, # Initial value for the additive bias.109):110super().__init__()111self.activation = activation112self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)113self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None114self.weight_gain = lr_multiplier / np.sqrt(in_features)115self.bias_gain = lr_multiplier116117def forward(self, x):118w = self.weight.to(x.dtype) * self.weight_gain119b = self.bias120if b is not None:121b = b.to(x.dtype)122if self.bias_gain != 1:123b = b * self.bias_gain124125if self.activation == "linear" and b is not None:126x = torch.addmm(b.unsqueeze(0), x, w.t())127else:128x = x.matmul(w.t())129x = bias_act.bias_act(x, b, act=self.activation)130return x131132133class Conv2dLayer(torch.nn.Module):134def __init__(135self,136in_channels, # Number of input channels.137out_channels, # Number of output channels.138kernel_size, # Width and height of the convolution kernel.139bias=True, # Apply additive bias before the activation function?140activation="linear", # Activation function: "relu", "lrelu", etc.141up=1, # Integer upsampling factor.142down=1, # Integer downsampling factor.143resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.144conv_clamp=None, # Clamp the output to +-X, None = disable clamping.145channels_last=False, # Expect the input to have memory_format=channels_last?146trainable=True, # Update the weights of this layer during training?147):148super().__init__()149self.activation = activation150self.up = up151self.down = down152self.conv_clamp = conv_clamp153self.register_buffer("resample_filter", upfirdn2d.setup_filter(resample_filter))154self.padding = kernel_size // 2155self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))156self.act_gain = bias_act.activation_funcs[activation].def_gain157158memory_format = torch.channels_last if channels_last else torch.contiguous_format159weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)160bias = torch.zeros([out_channels]) if bias else None161if trainable:162self.weight = torch.nn.Parameter(weight)163self.bias = torch.nn.Parameter(bias) if bias is not None else None164else:165self.register_buffer("weight", weight)166if bias is not None:167self.register_buffer("bias", bias)168else:169self.bias = None170171def forward(self, x, gain=1):172w = self.weight * self.weight_gain173b = self.bias.to(x.dtype) if self.bias is not None else None174flip_weight = (self.up == 1) # slightly faster175x = conv2d_resample.conv2d_resample(x=x,176w=w.to(x.dtype),177f=self.resample_filter,178up=self.up,179down=self.down,180padding=self.padding,181flip_weight=flip_weight)182183act_gain = self.act_gain * gain184act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None185x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)186return x187188189class MappingNetwork(torch.nn.Module):190def __init__(191self,192z_dim, # Input latent (Z) dimensionality, 0 = no latent.193c_dim, # Conditioning label (C) dimensionality, 0 = no label.194w_dim, # Intermediate latent (W) dimensionality.195num_ws, # Number of intermediate latents to output, None = do not broadcast.196num_layers=8, # Number of mapping layers.197embed_features=None, # Label embedding dimensionality, None = same as w_dim.198layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.199activation="lrelu", # Activation function: "relu", "lrelu", etc.200lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.201w_avg_beta=0.998, # Decay for tracking the moving average of W during training, None = do not track.202):203super().__init__()204self.z_dim = z_dim205self.c_dim = c_dim206self.w_dim = w_dim207self.num_ws = num_ws208self.num_layers = num_layers209self.w_avg_beta = w_avg_beta210211if embed_features is None:212embed_features = w_dim213if c_dim == 0:214embed_features = 0215if layer_features is None:216layer_features = w_dim217features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]218219if c_dim > 0:220self.embed = FullyConnectedLayer(c_dim, embed_features)221for idx in range(num_layers):222in_features = features_list[idx]223out_features = features_list[idx + 1]224layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)225setattr(self, f"fc{idx}", layer)226227if num_ws is not None and w_avg_beta is not None:228self.register_buffer("w_avg", torch.zeros([w_dim]))229230def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):231# Embed, normalize, and concat inputs.232x = None233if self.z_dim > 0:234misc.assert_shape(z, [None, self.z_dim])235x = normalize_2nd_moment(z.to(torch.float32))236if self.c_dim > 0:237misc.assert_shape(c, [None, self.c_dim])238y = normalize_2nd_moment(self.embed(c.to(torch.float32)))239x = torch.cat([x, y], dim=1) if x is not None else y240241# Main layers.242for idx in range(self.num_layers):243layer = getattr(self, f"fc{idx}")244x = layer(x)245246# Update moving average of W.247if update_emas and self.w_avg_beta is not None:248self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))249250# Broadcast.251if self.num_ws is not None:252x = x.unsqueeze(1).repeat([1, self.num_ws, 1])253254# Apply truncation.255if truncation_psi != 1:256assert self.w_avg_beta is not None257if self.num_ws is None or truncation_cutoff is None:258x = self.w_avg.lerp(x, truncation_psi)259else:260x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)261return x262263264class SynthesisLayer(torch.nn.Module):265def __init__(266self,267in_channels, # Number of input channels.268out_channels, # Number of output channels.269w_dim, # Intermediate latent (W) dimensionality.270resolution, # Resolution of this layer.271kernel_size=3, # Convolution kernel size.272up=1, # Integer upsampling factor.273use_noise=True, # Enable noise input?274activation="lrelu", # Activation function: "relu", "lrelu", etc.275resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.276conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.277channels_last=False, # Use channels_last format for the weights?278):279super().__init__()280self.resolution = resolution281self.up = up282self.use_noise = use_noise283self.activation = activation284self.conv_clamp = conv_clamp285self.register_buffer("resample_filter", upfirdn2d.setup_filter(resample_filter))286self.padding = kernel_size // 2287self.act_gain = bias_act.activation_funcs[activation].def_gain288289self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)290memory_format = torch.channels_last if channels_last else torch.contiguous_format291self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))292if use_noise:293self.register_buffer("noise_const", torch.randn([resolution, resolution]))294self.noise_strength = torch.nn.Parameter(torch.zeros([]))295self.bias = torch.nn.Parameter(torch.zeros([out_channels]))296297def forward(self, x, w, noise_mode="random", fused_modconv=True, gain=1):298assert noise_mode in ["random", "const", "none"]299in_resolution = self.resolution // self.up300misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])301styles = self.affine(w)302303noise = None304if self.use_noise and noise_mode == "random":305noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength306if self.use_noise and noise_mode == "const":307noise = self.noise_const * self.noise_strength308309flip_weight = (self.up == 1) # slightly faster310x = modulated_conv2d(x=x,311weight=self.weight,312styles=styles,313noise=noise,314up=self.up,315padding=self.padding,316resample_filter=self.resample_filter,317flip_weight=flip_weight,318fused_modconv=fused_modconv)319320act_gain = self.act_gain * gain321act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None322x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)323return x324325326class ToRGBLayer(torch.nn.Module):327def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):328super().__init__()329self.conv_clamp = conv_clamp330self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)331memory_format = torch.channels_last if channels_last else torch.contiguous_format332self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))333self.bias = torch.nn.Parameter(torch.zeros([out_channels]))334self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))335336def forward(self, x, w, fused_modconv=True):337styles = self.affine(w) * self.weight_gain338x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)339x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)340return x341342343class SynthesisBlock(torch.nn.Module):344def __init__(345self,346in_channels, # Number of input channels, 0 = first block.347out_channels, # Number of output channels.348w_dim, # Intermediate latent (W) dimensionality.349resolution, # Resolution of this block.350img_channels, # Number of output color channels.351is_last, # Is this the last block?352architecture="skip", # Architecture: "orig", "skip", "resnet".353resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.354conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.355use_fp16=False, # Use FP16 for this block?356fp16_channels_last=False, # Use channels-last memory format with FP16?357**layer_kwargs, # Arguments for SynthesisLayer.358):359assert architecture in ["orig", "skip", "resnet"]360super().__init__()361self.in_channels = in_channels362self.w_dim = w_dim363self.resolution = resolution364self.img_channels = img_channels365self.is_last = is_last366self.architecture = architecture367self.use_fp16 = use_fp16368self.channels_last = (use_fp16 and fp16_channels_last)369self.register_buffer("resample_filter", upfirdn2d.setup_filter(resample_filter))370self.num_conv = 0371self.num_torgb = 0372373if in_channels == 0:374self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))375376if in_channels != 0:377self.conv0 = SynthesisLayer(in_channels,378out_channels,379w_dim=w_dim,380resolution=resolution,381up=2,382resample_filter=resample_filter,383conv_clamp=conv_clamp,384channels_last=self.channels_last,385**layer_kwargs)386self.num_conv += 1387388self.conv1 = SynthesisLayer(out_channels,389out_channels,390w_dim=w_dim,391resolution=resolution,392conv_clamp=conv_clamp,393channels_last=self.channels_last,394**layer_kwargs)395self.num_conv += 1396397if is_last or architecture == "skip":398self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, conv_clamp=conv_clamp, channels_last=self.channels_last)399self.num_torgb += 1400401if in_channels != 0 and architecture == "resnet":402self.skip = Conv2dLayer(in_channels,403out_channels,404kernel_size=1,405bias=False,406up=2,407resample_filter=resample_filter,408channels_last=self.channels_last)409410def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):411_ = update_emas # unused412misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])413w_iter = iter(ws.unbind(dim=1))414dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32415memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format416if fused_modconv is None:417with misc.suppress_tracer_warnings(): # this value will be treated as a constant418fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)419420# Input.421if self.in_channels == 0:422x = self.const.to(dtype=dtype, memory_format=memory_format)423x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])424else:425misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])426x = x.to(dtype=dtype, memory_format=memory_format)427428# Main layers.429if self.in_channels == 0:430x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)431elif self.architecture == "resnet":432y = self.skip(x, gain=np.sqrt(0.5))433x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)434x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)435x = y.add_(x)436else:437x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)438x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)439440# ToRGB.441if img is not None:442misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])443img = upfirdn2d.upsample2d(img, self.resample_filter)444if self.is_last or self.architecture == "skip":445y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)446y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)447img = img.add_(y) if img is not None else y448449assert x.dtype == dtype450assert img is None or img.dtype == torch.float32451return x, img452453454class SynthesisNetwork(torch.nn.Module):455def __init__(456self,457w_dim, # Intermediate latent (W) dimensionality.458img_resolution, # Output image resolution.459img_channels, # Number of color channels.460channel_base=32768, # Overall multiplier for the number of channels.461channel_max=512, # Maximum number of channels in any layer.462num_fp16_res=0, # Use FP16 for the N highest resolutions.463**block_kwargs, # Arguments for SynthesisBlock.464):465assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0466super().__init__()467self.w_dim = w_dim468self.img_resolution = img_resolution469self.img_resolution_log2 = int(np.log2(img_resolution))470self.img_channels = img_channels471self.block_resolutions = [2**i for i in range(2, self.img_resolution_log2 + 1)]472channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}473fp16_resolution = max(2**(self.img_resolution_log2 + 1 - num_fp16_res), 8)474475self.num_ws = 0476for res in self.block_resolutions:477in_channels = channels_dict[res // 2] if res > 4 else 0478out_channels = channels_dict[res]479use_fp16 = (res >= fp16_resolution)480is_last = (res == self.img_resolution)481block = SynthesisBlock(in_channels,482out_channels,483w_dim=w_dim,484resolution=res,485img_channels=img_channels,486is_last=is_last,487use_fp16=use_fp16,488**block_kwargs)489self.num_ws += block.num_conv490if is_last:491self.num_ws += block.num_torgb492setattr(self, f"b{res}", block)493494def forward(self, ws, **block_kwargs):495block_ws = []496misc.assert_shape(ws, [None, self.num_ws, self.w_dim])497ws = ws.to(torch.float32)498w_idx = 0499for res in self.block_resolutions:500block = getattr(self, f"b{res}")501block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))502w_idx += block.num_conv503504x = img = None505for res, cur_ws in zip(self.block_resolutions, block_ws):506block = getattr(self, f"b{res}")507x, img = block(x, img, cur_ws, **block_kwargs)508return img509510511class Generator(torch.nn.Module):512def __init__(513self,514z_dim, # Input latent (Z) dimensionality.515c_dim, # Conditioning label (C) dimensionality.516w_dim, # Intermediate latent (W) dimensionality.517img_resolution, # Output resolution.518img_channels, # Number of output color channels.519MODEL, # MODEL config required for applying infoGAN520mapping_kwargs={}, # Arguments for MappingNetwork.521synthesis_kwargs={}, # Arguments for SynthesisNetwork.522):523super().__init__()524self.z_dim = z_dim525self.c_dim = c_dim526self.w_dim = w_dim527self.MODEL = MODEL528self.img_resolution = img_resolution529self.img_channels = img_channels530531z_extra_dim = 0532if self.MODEL.info_type in ["discrete", "both"]:533z_extra_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c534if self.MODEL.info_type in ["continuous", "both"]:535z_extra_dim += self.MODEL.info_num_conti_c536537if self.MODEL.info_type != "N/A":538self.z_dim += z_extra_dim539540self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)541self.num_ws = self.synthesis.num_ws542self.mapping = MappingNetwork(z_dim=self.z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)543544def forward(self, z, c, eval=False, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):545ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)546img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)547return img548549550class DiscriminatorBlock(torch.nn.Module):551def __init__(552self,553in_channels, # Number of input channels, 0 = first block.554tmp_channels, # Number of intermediate channels.555out_channels, # Number of output channels.556resolution, # Resolution of this block.557img_channels, # Number of input color channels.558first_layer_idx, # Index of the first layer.559architecture="resnet", # Architecture: "orig", "skip", "resnet".560activation="lrelu", # Activation function: "relu", "lrelu", etc.561resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.562conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.563use_fp16=False, # Use FP16 for this block?564fp16_channels_last=False, # Use channels-last memory format with FP16?565freeze_layers=0, # Freeze-D: Number of layers to freeze.566):567assert in_channels in [0, tmp_channels]568assert architecture in ["orig", "skip", "resnet"]569super().__init__()570self.in_channels = in_channels571self.resolution = resolution572self.img_channels = img_channels573self.first_layer_idx = first_layer_idx574self.architecture = architecture575self.use_fp16 = use_fp16576self.channels_last = (use_fp16 and fp16_channels_last)577self.register_buffer("resample_filter", upfirdn2d.setup_filter(resample_filter))578579self.num_layers = 0580581def trainable_gen():582while True:583layer_idx = self.first_layer_idx + self.num_layers584trainable = (layer_idx >= freeze_layers)585self.num_layers += 1586yield trainable587588trainable_iter = trainable_gen()589590if in_channels == 0 or architecture == "skip":591self.fromrgb = Conv2dLayer(img_channels,592tmp_channels,593kernel_size=1,594activation=activation,595trainable=next(trainable_iter),596conv_clamp=conv_clamp,597channels_last=self.channels_last)598599self.conv0 = Conv2dLayer(tmp_channels,600tmp_channels,601kernel_size=3,602activation=activation,603trainable=next(trainable_iter),604conv_clamp=conv_clamp,605channels_last=self.channels_last)606607self.conv1 = Conv2dLayer(tmp_channels,608out_channels,609kernel_size=3,610activation=activation,611down=2,612trainable=next(trainable_iter),613resample_filter=resample_filter,614conv_clamp=conv_clamp,615channels_last=self.channels_last)616617if architecture == "resnet":618self.skip = Conv2dLayer(tmp_channels,619out_channels,620kernel_size=1,621bias=False,622down=2,623trainable=next(trainable_iter),624resample_filter=resample_filter,625channels_last=self.channels_last)626627def forward(self, x, img, force_fp32=False):628dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32629memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format630631# Input.632if x is not None:633misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])634x = x.to(dtype=dtype, memory_format=memory_format)635636# FromRGB.637if self.in_channels == 0 or self.architecture == "skip":638misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])639img = img.to(dtype=dtype, memory_format=memory_format)640y = self.fromrgb(img)641x = x + y if x is not None else y642img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == "skip" else None643644# Main layers.645if self.architecture == "resnet":646y = self.skip(x, gain=np.sqrt(0.5))647x = self.conv0(x)648x = self.conv1(x, gain=np.sqrt(0.5))649x = y.add_(x)650else:651x = self.conv0(x)652x = self.conv1(x)653654assert x.dtype == dtype655return x, img656657658class MinibatchStdLayer(torch.nn.Module):659def __init__(self, group_size, num_channels=1):660super().__init__()661self.group_size = group_size662self.num_channels = num_channels663664def forward(self, x):665N, C, H, W = x.shape666with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants667G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N668F = self.num_channels669c = C // F670671y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.672y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.673y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.674y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.675y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.676y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.677y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.678x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.679return x680681682class DiscriminatorEpilogue(torch.nn.Module):683def __init__(684self,685in_channels, # Number of input channels.686cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.687resolution, # Resolution of this block.688img_channels, # Number of input color channels.689architecture="resnet", # Architecture: "orig", "skip", "resnet".690mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.691mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.692activation="lrelu", # Activation function: "relu", "lrelu", etc.693conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.694):695assert architecture in ["orig", "skip", "resnet"]696super().__init__()697self.in_channels = in_channels698self.cmap_dim = cmap_dim699self.resolution = resolution700self.img_channels = img_channels701self.architecture = architecture702703if architecture == "skip":704self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)705self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None706self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)707self.fc = FullyConnectedLayer(in_channels * (resolution**2), in_channels, activation=activation)708# self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)709710def forward(self, x, img, force_fp32=False):711misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]712_ = force_fp32 # unused713dtype = torch.float32714memory_format = torch.contiguous_format715716# FromRGB.717x = x.to(dtype=dtype, memory_format=memory_format)718if self.architecture == "skip":719misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])720img = img.to(dtype=dtype, memory_format=memory_format)721x = x + self.fromrgb(img)722723# Main layers.724if self.mbstd is not None:725x = self.mbstd(x)726x = self.conv(x)727x = self.fc(x.flatten(1))728# x = self.out(x)729730return x731732733class Discriminator(torch.nn.Module):734def __init__(735self,736c_dim, # Conditioning label (C) dimensionality.737img_resolution, # Input resolution.738img_channels, # Number of input color channels.739architecture="resnet", # Architecture: "orig", "skip", "resnet".740channel_base=32768, # Overall multiplier for the number of channels.741channel_max=512, # Maximum number of channels in any layer.742num_fp16_res=0, # Use FP16 for the N highest resolutions.743conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.744cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.745d_cond_mtd=None, # conditioning method of the discriminator746aux_cls_type=None, # type of auxiliary classifier747d_embed_dim=None, # dimension of feature maps after convolution operations748num_classes=None, # number of classes749normalize_d_embed=None, # whether to normalize the feature maps or not750block_kwargs={}, # Arguments for DiscriminatorBlock.751mapping_kwargs={}, # Arguments for MappingNetwork.752epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue.753MODEL=None, # needed to check options for infoGAN754):755super().__init__()756self.c_dim = c_dim757self.img_resolution = img_resolution758self.img_channels = img_channels759self.cmap_dim = cmap_dim760self.d_cond_mtd = d_cond_mtd761self.aux_cls_type = aux_cls_type762self.num_classes = num_classes763self.normalize_d_embed = normalize_d_embed764self.img_resolution_log2 = int(np.log2(img_resolution))765self.block_resolutions = [2**i for i in range(self.img_resolution_log2, 2, -1)]766self.MODEL = MODEL767channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}768fp16_resolution = max(2**(self.img_resolution_log2 + 1 - num_fp16_res), 8)769770if self.cmap_dim is None:771self.cmap_dim = channels_dict[4]772if c_dim == 0:773self.cmap_dim = 0774775common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)776cur_layer_idx = 0777for res in self.block_resolutions:778in_channels = channels_dict[res] if res < img_resolution else 0779tmp_channels = channels_dict[res]780out_channels = channels_dict[res // 2]781use_fp16 = (res >= fp16_resolution)782block = DiscriminatorBlock(in_channels,783tmp_channels,784out_channels,785resolution=res,786first_layer_idx=cur_layer_idx,787use_fp16=use_fp16,788**block_kwargs,789**common_kwargs)790setattr(self, f"b{res}", block)791cur_layer_idx += block.num_layers792793self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=self.cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)794795# linear layer for adversarial training796if self.d_cond_mtd == "MH":797self.linear1 = FullyConnectedLayer(channels_dict[4], 1 + self.num_classes, bias=True)798elif self.d_cond_mtd == "MD":799self.linear1 = FullyConnectedLayer(channels_dict[4], self.num_classes, bias=True)800elif self.d_cond_mtd == "SPD":801self.linear1 = FullyConnectedLayer(channels_dict[4], 1 if self.cmap_dim == 0 else self.cmap_dim, bias=True)802else:803self.linear1 = FullyConnectedLayer(channels_dict[4], 1, bias=True)804805# double num_classes for Auxiliary Discriminative Classifier806if self.aux_cls_type == "ADC":807num_classes, c_dim = num_classes * 2, c_dim * 2808809# linear and embedding layers for discriminator conditioning810if self.d_cond_mtd == "AC":811self.linear2 = FullyConnectedLayer(channels_dict[4], num_classes, bias=False)812elif self.d_cond_mtd == "PD":813self.linear2 = FullyConnectedLayer(channels_dict[4], self.cmap_dim, bias=True)814elif self.d_cond_mtd == "SPD":815self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=self.cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)816elif self.d_cond_mtd in ["2C", "D2DCE"]:817self.linear2 = FullyConnectedLayer(channels_dict[4], d_embed_dim, bias=True)818self.embedding = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=d_embed_dim, num_ws=None, w_avg_beta=None, num_layers=1, **mapping_kwargs)819else:820pass821822# linear and embedding layers for evolved classifier-based GAN823if self.aux_cls_type == "TAC":824if self.d_cond_mtd == "AC":825self.linear_mi = FullyConnectedLayer(channels_dict[4], num_classes, bias=False)826elif self.d_cond_mtd in ["2C", "D2DCE"]:827self.linear_mi = FullyConnectedLayer(channels_dict[4], d_embed_dim, bias=True)828self.embedding_mi = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=d_embed_dim, num_ws=None, w_avg_beta=None, num_layers=1, **mapping_kwargs)829else:830raise NotImplementedError831832# Q head network for infoGAN833if self.MODEL.info_type in ["discrete", "both"]:834out_features = self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c835self.info_discrete_linear = FullyConnectedLayer(in_features=channels_dict[4], out_features=out_features, bias=False)836if self.MODEL.info_type in ["continuous", "both"]:837out_features = self.MODEL.info_num_conti_c838self.info_conti_mu_linear = FullyConnectedLayer(in_features=channels_dict[4], out_features=out_features, bias=False)839self.info_conti_var_linear = FullyConnectedLayer(in_features=channels_dict[4], out_features=out_features, bias=False)840841def forward(self, img, label, eval=False, adc_fake=False, update_emas=False, **block_kwargs):842_ = update_emas # unused843x, embed, proxy, cls_output = None, None, None, None844mi_embed, mi_proxy, mi_cls_output = None, None, None845info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None846for res in self.block_resolutions:847block = getattr(self, f"b{res}")848x, img = block(x, img, **block_kwargs)849h = self.b4(x, img)850851# adversarial training852if self.d_cond_mtd != "SPD":853adv_output = torch.squeeze(self.linear1(h))854855# make class labels odd (for fake) or even (for real) for ADC856if self.aux_cls_type == "ADC":857if adc_fake:858label = label*2 + 1859else:860label = label*2861oh_label = F.one_hot(label, self.num_classes * 2 if self.aux_cls_type=="ADC" else self.num_classes)862863# forward pass through InfoGAN Q head864if self.MODEL.info_type in ["discrete", "both"]:865info_discrete_c_logits = self.info_discrete_linear(h)866if self.MODEL.info_type in ["continuous", "both"]:867info_conti_mu = self.info_conti_mu_linear(h)868info_conti_var = torch.exp(self.info_conti_var_linear(h))869870# class conditioning871if self.d_cond_mtd == "AC":872if self.normalize_d_embed:873for W in self.linear2.parameters():874W = F.normalize(W, dim=1)875h = F.normalize(h, dim=1)876cls_output = self.linear2(h)877elif self.d_cond_mtd == "PD":878adv_output = adv_output + torch.sum(torch.mul(self.embedding(None, oh_label), h), 1)879elif self.d_cond_mtd == "SPD":880embed = self.linear1(h)881cmap = self.mapping(None, oh_label)882adv_output = (embed * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))883elif self.d_cond_mtd in ["2C", "D2DCE"]:884embed = self.linear2(h)885proxy = self.embedding(None, oh_label)886if self.normalize_d_embed:887embed = F.normalize(embed, dim=1)888proxy = F.normalize(proxy, dim=1)889elif self.d_cond_mtd == "MD":890idx = torch.LongTensor(range(label.size(0))).to(label.device)891adv_output = adv_output[idx, label]892elif self.d_cond_mtd in ["W/O", "MH"]:893pass894else:895raise NotImplementedError896897# extra conditioning for TACGAN and ADCGAN898if self.aux_cls_type == "TAC":899if self.d_cond_mtd == "AC":900if self.normalize_d_embed:901for W in self.linear_mi.parameters():902W = F.normalize(W, dim=1)903mi_cls_output = self.linear_mi(h)904elif self.d_cond_mtd in ["2C", "D2DCE"]:905mi_embed = self.linear_mi(h)906mi_proxy = self.embedding_mi(None, oh_label)907if self.normalize_d_embed:908mi_embed = F.normalize(mi_embed, dim=1)909mi_proxy = F.normalize(mi_proxy, dim=1)910return {911"h": h,912"adv_output": adv_output,913"embed": embed,914"proxy": proxy,915"cls_output": cls_output,916"label": label,917"mi_embed": mi_embed,918"mi_proxy": mi_proxy,919"mi_cls_output": mi_cls_output,920"info_discrete_c_logits": info_discrete_c_logits,921"info_conti_mu": info_conti_mu,922"info_conti_var": info_conti_var923}924925926