Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/resnet/experiment.py
4918 views
1
"""
---
title: Train a ResNet on CIFAR 10
summary: >
  Train a ResNet on CIFAR 10
---

# Train a [ResNet](index.html) on CIFAR 10
"""
10
from typing import List, Optional
11
12
from torch import nn
13
14
from labml import experiment
15
from labml.configs import option
16
from labml_nn.experiments.cifar10 import CIFAR10Configs
17
from labml_nn.resnet import ResNetBase
18
19
20
class Configs(CIFAR10Configs):
    """
    ## Configurations

    We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the
    dataset related configurations, optimizer, and a training loop.

    The attributes below are the ResNet-specific settings; the list-valued
    ones hold one entry per feature map size.
    """

    # Number of blocks for each feature map size
    n_blocks: List[int] = [3, 3, 3]
    # Number of channels for each feature map size
    n_channels: List[int] = [16, 32, 64]
    # Bottleneck sizes (`None` by default; passed through to `ResNetBase` as-is)
    bottlenecks: Optional[List[int]] = None
    # Kernel size of the initial convolution layer
    first_kernel_size: int = 3
36
37
38
@option(Configs.model)
def _resnet(c: Configs):
    """
    ### Create model

    Registered as the option for `Configs.model`. Builds a `ResNetBase` from
    the configured block counts, channel counts, bottleneck sizes and initial
    kernel size (with 3 image channels), appends a linear classification
    layer with 10 outputs, and returns the stacked model moved to `c.device`.
    """
    # [ResNet](index.html)
    base = ResNetBase(c.n_blocks, c.n_channels, c.bottlenecks, img_channels=3, first_kernel_size=c.first_kernel_size)
    # Linear layer for classification; its input size is the last configured channel count
    classification = nn.Linear(c.n_channels[-1], 10)

    # Stack them
    model = nn.Sequential(base, classification)
    # Move the model to the device
    return model.to(c.device)
52
53
54
def main():
    """
    Create the `resnet` experiment, apply the configuration overrides,
    register the model for saving/loading, and run the training loop.
    """
    # Create experiment
    experiment.create(name='resnet', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'bottlenecks': [8, 16, 16],
        'n_blocks': [6, 6, 6],

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        'epochs': 500,
        'train_batch_size': 256,

        'train_dataset': 'cifar10_train_augmented',
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
78
79
80
#
81
# Run the experiment only when this file is executed as a script
if __name__ == '__main__':
    main()
83
84