# coding: utf-8


import sys
from python_environment_check import check_packages
#from google.colab import drive
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.autograd import grad as torch_grad

# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples

# ## Package version checks

# Add the parent folder to the module search path so that check_packages.py can be imported:

sys.path.insert(0, '..')

# Check recommended package versions:

d = {
    'torch': '1.8.0',
    'torchvision': '0.9.0',
    'numpy': '1.21.2',
    'matplotlib': '3.4.3',
}

check_packages(d)

# # Chapter 17 - Generative Adversarial Networks for Synthesizing New Data (Part 2/2)

# **Contents**
#
# - [Improving the quality of synthesized images using a convolutional and Wasserstein GAN](#Improving-the-quality-of-synthesized-images-using-a-convolutional-and-Wasserstein-GAN)
#   - [Transposed convolution](#Transposed-convolution)
#   - [Batch normalization](#Batch-normalization)
#   - [Implementing the generator and discriminator](#Implementing-the-generator-and-discriminator)
#   - [Dissimilarity measures between two distributions](#Dissimilarity-measures-between-two-distributions)
#   - [Using EM distance in practice for GANs](#Using-EM-distance-in-practice-for-GANs)
#   - [Gradient penalty](#Gradient-penalty)
#   - [Implementing WGAN-GP to train the DCGAN model](#Implementing-WGAN-GP-to-train-the-DCGAN-model)
#   - [Mode collapse](#Mode-collapse)
# - [Other GAN applications](#Other-GAN-applications)
# - [Summary](#Summary)

# Note that the optional watermark extension is a small IPython notebook plugin that I developed to make the code reproducible. You can just skip the following line(s).

# # Improving the quality of synthesized images using a convolutional and Wasserstein GAN

# ## Transposed convolution
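
# As a quick, optional sanity check (not part of the original notebook), the cell
# below pushes a dummy 1x1 feature map through an nn.ConvTranspose2d layer to show
# how transposed convolution upsamples the spatial dimensions; the layer and tensor
# names here are only illustrative.

tconv_demo = nn.ConvTranspose2d(in_channels=16, out_channels=8,
                                kernel_size=4, stride=1, padding=0, bias=False)
demo_input = torch.zeros(1, 16, 1, 1)  # (batch, channels, height, width)
print('Transposed convolution output shape:', tconv_demo(demo_input).shape)  # (1, 8, 4, 4)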

# ## Batch normalization
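
# Another optional check (not part of the original notebook): nn.BatchNorm2d
# standardizes each channel across the batch and spatial dimensions, so the
# per-channel mean of its output is close to 0 and the standard deviation close
# to 1 before the learnable scale and shift take effect. The input is random demo data.

bn_demo = nn.BatchNorm2d(num_features=3)
demo_batch = torch.randn(16, 3, 8, 8) * 5.0 + 2.0   # arbitrary scale and shift
bn_out = bn_demo(demo_batch)
print('Per-channel mean after BatchNorm:', bn_out.mean(dim=(0, 2, 3)))
print('Per-channel std after BatchNorm: ', bn_out.std(dim=(0, 2, 3)))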

# ## Implementing the generator and discriminator

# * **Setting up Google Colab**

#drive.mount('/content/drive/')


print(torch.__version__)
print("GPU Available:", torch.cuda.is_available())

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = "cpu"


# ## Train the DCGAN model

image_path = './'
transform = transforms.Compose([
    transforms.ToTensor(),
    # Scale pixel intensities from [0, 1] to [-1, 1] to match the generator's tanh output
    transforms.Normalize(mean=(0.5), std=(0.5))
])
mnist_dataset = torchvision.datasets.MNIST(root=image_path,
                                           train=True,
                                           transform=transform,
                                           download=False)  # set download=True if MNIST is not yet in image_path

batch_size = 64

torch.manual_seed(1)
np.random.seed(1)

## Set up the dataset
mnist_dl = DataLoader(mnist_dataset, batch_size=batch_size,
                      shuffle=True, drop_last=True)
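
# Optional sanity check (not part of the original notebook): confirm the batch
# shape and that the normalized pixel values indeed lie in [-1, 1].
x_demo, _ = next(iter(mnist_dl))
print('Batch shape:', x_demo.shape)                       # (64, 1, 28, 28)
print('Pixel range:', float(x_demo.min()), float(x_demo.max()))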


def make_generator_network(input_size, n_filters):
    model = nn.Sequential(
        # (input_size, 1, 1) -> (n_filters*4, 4, 4)
        nn.ConvTranspose2d(input_size, n_filters*4, 4, 1, 0,
                           bias=False),
        nn.BatchNorm2d(n_filters*4),
        nn.LeakyReLU(0.2),

        # -> (n_filters*2, 7, 7)
        nn.ConvTranspose2d(n_filters*4, n_filters*2, 3, 2, 1, bias=False),
        nn.BatchNorm2d(n_filters*2),
        nn.LeakyReLU(0.2),

        # -> (n_filters, 14, 14)
        nn.ConvTranspose2d(n_filters*2, n_filters, 4, 2, 1, bias=False),
        nn.BatchNorm2d(n_filters),
        nn.LeakyReLU(0.2),

        # -> (1, 28, 28)
        nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False),
        nn.Tanh())
    return model


class Discriminator(nn.Module):
    def __init__(self, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            # (1, 28, 28) -> (n_filters, 14, 14)
            nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),

            # -> (n_filters*2, 7, 7)
            nn.Conv2d(n_filters, n_filters*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters*2),
            nn.LeakyReLU(0.2),

            # -> (n_filters*4, 4, 4)
            nn.Conv2d(n_filters*2, n_filters*4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),

            # -> (1, 1, 1)
            nn.Conv2d(n_filters*4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid())

    def forward(self, input):
        output = self.network(input)
        return output.view(-1, 1).squeeze(0)


z_size = 100
image_size = (28, 28)
n_filters = 32

gen_model = make_generator_network(z_size, n_filters).to(device)
print(gen_model)

disc_model = Discriminator(n_filters).to(device)
print(disc_model)
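
# Optional shape check (not part of the original notebook): the generator turns a
# (batch, 100, 1, 1) noise tensor into 28x28 images, and the discriminator returns
# one probability per example.
with torch.no_grad():
    z_check = torch.randn(4, z_size, 1, 1, device=device)
    fake_check = gen_model(z_check)
    print('Generator output shape:    ', fake_check.shape)              # (4, 1, 28, 28)
    print('Discriminator output shape:', disc_model(fake_check).shape)  # (4, 1)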

## Loss function and optimizers:
loss_fn = nn.BCELoss()
g_optimizer = torch.optim.Adam(gen_model.parameters(), 0.0003)
d_optimizer = torch.optim.Adam(disc_model.parameters(), 0.0002)


def create_noise(batch_size, z_size, mode_z):
    if mode_z == 'uniform':
        input_z = torch.rand(batch_size, z_size, 1, 1)*2 - 1
    elif mode_z == 'normal':
        input_z = torch.randn(batch_size, z_size, 1, 1)
    return input_z
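
# Optional check (not part of the original notebook): in 'uniform' mode the noise
# lies in [-1, 1) and has the (batch, z_size, 1, 1) shape the generator expects.
noise_check = create_noise(5, z_size, 'uniform')
print(noise_check.shape, float(noise_check.min()), float(noise_check.max()))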


## Train the discriminator
def d_train(x):
    disc_model.zero_grad()

    # Train discriminator with a real batch
    batch_size = x.size(0)
    x = x.to(device)
    d_labels_real = torch.ones(batch_size, 1, device=device)

    d_proba_real = disc_model(x)
    d_loss_real = loss_fn(d_proba_real, d_labels_real)

    # Train discriminator on a fake batch
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_output = gen_model(input_z)

    d_proba_fake = disc_model(g_output)
    d_labels_fake = torch.zeros(batch_size, 1, device=device)
    d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)

    # gradient backprop & optimize ONLY D's parameters
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    d_optimizer.step()

    return d_loss.data.item(), d_proba_real.detach(), d_proba_fake.detach()


## Train the generator
def g_train(x):
    gen_model.zero_grad()

    batch_size = x.size(0)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_labels_real = torch.ones((batch_size, 1), device=device)

    g_output = gen_model(input_z)
    d_proba_fake = disc_model(g_output)
    g_loss = loss_fn(d_proba_fake, g_labels_real)

    # gradient backprop & optimize ONLY G's parameters
    g_loss.backward()
    g_optimizer.step()

    return g_loss.data.item()


mode_z = 'uniform'
fixed_z = create_noise(batch_size, z_size, mode_z).to(device)

def create_samples(g_model, input_z):
    g_output = g_model(input_z)
    images = torch.reshape(g_output, (batch_size, *image_size))
    return (images+1)/2.0

epoch_samples = []

num_epochs = 100
torch.manual_seed(1)

for epoch in range(1, num_epochs+1):
    gen_model.train()
    d_losses, g_losses = [], []
    for i, (x, _) in enumerate(mnist_dl):
        d_loss, d_proba_real, d_proba_fake = d_train(x)
        d_losses.append(d_loss)
        g_losses.append(g_train(x))

    print(f'Epoch {epoch:03d} | Avg Losses >>'
          f' G/D {torch.FloatTensor(g_losses).mean():.4f}'
          f'/{torch.FloatTensor(d_losses).mean():.4f}')
    gen_model.eval()
    epoch_samples.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())


selected_epochs = [1, 2, 4, 10, 50, 100]
fig = plt.figure(figsize=(10, 14))
for i, e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            ax.text(
                -0.06, 0.5, f'Epoch {e}',
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center',
                transform=ax.transAxes)

        image = epoch_samples[e-1][j]
        ax.imshow(image, cmap='gray_r')

# plt.savefig('figures/ch17-dcgan-samples.pdf')
plt.show()

# ## Dissimilarity measures between two distributions
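
# A tiny optional illustration (not part of the original notebook) of the earth
# mover's (EM, or Wasserstein-1) distance in 1D: for two equally sized samples it
# equals the mean absolute difference between the sorted samples. For two Gaussians
# that differ only by a shift of 2, the EM distance is 2, so the estimate below
# should come out close to that value.
rng = np.random.RandomState(0)
sample_p = rng.normal(loc=0.0, scale=1.0, size=10000)
sample_q = rng.normal(loc=2.0, scale=1.0, size=10000)
em_estimate = np.mean(np.abs(np.sort(sample_p) - np.sort(sample_q)))
print(f'Estimated EM distance: {em_estimate:.3f}')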

# ## Using EM distance in practice for GANs

# ## Gradient penalty
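
# For reference (this summary is added here and is not part of the original notebook),
# the WGAN-GP objectives implemented in the following cells are
#
# - Critic (discriminator) loss: $E[D(G(z))] - E[D(x)] + \lambda_{gp} \, E[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2]$,
#   where $\hat{x} = \alpha x + (1-\alpha) G(z)$ is a random interpolation between real and generated examples
# - Generator loss: $-E[D(G(z))]$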

# ## Implementing WGAN-GP to train the DCGAN model

def make_generator_network_wgan(input_size, n_filters):
    model = nn.Sequential(
        nn.ConvTranspose2d(input_size, n_filters*4, 4, 1, 0,
                           bias=False),
        nn.InstanceNorm2d(n_filters*4),
        nn.LeakyReLU(0.2),

        nn.ConvTranspose2d(n_filters*4, n_filters*2, 3, 2, 1, bias=False),
        nn.InstanceNorm2d(n_filters*2),
        nn.LeakyReLU(0.2),

        nn.ConvTranspose2d(n_filters*2, n_filters, 4, 2, 1, bias=False),
        nn.InstanceNorm2d(n_filters),
        nn.LeakyReLU(0.2),

        nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False),
        nn.Tanh())
    return model


# Same architecture as the DCGAN discriminator, but with instance normalization
# instead of batch normalization, which is preferable when using a gradient penalty
class DiscriminatorWGAN(nn.Module):
    def __init__(self, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(n_filters, n_filters*2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(n_filters*2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(n_filters*2, n_filters*4, 3, 2, 1, bias=False),
            nn.InstanceNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),

            nn.Conv2d(n_filters*4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid())

    def forward(self, input):
        output = self.network(input)
        return output.view(-1, 1).squeeze(0)


gen_model = make_generator_network_wgan(z_size, n_filters).to(device)
disc_model = DiscriminatorWGAN(n_filters).to(device)

g_optimizer = torch.optim.Adam(gen_model.parameters(), 0.0002)
d_optimizer = torch.optim.Adam(disc_model.parameters(), 0.0002)


def gradient_penalty(real_data, generated_data):
    batch_size = real_data.size(0)

    # Calculate interpolation
    alpha = torch.rand(real_data.shape[0], 1, 1, 1, requires_grad=True, device=device)
    interpolated = alpha * real_data + (1 - alpha) * generated_data

    # Calculate probability of interpolated examples
    proba_interpolated = disc_model(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch_grad(outputs=proba_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones(proba_interpolated.size(), device=device),
                           create_graph=True, retain_graph=True)[0]

    gradients = gradients.view(batch_size, -1)
    gradients_norm = gradients.norm(2, dim=1)
    return lambda_gp * ((gradients_norm - 1)**2).mean()
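
# Optional sanity check (not part of the original notebook). gradient_penalty()
# reads the global lambda_gp, which the training cell further below sets to 10.0;
# we assign the same value here so this check can run on its own.
lambda_gp = 10.0
real_check = torch.randn(8, 1, 28, 28, device=device)
fake_check = torch.randn(8, 1, 28, 28, device=device)
print('Gradient penalty on random tensors:', gradient_penalty(real_check, fake_check).item())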


## Train the discriminator
def d_train_wgan(x):
    disc_model.zero_grad()

    batch_size = x.size(0)
    x = x.to(device)

    # Calculate probabilities on real and generated data
    d_real = disc_model(x)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_output = gen_model(input_z)
    d_generated = disc_model(g_output)
    d_loss = d_generated.mean() - d_real.mean() + gradient_penalty(x.data, g_output.data)
    d_loss.backward()
    d_optimizer.step()

    return d_loss.data.item()


## Train the generator
def g_train_wgan(x):
    gen_model.zero_grad()

    batch_size = x.size(0)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_output = gen_model(input_z)

    d_generated = disc_model(g_output)
    g_loss = -d_generated.mean()

    # gradient backprop & optimize ONLY G's parameters
    g_loss.backward()
    g_optimizer.step()

    return g_loss.data.item()


epoch_samples_wgan = []

lambda_gp = 10.0
num_epochs = 100
torch.manual_seed(1)
critic_iterations = 5

for epoch in range(1, num_epochs+1):
    gen_model.train()
    d_losses, g_losses = [], []
    for i, (x, _) in enumerate(mnist_dl):
        # Update the critic several times per generator update
        for _ in range(critic_iterations):
            d_loss = d_train_wgan(x)
        d_losses.append(d_loss)
        g_losses.append(g_train_wgan(x))

    print(f'Epoch {epoch:03d} | D Loss >>'
          f' {torch.FloatTensor(d_losses).mean():.4f}')
    gen_model.eval()
    epoch_samples_wgan.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())


selected_epochs = [1, 2, 4, 10, 50, 100]
# selected_epochs = [1, 10, 20, 30, 50, 70]
fig = plt.figure(figsize=(10, 14))
for i, e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            ax.text(
                -0.06, 0.5, f'Epoch {e}',
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center',
                transform=ax.transAxes)

        image = epoch_samples_wgan[e-1][j]
        ax.imshow(image, cmap='gray_r')

# plt.savefig('figures/ch17-wgan-gp-samples.pdf')
plt.show()


# ## Mode collapse


#
# ----
#
#
# Readers may ignore the next cell.
#
#