Path: blob/master/src/models/deep_conv.py
809 views
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN1# The MIT License (MIT)2# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details34# models/deep_conv.py56import torch7import torch.nn as nn8import torch.nn.functional as F910import utils.ops as ops11import utils.misc as misc121314class GenBlock(nn.Module):15def __init__(self, in_channels, out_channels, g_cond_mtd, g_info_injection, affine_input_dim, MODULES):16super(GenBlock, self).__init__()17self.g_cond_mtd = g_cond_mtd18self.g_info_injection = g_info_injection1920self.deconv0 = MODULES.g_deconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1)2122if self.g_cond_mtd == "W/O" and self.g_info_injection in ["N/A", "concat"]:23self.bn0 = MODULES.g_bn(in_features=out_channels)24elif self.g_cond_mtd == "cBN" or self.g_info_injection == "cBN":25self.bn0 = MODULES.g_bn(affine_input_dim, out_channels, MODULES)26else:27raise NotImplementedError2829self.activation = MODULES.g_act_fn3031def forward(self, x, affine):32x = self.deconv0(x)33if self.g_cond_mtd == "W/O" and self.g_info_injection in ["N/A", "concat"]:34x = self.bn0(x)35elif self.g_cond_mtd == "cBN" or self.g_info_injection == "cBN":36x = self.bn0(x, affine)37out = self.activation(x)38return out394041class Generator(nn.Module):42def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,43mixed_precision, MODULES, MODEL):44super(Generator, self).__init__()45self.in_dims = [512, 256, 128]46self.out_dims = [256, 128, 64]4748self.z_dim = z_dim49self.num_classes = num_classes50self.g_cond_mtd = g_cond_mtd51self.mixed_precision = mixed_precision52self.MODEL = MODEL53self.affine_input_dim = 05455info_dim = 056if self.MODEL.info_type in ["discrete", "both"]:57info_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c58if self.MODEL.info_type in ["continuous", "both"]:59info_dim += self.MODEL.info_num_conti_c6061self.g_info_injection = self.MODEL.g_info_injection62if self.MODEL.info_type != "N/A":63if self.g_info_injection == "concat":64self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)65elif self.g_info_injection == "cBN":66self.affine_input_dim += self.z_dim67self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.z_dim, bias=True)6869if self.g_cond_mtd != "W/O" and self.g_cond_mtd == "cBN":70self.affine_input_dim += self.num_classes7172self.linear0 = MODULES.g_linear(in_features=self.z_dim, out_features=self.in_dims[0]*4*4, bias=True)7374self.blocks = []75for index in range(len(self.in_dims)):76self.blocks += [[77GenBlock(in_channels=self.in_dims[index],78out_channels=self.out_dims[index],79g_cond_mtd=self.g_cond_mtd,80g_info_injection=self.g_info_injection,81affine_input_dim=self.affine_input_dim,82MODULES=MODULES)83]]8485if index + 1 in attn_g_loc and apply_attn:86self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]8788self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])8990self.conv4 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)91self.tanh = nn.Tanh()9293ops.init_weights(self.modules, g_init)9495def forward(self, z, label, shared_label=None, eval=False):96affine_list = []97if self.g_cond_mtd != "W/O":98label = F.one_hot(label, num_classes=self.num_classes).to(torch.float32)99with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:100if self.MODEL.info_type != "N/A":101if self.g_info_injection == "concat":102z = self.info_mix_linear(z)103elif self.g_info_injection == "cBN":104z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]105affine_list.append(self.info_proj_linear(z_info))106107if self.g_cond_mtd != "W/O":108affine_list.append(label)109if len(affine_list) > 0:110affines = torch.cat(affine_list, 1)111else:112affines = None113114act = self.linear0(z)115act = act.view(-1, self.in_dims[0], 4, 4)116for index, blocklist in enumerate(self.blocks):117for block in blocklist:118if isinstance(block, ops.SelfAttention):119act = block(act)120else:121act = block(act, affines)122123act = self.conv4(act)124out = self.tanh(act)125return out126127128class DiscBlock(nn.Module):129def __init__(self, in_channels, out_channels, apply_d_sn, MODULES):130super(DiscBlock, self).__init__()131self.apply_d_sn = apply_d_sn132133self.conv0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)134self.conv1 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1)135136if not apply_d_sn:137self.bn0 = MODULES.d_bn(in_features=out_channels)138self.bn1 = MODULES.d_bn(in_features=out_channels)139140self.activation = MODULES.d_act_fn141142def forward(self, x):143x = self.conv0(x)144if not self.apply_d_sn:145x = self.bn0(x)146x = self.activation(x)147148x = self.conv1(x)149if not self.apply_d_sn:150x = self.bn1(x)151out = self.activation(x)152return out153154155class Discriminator(nn.Module):156def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,157num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):158super(Discriminator, self).__init__()159self.in_dims = [3] + [64, 128]160self.out_dims = [64, 128, 256]161162self.apply_d_sn = apply_d_sn163self.d_cond_mtd = d_cond_mtd164self.aux_cls_type = aux_cls_type165self.normalize_d_embed = normalize_d_embed166self.num_classes = num_classes167self.mixed_precision = mixed_precision168self.MODEL= MODEL169170self.blocks = []171for index in range(len(self.in_dims)):172self.blocks += [[173DiscBlock(in_channels=self.in_dims[index], out_channels=self.out_dims[index], apply_d_sn=self.apply_d_sn, MODULES=MODULES)174]]175176if index + 1 in attn_d_loc and apply_attn:177self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]178179self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])180181self.activation = MODULES.d_act_fn182self.conv1 = MODULES.d_conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)183184if not self.apply_d_sn:185self.bn1 = MODULES.d_bn(in_features=512)186187# linear layer for adversarial training188if self.d_cond_mtd == "MH":189self.linear1 = MODULES.d_linear(in_features=512, out_features=1 + num_classes, bias=True)190elif self.d_cond_mtd == "MD":191self.linear1 = MODULES.d_linear(in_features=512, out_features=num_classes, bias=True)192else:193self.linear1 = MODULES.d_linear(in_features=512, out_features=1, bias=True)194195# double num_classes for Auxiliary Discriminative Classifier196if self.aux_cls_type == "ADC":197num_classes = num_classes * 2198199# linear and embedding layers for discriminator conditioning200if self.d_cond_mtd == "AC":201self.linear2 = MODULES.d_linear(in_features=512, out_features=num_classes, bias=False)202elif self.d_cond_mtd == "PD":203self.embedding = MODULES.d_embedding(num_classes, 512)204elif self.d_cond_mtd in ["2C", "D2DCE"]:205self.linear2 = MODULES.d_linear(in_features=512, out_features=d_embed_dim, bias=True)206self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)207else:208pass209210# linear and embedding layers for evolved classifier-based GAN211if self.aux_cls_type == "TAC":212if self.d_cond_mtd == "AC":213self.linear_mi = MODULES.d_linear(in_features=512, out_features=num_classes, bias=False)214elif self.d_cond_mtd in ["2C", "D2DCE"]:215self.linear_mi = MODULES.d_linear(in_features=512, out_features=d_embed_dim, bias=True)216self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)217else:218raise NotImplementedError219220# Q head network for infoGAN221if self.MODEL.info_type in ["discrete", "both"]:222out_features = self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c223self.info_discrete_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)224if self.MODEL.info_type in ["continuous", "both"]:225out_features = self.MODEL.info_num_conti_c226self.info_conti_mu_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)227self.info_conti_var_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)228229if d_init:230ops.init_weights(self.modules, d_init)231232def forward(self, x, label, eval=False, adc_fake=False):233with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:234embed, proxy, cls_output = None, None, None235mi_embed, mi_proxy, mi_cls_output = None, None, None236info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None237h = x238for index, blocklist in enumerate(self.blocks):239for block in blocklist:240h = block(h)241h = self.conv1(h)242if not self.apply_d_sn:243h = self.bn1(h)244bottom_h, bottom_w = h.shape[2], h.shape[3]245h = self.activation(h)246h = torch.sum(h, dim=[2, 3])247248# adversarial training249adv_output = torch.squeeze(self.linear1(h))250251# make class labels odd (for fake) or even (for real) for ADC252if self.aux_cls_type == "ADC":253if adc_fake:254label = label*2 + 1255else:256label = label*2257258# forward pass through InfoGAN Q head259if self.MODEL.info_type in ["discrete", "both"]:260info_discrete_c_logits = self.info_discrete_linear(h/(bottom_h*bottom_w))261if self.MODEL.info_type in ["continuous", "both"]:262info_conti_mu = self.info_conti_mu_linear(h/(bottom_h*bottom_w))263info_conti_var = torch.exp(self.info_conti_var_linear(h/(bottom_h*bottom_w)))264265# class conditioning266if self.d_cond_mtd == "AC":267if self.normalize_d_embed:268for W in self.linear2.parameters():269W = F.normalize(W, dim=1)270h = F.normalize(h, dim=1)271cls_output = self.linear2(h)272elif self.d_cond_mtd == "PD":273adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)274elif self.d_cond_mtd in ["2C", "D2DCE"]:275embed = self.linear2(h)276proxy = self.embedding(label)277if self.normalize_d_embed:278embed = F.normalize(embed, dim=1)279proxy = F.normalize(proxy, dim=1)280elif self.d_cond_mtd == "MD":281idx = torch.LongTensor(range(label.size(0))).to(label.device)282adv_output = adv_output[idx, label]283elif self.d_cond_mtd in ["W/O", "MH"]:284pass285else:286raise NotImplementedError287288# extra conditioning for TACGAN and ADCGAN289if self.aux_cls_type == "TAC":290if self.d_cond_mtd == "AC":291if self.normalize_d_embed:292for W in self.linear_mi.parameters():293W = F.normalize(W, dim=1)294mi_cls_output = self.linear_mi(h)295elif self.d_cond_mtd in ["2C", "D2DCE"]:296mi_embed = self.linear_mi(h)297mi_proxy = self.embedding_mi(label)298if self.normalize_d_embed:299mi_embed = F.normalize(mi_embed, dim=1)300mi_proxy = F.normalize(mi_proxy, dim=1)301return {302"h": h,303"adv_output": adv_output,304"embed": embed,305"proxy": proxy,306"cls_output": cls_output,307"label": label,308"mi_embed": mi_embed,309"mi_proxy": mi_proxy,310"mi_cls_output": mi_cls_output,311"info_discrete_c_logits": info_discrete_c_logits,312"info_conti_mu": info_conti_mu,313"info_conti_var": info_conti_var314}315316317