# Source: GitHub repository hackassin/learnopencv
# Path: blob/master/Deep-Convolutional-GAN/PyTorch/dcgan_anime_pytorch.py
1
import os
2
import torch
3
import numpy as np
4
import torch.nn as nn
5
import torch.optim as optim
6
from torchvision import datasets, transforms
7
from torch.autograd import Variable
8
from torchvision.utils import save_image
9
from torchvision.utils import make_grid
10
from torch.utils.tensorboard import SummaryWriter
11
from torchsummary import summary
12
import datetime
13
import matplotlib.pyplot as plt
14
15
# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES selection the user has
# already made in the environment instead of silently overwriting it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

# Seed the global torch RNG so runs are repeatable.
torch.manual_seed(1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

# Resize to the 64x64 resolution the networks are built for, convert to a
# tensor, then normalise each channel with mean 0.5 / std 0.5 so pixel values
# lie in [-1, 1] -- the same range as the generator's final Tanh output.
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# ImageFolder yields (image, class_index) pairs; the class index is ignored by
# the training loop.
train_dataset = datasets.ImageFolder(root='../dcgan/anime', transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
27
28
def show_images(images, nrow=22):
    """Display a batch of image tensors as one grid in a matplotlib figure.

    Args:
        images: a batch accepted by ``torchvision.utils.make_grid``
            (presumably (B, C, H, W) -- TODO confirm at call sites). It may
            live on any device and may require grad.
        nrow: number of images per grid row (defaults to 22, as before).
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([])
    ax.set_yticks([])
    # detach() drops autograd history; cpu() moves CUDA tensors to host memory,
    # which matplotlib needs in order to read the pixel data.
    grid = make_grid(images.detach().cpu(), nrow=nrow)
    # The grid is channel-first; imshow expects channel-last (H, W, C).
    # NOTE(review): images normalised to [-1, 1] (see train_transform) are
    # clipped by imshow; consider make_grid(..., normalize=True) if needed.
    ax.imshow(grid.permute(1, 2, 0))
32
33
# Shape of a single training image as (channels, height, width).
image_shape = (3, 64, 64)
# Number of values in one flattened image (3 * 64 * 64).
image_dim = image_shape[0] * image_shape[1] * image_shape[2]
# Length of the noise vector fed to the generator.
latent_dim = 100
36
37
# custom weights initialization called on generator and discriminator
def weights_init(m):
    """Initialise one module DCGAN-style; intended for ``model.apply(weights_init)``.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and a zero bias. Every other module is left untouched.

    Args:
        m: the module currently visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    # getattr guards: a container whose class name merely contains 'Conv'
    # (e.g. a custom ConvBlock) has no weight attribute at all, and a
    # BatchNorm built with affine=False has weight/bias set to None.
    # init.normal_/zeros_ would raise on either.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)
45
46
# Generator Model Class Definition
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent noise map into a 3-channel 64x64 image.

    Input: noise of shape (N, latent_dim, 1, 1).
    Output: (N, 3, 64, 64), squashed into [-1, 1] by a final Tanh.
    """

    def __init__(self):
        super().__init__()
        # Channel widths of the four hidden stages.
        widths = (64 * 8, 64 * 4, 64 * 2, 64)

        # Stage 1: project the 1x1 latent input Z to a (64*8) x 4 x 4 map.
        layers = [
            nn.ConvTranspose2d(latent_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Stages 2-4: each stride-2 transposed conv doubles height and width
        # (4 -> 8 -> 16 -> 32) while halving the channel count.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ])
        # Stage 5: (64) x 32 x 32 -> (3) x 64 x 64 image.
        layers.extend([
            nn.ConvTranspose2d(widths[-1], 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        # A single Sequential named `main` keeps the state_dict keys
        # (main.0.weight, main.1.weight, ...) identical to earlier checkpoints.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
76
77
generator = Generator().to(device)
# Re-initialise every Conv / BatchNorm sub-module with the DCGAN scheme.
generator.apply(weights_init)

# torchsummary runs a dummy forward pass on the device it is told (its default
# is "cuda"); pass the device actually in use so this also works on CPU-only
# machines. The input shape is derived from latent_dim instead of a hard-coded 100.
summary(generator, (latent_dim, 1, 1), device=device.type)
81
82
# Discriminator Model Class Definition
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 3x64x64 images with a probability of being real.

    Input: (N, 3, 64, 64). Output: (N, 1) values in [0, 1] (Sigmoid, then Flatten).
    """

    def __init__(self):
        super().__init__()
        # Channel widths of the four downsampling stages.
        widths = (64, 64 * 2, 64 * 4, 64 * 8)

        # Block 1: (3) x 64 x 64 -> (64) x 32 x 32; this block has no BatchNorm.
        layers = [
            nn.Conv2d(3, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Blocks 2-4: each stride-2 conv halves height and width
        # (32 -> 16 -> 8 -> 4) while doubling the channel count.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Block 5: (64*8) x 4 x 4 -> a single score per image, mapped through a
        # Sigmoid and flattened to shape (N, 1).
        layers.extend([
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            nn.Flatten(),
        ])
        # One Sequential named `main` keeps state_dict keys unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
112
113
discriminator = Discriminator().to(device)
# Re-initialise every Conv / BatchNorm sub-module with the DCGAN scheme.
discriminator.apply(weights_init)
print(discriminator)

# torchsummary's dummy forward pass defaults to "cuda"; pass the device in use
# so the summary also works on CPU-only machines.
summary(discriminator, (3, 64, 64), device=device.type)
118
119
# Binary cross-entropy on probabilities. The discriminator already ends in a
# Sigmoid, so plain BCELoss (not the logits variant) is the matching criterion.
adversarial_loss = nn.BCELoss()


def generator_loss(fake_output, label):
    """Return the generator objective: BCE of D's scores on fakes vs. ``label``.

    The training loop passes all-ones targets, so G is rewarded when D
    scores its images as real.
    """
    return adversarial_loss(fake_output, label)


def discriminator_loss(output, label):
    """Return the discriminator objective: BCE of D's scores vs. ``label``."""
    return adversarial_loss(output, label)
129
130
# A fixed batch of 128 latent vectors drawn once from the seeded RNG.
# NOTE(review): the training loop below does not use it (samples are saved
# from the last training batch instead) -- confirm whether that is intended.
fixed_noise = torch.randn(128, latent_dim, 1, 1, device=device)

# Target label values for real and generated images.
# NOTE(review): the training loop builds its own ones/zeros targets and does
# not read these two names.
real_label, fake_label = 1, 0

# Separate Adam optimisers for G and D, sharing learning rate and betas.
learning_rate = 2e-4
G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
137
138
# Output directories written every epoch; create them up front so save_image /
# torch.save do not fail on a fresh checkout.
os.makedirs('dcgan/torch/images', exist_ok=True)
os.makedirs('dcgan/torch/training_weights', exist_ok=True)

num_epochs = 2
# Per-epoch mean losses (0-dim tensors), presumably kept for plotting.
D_loss_plot, G_loss_plot = [], []
for epoch in range(1, num_epochs + 1):

    # Per-batch losses for this epoch, stored as Python floats via .item()
    # so each batch's autograd graph and device tensors are not kept alive
    # for the whole epoch.
    D_loss_list, G_loss_list = [], []

    for index, (real_images, _) in enumerate(train_loader):
        # ---- Discriminator step: push D(real) -> 1 and D(fake) -> 0 ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)

        # Targets shaped (batch, 1) to match the discriminator's flattened output.
        real_target = torch.ones(real_images.size(0), 1, device=device)
        fake_target = torch.zeros(real_images.size(0), 1, device=device)

        D_real_loss = discriminator_loss(discriminator(real_images), real_target)
        D_real_loss.backward()

        noise_vector = torch.randn(real_images.size(0), latent_dim, 1, 1, device=device)
        generated_image = generator(noise_vector)
        # detach(): the fake-batch loss must update D only, not backprop into G.
        output = discriminator(generated_image.detach())
        D_fake_loss = discriminator_loss(output, fake_target)
        D_fake_loss.backward()

        # Gradients from both backward() calls have accumulated; the sum is
        # only for logging.
        D_total_loss = D_real_loss + D_fake_loss
        D_loss_list.append(D_total_loss.item())
        D_optimizer.step()

        # ---- Generator step: train G with real labels so it learns to fool D ----
        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator(generated_image), real_target)
        G_loss_list.append(G_loss.item())

        G_loss.backward()
        G_optimizer.step()

    D_epoch_loss = torch.mean(torch.FloatTensor(D_loss_list))
    G_epoch_loss = torch.mean(torch.FloatTensor(G_loss_list))
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        epoch, num_epochs, float(D_epoch_loss), float(G_epoch_loss)))

    D_loss_plot.append(D_epoch_loss)
    G_loss_plot.append(G_epoch_loss)
    # Save up to 50 images from the last generated batch of the epoch.
    save_image(generated_image.data[:50], 'dcgan/torch/images/sample_%d.png' % epoch, nrow=5, normalize=True)

    torch.save(generator.state_dict(), 'dcgan/torch/training_weights/generator_epoch_%d.pth' % (epoch))
    torch.save(discriminator.state_dict(), 'dcgan/torch/training_weights/discriminator_epoch_%d.pth' % (epoch))
194