GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/experiments/mnist.py
"""
2
---
3
title: MNIST Experiment
4
summary: >
5
This is a reusable trainer for MNIST dataset
6
---
7
8
# MNIST Experiment
9
"""

from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs


class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
    """
    <a id="MNISTConfigs"></a>

    ## Trainer configurations
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model
    model: nn.Module
    # Number of epochs to train for
    epochs: int = 10

    # Number of times to switch between training and validation within an epoch
    inner_iterations = 10

    # Accuracy function
    accuracy = Accuracy()
    # Loss function
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        # Add accuracy as a state module. The name is a bit confusing, since
        # state modules are meant to store state (e.g. RNN hidden states)
        # between training and validation; here it simply keeps the accuracy
        # metric's statistics separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Set training/evaluation mode
        self.model.train(self.mode.is_train)

        # Move the data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update the global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get the model outputs
        output = self.model(data)

        # Calculate and log the loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log the accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take an optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    """
    ### Default optimizer configurations
    """
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'
    return opt_conf
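

# The trainer above leaves `model` undefined, so a concrete experiment plugs
# one in through the same `@option` mechanism used for the optimizer. The
# following is a minimal, illustrative sketch, not part of the original file;
# `TinyModel` is a hypothetical classifier defined here only for demonstration.
class TinyModel(nn.Module):
    """A small fully connected classifier for 28x28 MNIST images"""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # Flatten each 28x28 image into a 784-dimensional vector
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            # Project to the 10 digit classes
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor):
        return self.layers(x)


@option(MNISTConfigs.model)
def _model(c: MNISTConfigs):
    """Example default model, moved to the configured device"""
    return TinyModel().to(c.device)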
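

# A usage sketch, also not part of the original file: it assumes the standard
# labml `experiment` API (`create`, `configs`, `start`) and the `run()`
# training loop provided by `TrainValidConfigs`. The experiment name and the
# learning-rate override are arbitrary example values.
def main():
    from labml import experiment

    # Create the experiment
    experiment.create(name='mnist_example')
    # Create the configurations
    conf = MNISTConfigs()
    # Override configurations; here, the optimizer's learning rate
    experiment.configs(conf, {'optimizer.learning_rate': 1e-3})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()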