Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/Conditional-GAN-PyTorch-TensorFlow/PyTorch/cgan_pytorch.py
3150 views
1
import torch
2
import numpy as np
3
import torch.nn as nn
4
import torch.optim as optim
5
from torchvision import datasets, transforms
6
from torch.autograd import Variable
7
from torchvision.utils import save_image
8
from torchvision.utils import make_grid
9
from torch.utils.tensorboard import SummaryWriter
10
from torchsummary import summary
11
import matplotlib.pyplot as plt
12
import datetime
13
from numpy import asarray
14
from numpy.random import randn
15
from numpy.random import randint
16
from numpy import linspace
17
from matplotlib import pyplot
18
from matplotlib import gridspec
19
20
21
# Fix the RNG seed so training runs are reproducible.
torch.manual_seed(1)

# Prefer the GPU when one is present, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128

# Resize to 128x128, convert to [0, 1] tensors, then shift each RGB channel
# to [-1, 1] so the images match the generator's Tanh output range.
_channel_stats = [0.5, 0.5, 0.5]
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize(_channel_stats, _channel_stats),
])

# Rock/paper/scissors images, one class per sub-directory of 'rps'.
train_dataset = datasets.ImageFolder(root='rps', transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
32
33
def show_images(images):
    """Render a batch of image tensors as one grid, 22 images per row."""
    grid = make_grid(images.detach(), nrow=22)
    _, axis = plt.subplots(figsize=(20, 20))
    axis.set_xticks([])
    axis.set_yticks([])
    # Permute CHW -> HWC, the layout matplotlib expects.
    axis.imshow(grid.permute(1, 2, 0))
37
38
def show_batch(dl):
    """Display the images of the first batch yielded by *dl* (if any)."""
    for batch_images, _labels in dl:
        show_images(batch_images)
        break
42
43
# Sanity-check the input pipeline by displaying one batch.
show_batch(train_loader)

# Dataset geometry and model hyper-parameters.
image_shape = (3, 128, 128)
image_dim = int(np.prod(image_shape))  # flattened pixel count
latent_dim = 100

n_classes = 3        # rock / paper / scissors
embedding_dim = 100  # dimensionality of the label embedding
51
52
# custom weights initialization called on generator and discriminator
53
def weights_init(m):
54
classname = m.__class__.__name__
55
if classname.find('Conv') != -1:
56
torch.nn.init.normal_(m.weight, 0.0, 0.02)
57
elif classname.find('BatchNorm') != -1:
58
torch.nn.init.normal_(m.weight, 1.0, 0.02)
59
torch.nn.init.zeros_(m.bias)
60
61
class Generator(nn.Module):
    """Conditional DCGAN generator: (noise, label) -> 3x128x128 image in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()

        # Embed the class label and project it to a 4x4 single-channel map.
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16),
        )

        # Project the latent vector to a 512-channel 4x4 feature map.
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4 * 4 * 512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        def up_block(in_ch, out_ch):
            # One upsampling stage: stride-2 transposed conv + batch-norm + ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch, momentum=0.1, eps=0.8),
                nn.ReLU(True),
            ]

        # 513 input channels = 512 latent channels + 1 label channel.
        self.model = nn.Sequential(
            *up_block(513, 64 * 8),
            *up_block(64 * 8, 64 * 4),
            *up_block(64 * 4, 64 * 2),
            *up_block(64 * 2, 64 * 1),
            nn.ConvTranspose2d(64 * 1, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, inputs):
        """Take a (noise_vector, label) pair and return generated images."""
        noise_vector, label = inputs
        label_map = self.label_conditioned_generator(label).view(-1, 1, 4, 4)
        latent_map = self.latent(noise_vector).view(-1, 512, 4, 4)
        joined = torch.cat((latent_map, label_map), dim=1)
        return self.model(joined)
99
100
# Instantiate the generator, apply DCGAN weight init, and print its layout.
generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

# Smoke-test inputs: one 100-d noise vector and one integer class label.
a = torch.ones(100).to(device)
b = torch.ones(1).long().to(device)
109
110
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator: (image, label) -> real/fake probability."""

    def __init__(self):
        super(Discriminator, self).__init__()

        # Expand the class label into a full 3x128x128 conditioning "image".
        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128),
        )

        def down_block(in_ch, out_ch):
            # One downsampling stage: conv + batch-norm + LeakyReLU.
            return [
                nn.Conv2d(in_ch, out_ch, 4, 3, 2, bias=False),
                nn.BatchNorm2d(out_ch, momentum=0.1, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # 6 input channels = 3 image channels + 3 label channels.
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            *down_block(64, 64 * 2),
            *down_block(64 * 2, 64 * 4),
            *down_block(64 * 4, 64 * 8),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid(),
        )

    def forward(self, inputs):
        """Take an (image, label) pair and return one probability per sample."""
        img, label = inputs
        label_map = self.label_condition_disc(label).view(-1, 3, 128, 128)
        joined = torch.cat((img, label_map), dim=1)
        return self.model(joined)
143
144
# Instantiate the discriminator, apply weight init, and print its layout.
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

# Smoke-test: a batch of two all-ones images with integer labels.
a = torch.ones(2, 3, 128, 128).to(device)
b = torch.ones(2, 1).long().to(device)

c = discriminator((a, b))
c.size()
156
157
# Binary cross-entropy on the discriminator's sigmoid output.
# (The original instantiated nn.BCELoss() twice; one instance suffices.)
adversarial_loss = nn.BCELoss()


def generator_loss(fake_output, label):
    """BCE between D's scores on generated images and the (all-real) targets."""
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss


def discriminator_loss(output, label):
    """BCE between D's scores and the real/fake target labels."""
    disc_loss = adversarial_loss(output, label)
    return disc_loss
169
170
# Adam with the DCGAN-recommended betas for both networks.
learning_rate = 0.0002
G_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

num_epochs = 2
D_loss_plot, G_loss_plot = [], []
for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        # ---- Discriminator update ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)
        # Embedding layers expect integer labels of shape (batch, 1).
        labels = labels.to(device)
        labels = labels.unsqueeze(1).long()

        # Targets: 1 for real images, 0 for generated ones.
        # (torch.autograd.Variable is a deprecated no-op; plain tensors suffice.)
        real_target = torch.ones(real_images.size(0), 1, device=device)
        fake_target = torch.zeros(real_images.size(0), 1, device=device)

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        # Sample noise directly on the target device (no redundant .to()).
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)

        generated_image = generator((noise_vector, labels))
        # detach() keeps the generator out of the discriminator's backward pass.
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        # Store plain floats: appending the loss *tensor* would keep each
        # iteration's autograd graph alive and steadily leak memory.
        D_loss_list.append(D_total_loss.item())

        D_total_loss.backward()
        D_optimizer.step()

        # ---- Generator update: push D towards predicting "real" on fakes.
        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss.item())

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
        torch.mean(torch.FloatTensor(G_loss_list))))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))
    # Save a sample grid and both networks' weights after every epoch.
    save_image(generated_image.data[:50], 'torch/images/sample_%d' % epoch + '.png', nrow=5, normalize=True)

    torch.save(generator.state_dict(), 'torch/training_weights/generator_epoch_%d.pth' % (epoch))
    torch.save(discriminator.state_dict(), 'torch/training_weights/discriminator_epoch_%d.pth' % (epoch))
231
232
# Reload the epoch-1 generator weights and switch to inference mode.
checkpoint = torch.load('torch/training_weights/generator_epoch_1.pth')
generator.load_state_dict(checkpoint, strict=False)
generator.eval()
234
235
# example of interpolating between generated faces
236
# generate points in latent space as input for the generator
237
# example of interpolating between generated faces
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    """Sample *n_samples* standard-normal latent vectors of size *latent_dim*.

    *n_classes* is accepted for API compatibility but is not used here.
    """
    flat = randn(latent_dim * n_samples)
    # One row per sample, ready to feed the generator.
    return flat.reshape(n_samples, latent_dim)
243
244
# uniform interpolation between two points in latent space
245
# uniform interpolation between two points in latent space
def interpolate_points(p1, p2, n_steps=10):
    """Linearly interpolate from *p1* to *p2* over *n_steps* evenly spaced ratios."""
    ratios = linspace(0, 1, num=n_steps)
    blended = [(1.0 - r) * p1 + r * p2 for r in ratios]
    return asarray(blended)
254
255
256
# Sample two latent points and build 10 evenly spaced blends between them.
pts = generate_latent_points(100, 2)
interpolated = interpolate_points(pts[0], pts[1])

# Convert to a float32 tensor on the model's device.
interpolated = torch.tensor(interpolated)
interpolated = interpolated.to(device)
interpolated = interpolated.type(torch.float32)

# Run the same 10 interpolated latents through the generator once per class,
# stacking the decoded images (as HWC arrays) into one output batch.
output = None
for label in range(3):
    labels = (torch.ones(10) * label).to(device).unsqueeze(1).long()
    print(labels.size())
    predictions = generator((interpolated, labels)).permute(0, 2, 3, 1)
    pred = predictions.detach().cpu()
    output = pred if output is None else np.concatenate((output, pred))

print(output.shape)
279
280
# Lay out the 30 generated images on a 3x10 grid: one row per class,
# one column per interpolation step.
nrow = 3
ncol = 10
fig = plt.figure(figsize=(25, 25))
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                       wspace=0.0, hspace=0.0, top=0.2, bottom=0.00,
                       left=0.17, right=0.845)

k = 0
for i in range(nrow):
    for j in range(ncol):
        # Map Tanh output from [-1, 1] back to [0, 255] pixel values.
        img = np.array((output[k, :, :, :] + 1) * 127.5)
        cell = plt.subplot(gs[i, j])
        cell.imshow(img.astype(np.uint8))
        cell.set_xticklabels([])
        cell.set_yticklabels([])
        cell.axis('off')
        k += 1

#plt.savefig('result_torch.png', dpi=300)
plt.show()
303