Path: blob/master/src/models/big_resnet_deep_studiogan.py
809 views
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN1# The MIT License (MIT)2# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details34# models/deep_big_resnet.py56import torch7import torch.nn as nn8import torch.nn.functional as F910import utils.ops as ops11import utils.misc as misc121314class GenBlock(nn.Module):15def __init__(self, in_channels, out_channels, g_cond_mtd, affine_input_dim, upsample,16MODULES, channel_ratio=4):17super(GenBlock, self).__init__()18self.in_channels = in_channels19self.out_channels = out_channels20self.g_cond_mtd = g_cond_mtd21self.upsample = upsample22self.hidden_channels = self.in_channels // channel_ratio2324self.bn1 = MODULES.g_bn(affine_input_dim, self.in_channels, MODULES)25self.bn2 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)26self.bn3 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)27self.bn4 = MODULES.g_bn(affine_input_dim, self.hidden_channels, MODULES)2829self.activation = MODULES.g_act_fn30self.conv2d0 = MODULES.g_conv2d(in_channels=self.in_channels,31out_channels=self.out_channels,32kernel_size=1,33stride=1,34padding=0)35self.conv2d1 = MODULES.g_conv2d(in_channels=self.in_channels,36out_channels=self.hidden_channels,37kernel_size=1,38stride=1,39padding=0)40self.conv2d2 = MODULES.g_conv2d(in_channels=self.hidden_channels,41out_channels=self.hidden_channels,42kernel_size=3,43stride=1,44padding=1)45self.conv2d3 = MODULES.g_conv2d(in_channels=self.hidden_channels,46out_channels=self.hidden_channels,47kernel_size=3,48stride=1,49padding=1)50self.conv2d4 = MODULES.g_conv2d(in_channels=self.hidden_channels,51out_channels=self.out_channels,52kernel_size=1,53stride=1,54padding=0)5556def forward(self, x, affine):57x0 = x58x = self.bn1(x, affine)59x = self.conv2d1(self.activation(x))6061x = self.bn2(x, affine)62x = self.activation(x)63if self.upsample:64x = F.interpolate(x, scale_factor=2, mode="nearest") # upsample65x = self.conv2d2(x)6667x = self.bn3(x, affine)68x = self.conv2d3(self.activation(x))6970x = self.bn4(x, affine)71x = self.conv2d4(self.activation(x))7273if self.upsample:74x0 = F.interpolate(x0, scale_factor=2, mode="nearest") # upsample75x0 = self.conv2d0(x0)76out = x + x077return out787980class Generator(nn.Module):81def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,82mixed_precision, MODULES, MODEL):83super(Generator, self).__init__()84g_in_dims_collection = {85"32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],86"64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],87"128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],88"256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],89"512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]90}9192g_out_dims_collection = {93"32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],94"64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],95"128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],96"256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],97"512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]98}99100bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}101102self.z_dim = z_dim103self.g_shared_dim = g_shared_dim104self.g_cond_mtd = g_cond_mtd105self.num_classes = num_classes106self.mixed_precision = mixed_precision107self.MODEL = MODEL108self.in_dims = g_in_dims_collection[str(img_size)]109self.out_dims = g_out_dims_collection[str(img_size)]110self.bottom = bottom_collection[str(img_size)]111self.num_blocks = len(self.in_dims)112self.affine_input_dim = self.z_dim113114info_dim = 0115if self.MODEL.info_type in ["discrete", "both"]:116info_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c117if self.MODEL.info_type in ["continuous", "both"]:118info_dim += self.MODEL.info_num_conti_c119120if self.MODEL.info_type != "N/A":121if self.MODEL.g_info_injection == "concat":122self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)123elif self.MODEL.g_info_injection == "cBN":124self.affine_input_dim += self.g_shared_dim125self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.g_shared_dim, bias=True)126127if self.g_cond_mtd != "W/O":128self.affine_input_dim += self.g_shared_dim129self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.g_shared_dim)130131self.linear0 = MODULES.g_linear(in_features=self.affine_input_dim, out_features=self.in_dims[0]*self.bottom*self.bottom, bias=True)132133134self.blocks = []135for index in range(self.num_blocks):136self.blocks += [[137GenBlock(in_channels=self.in_dims[index],138out_channels=self.in_dims[index] if g_index == 0 else self.out_dims[index],139g_cond_mtd=g_cond_mtd,140affine_input_dim=self.affine_input_dim,141upsample=True if g_index == (g_depth - 1) else False,142MODULES=MODULES)143] for g_index in range(g_depth)]144145if index + 1 in attn_g_loc and apply_attn:146self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]147148self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])149150self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])151self.activation = MODULES.g_act_fn152self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)153self.tanh = nn.Tanh()154155ops.init_weights(self.modules, g_init)156157def forward(self, z, label, shared_label=None, eval=False):158affine_list = []159with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:160if self.MODEL.info_type != "N/A":161if self.MODEL.g_info_injection == "concat":162z = self.info_mix_linear(z)163elif self.MODEL.g_info_injection == "cBN":164z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]165affine_list.append(self.info_proj_linear(z_info))166167if self.g_cond_mtd != "W/O":168if shared_label is None:169shared_label = self.shared(label)170affine_list.append(shared_label)171if len(affine_list) > 0:172z = torch.cat(affine_list + [z], 1)173174affine = z175act = self.linear0(z)176act = act.view(-1, self.in_dims[0], self.bottom, self.bottom)177for index, blocklist in enumerate(self.blocks):178for block in blocklist:179if isinstance(block, ops.SelfAttention):180act = block(act)181else:182act = block(act, affine)183184act = self.bn4(act)185act = self.activation(act)186act = self.conv2d5(act)187out = self.tanh(act)188return out189190191class DiscBlock(nn.Module):192def __init__(self, in_channels, out_channels, MODULES, optblock, downsample=True, channel_ratio=4):193super(DiscBlock, self).__init__()194self.optblock = optblock195self.downsample = downsample196hidden_channels = out_channels // channel_ratio197self.ch_mismatch = True if (in_channels != out_channels) else False198if self.optblock: assert self.downsample and self.ch_mismatch, "downsample and ch_mismatch should be True."199200self.activation = MODULES.d_act_fn201self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels,202out_channels=hidden_channels,203kernel_size=1,204stride=1,205padding=0)206self.conv2d2 = MODULES.d_conv2d(in_channels=hidden_channels,207out_channels=hidden_channels,208kernel_size=3,209stride=1,210padding=1)211self.conv2d3 = MODULES.d_conv2d(in_channels=hidden_channels,212out_channels=hidden_channels,213kernel_size=3,214stride=1,215padding=1)216self.conv2d4 = MODULES.d_conv2d(in_channels=hidden_channels,217out_channels=out_channels,218kernel_size=1,219stride=1,220padding=0)221222if self.ch_mismatch or self.downsample:223self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels,224out_channels=out_channels,225kernel_size=1,226stride=1,227padding=0)228229if self.downsample:230self.average_pooling = nn.AvgPool2d(2)231232def forward(self, x):233x0 = x234x = self.conv2d1(self.activation(x))235x = self.conv2d2(self.activation(x))236x = self.conv2d3(self.activation(x))237if self.downsample:238x = self.average_pooling(x)239x = self.conv2d4(self.activation(x))240241if self.optblock:242x0 = self.average_pooling(x0)243x0 = self.conv2d0(x0)244else:245if self.downsample or self.ch_mismatch:246x0 = self.conv2d0(x0)247if self.downsample:248x0 = self.average_pooling(x0)249out = x + x0250return out251252253class Discriminator(nn.Module):254def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,255num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):256super(Discriminator, self).__init__()257d_in_dims_collection = {258"32": [d_conv_dim, d_conv_dim * 4, d_conv_dim * 4],259"64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],260"128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],261"256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],262"512": [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]263}264265d_out_dims_collection = {266"32": [d_conv_dim * 4, d_conv_dim * 4, d_conv_dim * 4],267"64": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],268"128": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],269"256": [d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],270"512":271[d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]272}273274d_down = {275"32": [True, True, False, False],276"64": [True, True, True, True, False],277"128": [True, True, True, True, True, False],278"256": [True, True, True, True, True, True, False],279"512": [True, True, True, True, True, True, True, False]280}281282self.d_cond_mtd = d_cond_mtd283self.aux_cls_type = aux_cls_type284self.normalize_d_embed = normalize_d_embed285self.num_classes = num_classes286self.mixed_precision = mixed_precision287self.in_dims = d_in_dims_collection[str(img_size)]288self.out_dims = d_out_dims_collection[str(img_size)]289self.MODEL = MODEL290down = d_down[str(img_size)]291292self.input_conv = MODULES.d_conv2d(in_channels=3, out_channels=self.in_dims[0], kernel_size=3, stride=1, padding=1)293294self.blocks = []295for index in range(len(self.in_dims)):296self.blocks += [[297DiscBlock(in_channels=self.in_dims[index] if d_index == 0 else self.out_dims[index],298out_channels=self.out_dims[index],299MODULES=MODULES,300optblock=index == 0 and d_index == 0,301downsample=True if down[index] and d_index == 0 else False)302] for d_index in range(d_depth)]303304if (index+1) in attn_d_loc and apply_attn:305self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]306307self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])308309self.activation = MODULES.d_act_fn310311# linear layer for adversarial training312if self.d_cond_mtd == "MH":313self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1 + num_classes, bias=True)314elif self.d_cond_mtd == "MD":315self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=True)316else:317self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1, bias=True)318319# double num_classes for Auxiliary Discriminative Classifier320if self.aux_cls_type == "ADC":321num_classes = num_classes * 2322323# linear and embedding layers for discriminator conditioning324if self.d_cond_mtd == "AC":325self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)326elif self.d_cond_mtd == "PD":327self.embedding = MODULES.d_embedding(num_classes, self.out_dims[-1])328elif self.d_cond_mtd in ["2C", "D2DCE"]:329self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)330self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)331else:332pass333334# linear and embedding layers for evolved classifier-based GAN335if self.aux_cls_type == "TAC":336if self.d_cond_mtd == "AC":337self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)338elif self.d_cond_mtd in ["2C", "D2DCE"]:339self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)340self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)341else:342raise NotImplementedError343344# Q head network for infoGAN345if self.MODEL.info_type in ["discrete", "both"]:346out_features = self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c347self.info_discrete_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)348if self.MODEL.info_type in ["continuous", "both"]:349out_features = self.MODEL.info_num_conti_c350self.info_conti_mu_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)351self.info_conti_var_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)352353if d_init:354ops.init_weights(self.modules, d_init)355356def forward(self, x, label, eval=False, adc_fake=False):357with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:358embed, proxy, cls_output = None, None, None359mi_embed, mi_proxy, mi_cls_output = None, None, None360info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None361h = self.input_conv(x)362for index, blocklist in enumerate(self.blocks):363for block in blocklist:364h = block(h)365bottom_h, bottom_w = h.shape[2], h.shape[3]366h = self.activation(h)367h = torch.sum(h, dim=[2, 3])368369# adversarial training370adv_output = torch.squeeze(self.linear1(h))371372# make class labels odd (for fake) or even (for real) for ADC373if self.aux_cls_type == "ADC":374if adc_fake:375label = label*2 + 1376else:377label = label*2378379# forward pass through InfoGAN Q head380if self.MODEL.info_type in ["discrete", "both"]:381info_discrete_c_logits = self.info_discrete_linear(h/(bottom_h*bottom_w))382if self.MODEL.info_type in ["continuous", "both"]:383info_conti_mu = self.info_conti_mu_linear(h/(bottom_h*bottom_w))384info_conti_var = torch.exp(self.info_conti_var_linear(h/(bottom_h*bottom_w)))385386# class conditioning387if self.d_cond_mtd == "AC":388if self.normalize_d_embed:389for W in self.linear2.parameters():390W = F.normalize(W, dim=1)391h = F.normalize(h, dim=1)392cls_output = self.linear2(h)393elif self.d_cond_mtd == "PD":394adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)395elif self.d_cond_mtd in ["2C", "D2DCE"]:396embed = self.linear2(h)397proxy = self.embedding(label)398if self.normalize_d_embed:399embed = F.normalize(embed, dim=1)400proxy = F.normalize(proxy, dim=1)401elif self.d_cond_mtd == "MD":402idx = torch.LongTensor(range(label.size(0))).to(label.device)403adv_output = adv_output[idx, label]404elif self.d_cond_mtd in ["W/O", "MH"]:405pass406else:407raise NotImplementedError408409# extra conditioning for TACGAN and ADCGAN410if self.aux_cls_type == "TAC":411if self.d_cond_mtd == "AC":412if self.normalize_d_embed:413for W in self.linear_mi.parameters():414W = F.normalize(W, dim=1)415mi_cls_output = self.linear_mi(h)416elif self.d_cond_mtd in ["2C", "D2DCE"]:417mi_embed = self.linear_mi(h)418mi_proxy = self.embedding_mi(label)419if self.normalize_d_embed:420mi_embed = F.normalize(mi_embed, dim=1)421mi_proxy = F.normalize(mi_proxy, dim=1)422return {423"h": h,424"adv_output": adv_output,425"embed": embed,426"proxy": proxy,427"cls_output": cls_output,428"label": label,429"mi_embed": mi_embed,430"mi_proxy": mi_proxy,431"mi_cls_output": mi_cls_output,432"info_discrete_c_logits": info_discrete_c_logits,433"info_conti_mu": info_conti_mu,434"info_conti_var": info_conti_var435}436437438