# src/models/big_resnet_deep_legacy.py
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# models/deep_big_resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils.ops as ops
import utils.misc as misc


class GenBlock(nn.Module):
    """Bottleneck residual generator block (deep BigGAN style).

    Channel flow: in_channels -> in_channels//channel_ratio (1x1) ->
    hidden (3x3) -> hidden (3x3) -> out_channels (1x1).  Each conv is
    preceded by a conditional BN (driven by `affine`) and an activation.
    The skip path keeps the first `out_channels` channels of the input
    when channels shrink, and is upsampled in lockstep with the main path.
    """

    def __init__(self, in_channels, out_channels, g_cond_mtd, affine_input_dim, upsample,
                 MODULES, channel_ratio=4):
        super(GenBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.g_cond_mtd = g_cond_mtd
        self.upsample = upsample
        # Bottleneck width for the two 3x3 convs.
        self.hidden_channels = self.in_channels // channel_ratio

        # Conditional batch norms: one per conv, fed by the affine vector.
        self.bn1 = MODULES.g_bn(affine_input_dim, self.in_channels, MODULES)
        self.bn2 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)
        self.bn3 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)
        self.bn4 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)

        self.activation = MODULES.g_act_fn

        self.conv2d1 = MODULES.g_conv2d(in_channels=self.in_channels, out_channels=self.hidden_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d2 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.hidden_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.conv2d3 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.hidden_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.conv2d4 = MODULES.g_conv2d(in_channels=self.hidden_channels,
                                        out_channels=self.out_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x, affine):
        """Run the bottleneck residual path; `affine` conditions every BN."""
        # Skip connection: drop surplus channels when the block narrows.
        if self.in_channels != self.out_channels:
            x0 = x[:, :self.out_channels]
        else:
            x0 = x

        x = self.bn1(x, affine)
        x = self.conv2d1(self.activation(x))

        x = self.bn2(x, affine)
        x = self.activation(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="nearest")  # upsample
        x = self.conv2d2(x)

        x = self.bn3(x, affine)
        x = self.conv2d3(self.activation(x))

        x = self.bn4(x, affine)
        x = self.conv2d4(self.activation(x))

        # Upsample the skip path to match the main path's resolution.
        if self.upsample:
            x0 = F.interpolate(x0, scale_factor=2, mode="nearest")  # upsample
        out = x + x0
        return out


class Generator(nn.Module):
    """Deep BigGAN generator (legacy variant).

    Maps a latent `z` (optionally concatenated with class/info embeddings)
    through a linear projection to a `bottom x bottom` feature map, then a
    stack of `GenBlock`s (with optional self-attention), and a final
    BN -> act -> 3x3 conv -> tanh head producing a 3-channel image.
    """

    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        # Per-resolution channel schedules, keyed by target image size.
        g_in_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]
        }

        g_out_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]
        }

        # Spatial size of the initial feature map produced by linear0.
        bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}

        self.z_dim = z_dim
        self.g_shared_dim = g_shared_dim
        self.g_cond_mtd = g_cond_mtd
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.in_dims = g_in_dims_collection[str(img_size)]
        self.out_dims = g_out_dims_collection[str(img_size)]
        self.bottom = bottom_collection[str(img_size)]
        self.num_blocks = len(self.in_dims)
        self.affine_input_dim = self.z_dim

        # Width of the InfoGAN code appended to z (0 when info_type == "N/A").
        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        if self.MODEL.info_type != "N/A":
            if self.MODEL.g_info_injection == "concat":
                # Mix the info code into z, keeping the latent width at z_dim.
                self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
            elif self.MODEL.g_info_injection == "cBN":
                # Project the info code to an extra conditional-BN affine slot.
                self.affine_input_dim += self.g_shared_dim
                self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.g_shared_dim, bias=True)

        if self.g_cond_mtd != "W/O":
            # Shared class embedding concatenated into the affine vector.
            self.affine_input_dim += self.g_shared_dim
            self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.g_shared_dim)

        self.linear0 = MODULES.g_linear(in_features=self.affine_input_dim, out_features=self.in_dims[0] * self.bottom * self.bottom, bias=True)

        # g_depth GenBlocks per stage; only the last block of a stage upsamples
        # and changes the channel count.
        self.blocks = []
        for index in range(self.num_blocks):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.in_dims[index] if g_index == 0 else self.out_dims[index],
                         g_cond_mtd=g_cond_mtd,
                         affine_input_dim=self.affine_input_dim,
                         upsample=True if g_index == (g_depth - 1) else False,
                         MODULES=MODULES)
            ] for g_index in range(g_depth)]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Output head: plain BN -> activation -> 3x3 conv -> tanh in [-1, 1].
        self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
        self.activation = MODULES.g_act_fn
        self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # init_weights expects the bound `modules` method (it calls modules()).
        ops.init_weights(self.modules, g_init)

    def forward(self, z, label, shared_label=None, eval=False):
        """Generate images from latents `z` and class `label`.

        When g_info_injection == "cBN", `z` is expected to carry the info
        code in its trailing columns (z[:, z_dim:]).  `shared_label` lets a
        caller reuse a precomputed class embedding.
        """
        affine_list = []
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            if self.MODEL.info_type != "N/A":
                if self.MODEL.g_info_injection == "concat":
                    z = self.info_mix_linear(z)
                elif self.MODEL.g_info_injection == "cBN":
                    # Split off the info code and project it into the affine vector.
                    z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]
                    affine_list.append(self.info_proj_linear(z_info))

            if self.g_cond_mtd != "W/O":
                if shared_label is None:
                    shared_label = self.shared(label)
                affine_list.append(shared_label)
            if len(affine_list) > 0:
                z = torch.cat(affine_list + [z], 1)

            # The full (embedding + z) vector conditions every GenBlock BN.
            affine = z
            act = self.linear0(z)
            act = act.view(-1, self.in_dims[0], self.bottom, self.bottom)
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    if isinstance(block, ops.SelfAttention):
                        act = block(act)
                    else:
                        act = block(act, affine)

            act = self.bn4(act)
            act = self.activation(act)
            act = self.conv2d5(act)
            out = self.tanh(act)
        return out


class DiscBlock(nn.Module):
    """Bottleneck residual discriminator block (deep BigGAN style).

    Pre-activation 1x1 -> 3x3 -> 3x3 -> 1x1 stack with optional 2x average
    pooling.  When the block widens, the skip path concatenates a learned
    1x1 projection of the input to reach `out_channels`.
    """

    def __init__(self, in_channels, out_channels, MODULES, downsample=True, channel_ratio=4):
        super(DiscBlock, self).__init__()
        self.downsample = downsample
        hidden_channels = out_channels // channel_ratio

        self.activation = MODULES.d_act_fn
        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d2 = MODULES.d_conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d3 = MODULES.d_conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d4 = MODULES.d_conv2d(in_channels=hidden_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

        # Learned shortcut only supplies the extra channels; the original
        # channels pass through untouched (see torch.cat in forward).
        self.learnable_sc = True if (in_channels != out_channels) else False
        if self.learnable_sc:
            self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels,
                                            out_channels=out_channels - in_channels,
                                            kernel_size=1,
                                            stride=1,
                                            padding=0)

        if self.downsample:
            self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        x0 = x

        x = self.conv2d1(self.activation(x))
        x = self.conv2d2(self.activation(x))
        x = self.conv2d3(self.activation(x))
        x = self.activation(x)

        if self.downsample:
            x = self.average_pooling(x)

        x = self.conv2d4(x)

        # Mirror pooling on the skip path, then widen it if needed.
        if self.downsample:
            x0 = self.average_pooling(x0)
        if self.learnable_sc:
            x0 = torch.cat([x0, self.conv2d0(x0)], 1)

        out = x + x0
        return out


class Discriminator(nn.Module):
    """Deep BigGAN discriminator (legacy variant).

    A stack of `DiscBlock`s (with optional self-attention) followed by
    global sum pooling, an adversarial linear head, and optional
    conditioning / auxiliary-classifier / InfoGAN heads selected by
    `d_cond_mtd`, `aux_cls_type`, and `MODEL.info_type`.
    """

    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,
                 num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        d_in_dims_collection = {
            "32": [d_conv_dim * 4, d_conv_dim * 4, d_conv_dim * 4],
            "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],
            "512": [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]
        }

        d_out_dims_collection = {
            "32": [d_conv_dim * 4, d_conv_dim * 4, d_conv_dim * 4],
            "64": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "128": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "256": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "512":
            [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]
        }

        # Which stages downsample (first block of the stage only).
        d_down = {
            "32": [True, True, False, False],
            "64": [True, True, True, True, False],
            "128": [True, True, True, True, True, False],
            "256": [True, True, True, True, True, True, False],
            "512": [True, True, True, True, True, True, True, False]
        }

        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        self.MODEL = MODEL
        down = d_down[str(img_size)]

        self.input_conv = MODULES.d_conv2d(in_channels=3, out_channels=self.in_dims[0], kernel_size=3, stride=1, padding=1)

        self.blocks = []
        for index in range(len(self.in_dims)):
            self.blocks += [[
                DiscBlock(in_channels=self.in_dims[index] if d_index == 0 else self.out_dims[index],
                          out_channels=self.out_dims[index],
                          MODULES=MODULES,
                          downsample=True if down[index] and d_index == 0 else False)
            ] for d_index in range(d_depth)]

            if (index + 1) in attn_d_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1 + num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1, bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, self.out_dims[-1])
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            pass

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
                self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)

        if d_init:
            ops.init_weights(self.modules, d_init)

    def forward(self, x, label, eval=False, adc_fake=False):
        """Score images `x`; returns a dict of every head's output.

        Heads that are disabled by the current configuration are returned
        as None so callers can index the dict uniformly.
        """
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            embed, proxy, cls_output = None, None, None
            mi_embed, mi_proxy, mi_cls_output = None, None, None
            info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None
            h = self.input_conv(x)
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    h = block(h)
            bottom_h, bottom_w = h.shape[2], h.shape[3]
            h = self.activation(h)
            # Global sum pooling over the spatial dimensions.
            h = torch.sum(h, dim=[2, 3])

            # adversarial training
            adv_output = torch.squeeze(self.linear1(h))

            # make class labels odd (for fake) or even (for real) for ADC
            if self.aux_cls_type == "ADC":
                if adc_fake:
                    label = label * 2 + 1
                else:
                    label = label * 2

            # forward pass through InfoGAN Q head
            # (divide by bottom_h*bottom_w to turn the sum pool into a mean)
            if self.MODEL.info_type in ["discrete", "both"]:
                info_discrete_c_logits = self.info_discrete_linear(h / (bottom_h * bottom_w))
            if self.MODEL.info_type in ["continuous", "both"]:
                info_conti_mu = self.info_conti_mu_linear(h / (bottom_h * bottom_w))
                info_conti_var = torch.exp(self.info_conti_var_linear(h / (bottom_h * bottom_w)))

            # class conditioning
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    # NOTE(review): F.normalize is out-of-place, so this loop
                    # rebinds W without touching the stored weights — kept
                    # byte-for-byte from the original; verify intent upstream.
                    for W in self.linear2.parameters():
                        W = F.normalize(W, dim=1)
                    h = F.normalize(h, dim=1)
                cls_output = self.linear2(h)
            elif self.d_cond_mtd == "PD":
                # Projection discriminator: inner product with class embedding.
                adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                embed = self.linear2(h)
                proxy = self.embedding(label)
                if self.normalize_d_embed:
                    embed = F.normalize(embed, dim=1)
                    proxy = F.normalize(proxy, dim=1)
            elif self.d_cond_mtd == "MD":
                # Select each sample's own class column from the MD head.
                idx = torch.LongTensor(range(label.size(0))).to(label.device)
                adv_output = adv_output[idx, label]
            elif self.d_cond_mtd in ["W/O", "MH"]:
                pass
            else:
                raise NotImplementedError

            # extra conditioning for TACGAN and ADCGAN
            if self.aux_cls_type == "TAC":
                if self.d_cond_mtd == "AC":
                    if self.normalize_d_embed:
                        for W in self.linear_mi.parameters():
                            W = F.normalize(W, dim=1)
                    mi_cls_output = self.linear_mi(h)
                elif self.d_cond_mtd in ["2C", "D2DCE"]:
                    mi_embed = self.linear_mi(h)
                    mi_proxy = self.embedding_mi(label)
                    if self.normalize_d_embed:
                        mi_embed = F.normalize(mi_embed, dim=1)
                        mi_proxy = F.normalize(mi_proxy, dim=1)
        return {
            "h": h,
            "adv_output": adv_output,
            "embed": embed,
            "proxy": proxy,
            "cls_output": cls_output,
            "label": label,
            "mi_embed": mi_embed,
            "mi_proxy": mi_proxy,
            "mi_cls_output": mi_cls_output,
            "info_discrete_c_logits": info_discrete_c_logits,
            "info_conti_mu": info_conti_mu,
            "info_conti_var": info_conti_var
        }