GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/optimizers/mnist_experiment.py
"""
---
title: MNIST example to test the optimizers
summary: This is a simple MNIST example with a CNN model to test the optimizers.
---

# MNIST example to test the optimizers
"""
import torch.nn as nn
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs


class Model(nn.Module):
    """
    ## The model
    """

    def __init__(self):
        super().__init__()
        # Two convolution layers, each followed by max-pooling
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.pool2 = nn.MaxPool2d(2)
        # Two fully-connected layers; the flattened feature map has 4 * 4 * 50 = 16 * 50 values
        self.fc1 = nn.Linear(16 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.pool1(x)
        x = self.activation(self.conv2(x))
        x = self.pool2(x)
        # Flatten before the fully-connected layers
        x = self.activation(self.fc1(x.view(-1, 16 * 50)))
        return self.fc2(x)
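
# The `16 * 50` input size of `fc1` follows from the layer shapes above: a 1x28x28 MNIST
# image becomes 20x24x24 after `conv1` (5x5 kernel), 20x12x12 after `pool1`, 50x8x8 after
# `conv2`, and 50x4x4 after `pool2`, i.e. 4 * 4 * 50 = 800 = 16 * 50 features.
# The helper below is a minimal sanity-check sketch, not part of the original file;
# `_check_model_shapes` is a hypothetical name added here for illustration.
def _check_model_shapes(batch_size: int = 4):
    # A dummy batch with the MNIST image shape
    x = torch.zeros(batch_size, 1, 28, 28)
    # The forward pass should produce one logit per class
    y = Model()(x)
    assert y.shape == (batch_size, 10)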


class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## Configurable Experiment Definition
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        # Smooth the loss over a queue of the last 20 values and print it
        tracker.set_queue("loss.*", 20, True)
        # Log accuracy as a scalar and print it
        tracker.set_scalar("accuracy.*", True)
        # Let the trainer manage the accuracy metric's state
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        # Get the batch
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Add global step if we are in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model
        output = self.model(data)

        # Calculate the loss
        loss = self.loss_func(output, target)
        # Calculate the accuracy
        self.accuracy_func(output, target)
        # Log the loss
        tracker.add("loss.", loss)

        # Optimize if we are in training mode
        if self.mode.is_train:
            # Calculate the gradients
            loss.backward()

            # Take optimizer step
            self.optimizer.step()
            # Log the parameter and gradient L2 norms once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {'model': self.model}))
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save logs
        tracker.save()


@option(Configs.model)
def model(c: Configs):
    # Create the model and move it to the configured device
    return Model().to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create a configurable optimizer.
    We can change the optimizer type and hyper-parameters using configurations.
    """
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
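
# Because the optimizer is built through `OptimizerConfigs`, the optimizer type and its
# hyper-parameters can be swapped from the configurations dictionary alone, with no code
# changes. The sketch below shows overrides in the same style as `main()`; only
# 'optimizer.optimizer' and 'optimizer.learning_rate' appear in this file, so treating
# 'AdaBelief' as a registered optimizer name and 'optimizer.weight_decay' as a valid key
# is an assumption based on `labml_nn.optimizers.configs.OptimizerConfigs`.
def _example_optimizer_overrides():
    # A hypothetical helper (not part of the original file) returning configuration
    # overrides that would select AdaBelief instead of Adam; they would be passed as the
    # second argument to `experiment.configs`, as in `main()` below.
    return {
        'optimizer.optimizer': 'AdaBelief',
        'optimizer.learning_rate': 1.5e-4,
        'optimizer.weight_decay': 1e-4,  # assumed key name
    }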


def main():
    conf = Configs()
    conf.inner_iterations = 10
    experiment.create(name='mnist_ada_belief')
    experiment.configs(conf, {'inner_iterations': 10,
                              # Specify the optimizer
                              'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1.5e-4})
    experiment.add_pytorch_models(dict(model=conf.model))
    with experiment.start():
        conf.run()
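
# For reference, a single optimization step of this experiment boils down to the plain
# PyTorch below, without the labml tracking. This is a minimal sketch, not part of the
# original file; `_plain_pytorch_step` is a hypothetical name and the random tensors
# merely stand in for an MNIST batch.
def _plain_pytorch_step():
    model = Model()
    optimizer = torch.optim.Adam(model.parameters(), lr=1.5e-4)
    loss_func = nn.CrossEntropyLoss()
    # Random stand-ins for a batch of 32 MNIST images and labels
    data = torch.randn(32, 1, 28, 28)
    target = torch.randint(0, 10, (32,))
    # Forward pass, loss, backward pass and parameter update
    loss = loss_func(model(data), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()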


if __name__ == '__main__':
    main()