Path: blob/master/labml_nn/conv_mixer/experiment.py
4925 views
"""
---
title: Train ConvMixer on CIFAR 10
summary: >
  Train ConvMixer on CIFAR 10
---

# Train a [ConvMixer](index.html) on CIFAR 10

This script trains a ConvMixer on the CIFAR 10 dataset.

This is not an attempt to reproduce the results of the paper.
The paper uses image augmentations
present in [PyTorch Image Models (timm)](https://github.com/rwightman/pytorch-image-models)
for training. We haven't done this, for simplicity, which lowers our validation accuracy.
"""

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs


class Configs(CIFAR10Configs):
    """
    ## Configurations

    We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the
    dataset related configurations, optimizer, and a training loop.
    """

    # Size of a patch, $p$
    patch_size: int = 2
    # Number of channels in patch embeddings, $h$
    d_model: int = 256
    # Number of [ConvMixer layers](#ConvMixerLayer) or depth, $d$
    n_layers: int = 8
    # Kernel size of the depth-wise convolution, $k$
    kernel_size: int = 7
    # Number of channels in the input images; passed to `PatchEmbeddings`.
    # Defaults to 3 (the value previously hard-coded for CIFAR 10's RGB images).
    in_channels: int = 3
    # Number of classes in the task
    n_classes: int = 10


@option(Configs.model)
def _conv_mixer(c: Configs):
    """
    ### Create model

    Registered as the calculator for `Configs.model`: builds a ConvMixer from
    the configuration values in `c` and moves it to `c.device`.
    """
    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

    # Create ConvMixer: a ConvMixer layer (repeated `n_layers` times),
    # the patch embedding stem, and the classification head.
    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
                     PatchEmbeddings(c.d_model, c.patch_size, c.in_channels),
                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)


def main():
    # Create experiment
    experiment.create(name='ConvMixer', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        # Training epochs and batch size
        'epochs': 150,
        'train_batch_size': 64,

        # Simple image augmentations
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()