Path: blob/master/labml_nn/optimizers/mnist_experiment.py
"""1---2title: MNIST example to test the optimizers3summary: This is a simple MNIST example with a CNN model to test the optimizers.4---56# MNIST example to test the optimizers7"""8import torch.nn as nn9import torch.utils.data1011from labml import experiment, tracker12from labml.configs import option13from labml_nn.helpers.datasets import MNISTConfigs14from labml_nn.helpers.device import DeviceConfigs15from labml_nn.helpers.metrics import Accuracy16from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex17from labml_nn.optimizers.configs import OptimizerConfigs181920class Model(nn.Module):21"""22## The model23"""2425def __init__(self):26super().__init__()27self.conv1 = nn.Conv2d(1, 20, 5, 1)28self.pool1 = nn.MaxPool2d(2)29self.conv2 = nn.Conv2d(20, 50, 5, 1)30self.pool2 = nn.MaxPool2d(2)31self.fc1 = nn.Linear(16 * 50, 500)32self.fc2 = nn.Linear(500, 10)33self.activation = nn.ReLU()3435def forward(self, x):36x = self.activation(self.conv1(x))37x = self.pool1(x)38x = self.activation(self.conv2(x))39x = self.pool2(x)40x = self.activation(self.fc1(x.view(-1, 16 * 50)))41return self.fc2(x)424344class Configs(MNISTConfigs, TrainValidConfigs):45"""46## Configurable Experiment Definition47"""48optimizer: torch.optim.Adam49model: nn.Module50device: torch.device = DeviceConfigs()51epochs: int = 105253is_save_models = True54model: nn.Module55inner_iterations = 105657accuracy_func = Accuracy()58loss_func = nn.CrossEntropyLoss()5960def init(self):61tracker.set_queue("loss.*", 20, True)62tracker.set_scalar("accuracy.*", True)63self.state_modules = [self.accuracy_func]6465def step(self, batch: any, batch_idx: BatchIndex):66# Get the batch67data, target = batch[0].to(self.device), batch[1].to(self.device)6869# Add global step if we are in training mode70if self.mode.is_train:71tracker.add_global_step(len(data))7273# Run the model74output = self.model(data)7576# Calculate the loss77loss = self.loss_func(output, target)78# Calculate the accuracy79self.accuracy_func(output, target)80# Log the loss81tracker.add("loss.", loss)8283# Optimize if we are in training mode84if self.mode.is_train:85# Calculate the gradients86loss.backward()8788# Take optimizer step89self.optimizer.step()90# Log the parameter and gradient L2 norms once per epoch91if batch_idx.is_last:92tracker.add('model', self.model)93tracker.add('optimizer', (self.optimizer, {'model': self.model}))94# Clear the gradients95self.optimizer.zero_grad()9697# Save logs98tracker.save()99100101@option(Configs.model)102def model(c: Configs):103return Model().to(c.device)104105106@option(Configs.optimizer)107def _optimizer(c: Configs):108"""109Create a configurable optimizer.110We can change the optimizer type and hyper-parameters using configurations.111"""112opt_conf = OptimizerConfigs()113opt_conf.parameters = c.model.parameters()114return opt_conf115116117def main():118conf = Configs()119conf.inner_iterations = 10120experiment.create(name='mnist_ada_belief')121experiment.configs(conf, {'inner_iterations': 10,122# Specify the optimizer123'optimizer.optimizer': 'Adam',124'optimizer.learning_rate': 1.5e-4})125experiment.add_pytorch_models(dict(model=conf.model))126with experiment.start():127conf.run()128129130if __name__ == '__main__':131main()132133134