GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/distillation/large.py
"""
---
title: Train a large model on CIFAR 10
summary: >
  Train a large model on CIFAR 10 for distillation.
---

# Train a large model on CIFAR 10

This trains a large model on CIFAR 10 for [distillation](index.html).
"""

import torch.nn as nn

from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm


class Configs(CIFAR10Configs):
    """
    ## Configurations

    We use [`CIFAR10Configs`](../experiments/cifar10.html), which defines all the
    dataset-related configurations, the optimizer, and the training loop.
    """
    pass
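
# `Configs` inherits the CIFAR-10 data loaders, the optimizer setup, and the
# training loop from `CIFAR10Configs`; individual options (e.g. `epochs`,
# `optimizer.learning_rate`) are overridden by name in `main()` below.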


class LargeModel(CIFAR10VGGModel):
    """
    ### VGG style model for CIFAR-10 classification

    This derives from the [generic VGG style architecture](../experiments/cifar10.html).
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Create a convolution block: dropout, a 3x3 convolution,
        batch normalization, and a ReLU activation
        """
        return nn.Sequential(
            # Dropout
            nn.Dropout(0.1),
            # Convolution layer
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Batch normalization
            BatchNorm(out_channels, track_running_stats=False),
            # ReLU activation
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        # Create a model with the given convolution sizes (channels)
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
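

# A minimal shape sanity check (an illustrative sketch, not part of the
# original script). It assumes the generic `CIFAR10VGGModel` ends in a
# 10-way classification head, as in the CIFAR-10 experiment code above.
def _check_model_shape():
    import torch

    model = LargeModel()
    # CIFAR-10 inputs are 3x32x32; a batch of 4 images should yield 4x10 logits
    logits = model(torch.zeros(4, 3, 32, 32))
    assert logits.shape == (4, 10)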
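

# `@option(Configs.model)` registers the function below as the calculator for
# the `model` option of `Configs`: labml runs it lazily, the first time
# `conf.model` is accessed (e.g. in `experiment.add_pytorch_models` below).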
@option(Configs.model)
def _large_model(c: Configs):
    """
    ### Create model
    """
    return LargeModel().to(c.device)


def main():
    # Create experiment
    experiment.create(name='cifar10', comment='large model')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'is_save_models': True,
        'epochs': 20,
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Print the number of trainable parameters in the model
    logger.inspect(params=sum(p.numel() for p in conf.model.parameters() if p.requires_grad))
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()
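
# Once trained with `'is_save_models': True`, the saved weights serve as the
# teacher for the [distillation experiment](index.html). A minimal loading
# sketch in plain PyTorch (the checkpoint path is hypothetical; labml stores
# checkpoints under its run directory):
#
#     import torch
#     teacher = LargeModel()
#     teacher.load_state_dict(torch.load('logs/cifar10/<run_uuid>/checkpoint.pth'))
#     teacher.eval()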