Path: blob/master/labml_nn/uncertainty/evidence/experiment.py
"""1---2title: "Evidential Deep Learning to Quantify Classification Uncertainty Experiment"3summary: >4This trains is EDL model on MNIST5---67# [Evidential Deep Learning to Quantify Classification Uncertainty](index.html) Experiment89This trains a model based on [Evidential Deep Learning to Quantify Classification Uncertainty](index.html)10on MNIST dataset.11"""1213from typing import Any1415import torch.nn as nn16import torch.utils.data1718from labml import tracker, experiment19from labml.configs import option, calculate20from labml_nn.helpers.schedule import Schedule, RelativePiecewise21from labml_nn.helpers.trainer import BatchIndex22from labml_nn.experiments.mnist import MNISTConfigs23from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \24CrossEntropyBayesRisk, SquaredErrorBayesRisk252627class Model(nn.Module):28"""29## LeNet based model fro MNIST classification30"""3132def __init__(self, dropout: float):33super().__init__()34# First $5x5$ convolution layer35self.conv1 = nn.Conv2d(1, 20, kernel_size=5)36# ReLU activation37self.act1 = nn.ReLU()38# $2x2$ max-pooling39self.max_pool1 = nn.MaxPool2d(2, 2)40# Second $5x5$ convolution layer41self.conv2 = nn.Conv2d(20, 50, kernel_size=5)42# ReLU activation43self.act2 = nn.ReLU()44# $2x2$ max-pooling45self.max_pool2 = nn.MaxPool2d(2, 2)46# First fully-connected layer that maps to $500$ features47self.fc1 = nn.Linear(50 * 4 * 4, 500)48# ReLU activation49self.act3 = nn.ReLU()50# Final fully connected layer to output evidence for $10$ classes.51# The ReLU or Softplus activation is applied to this outside the model to get the52# non-negative evidence53self.fc2 = nn.Linear(500, 10)54# Dropout for the hidden layer55self.dropout = nn.Dropout(p=dropout)5657def __call__(self, x: torch.Tensor):58"""59* `x` is the batch of MNIST images of shape `[batch_size, 1, 28, 28]`60"""61# Apply first convolution and max pooling.62# The result has shape `[batch_size, 20, 12, 12]`63x = self.max_pool1(self.act1(self.conv1(x)))64# Apply second convolution and max pooling.65# The result has shape `[batch_size, 50, 4, 4]`66x = self.max_pool2(self.act2(self.conv2(x)))67# Flatten the tensor to shape `[batch_size, 50 * 4 * 4]`68x = x.view(x.shape[0], -1)69# Apply hidden layer70x = self.act3(self.fc1(x))71# Apply dropout72x = self.dropout(x)73# Apply final layer and return74return self.fc2(x)757677class Configs(MNISTConfigs):78"""79## Configurations8081We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations.82"""8384# [KL Divergence regularization](index.html#KLDivergenceLoss)85kl_div_loss = KLDivergenceLoss()86# KL Divergence regularization coefficient schedule87kl_div_coef: Schedule88# KL Divergence regularization coefficient schedule89kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]90# [Stats module](index.html#TrackStatistics) for tracking91stats = TrackStatistics()92# Dropout93dropout: float = 0.594# Module to convert the model output to non-zero evidences95outputs_to_evidence: nn.Module9697def init(self):98"""99### Initialization100"""101# Set tracker configurations102tracker.set_scalar("loss.*", True)103tracker.set_scalar("accuracy.*", True)104tracker.set_histogram('u.*', True)105tracker.set_histogram('prob.*', False)106tracker.set_scalar('annealing_coef.*', False)107tracker.set_scalar('kl_div_loss.*', False)108109#110self.state_modules = []111112def step(self, batch: Any, batch_idx: BatchIndex):113"""114### Training or validation step115"""116117# Training/Evaluation 
class Configs(MNISTConfigs):
    """
    ## Configurations

    We use [`MNISTConfigs`](../../experiments/mnist.html#MNISTConfigs) configurations.
    """

    # [KL Divergence regularization](index.html#KLDivergenceLoss)
    kl_div_loss = KLDivergenceLoss()
    # KL Divergence regularization coefficient schedule
    kl_div_coef: Schedule
    # Schedule points for the coefficient, as `(relative training position, coefficient)` pairs
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
    # [Stats module](index.html#TrackStatistics) for tracking
    stats = TrackStatistics()
    # Dropout
    dropout: float = 0.5
    # Module to convert the model output to non-negative evidence
    outputs_to_evidence: nn.Module

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)

        # This experiment does not use any state modules
        self.state_modules = []

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        ### Training or validation step
        """

        # Training/Evaluation mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # One-hot encoded targets
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get model outputs
        outputs = self.model(data)
        # Get evidence $e_k \ge 0$
        evidence = self.outputs_to_evidence(outputs)

        # Calculate loss
        loss = self.loss_func(evidence, target)
        # Calculate KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)

        # KL Divergence loss coefficient $\lambda_t$
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)

        # Total loss
        loss = loss + annealing_coef * kl_div_loss

        # Track statistics
        self.stats(evidence, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(Configs.model)
def mnist_model(c: Configs):
    """
    ### Create model
    """
    return Model(c.dropout).to(c.device)


@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
    """
    ### KL Divergence Loss Coefficient Schedule
    """

    # Create a [relative piecewise schedule](../../helpers/schedule.html)
    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))


# [Maximum Likelihood Loss](index.html#MaximumLikelihoodLoss)
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
# [Cross Entropy Bayes Risk](index.html#CrossEntropyBayesRisk)
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
# [Squared Error Bayes Risk](index.html#SquaredErrorBayesRisk)
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

# ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
# Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())


def main():
    # Create experiment
    experiment.create(name='evidence_mnist')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()
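

# For illustration only: `_piecewise_coef` below is a hypothetical sketch of how a
# relative piecewise-linear schedule such as
# `kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]` can be read: the
# coefficient ramps linearly from $0$ to $0.01$ over the first $20\%$ of training,
# then to $1$ by the end. The experiment itself uses labml's `RelativePiecewise`,
# whose implementation details may differ.
def _piecewise_coef(step: int, total_steps: int, schedule) -> float:
    # Fraction of training completed, in $[0, 1]$
    t = step / total_steps
    # Find the segment containing `t` and interpolate linearly within it
    for (x0, y0), (x1, y1) in zip(schedule[:-1], schedule[1:]):
        if t <= x1:
            return y0 + (y1 - y0) * (t - x0) / (x1 - x0)
    # Past the last point, hold the final value
    return schedule[-1][1]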