GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/uncertainty/evidence/experiment.py

"""
---
title: "Evidential Deep Learning to Quantify Classification Uncertainty Experiment"
summary: >
  This trains an EDL model on MNIST
---

# [Evidential Deep Learning to Quantify Classification Uncertainty](index.html) Experiment

This trains a model based on [Evidential Deep Learning to Quantify Classification Uncertainty](index.html)
on the MNIST dataset.
"""

from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker, experiment
from labml.configs import option, calculate
from labml_nn.helpers.schedule import Schedule, RelativePiecewise
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
    CrossEntropyBayesRisk, SquaredErrorBayesRisk


class Model(nn.Module):
    """
    ## LeNet based model for MNIST classification
    """

    def __init__(self, dropout: float):
        super().__init__()
        # First $5 \times 5$ convolution layer
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        # ReLU activation
        self.act1 = nn.ReLU()
        # $2 \times 2$ max-pooling
        self.max_pool1 = nn.MaxPool2d(2, 2)
        # Second $5 \times 5$ convolution layer
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        # ReLU activation
        self.act2 = nn.ReLU()
        # $2 \times 2$ max-pooling
        self.max_pool2 = nn.MaxPool2d(2, 2)
        # First fully-connected layer that maps to $500$ features
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
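        # (The $4 \times 4$ comes from the shapes traced in the forward pass below:
        # $28 \to 24$ after the first $5 \times 5$ convolution, $\to 12$ after pooling,
        # $\to 8$ after the second convolution, $\to 4$ after the second pooling.)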
        # ReLU activation
        self.act3 = nn.ReLU()
        # Final fully connected layer to output evidence for $10$ classes.
        # The ReLU or Softplus activation is applied to this outside the model to get the
        # non-negative evidence
        self.fc2 = nn.Linear(500, 10)
        # Dropout for the hidden layer
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`
        """
        # Apply first convolution and max pooling.
        # The result has shape `[batch_size, 20, 12, 12]`
        x = self.max_pool1(self.act1(self.conv1(x)))
        # Apply second convolution and max pooling.
        # The result has shape `[batch_size, 50, 4, 4]`
        x = self.max_pool2(self.act2(self.conv2(x)))
        # Flatten the tensor to shape `[batch_size, 50 * 4 * 4]`
        x = x.view(x.shape[0], -1)
        # Apply hidden layer
        x = self.act3(self.fc1(x))
        # Apply dropout
        x = self.dropout(x)
        # Apply final layer and return
        return self.fc2(x)
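

# A minimal illustration of how this model's raw outputs become evidence and
# uncertainty (the actual wiring happens in `Configs.step` below, and Softplus is
# just one of the registered `outputs_to_evidence` options):
#
#     model = Model(dropout=0.5)
#     logits = model(torch.zeros(4, 1, 28, 28))   # shape `[4, 10]`
#     evidence = nn.Softplus()(logits)            # non-negative evidence $e_k \ge 0$
#     alpha = evidence + 1.                       # Dirichlet parameters $\alpha_k = e_k + 1$
#     uncertainty = 10. / alpha.sum(dim=-1)       # $u = K / S$ with $K = 10$ classes
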
class Configs(MNISTConfigs):
    """
    ## Configurations

    We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations.
    """

    # [KL Divergence regularization](index.html#KLDivergenceLoss)
    kl_div_loss = KLDivergenceLoss()
    # KL Divergence regularization coefficient schedule
    kl_div_coef: Schedule
    # Relative piecewise points for the KL Divergence regularization coefficient schedule
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
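    # (The points are `(fraction of training, value)` pairs: the coefficient is $0$ at
    # the start of training, $0.01$ at $20\%$ of training, and $1$ at the end, so the
    # KL regularization is introduced gradually rather than dominating early training.)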
    # [Stats module](index.html#TrackStatistics) for tracking
    stats = TrackStatistics()
    # Dropout
    dropout: float = 0.5
    # Module to convert the model output to non-negative evidence
    outputs_to_evidence: nn.Module

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)

        #
        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Training/Evaluation mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # One-hot encoded targets
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]
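        # (Indexing the identity matrix with the label tensor picks out one row per label:
        # e.g. labels `[3, 1]` select rows $3$ and $1$ of `eye`, giving a
        # `[batch_size, 10]` tensor of one-hot targets.)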

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get model outputs
        outputs = self.model(data)
        # Get the evidence $e_k \ge 0$
        evidence = self.outputs_to_evidence(outputs)

        # Calculate loss
        loss = self.loss_func(evidence, target)
        # Calculate KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)

        # KL Divergence loss coefficient $\lambda_t$
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)

        # Total loss
        loss = loss + annealing_coef * kl_div_loss
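        # i.e. $\mathcal{L} = \mathcal{L}_{\text{risk}} + \lambda_t \, \mathcal{L}_{KL}$,
        # where $\lambda_t$ grows from $0$ to $1$ over training (see `kl_div_coef`).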

        # Track statistics
        self.stats(evidence, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(Configs.model)
def mnist_model(c: Configs):
    """
    ### Create model
    """
    return Model(c.dropout).to(c.device)


@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
    """
    ### KL Divergence Loss Coefficient Schedule
    """

    # Create a [relative piecewise schedule](../../helpers/schedule.html)
    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
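    # (The schedule length is measured in samples: for example, with the standard
    # $60{,}000$ MNIST training images and, say, $10$ epochs, the coefficient would
    # reach $0.01$ after $120{,}000$ samples and $1.0$ at $600{,}000$ samples.)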


# [Maximum Likelihood Loss](index.html#MaximumLikelihoodLoss)
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
# [Cross Entropy Bayes Risk](index.html#CrossEntropyBayesRisk)
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
# [Squared Error Bayes Risk](index.html#SquaredErrorBayesRisk)
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

# ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
# Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
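# (These registered names are what the `'loss_func'` and `'outputs_to_evidence'`
# strings passed to `experiment.configs` in `main()` below refer to.)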


def main():
    # Create experiment
    experiment.create(name='evidence_mnist')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()