# GitHub Repository: rasbt/machine-learning-book
# Path: blob/main/ch13/ch13_part3_lightning.py
# coding: utf-8


from pkg_resources import parse_version
import sys
from python_environment_check import check_packages
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import __version__ as torchmetrics_version
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.callbacks import ModelCheckpoint

# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples

# ## Package version checks

# Add folder to path in order to load from the check_packages.py script:

sys.path.insert(0, '..')


# Check recommended package versions:

d = {
    'torch': '1.8',
    'torchvision': '0.9.0',
    'tensorboard': '2.7.0',
    'pytorch_lightning': '1.5.0',
    'torchmetrics': '0.6.2'
}
check_packages(d)


# # Chapter 13: Going Deeper -- the Mechanics of PyTorch (Part 3/3)

# **Outline**
#
# - [Higher-level PyTorch APIs: a short introduction to PyTorch Lightning](#Higher-level-PyTorch-APIs-a-short-introduction-to-PyTorch-Lightning)
# - [Setting up the PyTorch Lightning model](#Setting-up-the-PyTorch-Lightning-model)
# - [Setting up the data loaders for Lightning](#Setting-up-the-data-loaders-for-Lightning)
# - [Training the model using the PyTorch Lightning Trainer class](#Training-the-model-using-the-PyTorch-Lightning-Trainer-class)
# - [Evaluating the model using TensorBoard](#Evaluating-the-model-using-TensorBoard)
# - [Summary](#Summary)

# ## Higher-level PyTorch APIs: a short introduction to PyTorch Lightning

# ### Setting up the PyTorch Lightning model


class MultiLayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()

        # New PL attributes
        # (newer torchmetrics versions require an explicit task argument):
        if parse_version(torchmetrics_version) > parse_version("0.8"):
            self.train_acc = Accuracy(task="multiclass", num_classes=10)
            self.valid_acc = Accuracy(task="multiclass", num_classes=10)
            self.test_acc = Accuracy(task="multiclass", num_classes=10)
        else:
            self.train_acc = Accuracy()
            self.valid_acc = Accuracy()
            self.test_acc = Accuracy()

        # Model similar to the previous section:
        input_size = image_shape[0] * image_shape[1] * image_shape[2]
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units:
            layer = nn.Linear(input_size, hidden_unit)
            all_layers.append(layer)
            all_layers.append(nn.ReLU())
            input_size = hidden_unit

        all_layers.append(nn.Linear(hidden_units[-1], 10))
        self.model = nn.Sequential(*all_layers)

    def forward(self, x):
        x = self.model(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_acc.update(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def training_epoch_end(self, outs):
        self.log("train_acc", self.train_acc.compute())
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        return loss

    def validation_epoch_end(self, outs):
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        self.valid_acc.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

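# Minimal sanity check (illustrative sketch, not from the book code): verify that the
# untrained model maps a random batch of MNIST-sized images to a (batch_size, 10)
# tensor of logits before handing it to the Trainer.
_check_model = MultiLayerPerceptron()
_check_out = _check_model(torch.rand(4, 1, 28, 28))
print(_check_out.shape)  # expected: torch.Size([4, 10])
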
# ### Setting up the data loaders for Lightning

class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path='./'):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        MNIST(root=self.data_path, download=True)

    def setup(self, stage=None):
        # stage is either 'fit', 'validate', 'test', or 'predict'
        # (not relevant here)
        mnist_all = MNIST(
            root=self.data_path,
            train=True,
            transform=self.transform,
            download=False
        )

        self.train, self.val = random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )

        self.test = MNIST(
            root=self.data_path,
            train=False,
            transform=self.transform,
            download=False
        )

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64, num_workers=4)


torch.manual_seed(1)
mnist_dm = MnistDataModule()

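# Illustrative sketch (not from the book code): one way to pull a single batch out of
# the data module for inspection; assumes the MNIST download above is permitted here.
mnist_dm.prepare_data()
mnist_dm.setup()
_images, _labels = next(iter(mnist_dm.train_dataloader()))
print(_images.shape, _labels.shape)  # expected: torch.Size([64, 1, 28, 28]) torch.Size([64])
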
# ### Training the model using the PyTorch Lightning Trainer class

mnistclassifier = MultiLayerPerceptron()

callbacks = [ModelCheckpoint(save_top_k=1, mode='max', monitor="valid_acc")]  # save top 1 model

if torch.cuda.is_available():  # if you have GPUs
    trainer = pl.Trainer(max_epochs=10, callbacks=callbacks, gpus=1)
else:
    trainer = pl.Trainer(max_epochs=10, callbacks=callbacks)

trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

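# Illustrative (not from the book code): the ModelCheckpoint callback records the path of
# the best checkpoint it wrote, so it does not have to be copied by hand from the logs.
print(callbacks[0].best_model_path)
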
# ### Evaluating the model using TensorBoard

trainer.test(model=mnistclassifier, datamodule=mnist_dm, ckpt_path='best')

# Start tensorboard

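# Illustrative (not from the book code): TensorBoard can be started from a terminal with
#     tensorboard --logdir lightning_logs/
# or, inside a Jupyter notebook, via the %load_ext tensorboard and
# %tensorboard --logdir lightning_logs/ magics.
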
path = 'lightning_logs/version_0/checkpoints/epoch=8-step=7739.ckpt'

if torch.cuda.is_available():  # if you have GPUs
    trainer = pl.Trainer(
        max_epochs=15, callbacks=callbacks, resume_from_checkpoint=path, gpus=1
    )
else:
    trainer = pl.Trainer(
        max_epochs=15, callbacks=callbacks, resume_from_checkpoint=path
    )

trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

trainer.test(model=mnistclassifier, datamodule=mnist_dm)


trainer.test(model=mnistclassifier, datamodule=mnist_dm, ckpt_path='best')


path = "lightning_logs/version_0/checkpoints/epoch=13-step=12039.ckpt"
model = MultiLayerPerceptron.load_from_checkpoint(path)

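# Illustrative (not from the book code): the restored LightningModule behaves like any
# nn.Module, so it can be used directly for inference on a random MNIST-sized input.
model.eval()
with torch.no_grad():
    _logits = model(torch.rand(1, 1, 28, 28))
    print("predicted class:", _logits.argmax(dim=1).item())
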
# ## Summary

# ---
#
# Readers may ignore the next cell.