GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/conv_mixer/experiment.py

"""
---
title: Train ConvMixer on CIFAR 10
summary: >
  Train ConvMixer on CIFAR 10
---

# Train a [ConvMixer](index.html) on CIFAR 10

This script trains a ConvMixer on the CIFAR 10 dataset.

This is not an attempt to reproduce the results of the paper.
The paper uses image augmentations from
[PyTorch Image Models (timm)](https://github.com/rwightman/pytorch-image-models)
for training. We have skipped those augmentations for simplicity, which causes our
validation accuracy to drop.
"""

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs


class Configs(CIFAR10Configs):
    """
    ## Configurations

    We use [`CIFAR10Configs`](../experiments/cifar10.html), which defines all the
    dataset-related configurations, the optimizer, and a training loop.
    """

    # Size of a patch, $p$
    patch_size: int = 2
    # Number of channels in patch embeddings, $h$
    d_model: int = 256
    # Number of [ConvMixer layers](#ConvMixerLayer) or depth, $d$
    n_layers: int = 8
    # Kernel size of the depth-wise convolution, $k$
    kernel_size: int = 7
    # Number of classes in the task
    n_classes: int = 10
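
    # With these defaults, the patch embedding (a convolution with kernel size and stride $p$)
    # maps a $3 \times 32 \times 32$ CIFAR-10 image to a $256 \times 16 \times 16$ feature map
    # ($32 / p = 16$ patches per side), and the $d = 8$ ConvMixer layers keep that resolution.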


@option(Configs.model)
def _conv_mixer(c: Configs):
    """
    ### Create model
    """
    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

    # Create ConvMixer
    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
                     PatchEmbeddings(c.d_model, c.patch_size, 3),
                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
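

# ### A reference sketch of the architecture (not used by this experiment)
#
# The model above is assembled from [`labml_nn.conv_mixer`](index.html). As a quick reference,
# the helper below is a minimal sketch of the same architecture in plain PyTorch, following the
# ConvMixer paper: a patch embedding (a convolution with kernel size and stride $p$), then $d$
# blocks of a residual depth-wise convolution followed by a point-wise ($1 \times 1$)
# convolution, each with GELU and BatchNorm, and finally global average pooling with a linear
# classifier. The name `_conv_mixer_sketch` is illustrative only and nothing in this experiment
# calls it; for CIFAR-10 inputs of shape `(N, 3, 32, 32)` it returns `(N, 10)` class logits.
def _conv_mixer_sketch(d_model: int = 256, n_layers: int = 8, kernel_size: int = 7,
                       patch_size: int = 2, n_classes: int = 10):
    from torch import nn

    # Residual connection, $x + f(x)$
    class Residual(nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x) + x

    return nn.Sequential(
        # Patch embedding: $h$ channels with kernel size and stride $p$
        nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(d_model),
        # $d$ ConvMixer layers
        *[nn.Sequential(
            # Depth-wise convolution (spatial mixing) with a residual connection;
            # `kernel_size // 2` padding keeps the spatial size for an odd kernel size
            Residual(nn.Sequential(
                nn.Conv2d(d_model, d_model, kernel_size, groups=d_model, padding=kernel_size // 2),
                nn.GELU(),
                nn.BatchNorm2d(d_model),
            )),
            # Point-wise convolution (channel mixing)
            nn.Conv2d(d_model, d_model, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(d_model),
        ) for _ in range(n_layers)],
        # Classification head: global average pooling and a linear layer
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(d_model, n_classes),
    )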


def main():
    # Create experiment
    experiment.create(name='ConvMixer', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        # Training epochs and batch size
        'epochs': 150,
        'train_batch_size': 64,

        # Simple image augmentations
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
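# Running this file directly starts the training run; for example, from the repository root
# (assuming its dependencies are installed): `python labml_nn/conv_mixer/experiment.py`.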
if __name__ == '__main__':
    main()