# GitHub Repository: rasbt/machine-learning-book
# Path: blob/main/ch17/ch17_part1.py
# coding: utf-8


import sys
from python_environment_check import check_packages
import torch
#from google.colab import drive
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import itertools
import math

# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples

# ## Package version checks

# Add folder to path in order to load from the check_packages.py script:

sys.path.insert(0, '..')


# Check recommended package versions:

d = {
    'torch': '1.8.0',
    'torchvision': '0.9.0',
    'numpy': '1.21.2',
    'matplotlib': '3.4.3',
}

check_packages(d)

# # Chapter 17 - Generative Adversarial Networks for Synthesizing New Data (Part 1/2)

# **Contents**
#
# - [Introducing generative adversarial networks](#Introducing-generative-adversarial-networks)
#   - [Starting with autoencoders](#Starting-with-autoencoders)
#   - [Generative models for synthesizing new data](#Generative-models-for-synthesizing-new-data)
#   - [Generating new samples with GANs](#Generating-new-samples-with-GANs)
#   - [Understanding the loss functions for the generator and discriminator networks in a GAN model](#Understanding-the-loss-functions-for-the-generator-and-discriminator-networks-in-a-GAN-model)
# - [Implementing a GAN from scratch](#Implementing-a-GAN-from-scratch)
#   - [Training GAN models on Google Colab](#Training-GAN-models-on-Google-Colab)
#   - [Implementing the generator and the discriminator networks](#Implementing-the-generator-and-the-discriminator-networks)
#   - [Defining the training dataset](#Defining-the-training-dataset)
#   - [Training the GAN model](#Training-the-GAN-model)

# Note that the optional watermark extension is a small IPython notebook plugin that I developed to make the code reproducible. You can just skip the following line(s).
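
# (In the notebook version, this corresponds to loading the extension with
# `%load_ext watermark` and printing version info with a `%watermark` magic;
# the exact arguments used in the original notebook are not shown here.)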


# # Introducing generative adversarial networks
#
# ## Starting with autoencoders




# ## Generative models for synthesizing new data




# ## Generating new samples with GANs




# ## Understanding the loss functions for the generator and discriminator networks in a GAN model
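#
# For reference, the loss functions discussed in this section derive from the
# original GAN value function of Goodfellow et al. (2014):
#
# $$\min_G \max_D \; V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{data}(\mathbf{x})}\big[\log D(\mathbf{x})\big] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})}\big[\log\big(1 - D(G(\mathbf{z}))\big)\big]$$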


# # Implementing a GAN from scratch
#

# ## Training GAN models on Google Colab


print(torch.__version__)
print("GPU Available:", torch.cuda.is_available())

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")



# !pip install torchvision


#drive.mount('/content/drive/')

# ## Implementing the generator and the discriminator networks


## define a function for the generator:
def make_generator_network(input_size=20,
                           num_hidden_layers=1,
                           num_hidden_units=100,
                           num_output_units=784):
    model = nn.Sequential()
    for i in range(num_hidden_layers):
        model.add_module(f'fc_g{i}',
                         nn.Linear(input_size, num_hidden_units))
        model.add_module(f'relu_g{i}', nn.LeakyReLU())
        input_size = num_hidden_units
    model.add_module(f'fc_g{num_hidden_layers}',
                     nn.Linear(input_size, num_output_units))
    model.add_module('tanh_g', nn.Tanh())
    return model


## define a function for the discriminator:
def make_discriminator_network(input_size,
                               num_hidden_layers=1,
                               num_hidden_units=100,
                               num_output_units=1):
    model = nn.Sequential()
    for i in range(num_hidden_layers):
        model.add_module(f'fc_d{i}',
                         nn.Linear(input_size, num_hidden_units, bias=False))
        model.add_module(f'relu_d{i}', nn.LeakyReLU())
        # index the dropout name so that stacked hidden layers do not
        # overwrite each other's dropout module:
        model.add_module(f'dropout_d{i}', nn.Dropout(p=0.5))
        input_size = num_hidden_units
    model.add_module(f'fc_d{num_hidden_layers}',
                     nn.Linear(input_size, num_output_units))
    model.add_module('sigmoid', nn.Sigmoid())
    return model


image_size = (28, 28)
z_size = 20

gen_hidden_layers = 1
gen_hidden_size = 100
disc_hidden_layers = 1
disc_hidden_size = 100

torch.manual_seed(1)

gen_model = make_generator_network(
    input_size=z_size,
    num_hidden_layers=gen_hidden_layers,
    num_hidden_units=gen_hidden_size,
    num_output_units=np.prod(image_size))

print(gen_model)


disc_model = make_discriminator_network(
    input_size=np.prod(image_size),
    num_hidden_layers=disc_hidden_layers,
    num_hidden_units=disc_hidden_size)

print(disc_model)


# ## Defining the training dataset

# * **Step-by-step walk-through of the data flow**


image_path = './'
# ToTensor() scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps
# them to [-1, 1], matching the generator's tanh output range:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
mnist_dataset = torchvision.datasets.MNIST(root=image_path,
                                           train=True,
                                           transform=transform,
                                           download=False)

example, label = next(iter(mnist_dataset))
print(f'Min: {example.min()} Max: {example.max()}')
print(example.shape)


def create_noise(batch_size, z_size, mode_z):
    if mode_z == 'uniform':
        input_z = torch.rand(batch_size, z_size)*2 - 1
    elif mode_z == 'normal':
        input_z = torch.randn(batch_size, z_size)
    else:
        raise ValueError(f"Unknown mode_z: {mode_z}")
    return input_z


batch_size = 32
dataloader = DataLoader(mnist_dataset, batch_size, shuffle=False)
input_real, label = next(iter(dataloader))
input_real = input_real.view(batch_size, -1)

torch.manual_seed(1)
mode_z = 'uniform'  # 'uniform' vs. 'normal'
input_z = create_noise(batch_size, z_size, mode_z)

print('input-z -- shape:', input_z.shape)
print('input-real -- shape:', input_real.shape)

g_output = gen_model(input_z)
print('Output of G -- shape:', g_output.shape)

d_proba_real = disc_model(input_real)
d_proba_fake = disc_model(g_output)
print('Disc. (real) -- shape:', d_proba_real.shape)
print('Disc. (fake) -- shape:', d_proba_fake.shape)


# ## Training the GAN model


loss_fn = nn.BCELoss()

## Loss for the Generator
g_labels_real = torch.ones_like(d_proba_fake)
g_loss = loss_fn(d_proba_fake, g_labels_real)
print(f'Generator Loss: {g_loss:.4f}')

## Loss for the Discriminator
d_labels_real = torch.ones_like(d_proba_real)
d_labels_fake = torch.zeros_like(d_proba_fake)

d_loss_real = loss_fn(d_proba_real, d_labels_real)
d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)
print(f'Discriminator Losses: Real {d_loss_real:.4f} Fake {d_loss_fake:.4f}')
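
## Illustrative cross-check: for all-ones targets, BCELoss reduces to
## -mean(log(D(G(z)))), so computing that expression directly should
## reproduce the generator loss printed above.
manual_g_loss = -torch.log(d_proba_fake).mean()
print(f'Generator Loss (manual): {manual_g_loss:.4f}')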

# * **Final training**


batch_size = 64

torch.manual_seed(1)
np.random.seed(1)

## Set up the dataset
mnist_dl = DataLoader(mnist_dataset, batch_size=batch_size,
                      shuffle=True, drop_last=True)

## Set up the models
gen_model = make_generator_network(
    input_size=z_size,
    num_hidden_layers=gen_hidden_layers,
    num_hidden_units=gen_hidden_size,
    num_output_units=np.prod(image_size)).to(device)

disc_model = make_discriminator_network(
    input_size=np.prod(image_size),
    num_hidden_layers=disc_hidden_layers,
    num_hidden_units=disc_hidden_size).to(device)

## Loss function and optimizers:
loss_fn = nn.BCELoss()
g_optimizer = torch.optim.Adam(gen_model.parameters())
d_optimizer = torch.optim.Adam(disc_model.parameters())
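# Both optimizers use PyTorch's Adam defaults (lr=0.001, betas=(0.9, 0.999));
# no learning rate is specified explicitly.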


## Train the discriminator
def d_train(x):
    disc_model.zero_grad()

    # Train discriminator with a real batch
    batch_size = x.size(0)
    x = x.view(batch_size, -1).to(device)
    d_labels_real = torch.ones(batch_size, 1, device=device)

    d_proba_real = disc_model(x)
    d_loss_real = loss_fn(d_proba_real, d_labels_real)

    # Train discriminator on a fake batch
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_output = gen_model(input_z)

    d_proba_fake = disc_model(g_output)
    d_labels_fake = torch.zeros(batch_size, 1, device=device)
    d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)

    # gradient backprop & optimize ONLY D's parameters
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    d_optimizer.step()

    return d_loss.item(), d_proba_real.detach(), d_proba_fake.detach()


## Train the generator
def g_train(x):
    gen_model.zero_grad()

    batch_size = x.size(0)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_labels_real = torch.ones(batch_size, 1, device=device)

    g_output = gen_model(input_z)
    d_proba_fake = disc_model(g_output)
    g_loss = loss_fn(d_proba_fake, g_labels_real)

    # gradient backprop & optimize ONLY G's parameters
    g_loss.backward()
    g_optimizer.step()

    return g_loss.item()
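
# Note: labeling the generator's fake samples as "real" in g_train implements
# the non-saturating generator loss, which maximizes log(D(G(z))) rather than
# minimizing log(1 - D(G(z))) and provides stronger gradients early in training.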


fixed_z = create_noise(batch_size, z_size, mode_z).to(device)

def create_samples(g_model, input_z):
    g_output = g_model(input_z)
    images = torch.reshape(g_output, (input_z.size(0), *image_size))
    # rescale from the tanh output range [-1, 1] to [0, 1] for plotting
    return (images+1)/2.0

epoch_samples = []

all_d_losses = []
all_g_losses = []

all_d_real = []
all_d_fake = []

num_epochs = 100
torch.manual_seed(1)
for epoch in range(1, num_epochs+1):
    d_losses, g_losses = [], []
    d_vals_real, d_vals_fake = [], []
    for i, (x, _) in enumerate(mnist_dl):
        d_loss, d_proba_real, d_proba_fake = d_train(x)
        d_losses.append(d_loss)
        g_losses.append(g_train(x))

        d_vals_real.append(d_proba_real.mean().cpu())
        d_vals_fake.append(d_proba_fake.mean().cpu())

    all_d_losses.append(torch.tensor(d_losses).mean())
    all_g_losses.append(torch.tensor(g_losses).mean())
    all_d_real.append(torch.tensor(d_vals_real).mean())
    all_d_fake.append(torch.tensor(d_vals_fake).mean())
    print(f'Epoch {epoch:03d} | Avg Losses >>'
          f' G/D {all_g_losses[-1]:.4f}/{all_d_losses[-1]:.4f}'
          f' [D-Real: {all_d_real[-1]:.4f} D-Fake: {all_d_fake[-1]:.4f}]')
    epoch_samples.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())


fig = plt.figure(figsize=(16, 6))

## Plotting the losses
ax = fig.add_subplot(1, 2, 1)
plt.plot(all_g_losses, label='Generator loss')
# halve the discriminator loss (a sum of two BCE terms) so that it is on a
# comparable scale to the generator loss:
half_d_losses = [d_loss/2 for d_loss in all_d_losses]
plt.plot(half_d_losses, label='Discriminator loss')
plt.legend(fontsize=20)
ax.set_xlabel('Epoch', size=15)
ax.set_ylabel('Loss', size=15)

## Plotting the outputs of the discriminator
ax = fig.add_subplot(1, 2, 2)
plt.plot(all_d_real, label=r'Real: $D(\mathbf{x})$')
plt.plot(all_d_fake, label=r'Fake: $D(G(\mathbf{z}))$')
plt.legend(fontsize=20)
ax.set_xlabel('Epoch', size=15)
ax.set_ylabel('Discriminator output', size=15)

#plt.savefig('figures/ch17-gan-learning-curve.pdf')
plt.show()


selected_epochs = [1, 2, 4, 10, 50, 100]
fig = plt.figure(figsize=(10, 14))
for i, e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            ax.text(
                -0.06, 0.5, f'Epoch {e}',
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center',
                transform=ax.transAxes)

        image = epoch_samples[e-1][j]
        ax.imshow(image, cmap='gray_r')

#plt.savefig('figures/ch17-vanila-gan-samples.pdf')
plt.show()


#
# ----


def distance(X, Y, sqrt):
    # pairwise squared Euclidean distances via ||x||^2 + ||y||^2 - 2*x.y
    nX = X.size(0)
    nY = Y.size(0)
    X = X.view(nX, -1).to(device)
    X2 = (X*X).sum(1).view(nX, 1)
    Y = Y.view(nY, -1).to(device)
    Y2 = (Y*Y).sum(1).view(nY, 1)

    M = torch.zeros(nX, nY)
    M.copy_(X2.expand(nX, nY)
            + Y2.expand(nY, nX).transpose(0, 1)
            - 2*torch.mm(X, Y.transpose(0, 1)))

    del X, X2, Y, Y2

    if sqrt:
        # clamp small negative values from floating-point error before the sqrt
        M = ((M + M.abs())/2).sqrt()

    return M



def mmd(Mxx, Mxy, Myy, sigma):
    # Gaussian-kernel MMD estimate; the kernel bandwidth is scaled by the
    # mean of the real-vs-real distances
    scale = Mxx.mean()
    Mxx = torch.exp(-Mxx/(scale*2*sigma*sigma))
    Mxy = torch.exp(-Mxy/(scale*2*sigma*sigma))
    Myy = torch.exp(-Myy/(scale*2*sigma*sigma))
    a = Mxx.mean() + Myy.mean() - 2*Mxy.mean()
    mmd = math.sqrt(max(a, 0))

    return mmd



def compute_score(fake, real, k=1, sigma=1, sqrt=True):
    # a lower MMD score means the fake samples are distributed more like
    # the real ones
    Mxx = distance(real, real, False)
    Mxy = distance(real, fake, False)
    Myy = distance(fake, fake, False)

    print(mmd(Mxx, Mxy, Myy, sigma))
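
## Quick sanity check of the MMD score (an illustrative addition; `_a`, `_b`,
## and `_c` are throwaway variables): samples from the same distribution
## should give a score near 0, and a shifted distribution a larger one.
_a = torch.randn(200, 784)
_b = torch.randn(200, 784)
_c = torch.randn(200, 784) + 5.0
print('MMD same distribution:   ',
      mmd(distance(_a, _a, False), distance(_a, _b, False),
          distance(_b, _b, False), sigma=1))
print('MMD shifted distribution:',
      mmd(distance(_a, _a, False), distance(_a, _c, False),
          distance(_c, _c, False), sigma=1))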


whole_dl = DataLoader(mnist_dataset, batch_size=10000,
                      shuffle=True, drop_last=True)

real_image = next(iter(whole_dl))[0]

compute_score(torch.from_numpy(epoch_samples[-1]), real_image)


#
#
# Readers may ignore the next cell.
#