Path: blob/master/Conditional-GAN-PyTorch-TensorFlow/PyTorch/cgan_pytorch.py
3150 views
# Conditional GAN (cGAN) trained on an ImageFolder dataset ('rps':
# rock / paper / scissors, 3 classes).  Both the generator and the
# discriminator are conditioned on the class label through an embedding.
# After training, latent-space interpolation is rendered per class.

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import matplotlib.pyplot as plt
import datetime
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from numpy import linspace
from matplotlib import pyplot
from matplotlib import gridspec


torch.manual_seed(1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

# Images are resized to 128x128 and normalized to [-1, 1] (matches the
# generator's Tanh output range).
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
train_dataset = datasets.ImageFolder(root='rps', transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)


def show_images(images):
    """Render a batch of image tensors as a single grid.

    NOTE(review): the tensors are still normalized to [-1, 1] here, so
    matplotlib clips negative values — kept as-is to preserve behavior.
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))


def show_batch(dl):
    """Display the first batch from a DataLoader and stop."""
    for images, _ in dl:
        show_images(images)
        break


show_batch(train_loader)

image_shape = (3, 128, 128)
image_dim = int(np.prod(image_shape))
latent_dim = 100

n_classes = 3
embedding_dim = 100


# Custom weights initialization called on generator and discriminator
# (standard DCGAN scheme: N(0, 0.02) conv weights, N(1, 0.02) BN scale).
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


class Generator(nn.Module):
    """Label-conditioned DCGAN-style generator.

    forward() takes a tuple ``(noise_vector, label)`` where noise is
    ``(N, latent_dim)`` and label is an integer tensor, and returns a
    ``(N, 3, 128, 128)`` image batch in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Label -> embedding -> 16 values, later reshaped to one 4x4 plane.
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16))

        # Noise -> 512 feature maps of spatial size 4x4.
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4 * 4 * 512),
            nn.LeakyReLU(0.2, inplace=True))

        # Five stride-2 transposed convolutions upsample 4x4 -> 128x128.
        # Input has 513 channels: 512 latent maps + 1 label map.
        # NOTE(review): eps=0.8 is far above the BatchNorm default (1e-5)
        # and looks unintended, but is kept to preserve trained behavior.
        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 64 * 1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 1, 3, 4, 2, 1, bias=False),
            nn.Tanh())

    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        # Concatenate label plane onto the latent maps along channels.
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image


generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

# Quick smoke-test inputs for the generator (kept from the original).
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)


class Discriminator(nn.Module):
    """Label-conditioned discriminator.

    forward() takes ``(img, label)`` with img of shape (N, 3, 128, 128);
    the label is embedded and projected to a full 3x128x128 map that is
    concatenated with the image (6 input channels total).  Returns a
    (N, 1) probability via Sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128))

        # NOTE(review): eps=0.8 kept as in the generator (see note there).
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid())

    def forward(self, inputs):
        img, label = inputs
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)
        concat = torch.cat((img, label_output), dim=1)
        output = self.model(concat)
        return output


discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

# Quick smoke-test forward pass for the discriminator.
a = torch.ones(2, 3, 128, 128)
b = torch.ones(2, 1)
b = b.long()
a = a.to(device)
b = b.to(device)

c = discriminator((a, b))
c.size()

# BCE over the Sigmoid outputs (the original assigned this twice; once
# is enough).
adversarial_loss = nn.BCELoss()


def generator_loss(fake_output, label):
    """BCE between D's score on fakes and the target labels."""
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss


def discriminator_loss(output, label):
    """BCE between D's score and the real/fake target labels."""
    disc_loss = adversarial_loss(output, label)
    return disc_loss


learning_rate = 0.0002
G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

num_epochs = 2
D_loss_plot, G_loss_plot = [], []
for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        # ---- Discriminator step ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)
        labels = labels.to(device)
        labels = labels.unsqueeze(1).long()

        # Targets: 1 for real images, 0 for generated ones.
        # (Variable() wrappers removed: they are a no-op since PyTorch 0.4.)
        real_target = torch.ones(real_images.size(0), 1, device=device)
        fake_target = torch.zeros(real_images.size(0), 1, device=device)

        D_real_loss = discriminator_loss(
            discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)

        generated_image = generator((noise_vector, labels))
        # detach(): do not backprop through the generator during D's step.
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        # .item() logs a plain float; appending the loss tensor would keep
        # every batch's autograd graph alive for the whole epoch.
        D_loss_list.append(D_total_loss.item())

        D_total_loss.backward()
        D_optimizer.step()

        # ---- Generator step: make D classify fakes as real ----
        G_optimizer.zero_grad()
        G_loss = generator_loss(
            discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss.item())

        G_loss.backward()
        G_optimizer.step()

    D_epoch_loss = float(np.mean(D_loss_list))
    G_epoch_loss = float(np.mean(G_loss_list))
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        epoch, num_epochs, D_epoch_loss, G_epoch_loss))

    D_loss_plot.append(D_epoch_loss)
    G_loss_plot.append(G_epoch_loss)
    save_image(generated_image.detach()[:50],
               'torch/images/sample_%d' % epoch + '.png', nrow=5, normalize=True)

    torch.save(generator.state_dict(),
               'torch/training_weights/generator_epoch_%d.pth' % (epoch))
    torch.save(discriminator.state_dict(),
               'torch/training_weights/discriminator_epoch_%d.pth' % (epoch))

generator.load_state_dict(
    torch.load('torch/training_weights/generator_epoch_1.pth'), strict=False)
generator.eval()


# Example of interpolating between generated images.
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    """Sample ``n_samples`` standard-normal latent vectors of size ``latent_dim``.

    ``n_classes`` is unused; kept for interface compatibility.
    """
    x_input = randn(latent_dim * n_samples)
    # Reshape into a batch of inputs for the network.
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input


def interpolate_points(p1, p2, n_steps=10):
    """Uniform linear interpolation between two latent points (n_steps rows)."""
    ratios = linspace(0, 1, num=n_steps)
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return asarray(vectors)


pts = generate_latent_points(100, 2)
# Interpolate between the two sampled points in latent space.
interpolated = interpolate_points(pts[0], pts[1])

interpolated = torch.tensor(interpolated)
interpolated = interpolated.to(device)
interpolated = interpolated.type(torch.float32)

# Generate the same interpolation path once per class label and stack
# the results: output has shape (3 * n_steps, 128, 128, 3) in HWC order.
output = None
for label in range(3):
    labels = torch.ones(10) * label
    labels = labels.to(device)
    labels = labels.unsqueeze(1).long()
    print(labels.size())
    predictions = generator((interpolated, labels))
    predictions = predictions.permute(0, 2, 3, 1)
    pred = predictions.detach().cpu()
    if output is None:
        output = pred
    else:
        output = np.concatenate((output, pred))

print(output.shape)

nrow = 3
ncol = 10
fig = plt.figure(figsize=(25, 25))
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                       wspace=0.0, hspace=0.0, top=0.2, bottom=0.00,
                       left=0.17, right=0.845)

k = 0
for i in range(nrow):
    for j in range(ncol):
        # Undo the [-1, 1] normalization back to [0, 255] for display.
        pred = (output[k, :, :, :] + 1) * 127.5
        pred = np.array(pred)
        ax = plt.subplot(gs[i, j])
        ax.imshow(pred.astype(np.uint8))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
        k += 1


# plt.savefig('result_torch.png', dpi=300)
plt.show()