Path: labml_nn/experiments/mnist.py
"""1---2title: MNIST Experiment3summary: >4This is a reusable trainer for MNIST dataset5---67# MNIST Experiment8"""910import torch.nn as nn11import torch.utils.data1213from labml import tracker14from labml.configs import option15from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs16from labml_nn.helpers.device import DeviceConfigs17from labml_nn.helpers.metrics import Accuracy18from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex19from labml_nn.optimizers.configs import OptimizerConfigs202122class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):23"""24<a id="MNISTConfigs"></a>2526## Trainer configurations27"""2829# Optimizer30optimizer: torch.optim.Adam31# Training device32device: torch.device = DeviceConfigs()3334# Classification model35model: nn.Module36# Number of epochs to train for37epochs: int = 103839# Number of times to switch between training and validation within an epoch40inner_iterations = 104142# Accuracy function43accuracy = Accuracy()44# Loss function45loss_func = nn.CrossEntropyLoss()4647def init(self):48"""49### Initialization50"""51# Set tracker configurations52tracker.set_scalar("loss.*", True)53tracker.set_scalar("accuracy.*", True)54# Add accuracy as a state module.55# The name is probably confusing, since it's meant to store56# states between training and validation for RNNs.57# This will keep the accuracy metric stats separate for training and validation.58self.state_modules = [self.accuracy]5960def step(self, batch: any, batch_idx: BatchIndex):61"""62### Training or validation step63"""6465# Training/Evaluation mode66self.model.train(self.mode.is_train)6768# Move data to the device69data, target = batch[0].to(self.device), batch[1].to(self.device)7071# Update global step (number of samples processed) when in training mode72if self.mode.is_train:73tracker.add_global_step(len(data))7475# Get model outputs.76output = self.model(data)7778# Calculate and log loss79loss = self.loss_func(output, target)80tracker.add("loss.", loss)8182# Calculate and log accuracy83self.accuracy(output, target)84self.accuracy.track()8586# Train the model87if self.mode.is_train:88# Calculate gradients89loss.backward()90# Take optimizer step91self.optimizer.step()92# Log the model parameters and gradients on last batch of every epoch93if batch_idx.is_last:94tracker.add('model', self.model)95# Clear the gradients96self.optimizer.zero_grad()9798# Save the tracked metrics99tracker.save()100101102@option(MNISTConfigs.optimizer)103def _optimizer(c: MNISTConfigs):104"""105### Default optimizer configurations106"""107opt_conf = OptimizerConfigs()108opt_conf.parameters = c.model.parameters()109opt_conf.optimizer = 'Adam'110return opt_conf111112113