Path: blob/master/labml_nn/resnet/experiment.py
4918 views
"""1---2title: Train a ResNet on CIFAR 103summary: >4Train a ResNet on CIFAR 105---67# Train a [ResNet](index.html) on CIFAR 108"""9from typing import List, Optional1011from torch import nn1213from labml import experiment14from labml.configs import option15from labml_nn.experiments.cifar10 import CIFAR10Configs16from labml_nn.resnet import ResNetBase171819class Configs(CIFAR10Configs):20"""21## Configurations2223We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the24dataset related configurations, optimizer, and a training loop.25"""2627# Number fo blocks for each feature map size28n_blocks: List[int] = [3, 3, 3]29# Number of channels for each feature map size30n_channels: List[int] = [16, 32, 64]31# Bottleneck sizes32bottlenecks: Optional[List[int]] = None33# Kernel size of the initial convolution layer34first_kernel_size: int = 3353637@option(Configs.model)38def _resnet(c: Configs):39"""40### Create model41"""42# [ResNet](index.html)43base = ResNetBase(c.n_blocks, c.n_channels, c.bottlenecks, img_channels=3, first_kernel_size=c.first_kernel_size)44# Linear layer for classification45classification = nn.Linear(c.n_channels[-1], 10)4647# Stack them48model = nn.Sequential(base, classification)49# Move the model to the device50return model.to(c.device)515253def main():54# Create experiment55experiment.create(name='resnet', comment='cifar10')56# Create configurations57conf = Configs()58# Load configurations59experiment.configs(conf, {60'bottlenecks': [8, 16, 16],61'n_blocks': [6, 6, 6],6263'optimizer.optimizer': 'Adam',64'optimizer.learning_rate': 2.5e-4,6566'epochs': 500,67'train_batch_size': 256,6869'train_dataset': 'cifar10_train_augmented',70'valid_dataset': 'cifar10_valid_no_augment',71})72# Set model for saving/loading73experiment.add_pytorch_models({'model': conf.model})74# Start the experiment and run the training loop75with experiment.start():76conf.run()777879#80if __name__ == '__main__':81main()828384