Path: src/models/big_resnet.py
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# models/big_resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils.ops as ops
import utils.misc as misc


class GenBlock(nn.Module):
    """BigGAN-style generator residual block: conditional BN -> activation ->
    nearest-neighbor upsampling -> 3x3 conv, twice, plus a 1x1 shortcut that is
    upsampled the same way so the residual sum is well-formed."""
    def __init__(self, in_channels, out_channels, g_cond_mtd, affine_input_dim, MODULES):
        super(GenBlock, self).__init__()
        self.g_cond_mtd = g_cond_mtd

        self.bn1 = MODULES.g_bn(affine_input_dim, in_channels, MODULES)
        self.bn2 = MODULES.g_bn(affine_input_dim, out_channels, MODULES)

        self.activation = MODULES.g_act_fn
        self.conv2d0 = MODULES.g_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d1 = MODULES.g_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.g_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, affine):
        x0 = x

        # main path: doubles the spatial resolution between the two convs
        x = self.bn1(x, affine)
        x = self.activation(x)
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.conv2d1(x)

        x = self.bn2(x, affine)
        x = self.activation(x)
        x = self.conv2d2(x)

        # shortcut path: matching upsampling plus a 1x1 conv for the channel change
        x0 = F.interpolate(x0, scale_factor=2, mode="nearest")
        x0 = self.conv2d0(x0)

        out = x + x0
        return out
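

# A minimal shape-check sketch, assuming a hypothetical stand-in for the MODULES
# namespace (plain BatchNorm2d wrapped to accept the unused affine input, plain
# Conv2d, ReLU); real configs pass conditional BN and spectrally normalized convs
# from utils/ops.py. Call by hand to see that a GenBlock maps
# (N, C_in, H, W) -> (N, C_out, 2H, 2W).
def _demo_genblock_shapes():
    from types import SimpleNamespace

    class _PlainCondBN(nn.Module):
        # signature mirrors MODULES.g_bn(affine_input_dim, num_features, MODULES)
        def __init__(self, affine_input_dim, num_features, MODULES):
            super().__init__()
            self.bn = nn.BatchNorm2d(num_features)

        def forward(self, x, affine):
            return self.bn(x)  # the affine chunk is ignored in this stand-in

    _M = SimpleNamespace(g_bn=_PlainCondBN, g_act_fn=nn.ReLU(), g_conv2d=nn.Conv2d)
    block = GenBlock(in_channels=8, out_channels=4, g_cond_mtd="W/O", affine_input_dim=16, MODULES=_M)
    out = block(torch.randn(2, 8, 4, 4), affine=torch.randn(2, 16))
    print(out.shape)  # torch.Size([2, 4, 8, 8])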


class Generator(nn.Module):
    """BigGAN generator: a stack of upsampling GenBlocks driven by hierarchical
    latent chunks and, optionally, a shared class embedding and self-attention."""
    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        g_in_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]
        }

        g_out_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]
        }

        bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}

        self.z_dim = z_dim
        self.g_shared_dim = g_shared_dim
        self.g_cond_mtd = g_cond_mtd
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.in_dims = g_in_dims_collection[str(img_size)]
        self.out_dims = g_out_dims_collection[str(img_size)]
        self.bottom = bottom_collection[str(img_size)]
        self.num_blocks = len(self.in_dims)
        self.chunk_size = z_dim // (self.num_blocks + 1)
        self.affine_input_dim = self.chunk_size
        assert self.z_dim % (self.num_blocks + 1) == 0, "z_dim should be divisible by (num_blocks + 1)"

        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        if self.MODEL.info_type != "N/A":
            if self.MODEL.g_info_injection == "concat":
                self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
            elif self.MODEL.g_info_injection == "cBN":
                self.affine_input_dim += self.g_shared_dim
                self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.g_shared_dim, bias=True)

        self.linear0 = MODULES.g_linear(in_features=self.chunk_size, out_features=self.in_dims[0] * self.bottom * self.bottom, bias=True)

        if self.g_cond_mtd != "W/O":
            self.affine_input_dim += self.g_shared_dim
            self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.g_shared_dim)

        self.blocks = []
        for index in range(self.num_blocks):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.out_dims[index],
                         g_cond_mtd=self.g_cond_mtd,
                         affine_input_dim=self.affine_input_dim,
                         MODULES=MODULES)
            ]]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
        self.activation = MODULES.g_act_fn
        self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        ops.init_weights(self.modules, g_init)

    def forward(self, z, label, shared_label=None, eval=False):
        affine_list = []
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            if self.MODEL.info_type != "N/A":
                if self.MODEL.g_info_injection == "concat":
                    z = self.info_mix_linear(z)
                elif self.MODEL.g_info_injection == "cBN":
                    z, z_info = z[:, :self.z_dim], z[:, self.z_dim:]
                    affine_list.append(self.info_proj_linear(z_info))

            # hierarchical latent: the first chunk seeds the bottom feature map,
            # the remaining chunks condition each GenBlock's batch norm
            zs = torch.split(z, self.chunk_size, 1)
            z = zs[0]
            if self.g_cond_mtd != "W/O":
                if shared_label is None:
                    shared_label = self.shared(label)
                affine_list.append(shared_label)
            if len(affine_list) == 0:
                affines = [item for item in zs[1:]]
            else:
                affines = [torch.cat(affine_list + [item], 1) for item in zs[1:]]

            act = self.linear0(z)
            act = act.view(-1, self.in_dims[0], self.bottom, self.bottom)
            counter = 0
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    if isinstance(block, ops.SelfAttention):
                        act = block(act)
                    else:
                        act = block(act, affines[counter])
                        counter += 1

            act = self.bn4(act)
            act = self.activation(act)
            act = self.conv2d5(act)
            out = self.tanh(act)
        return out
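

# A minimal standalone sketch of the latent splitting used in Generator.forward;
# the sizes are illustrative assumptions (e.g. z_dim=120 with five blocks), not
# values taken from a StudioGAN config. Call by hand.
def _demo_latent_split():
    z_dim, num_blocks = 120, 5                # must satisfy z_dim % (num_blocks + 1) == 0
    chunk_size = z_dim // (num_blocks + 1)    # 20 here
    z = torch.randn(4, z_dim)                 # a batch of 4 latent vectors
    zs = torch.split(z, chunk_size, 1)        # num_blocks + 1 chunks
    assert len(zs) == num_blocks + 1
    # zs[0] feeds linear0; each of zs[1:] is (optionally concatenated with the
    # shared class embedding and) passed to one GenBlock as its affine input.
    print([tuple(t.shape) for t in zs])       # six tensors of shape (4, 20)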


class DiscOptBlock(nn.Module):
    """Discriminator input block (conv comes before the first activation), with
    average pooling for downsampling and a 1x1 shortcut."""
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES):
        super(DiscOptBlock, self).__init__()
        self.apply_d_sn = apply_d_sn

        self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        # batch norm is only used when spectral norm is disabled
        if not apply_d_sn:
            self.bn0 = MODULES.d_bn(in_features=in_channels)
            self.bn1 = MODULES.d_bn(in_features=out_channels)

        self.activation = MODULES.d_act_fn
        self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        x0 = x

        x = self.conv2d1(x)
        if not self.apply_d_sn:
            x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2d2(x)
        x = self.average_pooling(x)

        x0 = self.average_pooling(x0)
        if not self.apply_d_sn:
            x0 = self.bn0(x0)
        x0 = self.conv2d0(x0)

        out = x + x0
        return out


class DiscBlock(nn.Module):
    """Standard discriminator residual block with optional downsampling; the 1x1
    shortcut conv is created only when needed (channel mismatch or downsampling)."""
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES, downsample=True):
        super(DiscBlock, self).__init__()
        self.apply_d_sn = apply_d_sn
        self.downsample = downsample

        self.activation = MODULES.d_act_fn

        self.ch_mismatch = False
        if in_channels != out_channels:
            self.ch_mismatch = True

        if self.ch_mismatch or downsample:
            self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
            if not apply_d_sn:
                self.bn0 = MODULES.d_bn(in_features=in_channels)

        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        if not apply_d_sn:
            self.bn1 = MODULES.d_bn(in_features=in_channels)
            self.bn2 = MODULES.d_bn(in_features=out_channels)

        self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        x0 = x

        if not self.apply_d_sn:
            x = self.bn1(x)
        x = self.activation(x)
        x = self.conv2d1(x)

        if not self.apply_d_sn:
            x = self.bn2(x)
        x = self.activation(x)
        x = self.conv2d2(x)
        if self.downsample:
            x = self.average_pooling(x)

        if self.downsample or self.ch_mismatch:
            if not self.apply_d_sn:
                x0 = self.bn0(x0)
            x0 = self.conv2d0(x0)
            if self.downsample:
                x0 = self.average_pooling(x0)

        out = x + x0
        return out
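

# A minimal shape-check sketch with a hypothetical MODULES stand-in (plain
# Conv2d/BatchNorm2d/ReLU; real configs usually enable spectral norm, in which
# case the BN branches are skipped). Call by hand to see that a downsampling
# DiscBlock maps (N, C_in, H, W) -> (N, C_out, H/2, W/2).
def _demo_discblock_shapes():
    from types import SimpleNamespace

    _M = SimpleNamespace(d_conv2d=nn.Conv2d,
                         d_bn=lambda in_features: nn.BatchNorm2d(in_features),
                         d_act_fn=nn.ReLU())
    block = DiscBlock(in_channels=8, out_channels=16, apply_d_sn=False, MODULES=_M, downsample=True)
    out = block(torch.randn(2, 8, 16, 16))
    print(out.shape)  # torch.Size([2, 16, 8, 8])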


class Discriminator(nn.Module):
    """BigGAN discriminator with a configurable conditioning head (AC, PD, 2C,
    D2DCE, MD, MH), optional auxiliary classifiers (TAC/ADC), and an optional
    InfoGAN Q head."""
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim, normalize_d_embed,
                 num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        d_in_dims_collection = {
            "32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "128": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "256": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],
            "512": [3] + [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]
        }

        d_out_dims_collection = {
            "32": [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "512": [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]
        }

        d_down = {
            "32": [True, True, False, False],
            "64": [True, True, True, True, False],
            "128": [True, True, True, True, True, False],
            "256": [True, True, True, True, True, True, False],
            "512": [True, True, True, True, True, True, True, False]
        }

        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        self.MODEL = MODEL
        down = d_down[str(img_size)]

        self.blocks = []
        for index in range(len(self.in_dims)):
            if index == 0:
                self.blocks += [[
                    DiscOptBlock(in_channels=self.in_dims[index], out_channels=self.out_dims[index], apply_d_sn=apply_d_sn, MODULES=MODULES)
                ]]
            else:
                self.blocks += [[
                    DiscBlock(in_channels=self.in_dims[index],
                              out_channels=self.out_dims[index],
                              apply_d_sn=apply_d_sn,
                              MODULES=MODULES,
                              downsample=down[index])
                ]]

            if index + 1 in attn_d_loc and apply_attn:
                self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1 + num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1, bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, self.out_dims[-1])
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
                self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)

        if d_init:
            ops.init_weights(self.modules, d_init)

    def forward(self, x, label, eval=False, adc_fake=False):
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr() as mp:
            embed, proxy, cls_output = None, None, None
            mi_embed, mi_proxy, mi_cls_output = None, None, None
            info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None
            h = x
            for index, blocklist in enumerate(self.blocks):
                for block in blocklist:
                    h = block(h)
            bottom_h, bottom_w = h.shape[2], h.shape[3]
            h = self.activation(h)
            h = torch.sum(h, dim=[2, 3])  # global sum pooling

            # adversarial training
            adv_output = torch.squeeze(self.linear1(h))

            # make class labels odd (for fake) or even (for real) for ADC
            if self.aux_cls_type == "ADC":
                if adc_fake:
                    label = label * 2 + 1
                else:
                    label = label * 2

            # forward pass through InfoGAN Q head
            if self.MODEL.info_type in ["discrete", "both"]:
                info_discrete_c_logits = self.info_discrete_linear(h / (bottom_h * bottom_w))
            if self.MODEL.info_type in ["continuous", "both"]:
                info_conti_mu = self.info_conti_mu_linear(h / (bottom_h * bottom_w))
                info_conti_var = torch.exp(self.info_conti_var_linear(h / (bottom_h * bottom_w)))

            # class conditioning
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    # NOTE: F.normalize returns a new tensor, so this loop rebinds W
                    # rather than normalizing the classifier weights in place.
                    for W in self.linear2.parameters():
                        W = F.normalize(W, dim=1)
                    h = F.normalize(h, dim=1)
                cls_output = self.linear2(h)
            elif self.d_cond_mtd == "PD":
                adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                embed = self.linear2(h)
                proxy = self.embedding(label)
                if self.normalize_d_embed:
                    embed = F.normalize(embed, dim=1)
                    proxy = F.normalize(proxy, dim=1)
            elif self.d_cond_mtd == "MD":
                idx = torch.LongTensor(range(label.size(0))).to(label.device)
                adv_output = adv_output[idx, label]
            elif self.d_cond_mtd in ["W/O", "MH"]:
                pass
            else:
                raise NotImplementedError

            # extra conditioning for TACGAN and ADCGAN
            if self.aux_cls_type == "TAC":
                if self.d_cond_mtd == "AC":
                    if self.normalize_d_embed:
                        for W in self.linear_mi.parameters():
                            W = F.normalize(W, dim=1)
                    mi_cls_output = self.linear_mi(h)
                elif self.d_cond_mtd in ["2C", "D2DCE"]:
                    mi_embed = self.linear_mi(h)
                    mi_proxy = self.embedding_mi(label)
                    if self.normalize_d_embed:
                        mi_embed = F.normalize(mi_embed, dim=1)
                        mi_proxy = F.normalize(mi_proxy, dim=1)
        return {
            "h": h,
            "adv_output": adv_output,
            "embed": embed,
            "proxy": proxy,
            "cls_output": cls_output,
            "label": label,
            "mi_embed": mi_embed,
            "mi_proxy": mi_proxy,
            "mi_cls_output": mi_cls_output,
            "info_discrete_c_logits": info_discrete_c_logits,
            "info_conti_mu": info_conti_mu,
            "info_conti_var": info_conti_var
        }
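

# A minimal sketch (with assumed sizes) of the projection-discriminator ("PD")
# term used in Discriminator.forward: the class embedding is dotted with the
# pooled features h and added to the unconditional adversarial logit. Call by hand.
def _demo_projection_term():
    batch, feat_dim, num_classes = 4, 32, 10           # illustrative values
    h = torch.randn(batch, feat_dim)                   # pooled features, as after torch.sum(h, dim=[2, 3])
    adv_output = torch.randn(batch)                    # stand-in for torch.squeeze(self.linear1(h))
    label = torch.randint(num_classes, (batch,))
    embedding = nn.Embedding(num_classes, feat_dim)    # plays the role of self.embedding
    adv_output = adv_output + torch.sum(torch.mul(embedding(label), h), 1)
    print(adv_output.shape)  # torch.Size([4])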