# Path: blob/master/labml_nn/distillation/large.py
# 4918 views
"""
---
title: Train a large model on CIFAR 10
summary: >
  Train a large model on CIFAR 10 for distillation.
---

# Train a large model on CIFAR 10

This trains a large model on CIFAR 10 for [distillation](index.html).
"""

import torch.nn as nn

from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm


class Configs(CIFAR10Configs):
    """
    ## Configurations

    We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the
    dataset related configurations, optimizer, and a training loop.
    """


class LargeModel(CIFAR10VGGModel):
    """
    ### VGG style model for CIFAR-10 classification

    This derives from the [generic VGG style architecture](../experiments/cifar10.html).
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Create a convolution layer and its activations:
        dropout, a 3x3 convolution, batch normalization and a ReLU.

        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels

        Presumably the base class calls this once per convolution layer;
        see the [generic VGG style architecture](../experiments/cifar10.html).
        """
        return nn.Sequential(
            # Dropout; each activation is dropped with probability 0.1 in training mode
            nn.Dropout(0.1),
            # Convolution layer; a 3x3 kernel with `padding=1` (and the default
            # stride of 1) keeps the height and width unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Batch normalization.
            # NOTE(review): with `track_running_stats=False` this presumably
            # normalizes with batch statistics in evaluation mode too, as
            # `torch.nn.BatchNorm2d` does; confirm in `labml_nn.normalization.batch_norm`.
            BatchNorm(out_channels, track_running_stats=False),
            # ReLU activation, applied in place to avoid an extra allocation
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        # Create a model with given convolution sizes (channels).
        # Each inner list is presumably one group (stage) of convolution layers,
        # listing each layer's output channels. 2 + 2 + 3 + 3 + 3 = 13 layers,
        # the same channel layout as the convolutional part of VGG-16.
        super().__init__([[64, 64],
                          [128, 128],
                          [256, 256, 256],
                          [512, 512, 512],
                          [512, 512, 512]])


@option(Configs.model)
def _large_model(c: Configs):
    """
    ### Create model

    Registered as the `model` option of `Configs`. It builds a `LargeModel`
    and moves it to `c.device`.
    """
    return LargeModel().to(c.device)


def main():
    """
    ### Train the large model

    Trains `LargeModel` on CIFAR-10 with Adam. `is_save_models` is enabled and
    the model is registered for saving/loading, so the trained model can be
    loaded later, presumably as the teacher for [distillation](index.html).
    """
    # Create experiment
    experiment.create(name='cifar10', comment='large model')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'is_save_models': True,
        'epochs': 20,
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Print number of trainable parameters in the model
    n_params = sum(p.numel() for p in conf.model.parameters() if p.requires_grad)
    logger.inspect(params=n_params)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()