Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/distillation/small.py
4918 views
1
"""
2
---
3
title: Train a small model on CIFAR 10
4
summary: >
5
Train a small model on CIFAR 10 to test how much distillation benefits.
6
---
7
8
# Train a small model on CIFAR 10
9
10
This trains a small model on CIFAR 10 to test how much [distillation](index.html) benefits.
11
"""
12
13
import torch.nn as nn
14
15
from labml import experiment, logger
16
from labml.configs import option
17
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
18
from labml_nn.normalization.batch_norm import BatchNorm
19
20
21
class Configs(CIFAR10Configs):
    """
    ## Configurations

    Everything is inherited from [`CIFAR10Configs`](../experiments/cifar10.html),
    which provides the dataset-related configurations, the optimizer, and a
    training loop; nothing is overridden here.
    """
29
30
31
class SmallModel(CIFAR10VGGModel):
    """
    ### VGG style model for CIFAR-10 classification

    This derives from the [generic VGG style architecture](../experiments/cifar10.html).
    """

    def __init__(self):
        # Channel counts for each VGG stage; deliberately small so the
        # benefit of distillation can be measured against a larger teacher.
        super().__init__([[32, 32], [64, 64], [128], [128], [128]])

    def conv_block(self, in_channels: int, out_channels: int) -> nn.Module:
        """
        Build one convolution unit: 3x3 convolution -> batch norm -> ReLU.
        """
        # 3x3 convolution with padding to preserve spatial size
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Batch normalization without running statistics
        norm = BatchNorm(out_channels, track_running_stats=False)
        # In-place ReLU activation
        activation = nn.ReLU(inplace=True)
        return nn.Sequential(conv, norm, activation)
54
55
56
@option(Configs.model)
def _small_model(c: Configs):
    """
    ### Create model

    Instantiate the small VGG-style model and move it to the configured device.
    """
    model = SmallModel()
    return model.to(c.device)
62
63
64
def main():
    """
    Train the small model on CIFAR-10 as a distillation baseline.
    """
    # Create the experiment
    experiment.create(name='cifar10', comment='small model')
    # Configuration object with all defaults
    conf = Configs()
    # Override the optimizer choice and learning rate
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    }
    experiment.configs(conf, overrides)
    # Register the model so checkpoints can be saved/loaded
    experiment.add_pytorch_models({'model': conf.model})
    # Report the number of trainable parameters
    n_params = sum(p.numel() for p in conf.model.parameters() if p.requires_grad)
    logger.inspect(params=n_params)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
81
82
83
# Run the training when executed as a script
if __name__ == '__main__':
    main()
86
87