# Path: blob/master/Deep-Convolutional-GAN/PyTorch/dcgan_anime_pytorch.py
# (3142 views)
# DCGAN trained on an anime-faces ImageFolder dataset (3 x 64 x 64 RGB images).
import os
import datetime

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable  # no longer needed since PyTorch 0.4; kept for compatibility
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

torch.manual_seed(1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

# Resize to 64x64 and normalize each RGB channel to [-1, 1]
# (matches the generator's Tanh output range).
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
train_dataset = datasets.ImageFolder(root='../dcgan/anime', transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)


def show_images(images):
    """Render a batch of image tensors as a single grid (detached from autograd)."""
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))


image_shape = (3, 64, 64)
image_dim = int(np.prod(image_shape))
latent_dim = 100


# custom weights initialization called on generator and discriminator
def weights_init(m):
    """DCGAN init (Radford et al. 2015): conv weights ~ N(0, 0.02),
    batchnorm scale ~ N(1, 0.02) with zero bias."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


# Generator Model Class Definition
class Generator(nn.Module):
    """Maps a (latent_dim, 1, 1) noise vector to a (3, 64, 64) image in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Block 1: input is Z, going into a convolution
            nn.ConvTranspose2d(latent_dim, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            # Block 2: (64 * 8) x 4 x 4
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            # Block 3: (64 * 4) x 8 x 8
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            # Block 4: (64 * 2) x 16 x 16
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # Block 5: (64) x 32 x 32
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: (3) x 64 x 64
        )

    def forward(self, input):
        output = self.main(input)
        return output


generator = Generator().to(device)
generator.apply(weights_init)

summary(generator, (100, 1, 1))


# Discriminator Model Class Definition
class Discriminator(nn.Module):
    """Maps a (3, 64, 64) image to a single real/fake probability in [0, 1]."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Block 1: (3) x 64 x 64
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 2: (64) x 32 x 32
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 3: (64*2) x 16 x 16
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 4: (64*4) x 8 x 8
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 5: (64*8) x 4 x 4
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            nn.Flatten()
            # Output: 1
        )

    def forward(self, input):
        output = self.main(input)
        return output


discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

summary(discriminator, (3, 64, 64))

adversarial_loss = nn.BCELoss()


def generator_loss(fake_output, label):
    """BCE between D(G(z)) and the (real) labels the generator tries to fool D into."""
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss


def discriminator_loss(output, label):
    """BCE between the discriminator output and its real/fake target labels."""
    disc_loss = adversarial_loss(output, label)
    return disc_loss


fixed_noise = torch.randn(128, latent_dim, 1, 1, device=device)
real_label = 1
fake_label = 0

learning_rate = 0.0002
G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

num_epochs = 2
D_loss_plot, G_loss_plot = [], []

# Create the output directories up front so save_image / torch.save cannot fail
# at the end of the first epoch.
os.makedirs('dcgan/torch/images', exist_ok=True)
os.makedirs('dcgan/torch/training_weights', exist_ok=True)

for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, _) in enumerate(train_loader):
        # ---- Discriminator step: maximize log D(x) + log(1 - D(G(z))) ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)

        # Shape (batch, 1) matches the discriminator's flattened output,
        # so no unsqueeze is needed. (Variable wrappers are obsolete.)
        real_target = torch.ones(real_images.size(0), 1, device=device)
        fake_target = torch.zeros(real_images.size(0), 1, device=device)

        D_real_loss = discriminator_loss(discriminator(real_images), real_target)
        D_real_loss.backward()

        noise_vector = torch.randn(real_images.size(0), latent_dim, 1, 1, device=device)

        generated_image = generator(noise_vector)
        # detach() blocks gradients from flowing into the generator here.
        output = discriminator(generated_image.detach())
        D_fake_loss = discriminator_loss(output, fake_target)

        # train with fake
        D_fake_loss.backward()

        D_total_loss = D_real_loss + D_fake_loss
        # .item() frees the autograd graph; appending the raw tensor would
        # retain every batch's graph and leak memory across the epoch.
        D_loss_list.append(D_total_loss.item())

        D_optimizer.step()

        # ---- Generator step: train with real labels to maximize log D(G(z)) ----
        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator(generated_image), real_target)
        G_loss_list.append(G_loss.item())

        G_loss.backward()
        G_optimizer.step()

    mean_D_loss = float(np.mean(D_loss_list))
    mean_G_loss = float(np.mean(G_loss_list))
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        (epoch), num_epochs, mean_D_loss, mean_G_loss))

    D_loss_plot.append(mean_D_loss)
    G_loss_plot.append(mean_G_loss)
    save_image(generated_image.data[:50],
               'dcgan/torch/images/sample_%d' % epoch + '.png',
               nrow=5, normalize=True)

    torch.save(generator.state_dict(),
               'dcgan/torch/training_weights/generator_epoch_%d.pth' % (epoch))
    torch.save(discriminator.state_dict(),
               'dcgan/torch/training_weights/discriminator_epoch_%d.pth' % (epoch))