Path: blob/master/labml_nn/experiments/cifar10.py
4910 views
"""1---2title: CIFAR10 Experiment3summary: >4This is a reusable trainer for CIFAR10 dataset5---67# CIFAR10 Experiment8"""9from typing import List1011import torch.nn as nn1213from labml import lab14from labml.configs import option15from labml_nn.helpers.datasets import CIFAR10Configs as CIFAR10DatasetConfigs16from labml_nn.experiments.mnist import MNISTConfigs171819class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):20"""21## Configurations2223This extends from [CIFAR 10 dataset configurations](../helpers/datasets.html)24and [`MNISTConfigs`](mnist.html).25"""26# Use CIFAR10 dataset by default27dataset_name: str = 'CIFAR10'282930@option(CIFAR10Configs.train_dataset)31def cifar10_train_augmented():32"""33### Augmented CIFAR 10 train dataset34"""35from torchvision.datasets import CIFAR1036from torchvision.transforms import transforms37return CIFAR10(str(lab.get_data_path()),38train=True,39download=True,40transform=transforms.Compose([41# Pad and crop42transforms.RandomCrop(32, padding=4),43# Random horizontal flip44transforms.RandomHorizontalFlip(),45#46transforms.ToTensor(),47transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))48]))495051@option(CIFAR10Configs.valid_dataset)52def cifar10_valid_no_augment():53"""54### Non-augmented CIFAR 10 validation dataset55"""56from torchvision.datasets import CIFAR1057from torchvision.transforms import transforms58return CIFAR10(str(lab.get_data_path()),59train=False,60download=True,61transform=transforms.Compose([62transforms.ToTensor(),63transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))64]))656667class CIFAR10VGGModel(nn.Module):68"""69### VGG model for CIFAR-10 classification70"""7172def conv_block(self, in_channels, out_channels) -> nn.Module:73"""74Convolution and activation combined75"""76return nn.Sequential(77nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),78nn.ReLU(inplace=True),79)8081def __init__(self, blocks: List[List[int]]):82super().__init__()8384# 5 $2 \times 2$ pooling layers will produce a output of size $1 \ times 1$.85# CIFAR 10 image size is $32 \times 32$86assert len(blocks) == 587layers = []88# RGB channels89in_channels = 390# Number of channels in each layer in each block91for block in blocks:92# Convolution, Normalization and Activation layers93for channels in block:94layers += self.conv_block(in_channels, channels)95in_channels = channels96# Max pooling at end of each block97layers += [nn.MaxPool2d(kernel_size=2, stride=2)]9899# Create a sequential model with the layers100self.layers = nn.Sequential(*layers)101# Final logits layer102self.fc = nn.Linear(in_channels, 10)103104def forward(self, x):105# The VGG layers106x = self.layers(x)107# Reshape for classification layer108x = x.view(x.shape[0], -1)109# Final linear layer110return self.fc(x)111112113