GitHub Repository: rasbt/machine-learning-book
Path: blob/main/ch14/ch14_part2.py
# coding: utf-8

import sys
from python_environment_check import check_packages
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Subset

# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples

# ## Package version checks

# Add folder to path in order to load from the check_packages.py script:

sys.path.insert(0, '..')

# Check recommended package versions:

d = {
    'numpy': '1.21.2',
    'scipy': '1.7.0',
    'matplotlib': '3.4.3',
    'torch': '1.8.0',
    'torchvision': '0.9.0'
}
check_packages(d)


# # Chapter 14: Classifying Images with Deep Convolutional Neural Networks (Part 2/2)

# **Outline**
#
# - [Smile classification from face images using a CNN](#Smile-classification-from-face-images-using-a-CNN)
#   - [Loading the CelebA dataset](#Loading-the-CelebA-dataset)
#   - [Image transformation and data augmentation](#Image-transformation-and-data-augmentation)
#   - [Training a CNN smile classifier](#Training-a-CNN-smile-classifier)
# - [Summary](#Summary)

# Note that the optional watermark extension is a small IPython notebook plugin that I developed to make the code reproducible. You can just skip the following line(s).

# ## Smile classification from face images using a CNN
#

# ### Loading the CelebA dataset

# You can try setting `download=True` in the code cell below; however, due to the daily download limits of the CelebA dataset, this will probably result in an error. Alternatively, we recommend trying the following:
#
# - You can download the files manually from the official CelebA website (https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
# - or use our download link, https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing (recommended).
#
# If you use our download link, it will download a `celeba.zip` file,
#
# 1. which you need to unpack in the current directory where you are running the code.
# 2. In addition, **please also make sure you unzip the `img_align_celeba.zip` file, which is inside the `celeba` folder.**
# 3. Also, after downloading and unzipping the `celeba` folder, you need to run the code with the setting `download=False` (as shown in the code cell below). A small sanity-check sketch for the expected folder layout follows these notes.
#
# In case you are encountering problems with this approach, please do not hesitate to open a new issue or start a discussion at https://github.com/rasbt/machine-learning-book so that we can provide you with additional information.

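# A minimal sanity-check sketch (an addition, not part of the book's script): the
# file names below are an assumption based on what torchvision.datasets.CelebA
# looks for; adjust them if your torchvision version expects a different layout.

import os

for f in ['celeba/img_align_celeba',
          'celeba/list_attr_celeba.txt',
          'celeba/list_eval_partition.txt']:
    if not os.path.exists(f):
        print(f'Missing {f} -- did you unzip both celeba.zip and img_align_celeba.zip?')
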
image_path = './'
celeba_train_dataset = torchvision.datasets.CelebA(image_path, split='train', target_type='attr', download=False)
celeba_valid_dataset = torchvision.datasets.CelebA(image_path, split='valid', target_type='attr', download=False)
celeba_test_dataset = torchvision.datasets.CelebA(image_path, split='test', target_type='attr', download=False)

print('Train set:', len(celeba_train_dataset))
print('Validation set:', len(celeba_valid_dataset))
print('Test set:', len(celeba_test_dataset))

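# Note (added): with the full CelebA data in place, the three splits should
# contain 162,770 / 19,867 / 19,962 images, respectively.
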
# ### Image transformation and data augmentation

## take 5 examples

fig = plt.figure(figsize=(16, 8.5))

## Column 1: cropping to a bounding-box
ax = fig.add_subplot(2, 5, 1)
img, attr = celeba_train_dataset[0]
ax.set_title('Crop to a \nbounding-box', size=15)
ax.imshow(img)
ax = fig.add_subplot(2, 5, 6)
img_cropped = transforms.functional.crop(img, 50, 20, 128, 128)
ax.imshow(img_cropped)

## Column 2: flipping (horizontally)
ax = fig.add_subplot(2, 5, 2)
img, attr = celeba_train_dataset[1]
ax.set_title('Flip (horizontal)', size=15)
ax.imshow(img)
ax = fig.add_subplot(2, 5, 7)
img_flipped = transforms.functional.hflip(img)
ax.imshow(img_flipped)

## Column 3: adjust contrast
ax = fig.add_subplot(2, 5, 3)
img, attr = celeba_train_dataset[2]
ax.set_title('Adjust contrast', size=15)
ax.imshow(img)
ax = fig.add_subplot(2, 5, 8)
img_adj_contrast = transforms.functional.adjust_contrast(img, contrast_factor=2)
ax.imshow(img_adj_contrast)

## Column 4: adjust brightness
ax = fig.add_subplot(2, 5, 4)
img, attr = celeba_train_dataset[3]
ax.set_title('Adjust brightness', size=15)
ax.imshow(img)
ax = fig.add_subplot(2, 5, 9)
img_adj_brightness = transforms.functional.adjust_brightness(img, brightness_factor=1.3)
ax.imshow(img_adj_brightness)

## Column 5: cropping from image center
ax = fig.add_subplot(2, 5, 5)
img, attr = celeba_train_dataset[4]
ax.set_title('Center crop\nand resize', size=15)
ax.imshow(img)
ax = fig.add_subplot(2, 5, 10)
img_center_crop = transforms.functional.center_crop(img, [int(0.7 * 218), int(0.7 * 178)])  # integer crop size
img_resized = transforms.functional.resize(img_center_crop, size=(218, 178))
ax.imshow(img_resized)

# plt.savefig('figures/14_14.png', dpi=300)
plt.show()
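
# Note (added): the `transforms.functional` API used above is deterministic --
# the caller picks all parameters -- whereas the `transforms.RandomCrop` and
# `transforms.RandomHorizontalFlip` classes used next sample their parameters
# anew on every call, which is what makes them useful for data augmentation.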

torch.manual_seed(1)

fig = plt.figure(figsize=(14, 12))

for i, (img, attr) in enumerate(celeba_train_dataset):
    ax = fig.add_subplot(3, 4, i*4+1)
    ax.imshow(img)
    if i == 0:
        ax.set_title('Orig.', size=15)

    ax = fig.add_subplot(3, 4, i*4+2)
    img_transform = transforms.Compose([transforms.RandomCrop([178, 178])])
    img_cropped = img_transform(img)
    ax.imshow(img_cropped)
    if i == 0:
        ax.set_title('Step 1: Random crop', size=15)

    ax = fig.add_subplot(3, 4, i*4+3)
    img_transform = transforms.Compose([transforms.RandomHorizontalFlip()])
    img_flip = img_transform(img_cropped)
    ax.imshow(img_flip)
    if i == 0:
        ax.set_title('Step 2: Random flip', size=15)

    ax = fig.add_subplot(3, 4, i*4+4)
    img_resized = transforms.functional.resize(img_flip, size=(128, 128))
    ax.imshow(img_resized)
    if i == 0:
        ax.set_title('Step 3: Resize', size=15)

    if i == 2:
        break

# plt.savefig('figures/14_15.png', dpi=300)
plt.show()


## extract the smile attribute to use as the binary target label
get_smile = lambda attr: attr[18]

transform_train = transforms.Compose([
    transforms.RandomCrop([178, 178]),
    transforms.RandomHorizontalFlip(),
    transforms.Resize([64, 64]),
    transforms.ToTensor(),
])

transform = transforms.Compose([
    transforms.CenterCrop([178, 178]),
    transforms.Resize([64, 64]),
    transforms.ToTensor(),
])

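# A quick sketch (an addition, not in the book's script) of what these pipelines
# produce: a 3x64x64 float tensor scaled to [0, 1], plus a single 0/1 smile label
# extracted by `get_smile`. The variable names here are illustrative only.

example_img, example_attr = celeba_train_dataset[0]   # raw dataset: PIL image + 40 attributes
print(transform_train(example_img).shape)             # torch.Size([3, 64, 64])
print(get_smile(example_attr))                        # tensor(0) or tensor(1)
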

celeba_train_dataset = torchvision.datasets.CelebA(image_path,
                                                   split='train',
                                                   target_type='attr',
                                                   download=False,
                                                   transform=transform_train,
                                                   target_transform=get_smile)

torch.manual_seed(1)
data_loader = DataLoader(celeba_train_dataset, batch_size=2)

fig = plt.figure(figsize=(15, 6))

num_epochs = 5
for j in range(num_epochs):
    img_batch, label_batch = next(iter(data_loader))
    img = img_batch[0]
    ax = fig.add_subplot(2, 5, j + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f'Epoch {j}:', size=15)
    ax.imshow(img.permute(1, 2, 0))

    img = img_batch[1]
    ax = fig.add_subplot(2, 5, j + 6)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(img.permute(1, 2, 0))

# plt.savefig('figures/14_16.png', dpi=300)
plt.show()


celeba_valid_dataset = torchvision.datasets.CelebA(image_path,
                                                   split='valid',
                                                   target_type='attr',
                                                   download=False,
                                                   transform=transform,
                                                   target_transform=get_smile)

celeba_test_dataset = torchvision.datasets.CelebA(image_path,
                                                  split='test',
                                                  target_type='attr',
                                                  download=False,
                                                  transform=transform,
                                                  target_transform=get_smile)

celeba_train_dataset = Subset(celeba_train_dataset, torch.arange(16000))
celeba_valid_dataset = Subset(celeba_valid_dataset, torch.arange(1000))

print('Train set:', len(celeba_train_dataset))
print('Validation set:', len(celeba_valid_dataset))

batch_size = 32

torch.manual_seed(1)
train_dl = DataLoader(celeba_train_dataset, batch_size, shuffle=True)
valid_dl = DataLoader(celeba_valid_dataset, batch_size, shuffle=False)
test_dl = DataLoader(celeba_test_dataset, batch_size, shuffle=False)

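# Note (added): only the training loader shuffles; the validation and test
# loaders keep a fixed order so that evaluation results are reproducible.
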
# ### Training a CNN smile classifier
#
# * **Global Average Pooling**

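# Side note (a minimal sketch, not from the book): global average pooling reduces
# each channel of a feature map to a single value. For the 8x8 feature maps this
# model produces, nn.AvgPool2d(kernel_size=8) is equivalent to the size-agnostic
# nn.AdaptiveAvgPool2d(1); the tensor below is a stand-in for real activations.

feature_maps = torch.randn(4, 256, 8, 8)
gap_fixed = nn.AvgPool2d(kernel_size=8)(feature_maps)
gap_adaptive = nn.AdaptiveAvgPool2d(1)(feature_maps)
print(gap_fixed.shape)                                # torch.Size([4, 256, 1, 1])
print(torch.allclose(gap_fixed, gap_adaptive))        # True
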
model = nn.Sequential()

model.add_module('conv1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1))
model.add_module('relu1', nn.ReLU())
model.add_module('pool1', nn.MaxPool2d(kernel_size=2))
model.add_module('dropout1', nn.Dropout(p=0.5))

model.add_module('conv2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1))
model.add_module('relu2', nn.ReLU())
model.add_module('pool2', nn.MaxPool2d(kernel_size=2))
model.add_module('dropout2', nn.Dropout(p=0.5))

model.add_module('conv3', nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1))
model.add_module('relu3', nn.ReLU())
model.add_module('pool3', nn.MaxPool2d(kernel_size=2))

model.add_module('conv4', nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1))
model.add_module('relu4', nn.ReLU())

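# A quick shape walk-through (added note): the 3x3 convolutions use padding=1 and
# therefore preserve the spatial size, while each 2x2 max-pooling halves it, so a
# 64x64 input shrinks to 32 -> 16 -> 8. The check below should thus report feature
# maps of shape (4, 256, 8, 8).
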
x = torch.ones((4, 3, 64, 64))
model(x).shape

model.add_module('pool4', nn.AvgPool2d(kernel_size=8))
model.add_module('flatten', nn.Flatten())

x = torch.ones((4, 3, 64, 64))
model(x).shape


model.add_module('fc', nn.Linear(256, 1))
model.add_module('sigmoid', nn.Sigmoid())

x = torch.ones((4, 3, 64, 64))
model(x).shape

335
336
337
338
model
339
340
341
342
343
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is available
model = model.to(device)

348
349
350
loss_fn = nn.BCELoss()
351
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
352
353
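# Design note (an aside, not from the book): nn.BCELoss matches the final
# nn.Sigmoid layer above. A numerically more stable alternative would be to
# drop the sigmoid and train on raw logits instead:
#
#     loss_fn = nn.BCEWithLogitsLoss()  # applies the sigmoid internally
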
def train(model, num_epochs, train_dl, valid_dl):
    loss_hist_train = [0] * num_epochs
    accuracy_hist_train = [0] * num_epochs
    loss_hist_valid = [0] * num_epochs
    accuracy_hist_valid = [0] * num_epochs
    for epoch in range(num_epochs):
        model.train()
        for x_batch, y_batch in train_dl:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            pred = model(x_batch)[:, 0]
            loss = loss_fn(pred, y_batch.float())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_hist_train[epoch] += loss.item()*y_batch.size(0)
            is_correct = ((pred>=0.5).float() == y_batch).float()
            accuracy_hist_train[epoch] += is_correct.sum().cpu()

        loss_hist_train[epoch] /= len(train_dl.dataset)
        accuracy_hist_train[epoch] /= len(train_dl.dataset)

        model.eval()
        with torch.no_grad():
            for x_batch, y_batch in valid_dl:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                pred = model(x_batch)[:, 0]
                loss = loss_fn(pred, y_batch.float())
                loss_hist_valid[epoch] += loss.item()*y_batch.size(0)
                is_correct = ((pred>=0.5).float() == y_batch).float()
                accuracy_hist_valid[epoch] += is_correct.sum().cpu()

        loss_hist_valid[epoch] /= len(valid_dl.dataset)
        accuracy_hist_valid[epoch] /= len(valid_dl.dataset)

        print(f'Epoch {epoch+1} accuracy: {accuracy_hist_train[epoch]:.4f} val_accuracy: {accuracy_hist_valid[epoch]:.4f}')
    return loss_hist_train, loss_hist_valid, accuracy_hist_train, accuracy_hist_valid

torch.manual_seed(1)
num_epochs = 30
hist = train(model, num_epochs, train_dl, valid_dl)

x_arr = np.arange(len(hist[0])) + 1

fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(1, 2, 1)
ax.plot(x_arr, hist[0], '-o', label='Train loss')
ax.plot(x_arr, hist[1], '--<', label='Validation loss')
ax.legend(fontsize=15)
ax.set_xlabel('Epoch', size=15)
ax.set_ylabel('Loss', size=15)

ax = fig.add_subplot(1, 2, 2)
ax.plot(x_arr, hist[2], '-o', label='Train acc.')
ax.plot(x_arr, hist[3], '--<', label='Validation acc.')
ax.legend(fontsize=15)
ax.set_xlabel('Epoch', size=15)
ax.set_ylabel('Accuracy', size=15)

# plt.savefig('figures/14_17.png', dpi=300)
plt.show()

accuracy_test = 0

model.eval()
with torch.no_grad():
    for x_batch, y_batch in test_dl:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        pred = model(x_batch)[:, 0]
        is_correct = ((pred>=0.5).float() == y_batch).float()
        accuracy_test += is_correct.sum().cpu()

accuracy_test /= len(test_dl.dataset)

print(f'Test accuracy: {accuracy_test:.4f}')


pred = model(x_batch)[:, 0] * 100

fig = plt.figure(figsize=(15, 7))
for j in range(10, 20):
    ax = fig.add_subplot(2, 5, j-10+1)
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(x_batch[j].cpu().permute(1, 2, 0))
    if y_batch[j] == 1:
        label = 'Smile'
    else:
        label = 'Not Smile'
    ax.text(
        0.5, -0.15,
        f'GT: {label:s}\nPr(Smile)={pred[j]:.0f}%',
        size=16,
        horizontalalignment='center',
        verticalalignment='center',
        transform=ax.transAxes)

# plt.savefig('figures/figures-14_18.png', dpi=300)
plt.show()

import os
os.makedirs('models', exist_ok=True)  # torch.save does not create missing directories
path = 'models/celeba-cnn.ph'
torch.save(model, path)

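# A minimal reload sketch (an addition, not in the book's script); map_location
# lets the full-model checkpoint load on a CPU-only machine, and the variable
# name is illustrative only:
#
#     model_reloaded = torch.load(path, map_location=device)
#     model_reloaded.eval()
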
# ...
#
#
# ## Summary
#
# ...
#
#

# ----
#
# Readers may ignore the next cell.