"""
---
title: CIFAR10 Experiment
summary: >
  This is a reusable trainer for the CIFAR10 dataset
---

# CIFAR10 Experiment
"""
from typing import List

import torch.nn as nn

from labml import lab
from labml.configs import option
from labml_nn.helpers.datasets import CIFAR10Configs as CIFAR10DatasetConfigs
from labml_nn.experiments.mnist import MNISTConfigs


class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    ## Configurations

    This extends the [CIFAR 10 dataset configurations](../helpers/datasets.html)
    and [`MNISTConfigs`](mnist.html).
    """
    # Use CIFAR10 dataset by default
    dataset_name: str = 'CIFAR10'
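

# A hedged usage sketch, not part of the original file: downstream experiments
# register a model for these configurations with `@option(CIFAR10Configs.model)`,
# mirroring the dataset options below. The VGG-16-style block plan and the
# helper name `_register_example_model` are illustrative assumptions.
def _register_example_model():
    @option(CIFAR10Configs.model)
    def _vgg_model(c: CIFAR10Configs):
        # Move the model to the device chosen by the inherited device configs
        return CIFAR10VGGModel([[64, 64], [128, 128], [256, 256, 256],
                                [512, 512, 512], [512, 512, 512]]).to(c.device)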


@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    ### Augmented CIFAR 10 train dataset
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=True,
                   download=True,
                   transform=transforms.Compose([
                       # Pad and crop
                       transforms.RandomCrop(32, padding=4),
                       # Random horizontal flip
                       transforms.RandomHorizontalFlip(),
                       # Convert to a tensor and normalize each channel to [-1, 1]
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))
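

# A quick sanity-check sketch, an assumption not in the original file:
# `ToTensor` scales pixel values to $[0, 1]$, and the
# `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` used above maps each channel to
# $[-1, 1]$ via $(x - 0.5) / 0.5$.
def _inspect_normalized_sample():
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    dataset = CIFAR10(str(lab.get_data_path()), train=True, download=True,
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                      ]))
    image, label = dataset[0]
    # Expect a 3x32x32 float tensor with values in [-1, 1]
    print(image.shape, image.min().item(), image.max().item(), label)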


@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    ### Non-augmented CIFAR 10 validation dataset
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=False,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))


class CIFAR10VGGModel(nn.Module):
    """
    ### VGG model for CIFAR-10 classification
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Convolution and activation combined
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
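
    def _check_conv_block_shape(self):
        """
        A shape-check sketch, not part of the original file: a $3 \times 3$
        convolution with `padding=1` preserves the spatial size, so only the
        pooling layers shrink the feature maps.
        """
        import torch
        block = self.conv_block(3, 64)
        out = block(torch.randn(1, 3, 32, 32))
        # Same 32x32 spatial size; channels grow from 3 to 64
        assert out.shape == (1, 64, 32, 32)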

    def __init__(self, blocks: List[List[int]]):
        super().__init__()

        # CIFAR 10 image size is $32 \times 32$, so five $2 \times 2$
        # pooling layers will produce an output of size $1 \times 1$
        assert len(blocks) == 5
        layers = []
        # RGB channels
        in_channels = 3
        # Number of channels in each layer in each block
        for block in blocks:
            # Convolution and activation layers
            for channels in block:
                layers += self.conv_block(in_channels, channels)
                in_channels = channels
            # Max pooling at end of each block
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

        # Create a sequential model with the layers
        self.layers = nn.Sequential(*layers)
        # Final logits layer
        self.fc = nn.Linear(in_channels, 10)

    def forward(self, x):
        # The VGG layers
        x = self.layers(x)
        # Reshape for classification layer
        x = x.view(x.shape[0], -1)
        # Final linear layer
        return self.fc(x)
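

# A worked instantiation sketch, an assumption not in the original file, using
# the VGG-16 channel plan. Five max-pooling layers halve the $32 \times 32$
# input down to $1 \times 1$, so the final linear layer sees a 512-dimensional
# feature vector.
def _example_forward_pass():
    import torch
    model = CIFAR10VGGModel([[64, 64], [128, 128], [256, 256, 256],
                             [512, 512, 512], [512, 512, 512]])
    logits = model(torch.randn(4, 3, 32, 32))
    # One logit per CIFAR-10 class for each image in the batch
    assert logits.shape == (4, 10)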