Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
rasbt
GitHub Repository: rasbt/machine-learning-book
Path: blob/main/ch12/ch12_part1.py
1247 views
1
# coding: utf-8
2
3
4
import sys
5
from python_environment_check import check_packages
6
import torch
7
import numpy as np
8
from torch.utils.data import DataLoader
9
from torch.utils.data import Dataset
10
from torch.utils.data import TensorDataset
11
import pathlib
12
import matplotlib.pyplot as plt
13
import os
14
from PIL import Image
15
import torchvision.transforms as transforms
16
import torchvision
17
from itertools import islice
18
19
# # Machine Learning with PyTorch and Scikit-Learn
20
# # -- Code Examples
21
22
# ## Package version checks
23
24
# Add folder to path in order to load from the check_packages.py script:
25
26
27
28
# Put the parent directory on the import path so the package-version helper
# next to this chapter's folder can be found.
# NOTE(review): `from python_environment_check import check_packages` executes
# at the top of the file, *before* this sys.path change, so it only succeeds
# if the module is already importable from the current working directory.
# Also note '..' is resolved relative to the cwd, not this script's location.
# Consider moving that import below this line — TODO confirm.
sys.path.insert(0, '..')

# Check recommended package versions:
d = dict(
    numpy='1.21.2',
    matplotlib='3.4.3',
    torch='1.9.0',
)
check_packages(d)
43
44
45
# # Chapter 12: Parallelizing Neural Network Training with PyTorch (Part 1/2)
46
#
47
48
# - [PyTorch and training performance](#PyTorch-and-training-performance)
49
# - [Performance challenges](#Performance-challenges)
50
# - [What is PyTorch?](#What-is-PyTorch?)
51
# - [How we will learn PyTorch](#How-we-will-learn-PyTorch)
52
# - [First steps with PyTorch](#First-steps-with-PyTorch)
53
# - [Installing PyTorch](#Installing-PyTorch)
54
# - [Creating tensors in PyTorch](#Creating-tensors-in-PyTorch)
55
# - [Manipulating the data type and shape of a tensor](#Manipulating-the-data-type-and-shape-of-a-tensor)
56
# - [Applying mathematical operations to tensors](#Applying-mathematical-operations-to-tensors)
57
# - [Split, stack, and concatenate tensors](#Split,-stack,-and-concatenate-tensors)
58
59
# Note that the optional watermark extension is a small IPython notebook plugin that I developed to make the code reproducible. You can just skip the following line(s).
60
61
62
63
64
65
# ## PyTorch and training performance
66
67
# ### Performance challenges
68
69
70
71
# `IPythonImage` is called by several figure cells in this script but is not
# imported at the top of the file, so a plain `python` run stopped here with
# NameError. When nothing has defined it yet, install a stand-in that simply
# reports the figure path. (As a script these calls are bare expression
# statements, so their return value would not be displayed anyway.) In a
# notebook, define it beforehand with
# `from IPython.display import Image as IPythonImage` to render the figures.
if 'IPythonImage' not in globals():
    def IPythonImage(filename, width=None):
        """Stand-in for an image-display helper: print the figure's path."""
        print(f'[figure: {filename}]')

IPythonImage(filename='figures/12_01.png', width=500)


# ### What is PyTorch?

IPythonImage(filename='figures/12_02.png', width=500)
79
80
81
# ### How we will learn PyTorch
82
83
# ## First steps with PyTorch
84
85
# ### Installing PyTorch
86
87
88
89
#! pip install torch
90
91
92
93
94
95
print('PyTorch version:', torch.__version__)

# Show NumPy floats with three decimals in the printed examples below.
np.set_printoptions(precision=3)


# ### Creating tensors in PyTorch

a = [1, 2, 3]
b = np.array([4, 5, 6], dtype=np.int32)

t_a = torch.tensor(a)      # builds a tensor from a Python list
t_b = torch.from_numpy(b)  # wraps the NumPy array (keeps its int32 dtype)

print(t_a)
print(t_b)

# In a notebook the bare expressions below were displayed automatically; in a
# script they were silently discarded, so print them explicitly.
print(torch.is_tensor(a), torch.is_tensor(t_a))

t_ones = torch.ones(2, 3)

print(t_ones.shape)

print(t_ones)

rand_tensor = torch.rand(2, 3)  # uniform samples from [0, 1)

print(rand_tensor)
140
141
142
# ### Manipulating the data type and shape of a tensor
143
144
145
146
# Convert to 64-bit integers.
t_a_new = t_a.to(dtype=torch.int64)
print(t_a_new.dtype)

# Swap dimensions 0 and 1: (3, 5) -> (5, 3).
t = torch.rand(3, 5)
t_tr = t.transpose(0, 1)
print(t.shape, ' --> ', t_tr.shape)

# Lay the 30 elements out as a 5 x 6 matrix.
t = torch.zeros(30)
t_reshape = torch.reshape(t, (5, 6))
print(t_reshape.shape)

# Remove only the size-1 dimension at index 2; the other size-1 dims remain.
t = torch.zeros(1, 2, 1, 4, 1)
t_sqz = t.squeeze(2)
print(t.shape, ' --> ', t_sqz.shape)
175
176
177
# ### Applying mathematical operations to tensors
178
179
180
181
torch.manual_seed(1)

t1 = 2 * torch.rand(5, 2) - 1  # uniform samples rescaled from [0, 1) to [-1, 1)
t2 = torch.normal(mean=0, std=1, size=(5, 2))

# Element-wise product.
t3 = torch.multiply(t1, t2)
print(t3)

# Mean over the rows, i.e. one value per column.
t4 = torch.mean(t1, dim=0)
print(t4)

# (5, 2) @ (2, 5) -> (5, 5)
t5 = torch.matmul(t1, torch.transpose(t2, 0, 1))

print(t5)

# (2, 5) @ (5, 2) -> (2, 2)
t6 = torch.matmul(torch.transpose(t1, 0, 1), t2)

print(t6)

# Euclidean (L2) norm of each row.
norm_t1 = torch.linalg.norm(t1, ord=2, dim=1)

print(norm_t1)

# Cross-check the row norms with NumPy. This was a bare expression, which a
# script silently discards, so print it.
print(np.sqrt(np.sum(np.square(t1.numpy()), axis=1)))
223
224
225
# ### Split, stack, and concatenate tensors
226
227
228
229
torch.manual_seed(1)

t = torch.rand(6)

print(t)

# chunk: split into 3 equally sized pieces (2 elements each here).
t_splits = torch.chunk(t, 3)

# Bare expressions are discarded when run as a script, so print the result.
print([item.numpy() for item in t_splits])


torch.manual_seed(1)
t = torch.rand(5)

print(t)

# split: explicit piece sizes, 3 + 2 = 5 elements.
t_splits = torch.split(t, split_size_or_sections=[3, 2])

print([item.numpy() for item in t_splits])


A = torch.ones(3)
B = torch.zeros(2)

# cat joins along an existing dimension: (3,) + (2,) -> (5,)
C = torch.cat([A, B], dim=0)
print(C)


A = torch.ones(3)
B = torch.zeros(3)

# stack adds a new dimension: two (3,) tensors -> (3, 2)
S = torch.stack([A, B], dim=1)
print(S)
268
269
270
# ## Building input pipelines in PyTorch
271
272
# ### Creating a PyTorch DataLoader from existing tensors
273
274
275
276
277
t = torch.arange(6, dtype=torch.float32)
data_loader = DataLoader(dataset=t)  # no batch_size given: one element per batch

for item in data_loader:
    print(item)

# Batches of three elements; drop_last=False would keep a short final batch.
data_loader = DataLoader(dataset=t, batch_size=3, drop_last=False)

for i, batch in enumerate(data_loader, start=1):
    print(f'batch {i}:', batch)
293
294
295
# ### Combining two tensors into a joint dataset
296
297
298
299
300
class JointDataset(Dataset):
    """Pair two equally long, indexable collections into a map-style dataset.

    ``dataset[idx]`` returns the tuple ``(x[idx], y[idx])``. Like
    ``torch.utils.data.TensorDataset``, the inputs must have the same length.

    Args:
        x: indexable collection of features (e.g. a tensor).
        y: indexable collection of targets aligned with ``x``.

    Raises:
        ValueError: if ``len(x) != len(y)``.
    """

    def __init__(self, x, y):
        if len(x) != len(y):
            # Without this check, mismatched inputs either drop trailing items
            # silently (plain iteration stops at the first IndexError) or
            # fail with IndexError partway through a DataLoader epoch.
            raise ValueError(
                f'x and y must have the same length, got {len(x)} and {len(y)}'
            )
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
308
309
310
311
312
torch.manual_seed(1)

t_x = torch.rand([4, 3], dtype=torch.float32)
t_y = torch.arange(4)

# Version built from the hand-written class above ...
joint_dataset = JointDataset(t_x, t_y)

# Or use TensorDataset directly
joint_dataset = TensorDataset(t_x, t_y)

# Each item is an (x, y) pair, so unpack it in the loop header.
for features, target in joint_dataset:
    print(' x: ', features,
          ' y: ', target)
324
325
326
# ### Shuffle, batch, and repeat
327
328
329
330
torch.manual_seed(1)
data_loader = DataLoader(dataset=joint_dataset, batch_size=2, shuffle=True)

# One pass over the shuffled data; each batch is an (x, y) pair of tensors.
for i, (batch_x, batch_y) in enumerate(data_loader, start=1):
    print(f'batch {i}:', 'x:', batch_x,
          '\n y:', batch_y)

# Repeated passes: with shuffle=True each pass draws a new order.
for epoch in range(2):
    print(f'epoch {epoch+1}')
    for i, (batch_x, batch_y) in enumerate(data_loader, start=1):
        print(f'batch {i}:', 'x:', batch_x,
              '\n y:', batch_y)
342
343
344
# ### Creating a dataset from files on your local storage disk
345
346
347
348
349
imgdir_path = pathlib.Path('cat_dog_images')

file_list = sorted([str(path) for path in imgdir_path.glob('*.jpg')])

print(file_list)


# Lay the images out three per row. The grid grows with the number of files
# instead of being fixed at 2 x 3, which failed for more than six images.
n_cols = 3
n_rows = max(1, -(-len(file_list) // n_cols))  # ceiling division

fig = plt.figure(figsize=(10, 2.5 * n_rows))
for i, file in enumerate(file_list):
    # Open inside a context manager so the file handle is released right
    # away; np.asarray reads the pixel data before the file is closed.
    with Image.open(file) as img:
        pixels = np.asarray(img)
    print('Image shape: ', pixels.shape)
    ax = fig.add_subplot(n_rows, n_cols, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(pixels)
    ax.set_title(os.path.basename(file), size=15)

#plt.savefig('figures/12_03.pdf')
plt.tight_layout()
plt.show()


# Label 1 when the file name contains 'dog', otherwise 0.
labels = [1 if 'dog' in os.path.basename(file) else 0
          for file in file_list]
print(labels)
379
380
381
382
383
class ImageDataset(Dataset):
    """Map-style dataset yielding ``(file_path, label)`` pairs.

    Nothing is read from disk: item ``index`` is simply the path at that
    position of ``file_list`` together with the matching entry of ``labels``.
    The dataset length is taken from ``labels``.
    """

    def __init__(self, file_list, labels):
        self.file_list = file_list
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return self.file_list[index], self.labels[index]
395
396
image_dataset = ImageDataset(file_list, labels)

# Each entry is a (path, label) tuple; print its two fields space-separated.
for entry in image_dataset:
    print(*entry)
399
400
401
402
403
404
class ImageDataset(Dataset):
    """Map-style dataset that loads images from disk on demand.

    Args:
        file_list: sequence of image file paths.
        labels: sequence of labels aligned with ``file_list``; its length is
            the dataset length.
        transform: optional callable applied to each loaded PIL image
            (for example a ``transforms.Compose`` pipeline).
    """

    def __init__(self, file_list, labels, transform=None):
        self.file_list = file_list
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        # Image.open is lazy and keeps the file handle open until the image
        # object is garbage-collected. Over many items/epochs that leaks file
        # descriptors, so read the data inside a context manager. copy() loads
        # the pixels into an image that no longer depends on the file.
        with Image.open(self.file_list[index]) as src:
            img = src.copy()
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return img, label

    def __len__(self):
        return len(self.labels)
417
418
img_height, img_width = 80, 120

# ToTensor turns each PIL image into a (C, H, W) float tensor; Resize then
# scales it to (img_height, img_width).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((img_height, img_width)),
])

image_dataset = ImageDataset(file_list, labels, transform)


# Three images per row, with as many rows as needed (previously a fixed
# 2 x 3 grid, which failed for more than six images).
n_cols = 3
n_rows = max(1, -(-len(image_dataset) // n_cols))  # ceiling division

fig = plt.figure(figsize=(10, 3 * n_rows))
for i, example in enumerate(image_dataset):
    ax = fig.add_subplot(n_rows, n_cols, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    # (C, H, W) tensor -> (H, W, C) array for imshow
    ax.imshow(example[0].numpy().transpose((1, 2, 0)))
    ax.set_title(f'{example[1]}', size=15)

plt.tight_layout()
# Disabled like every other savefig in this script: writing into figures/
# fails with FileNotFoundError when that directory does not exist.
#plt.savefig('figures/12_04.pdf')
plt.show()
440
441
442
# ### Fetching available datasets from the torchvision.datasets library
443
444
445
446
# ! pip install torchvision
447
448
449
450
451
452
453
# **Fetching CelebA dataset**
454
#
455
# ---
456
457
# 1. Downloading the image files manually
458
#
459
# - You can try setting `download=True` below. If this results in a `BadZipfile` error, we recommend downloading the `img_align_celeba.zip` file manually from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. In the Google Drive folder, you can find it under the `Img` folder as shown below:
460
461
462
463
IPythonImage(width=500, filename='figures/gdrive-download-location-1.png')  # screenshot: `Img` folder in the CelebA Google Drive
464
465
466
# - You can also try this direct link: https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?usp=sharing&resourcekey=0-dYn9z10tMJOBAkviAcfdyQ
467
# - After downloading, please put this file into the `./celeba` subfolder and unzip it.
468
469
# 2. Next, you need to download the annotation files and put them into the same `./celeba` subfolder. The annotation files can be found under `Anno`:
470
471
472
473
IPythonImage(width=300, filename='figures/gdrive-download-location-2.png')  # screenshot: `Anno` folder with the annotation files
474
475
476
# - direct links are provided below:
477
# - [identity_CelebA.txt](https://drive.google.com/file/d/1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS/view?usp=sharing)
478
# - [list_attr_celeba.txt](https://drive.google.com/file/d/0B7EVK8r0v71pblRyaVFSWGxPY0U/view?usp=sharing&resourcekey=0-YW2qIuRcWHy_1C2VaRGL3Q)
479
# - [list_bbox_celeba.txt](https://drive.google.com/file/d/0B7EVK8r0v71pbThiMVRxWXZ4dU0/view?usp=sharing&resourcekey=0-z-17UMo1wt4moRL2lu9D8A)
480
# - [list_landmarks_align_celeba.txt](https://drive.google.com/file/d/0B7EVK8r0v71pd0FJY3Blby1HUTQ/view?usp=sharing&resourcekey=0-aFtzLN5nfdhHXpAsgYA8_g)
481
# - [list_landmarks_celeba.txt](https://drive.google.com/file/d/0B7EVK8r0v71pTzJIdlJWdHczRlU/view?usp=sharing&resourcekey=0-49BtYuqFDomi-1v0vNVwrQ)
482
483
484
485
IPythonImage(width=300, filename='figures/gdrive-download-location-3.png')  # third Google Drive screenshot for the CelebA download steps
486
487
488
# 3. Lastly, you need to download the file `list_eval_partition.txt` and place it under `./celeba`:
489
490
# - [list_eval_partition.txt](https://drive.google.com/file/d/0B7EVK8r0v71pY0NSMzRuSXJEVkk/view?usp=sharing&resourcekey=0-i4TGCi_51OtQ5K9FSp4EDg)
491
492
# After completing steps 1-3 above, please ensure you have the following files in your `./celeba` subfolder, and the files are non-empty (that is, they have similar file sizes as shown below):
493
494
495
496
IPythonImage(width=400, filename='figures/celeba-files.png')  # screenshot: expected contents of ./celeba
497
498
499
# ---
500
501
502
503
image_path = './'
celeba_dataset = torchvision.datasets.CelebA(image_path, split='train', target_type='attr', download=False)

assert isinstance(celeba_dataset, torch.utils.data.Dataset)


# First (image, attributes) pair of the training split.
example = celeba_dataset[0]
print(example)


# Show the first 18 images in a 3 x 6 grid.
fig = plt.figure(figsize=(12, 8))
for i, (image, attributes) in enumerate(islice(celeba_dataset, 18)):
    ax = fig.add_subplot(3, 6, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(image)
    # NOTE(review): attribute index 31 is presumably 'Smiling' — confirm
    # against celeba_dataset.attr_names.
    ax.set_title(f'{attributes[31]}', size=15)

#plt.savefig('figures/12_05.pdf')
plt.show()
526
527
528
529
530
# `train` is a boolean flag, so pass it by keyword. The previous positional
# string 'train' only worked because any non-empty string is truthy; 'test'
# would also have silently loaded the training split.
mnist_dataset = torchvision.datasets.MNIST(image_path, train=True, download=True)

assert isinstance(mnist_dataset, torch.utils.data.Dataset)

# First (image, label) pair.
example = next(iter(mnist_dataset))
print(example)

# Show the first 10 digits in a 2 x 5 grid, titled with their labels.
fig = plt.figure(figsize=(15, 6))
for i, (image, label) in islice(enumerate(mnist_dataset), 10):
    ax = fig.add_subplot(2, 5, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(image, cmap='gray_r')
    ax.set_title(f'{label}', size=15)

#plt.savefig('figures/12_06.pdf')
plt.show()
546
547
548
# ---
549
#
550
# Readers may ignore the next cell.
551
552
553
554
555
556