# Path: blob/master/src/models/stylegan3.py
# 809 views
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Generator architecture from the paper
"Alias-Free Generative Adversarial Networks"."""

import numpy as np
import scipy.signal
import scipy.special  # Imported explicitly: design_lowpass_filter() calls scipy.special.j1.
import scipy.optimize
import torch
import utils.style_misc as misc

from utils.style_ops import conv2d_gradfix
from utils.style_ops import filtered_lrelu
from utils.style_ops import bias_act

#----------------------------------------------------------------------------

def modulated_conv2d(
    x,                  # Input tensor: [batch_size, in_channels, in_height, in_width]
    w,                  # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
    s,                  # Style tensor: [batch_size, in_channels]
    demodulate  = True, # Apply weight demodulation?
    padding     = 0,    # Padding: int or [padH, padW]
    input_gain  = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
):
    """Convolve every sample with its own style-modulated copy of `w`.

    The shared weight is scaled per sample by `s`, optionally demodulated,
    optionally scaled by `input_gain`, and all samples are then executed as a
    single grouped convolution (`groups=batch_size`).

    Returns:
        Tensor of shape [batch_size, out_channels, out_height, out_width] in
        the dtype of `x` (the weight is cast to `x.dtype` before the conv).
    """
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(x.shape[0])
    out_channels, in_channels, kh, kw = w.shape
    misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(s, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs. The per-output-channel and global scalars applied
    # here are cancelled by the demodulation below (up to its 1e-8 epsilon).
    # NOTE(review): presumably kept to keep magnitudes in a safe numeric range
    # (e.g. FP16) — confirm.
    if demodulate:
        w = w * w.square().mean([1,2,3], keepdim=True).rsqrt()
        s = s * s.square().mean().rsqrt()

    # Modulate weights.
    w = w.unsqueeze(0) # [NOIkk]
    w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]

    # Demodulate weights: rescale every (sample, output channel) filter to unit L2 norm.
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
        w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]

    # Apply input scaling.
    if input_gain is not None:
        input_gain = input_gain.expand(batch_size, in_channels) # [NI]
        w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]

    # Execute as one fused op using grouped convolution: the batch is folded
    # into the channel axis so each group only sees one sample's channels.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    return x

#----------------------------------------------------------------------------

class FullyConnectedLayer(torch.nn.Module):
    """Fully-connected layer whose weight and bias are rescaled at runtime.

    Parameters are stored divided by `lr_multiplier`; in `forward` the weight
    is multiplied by `lr_multiplier / sqrt(in_features)` and the bias by
    `lr_multiplier`.
    """
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        bias            = True,     # Apply additive bias before the activation function?
        lr_multiplier   = 1,        # Learning rate multiplier.
        weight_init     = 1,        # Initial standard deviation of the weight tensor.
        bias_init       = 0,        # Initial value of the additive bias.
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
        bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
        self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Return `activation(x @ W^T + b)`; `x` is expected as [batch_size, in_features]."""
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain
        if self.activation == 'linear' and b is not None:
            # Linear layer with bias: a single fused addmm.
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
            x = bias_act.bias_act(x, b, act=self.activation)
        return x

    def extra_repr(self):
        return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'

#----------------------------------------------------------------------------

class MappingNetwork(torch.nn.Module):
    """Map a latent `z` (plus an optional label `c`) to `num_ws` copies of W."""
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no labels.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output.
        num_layers      = 2,        # Number of mapping layers.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.998,    # Decay for tracking the moving average of W during training.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Construct layers.
        self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None
        features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers
        for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
            layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        """Return W broadcast to shape [batch_size, num_ws, w_dim].

        Args:
            z: Latents, [batch_size, z_dim].
            c: Labels, [batch_size, c_dim]; only read when c_dim > 0.
            truncation_psi: Interpolation factor between `w_avg` (0) and the
                mapped W (1); 1 disables truncation.
            truncation_cutoff: Number of leading ws to truncate; None = all.
            update_emas: If True, update the running average `w_avg` in place.
        """
        misc.assert_shape(z, [None, self.z_dim])
        if truncation_cutoff is None:
            truncation_cutoff = self.num_ws

        # Embed, normalize (to unit second moment per sample), and concatenate inputs.
        x = z.to(torch.float32)
        x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
        if self.c_dim > 0:
            misc.assert_shape(c, [None, self.c_dim])
            y = self.embed(c.to(torch.float32))
            y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
            x = torch.cat([x, y], dim=1) if x is not None else y

        # Execute layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)

        # Update moving average of W.
        if update_emas:
            self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast and apply truncation.
        x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
        if truncation_psi != 1:
            x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

    def extra_repr(self):
        return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'

#----------------------------------------------------------------------------

class SynthesisInput(torch.nn.Module):
    """Fourier-feature input whose rotation/translation is predicted from `w`.

    Produces a tensor of shape [batch_size, channels, size[1], size[0]].
    """
    def __init__(self,
        w_dim,          # Intermediate latent (W) dimensionality.
        channels,       # Number of output channels.
        size,           # Output spatial size: int or [width, height].
        sampling_rate,  # Output sampling rate.
        bandwidth,      # Output bandwidth.
    ):
        super().__init__()
        self.w_dim = w_dim
        self.channels = channels
        self.size = np.broadcast_to(np.asarray(size), [2])
        self.sampling_rate = sampling_rate
        self.bandwidth = bandwidth

        # Draw random frequencies from uniform 2D disc.
        freqs = torch.randn([self.channels, 2])
        radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
        freqs /= radii * radii.square().exp().pow(0.25)
        freqs *= bandwidth
        phases = torch.rand([self.channels]) - 0.5

        # Setup parameters and buffers.
        self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
        # Zero-initialized weight with bias [1,0,0,0]: starts as the identity transform.
        self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
        self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
        self.register_buffer('freqs', freqs)
        self.register_buffer('phases', phases)

    def forward(self, w):
        # Introduce batch dimension.
        transforms = self.transform.unsqueeze(0) # [batch, row, col]
        freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
        phases = self.phases.unsqueeze(0) # [batch, channel]

        # Apply learned transformation.
        t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
        t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
        m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
        m_r[:, 0, 0] = t[:, 0]  # r'_c
        m_r[:, 0, 1] = -t[:, 1] # r'_s
        m_r[:, 1, 0] = t[:, 1]  # r'_s
        m_r[:, 1, 1] = t[:, 0]  # r'_c
        m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
        m_t[:, 0, 2] = -t[:, 2] # t'_x
        m_t[:, 1, 2] = -t[:, 3] # t'_y
        transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.

        # Transform frequencies.
        phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
        freqs = freqs @ transforms[:, :2, :2]

        # Dampen out-of-band frequencies that may occur due to the user-specified transform.
        amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)

        # Construct sampling grid.
        theta = torch.eye(2, 3, device=w.device)
        theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
        theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
        grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)

        # Compute Fourier features.
        x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
        x = x + phases.unsqueeze(1).unsqueeze(2)
        x = torch.sin(x * (np.pi * 2))
        x = x * amplitudes.unsqueeze(1).unsqueeze(2)

        # Apply trainable mapping.
        weight = self.weight / np.sqrt(self.channels)
        x = x @ weight.t()

        # Ensure correct shape.
        x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
        misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
        return x

    def extra_repr(self):
        return '\n'.join([
            f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
            f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])

#----------------------------------------------------------------------------

class SynthesisLayer(torch.nn.Module):
    """Modulated conv followed by filtered leaky ReLU with up/downsampling.

    For the final ToRGB layer (`is_torgb=True`) the kernel is 1x1, there is
    no demodulation, and the activation is linear (gain 1, slope 1).
    """
    def __init__(self,
        w_dim,                          # Intermediate latent (W) dimensionality.
        is_torgb,                       # Is this the final ToRGB layer?
        is_critically_sampled,          # Does this layer use critical sampling?
        use_fp16,                       # Does this layer use FP16?

        # Input & output specifications.
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        in_size,                        # Input spatial size: int or [width, height].
        out_size,                       # Output spatial size: int or [width, height].
        in_sampling_rate,               # Input sampling rate (s).
        out_sampling_rate,              # Output sampling rate (s).
        in_cutoff,                      # Input cutoff frequency (f_c).
        out_cutoff,                     # Output cutoff frequency (f_c).
        in_half_width,                  # Input transition band half-width (f_h).
        out_half_width,                 # Output Transition band half-width (f_h).

        # Hyperparameters.
        conv_kernel         = 3,        # Convolution kernel size. Ignored for final the ToRGB layer.
        filter_size         = 6,        # Low-pass filter size relative to the lower resolution when up/downsampling.
        lrelu_upsampling    = 2,        # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer.
        use_radial_filters  = False,    # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
        conv_clamp          = 256,      # Clamp the output to [-X, +X], None = disable clamping.
        magnitude_ema_beta  = 0.999,    # Decay rate for the moving average of input magnitudes.
    ):
        super().__init__()
        self.w_dim = w_dim
        self.is_torgb = is_torgb
        self.is_critically_sampled = is_critically_sampled
        self.use_fp16 = use_fp16
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width
        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.conv_clamp = conv_clamp
        self.magnitude_ema_beta = magnitude_ema_beta

        # Setup parameters and buffers.
        self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1)
        self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel]))
        self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
        self.register_buffer('magnitude_ema', torch.ones([]))

        # Design upsampling filter.
        self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
        self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
        self.register_buffer('up_filter', self.design_lowpass_filter(
            numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate))

        # Design downsampling filter.
        self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
        self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
        self.down_radial = use_radial_filters and not self.is_critically_sampled
        self.register_buffer('down_filter', self.design_lowpass_filter(
            numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial))

        # Compute padding.
        pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling.
        pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling.
        pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
        pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
        pad_hi = pad_total - pad_lo
        self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]

    def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False):
        """Apply the layer to `x` ([N, in_channels, in_h, in_w]) modulated by `w` ([N, w_dim])."""
        assert noise_mode in ['random', 'const', 'none'] # unused
        misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])])
        misc.assert_shape(w, [x.shape[0], self.w_dim])

        # Track input magnitude.
        if update_emas:
            with torch.autograd.profiler.record_function('update_magnitude_ema'):
                magnitude_cur = x.detach().to(torch.float32).square().mean()
                self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta))
        # Always applied (not only while updating): divides input by its tracked RMS.
        input_gain = self.magnitude_ema.rsqrt()

        # Execute affine layer.
        styles = self.affine(w)
        if self.is_torgb:
            weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
            styles = styles * weight_gain

        # Execute modulated conv2d.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
        x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles,
            padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)

        # Execute bias, filtered leaky ReLU, and clamping.
        gain = 1 if self.is_torgb else np.sqrt(2)
        slope = 1 if self.is_torgb else 0.2
        x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype),
            up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)

        # Ensure correct shape and dtype.
        misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])])
        assert x.dtype == dtype
        return x

    @staticmethod
    def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
        """Return a float32 low-pass FIR filter tensor, or None when `numtaps == 1`.

        Non-radial: 1-D Kaiser-window design via `scipy.signal.firwin`.
        Radial: 2-D jinc-based filter windowed by an outer product of Kaiser
        windows and normalized to sum to 1.
        """
        assert numtaps >= 1

        # Identity filter.
        if numtaps == 1:
            return None

        # Separable Kaiser low-pass filter.
        if not radial:
            f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return torch.as_tensor(f, dtype=torch.float32)

        # Radially symmetric jinc-based filter.
        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        f /= np.sum(f)
        return torch.as_tensor(f, dtype=torch.float32)

    def extra_repr(self):
        return '\n'.join([
            f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},',
            f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},',
            f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},',
            f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},',
            f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},',
            f'in_size={list(self.in_size)}, out_size={list(self.out_size)},',
            f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}'])

#----------------------------------------------------------------------------

class SynthesisNetwork(torch.nn.Module):
    """Fourier-feature input followed by `num_layers + 1` SynthesisLayers (last is ToRGB).

    Consumes `num_ws = num_layers + 2` W vectors: one for the input layer and
    one per synthesis layer.
    """
    def __init__(self,
        w_dim,                          # Intermediate latent (W) dimensionality.
        img_resolution,                 # Output image resolution.
        img_channels,                   # Number of color channels.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_layers          = 14,       # Total number of layers, excluding Fourier features and ToRGB.
        num_critical        = 2,        # Number of critically sampled layers at the end.
        first_cutoff        = 2,        # Cutoff frequency of the first layer (f_{c,0}).
        first_stopband      = 2**2.1,   # Minimum stopband of the first layer (f_{t,0}).
        last_stopband_rel   = 2**0.3,   # Minimum stopband of the last layer, expressed relative to the cutoff.
        margin_size         = 10,       # Number of additional pixels outside the image.
        output_scale        = 0.25,     # Scale factor for the output image.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        **layer_kwargs,                 # Arguments for SynthesisLayer.
    ):
        super().__init__()
        self.w_dim = w_dim
        self.num_ws = num_layers + 2
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.num_layers = num_layers
        self.num_critical = num_critical
        self.margin_size = margin_size
        self.output_scale = output_scale
        self.num_fp16_res = num_fp16_res

        # Geometric progression of layer cutoffs and min. stopbands.
        last_cutoff = self.img_resolution / 2 # f_{c,N}
        last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
        exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1)
        cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i]
        stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i]

        # Compute remaining layer parameters.
        sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i]
        half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
        sizes = sampling_rates + self.margin_size * 2
        sizes[-2:] = self.img_resolution
        channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max))
        channels[-1] = self.img_channels

        # Construct layers.
        self.input = SynthesisInput(
            w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]),
            sampling_rate=sampling_rates[0], bandwidth=cutoffs[0])
        self.layer_names = []
        for idx in range(self.num_layers + 1):
            prev = max(idx - 1, 0)
            is_torgb = (idx == self.num_layers)
            is_critically_sampled = (idx >= self.num_layers - self.num_critical)
            use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution)
            layer = SynthesisLayer(
                w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16,
                in_channels=int(channels[prev]), out_channels= int(channels[idx]),
                in_size=int(sizes[prev]), out_size=int(sizes[idx]),
                in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]),
                in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx],
                in_half_width=half_widths[prev], out_half_width=half_widths[idx],
                **layer_kwargs)
            name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
            setattr(self, name, layer)
            self.layer_names.append(name)

    def forward(self, ws, **layer_kwargs):
        """Render images [N, img_channels, img_resolution, img_resolution] from ws [N, num_ws, w_dim]."""
        misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
        ws = ws.to(torch.float32).unbind(dim=1)

        # Execute layers.
        x = self.input(ws[0])
        for name, w in zip(self.layer_names, ws[1:]):
            x = getattr(self, name)(x, w, **layer_kwargs)
        if self.output_scale != 1:
            x = x * self.output_scale

        # Ensure correct shape and dtype.
        misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution])
        x = x.to(torch.float32)
        return x

    def extra_repr(self):
        return '\n'.join([
            f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
            f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
            f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},',
            f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}'])

#----------------------------------------------------------------------------

class Generator(torch.nn.Module):
    """StyleGAN3 generator: MappingNetwork (z, c -> ws) then SynthesisNetwork (ws -> image).

    When `MODEL.info_type` is not "N/A" (InfoGAN), `self.z_dim` is enlarged by
    the number of discrete/continuous code entries, and the mapping network is
    built for that enlarged latent size.
    """
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        MODEL,                      # MODEL config required for infoGAN
        mapping_kwargs      = None, # Arguments for MappingNetwork (None = {}).
        synthesis_kwargs    = None, # Arguments for SynthesisNetwork (None = {}).
    ):
        super().__init__()
        # None defaults avoid sharing one mutable dict across instances.
        mapping_kwargs = {} if mapping_kwargs is None else mapping_kwargs
        synthesis_kwargs = {} if synthesis_kwargs is None else synthesis_kwargs
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.MODEL = MODEL
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        z_extra_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            z_extra_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            z_extra_dim += self.MODEL.info_num_conti_c

        if self.MODEL.info_type != "N/A":
            self.z_dim += z_extra_dim

        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        # Must use self.z_dim (which includes any InfoGAN code dims) rather than
        # the raw z_dim argument; otherwise the mapping network's first layer
        # and its input-shape assertion would ignore the extra dimensions.
        self.mapping = MappingNetwork(z_dim=self.z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, eval=False, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Generate images from latents `z` and labels `c`.

        `eval` is accepted for interface compatibility and is not used here.
        Extra keyword arguments are forwarded to every SynthesisLayer.
        """
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        return img

#----------------------------------------------------------------------------