# Path: blob/main/ch13/ch13_part3_lightning.py
# 1245 views
# coding: utf-8

# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples

import sys

# Add the parent folder to the path so that the ``python_environment_check``
# helper can be imported. This must run BEFORE that import below; otherwise
# the path tweak has no effect and the import fails.
sys.path.insert(0, '..')

from pkg_resources import parse_version
from python_environment_check import check_packages
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import __version__ as torchmetrics_version
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.callbacks import ModelCheckpoint


# # Chapter 13: Going Deeper -- the Mechanics of PyTorch (Part 3/3)

# **Outline**
#
# - [Higher-level PyTorch APIs: a short introduction to PyTorch Lightning](#Higher-level-PyTorch-APIs-a-short-introduction-to-PyTorch-Lightning)
# - [Setting up the PyTorch Lightning model](#Setting-up-the-PyTorch-Lightning-model)
# - [Setting up the data loaders for Lightning](#Setting-up-the-data-loaders-for-Lightning)
# - [Training the model using the PyTorch Lightning Trainer class](#Training-the-model-using-the-PyTorch-Lightning-Trainer-class)
# - [Evaluating the model using TensorBoard](#Evaluating-the-model-using-TensorBoard)
# - [Summary](#Summary)

# ## Higher-level PyTorch APIs: a short introduction to PyTorch Lightning

# ### Setting up the PyTorch Lightning model


def _make_accuracy():
    """Return a fresh 10-class accuracy metric for the installed torchmetrics.

    torchmetrics 0.10 introduced the ``task`` argument of ``Accuracy`` (it is
    mandatory from 0.11 on); older releases only offer the argument-free
    constructor. The version must be compared as a *string* --
    ``parse_version`` does not accept floats.
    """
    if parse_version(torchmetrics_version) >= parse_version("0.10"):
        return Accuracy(task="multiclass", num_classes=10)
    return Accuracy()


class MultiLayerPerceptron(pl.LightningModule):
    """Fully connected 10-class classifier wrapped as a LightningModule.

    Args:
        image_shape: (channels, height, width) of one input image; the product
            of the three entries is the input size of the first linear layer.
        hidden_units: sizes of the hidden layers, each followed by a ReLU.
            An empty sequence yields a single linear layer.
    """

    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # ``load_from_checkpoint`` rebuilds the same architecture.
        self.save_hyperparameters()

        # One metric object per stage so their running states never mix.
        self.train_acc = _make_accuracy()
        self.valid_acc = _make_accuracy()
        self.test_acc = _make_accuracy()

        # Model similar to previous section:
        input_size = image_shape[0] * image_shape[1] * image_shape[2]
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units:
            all_layers.append(nn.Linear(input_size, hidden_unit))
            all_layers.append(nn.ReLU())
            input_size = hidden_unit

        # ``input_size`` is now the width of the last hidden layer, or the
        # flattened image size when ``hidden_units`` is empty.
        all_layers.append(nn.Linear(input_size, 10))
        self.model = nn.Sequential(*all_layers)

    def forward(self, x):
        return self.model(x)

    def _loss_and_preds(self, batch):
        """Return (cross-entropy loss, predicted labels, true labels)."""
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        loss, preds, y = self._loss_and_preds(batch)
        self.train_acc.update(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    # ``on_*_epoch_end`` hooks exist in Lightning 1.x and 2.x, whereas
    # ``training_epoch_end``/``validation_epoch_end`` were removed in 2.0.
    def on_train_epoch_end(self):
        self.log("train_acc", self.train_acc.compute())
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        loss, preds, y = self._loss_and_preds(batch)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        self.valid_acc.reset()

    def test_step(self, batch, batch_idx):
        loss, preds, y = self._loss_and_preds(batch)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        return loss

    def on_test_epoch_end(self):
        # Compute once over the whole test set, then reset so that repeated
        # ``trainer.test`` calls do not accumulate state from earlier runs.
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        self.test_acc.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


# ### Setting up the data loaders


class MnistDataModule(pl.LightningDataModule):
    """MNIST with a fixed 55,000/5,000 train/validation split.

    Args:
        data_path: directory where MNIST is downloaded/read.
        batch_size: mini-batch size used by all three data loaders.
        num_workers: worker processes used by each DataLoader.
    """

    def __init__(self, data_path='./', batch_size=64, num_workers=4):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        MNIST(root=self.data_path, download=True)

    def setup(self, stage=None):
        # stage is either 'fit', 'validate', 'test', or 'predict'
        # here not relevant
        mnist_all = MNIST(
            root=self.data_path,
            train=True,
            transform=self.transform,
            download=False
        )

        # Seeded generator -> the same split on every run.
        self.train, self.val = random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )

        self.test = MNIST(
            root=self.data_path,
            train=False,
            transform=self.transform,
            download=False
        )

    def _loader(self, dataset):
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        return self._loader(self.train)

    def val_dataloader(self):
        return self._loader(self.val)

    def test_dataloader(self):
        return self._loader(self.test)


def _make_trainer(max_epochs, callbacks):
    """Build a Trainer that uses one GPU when CUDA is available.

    ``accelerator``/``devices`` replace the ``gpus`` argument that was
    removed in Lightning 2.0.
    """
    if torch.cuda.is_available():  # if you have GPUs
        return pl.Trainer(max_epochs=max_epochs, callbacks=callbacks,
                          accelerator="gpu", devices=1)
    return pl.Trainer(max_epochs=max_epochs, callbacks=callbacks)


def main():
    """Check package versions, train, evaluate, resume training, reload.

    Returns:
        The best model of the resumed run, reloaded from its checkpoint.
    """
    # ## Package version checks
    d = {
        'torch': '1.8',
        'torchvision': '0.9.0',
        'tensorboard': '2.7.0',
        'pytorch_lightning': '1.5.0',
        'torchmetrics': '0.6.2'
    }
    check_packages(d)

    torch.manual_seed(1)
    mnist_dm = MnistDataModule()

    # ### Training the model using the PyTorch Lightning Trainer class
    mnistclassifier = MultiLayerPerceptron()

    # save top 1 model
    checkpoint_callback = ModelCheckpoint(save_top_k=1, mode='max',
                                          monitor="valid_acc")
    callbacks = [checkpoint_callback]

    trainer = _make_trainer(max_epochs=10, callbacks=callbacks)
    trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

    # ### Evaluating the model using TensorBoard
    trainer.test(model=mnistclassifier, datamodule=mnist_dm, ckpt_path='best')

    # Start tensorboard (e.g. ``tensorboard --logdir lightning_logs/``)

    # Resume training from the best checkpoint of the first run. Its file name
    # (such as 'epoch=8-step=7739.ckpt') depends on the run and on how the
    # Lightning version counts steps, so ask the callback instead of
    # hard-coding it. ``fit(ckpt_path=...)`` replaces the removed
    # ``resume_from_checkpoint`` Trainer argument.
    path = checkpoint_callback.best_model_path
    trainer = _make_trainer(max_epochs=15, callbacks=callbacks)
    trainer.fit(model=mnistclassifier, datamodule=mnist_dm, ckpt_path=path)

    trainer.test(model=mnistclassifier, datamodule=mnist_dm)
    trainer.test(model=mnistclassifier, datamodule=mnist_dm, ckpt_path='best')

    path = checkpoint_callback.best_model_path
    model = MultiLayerPerceptron.load_from_checkpoint(path)
    return model


# ## Summary

# The guard is required because DataLoader workers (num_workers > 0) re-import
# this module under the "spawn" start method (macOS/Windows); without it each
# worker would re-run the whole training script.
if __name__ == "__main__":
    main()