import sys
from python_environment_check import check_packages
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.autograd import grad as torch_grad
sys.path.insert(0, '..')
# Minimum package versions this script was written against.
d = {
    'torch': '1.8.0',
    'torchvision': '0.9.0',
    'numpy': '1.21.2',
    'matplotlib': '3.4.3',
}
check_packages(d)

print(torch.__version__)
print("GPU Available:", torch.cuda.is_available())

# Prefer the first CUDA device; fall back to CPU. Use torch.device in
# both branches so `device` has a consistent type.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

image_path = './'
# Scale MNIST pixels from [0, 1] to [-1, 1] to match the generator's
# Tanh output range. FIX: the original passed mean=(0.5)/std=(0.5) —
# plain floats, since (0.5) is not a tuple; Normalize documents a
# per-channel sequence, so use the 1-tuple form (same numerics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
# download=True is a no-op when the data is already present, but keeps
# the script from crashing on a fresh checkout (original had False).
mnist_dataset = torchvision.datasets.MNIST(root=image_path,
                                           train=True,
                                           transform=transform,
                                           download=True)

batch_size = 64
torch.manual_seed(1)
np.random.seed(1)
# drop_last=True keeps every batch at exactly batch_size images.
mnist_dl = DataLoader(mnist_dataset, batch_size=batch_size,
                      shuffle=True, drop_last=True)
def make_generator_network(input_size, n_filters):
    """Build the DCGAN generator.

    Maps a (N, input_size, 1, 1) latent tensor to a (N, 1, 28, 28)
    image in [-1, 1] via a stack of transposed convolutions.
    """
    layers = [
        # 1x1 -> 4x4
        nn.ConvTranspose2d(input_size, n_filters * 4, 4, 1, 0, bias=False),
        nn.BatchNorm2d(n_filters * 4),
        nn.LeakyReLU(0.2),
        # 4x4 -> 7x7
        nn.ConvTranspose2d(n_filters * 4, n_filters * 2, 3, 2, 1, bias=False),
        nn.BatchNorm2d(n_filters * 2),
        nn.LeakyReLU(0.2),
        # 7x7 -> 14x14
        nn.ConvTranspose2d(n_filters * 2, n_filters, 4, 2, 1, bias=False),
        nn.BatchNorm2d(n_filters),
        nn.LeakyReLU(0.2),
        # 14x14 -> 28x28, squashed to [-1, 1]
        nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False),
        nn.Tanh(),
    ]
    return nn.Sequential(*layers)
class Discriminator(nn.Module):
    """DCGAN discriminator: (N, 1, 28, 28) image -> probability of 'real'."""

    def __init__(self, n_filters):
        super().__init__()
        blocks = []
        # 28x28 -> 14x14 (no norm on the first conv)
        blocks += [nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),
                   nn.LeakyReLU(0.2)]
        # 14x14 -> 7x7
        blocks += [nn.Conv2d(n_filters, n_filters * 2, 4, 2, 1, bias=False),
                   nn.BatchNorm2d(n_filters * 2),
                   nn.LeakyReLU(0.2)]
        # 7x7 -> 4x4
        blocks += [nn.Conv2d(n_filters * 2, n_filters * 4, 3, 2, 1, bias=False),
                   nn.BatchNorm2d(n_filters * 4),
                   nn.LeakyReLU(0.2)]
        # 4x4 -> 1x1 score, squashed to (0, 1)
        blocks += [nn.Conv2d(n_filters * 4, 1, 4, 1, 0, bias=False),
                   nn.Sigmoid()]
        self.network = nn.Sequential(*blocks)

    def forward(self, input):
        out = self.network(input)
        # Flatten (N, 1, 1, 1) to (N, 1); squeeze(0) only fires for N == 1.
        return out.view(-1, 1).squeeze(0)
# Latent-vector length, output image geometry, and base channel width.
z_size = 100
image_size = (28, 28)
n_filters = 32
gen_model = make_generator_network(z_size, n_filters).to(device)
print(gen_model)
disc_model = Discriminator(n_filters).to(device)
print(disc_model)
# Vanilla-GAN objective: binary cross-entropy on the discriminator's
# real/fake probabilities.
loss_fn = nn.BCELoss()
# NOTE(review): the generator uses a slightly larger learning rate than
# the discriminator (3e-4 vs 2e-4) — presumably intentional; confirm.
g_optimizer = torch.optim.Adam(gen_model.parameters(), 0.0003)
d_optimizer = torch.optim.Adam(disc_model.parameters(), 0.0002)
def create_noise(batch_size, z_size, mode_z):
    """Draw a batch of latent vectors shaped (batch_size, z_size, 1, 1).

    mode_z selects the distribution: 'uniform' -> U(-1, 1),
    'normal' -> N(0, 1).

    Raises:
        ValueError: for any other mode (the original fell through to an
        UnboundLocalError on the return statement).
    """
    if mode_z == 'uniform':
        input_z = torch.rand(batch_size, z_size, 1, 1)*2 - 1
    elif mode_z == 'normal':
        input_z = torch.randn(batch_size, z_size, 1, 1)
    else:
        raise ValueError(
            f"Unknown mode_z: {mode_z!r}; expected 'uniform' or 'normal'")
    return input_z
def d_train(x):
    """One discriminator update on the real batch x.

    Returns (loss, real probabilities, fake probabilities); the
    probability tensors are detached for logging.
    """
    disc_model.zero_grad()
    n = x.size(0)
    x = x.to(device)

    # Real branch: target label 1 for genuine images.
    proba_real = disc_model(x)
    loss_real = loss_fn(proba_real, torch.ones(n, 1, device=device))

    # Fake branch: target label 0 for generator output.
    noise = create_noise(n, z_size, mode_z).to(device)
    fake_images = gen_model(noise)
    proba_fake = disc_model(fake_images)
    loss_fake = loss_fn(proba_fake, torch.zeros(n, 1, device=device))

    total_loss = loss_real + loss_fake
    total_loss.backward()
    d_optimizer.step()
    return total_loss.item(), proba_real.detach(), proba_fake.detach()
def g_train(x):
    """One generator update; x is used only for its batch size.

    Returns the scalar generator loss.
    """
    gen_model.zero_grad()
    n = x.size(0)
    noise = create_noise(n, z_size, mode_z).to(device)
    fake_images = gen_model(noise)
    proba_fake = disc_model(fake_images)
    # Non-saturating objective: train G so that D labels fakes as real.
    loss = loss_fn(proba_fake, torch.ones((n, 1), device=device))
    loss.backward()
    g_optimizer.step()
    return loss.item()
# Sampling mode for all latent draws, plus a fixed noise batch reused at
# the end of every epoch so the sample grids are comparable over time.
mode_z = 'uniform'
fixed_z = create_noise(batch_size, z_size, mode_z).to(device)
def create_samples(g_model, input_z, image_size=(28, 28)):
    """Run g_model on the latent batch input_z and return displayable images.

    The generator's tanh output in [-1, 1] is rescaled to [0, 1], and
    each sample is reshaped to `image_size`.

    FIX: the batch size is now taken from input_z itself; the original
    read the module-level `batch_size`, so any latent batch of a
    different size made the reshape fail. `image_size` is a parameter
    with the original global's value as default (backward compatible).
    """
    g_output = g_model(input_z)
    images = torch.reshape(g_output, (input_z.size(0), *image_size))
    return (images + 1) / 2.0
# Vanilla-GAN training: alternate one discriminator step and one
# generator step per batch, then snapshot samples from fixed_z.
epoch_samples = []
num_epochs = 100
torch.manual_seed(1)
for epoch in range(1, num_epochs+1):
    gen_model.train()
    d_losses, g_losses = [], []
    for i, (x, _) in enumerate(mnist_dl):
        d_loss, d_proba_real, d_proba_fake = d_train(x)
        d_losses.append(d_loss)
        g_losses.append(g_train(x))
    print(f'Epoch {epoch:03d} | Avg Losses >>'
          f' G/D {torch.FloatTensor(g_losses).mean():.4f}'
          f'/{torch.FloatTensor(d_losses).mean():.4f}')
    # Switch BatchNorm to eval statistics before sampling, then record
    # the epoch's fixed-noise samples as a numpy array for plotting.
    gen_model.eval()
    epoch_samples.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())
# Grid of 5 fixed-noise samples (columns) at selected epochs (rows).
selected_epochs = [1, 2, 4, 10, 50, 100]
fig = plt.figure(figsize=(10, 14))
for i, e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        # Label each row with its epoch number on the left edge.
        if j == 0:
            ax.text(
                -0.06, 0.5, f'Epoch {e}',
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center',
                transform=ax.transAxes)
        image = epoch_samples[e-1][j]
        ax.imshow(image, cmap='gray_r')
plt.show()
def make_generator_network_wgan(input_size, n_filters):
    """WGAN generator: (N, input_size, 1, 1) noise -> (N, 1, 28, 28)
    image in [-1, 1]. Same topology as the vanilla generator but with
    InstanceNorm in place of BatchNorm."""
    def upsample(c_in, c_out, kernel, stride, padding):
        # One ConvTranspose -> InstanceNorm -> LeakyReLU stage.
        return [nn.ConvTranspose2d(c_in, c_out, kernel, stride, padding,
                                   bias=False),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2)]

    stages = (upsample(input_size, n_filters * 4, 4, 1, 0)       # 1 -> 4
              + upsample(n_filters * 4, n_filters * 2, 3, 2, 1)  # 4 -> 7
              + upsample(n_filters * 2, n_filters, 4, 2, 1)      # 7 -> 14
              + [nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False),
                 nn.Tanh()])                                     # 14 -> 28
    return nn.Sequential(*stages)
class DiscriminatorWGAN(nn.Module):
    """WGAN-GP critic: (N, 1, 28, 28) image -> unbounded scalar score.

    InstanceNorm is used instead of BatchNorm because the gradient
    penalty is defined per-sample and batch statistics would couple
    samples within a batch.

    FIX: the original ended with nn.Sigmoid(). A Wasserstein critic must
    output an unconstrained score — squashing it to (0, 1) bounds the
    loss `d_fake.mean() - d_real.mean()` used in d_train_wgan and
    defeats the Wasserstein-distance estimate, so the Sigmoid is removed.
    """

    def __init__(self, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            # 28x28 -> 14x14 (no norm on the first conv)
            nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            # 14x14 -> 7x7
            nn.Conv2d(n_filters, n_filters*2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),
            # 7x7 -> 4x4
            nn.Conv2d(n_filters*2, n_filters*4, 3, 2, 1, bias=False),
            nn.InstanceNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),
            # 4x4 -> 1x1 raw score (no Sigmoid — see class docstring)
            nn.Conv2d(n_filters*4, 1, 4, 1, 0, bias=False))

    def forward(self, input):
        output = self.network(input)
        # Flatten (N, 1, 1, 1) to (N, 1); squeeze(0) only fires for N == 1.
        return output.view(-1, 1).squeeze(0)
# Fresh WGAN generator/critic and optimizers (both lr=2e-4 here). These
# rebind the module-level gen_model/disc_model names, so the *_wgan
# train functions below operate on this pair.
gen_model = make_generator_network_wgan(z_size, n_filters).to(device)
disc_model = DiscriminatorWGAN(n_filters).to(device)
g_optimizer = torch.optim.Adam(gen_model.parameters(), 0.0002)
d_optimizer = torch.optim.Adam(disc_model.parameters(), 0.0002)
def gradient_penalty(real_data, generated_data):
    """WGAN-GP penalty term.

    Computes lambda_gp * E[(||grad_x D(x_hat)|| - 1)^2], where x_hat is
    a per-sample random interpolation between real and generated data.
    """
    n = real_data.size(0)
    # One interpolation coefficient per sample, broadcast over C/H/W;
    # requires_grad makes the interpolated tensor part of the graph.
    alpha = torch.rand(n, 1, 1, 1, requires_grad=True, device=device)
    mixed = alpha * real_data + (1 - alpha) * generated_data
    critic_out = disc_model(mixed)
    # create_graph=True so the penalty itself is differentiable and
    # contributes to the critic's parameter gradients.
    grads = torch_grad(
        outputs=critic_out, inputs=mixed,
        grad_outputs=torch.ones(critic_out.size(), device=device),
        create_graph=True, retain_graph=True)[0]
    per_sample_norm = grads.view(n, -1).norm(2, dim=1)
    return lambda_gp * ((per_sample_norm - 1) ** 2).mean()
def d_train_wgan(x):
    """One critic update: Wasserstein loss plus gradient penalty.

    Returns the scalar critic loss.
    """
    disc_model.zero_grad()
    n = x.size(0)
    x = x.to(device)

    score_real = disc_model(x)
    noise = create_noise(n, z_size, mode_z).to(device)
    fakes = gen_model(noise)
    score_fake = disc_model(fakes)

    # Minimize E[D(fake)] - E[D(real)]; the penalty is computed on
    # detached copies so only the critic path is regularized.
    loss = (score_fake.mean() - score_real.mean()
            + gradient_penalty(x.data, fakes.data))
    loss.backward()
    d_optimizer.step()
    return loss.item()
def g_train_wgan(x):
    """One generator update: maximize the critic's score on fakes.

    x is used only for its batch size; returns the scalar loss.
    """
    gen_model.zero_grad()
    noise = create_noise(x.size(0), z_size, mode_z).to(device)
    fakes = gen_model(noise)
    # Maximizing E[D(fake)] == minimizing its negation.
    loss = -disc_model(fakes).mean()
    loss.backward()
    g_optimizer.step()
    return loss.item()
# WGAN-GP training: 5 critic updates per generator update.
epoch_samples_wgan = []
lambda_gp = 10.0          # gradient-penalty weight (read by gradient_penalty)
num_epochs = 100
torch.manual_seed(1)
critic_iterations = 5     # critic steps per generator step
for epoch in range(1, num_epochs+1):
    gen_model.train()
    d_losses, g_losses = [], []
    for i, (x, _) in enumerate(mnist_dl):
        # NOTE(review): the same real batch x is reused for all 5 critic
        # steps; WGAN-GP implementations typically draw a fresh real
        # batch per critic step — confirm this simplification is intended.
        for _ in range(critic_iterations):
            d_loss = d_train_wgan(x)
        d_losses.append(d_loss)
        g_losses.append(g_train_wgan(x))
    print(f'Epoch {epoch:03d} | D Loss >>'
          f' {torch.FloatTensor(d_losses).mean():.4f}')
    # Snapshot fixed-noise samples for the epoch (eval mode for norms).
    gen_model.eval()
    epoch_samples_wgan.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())
# Grid of 5 WGAN fixed-noise samples (columns) at selected epochs (rows).
selected_epochs = [1, 2, 4, 10, 50, 100]
fig = plt.figure(figsize=(10, 14))
for i, e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        # Label each row with its epoch number on the left edge.
        if j == 0:
            ax.text(
                -0.06, 0.5, f'Epoch {e}',
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center',
                transform=ax.transAxes)
        image = epoch_samples_wgan[e-1][j]
        ax.imshow(image, cmap='gray_r')
plt.show()