Path: blob/master/labml_nn/unet/experiment.py
4925 views
"""
---
title: Training a U-Net on Carvana dataset
summary: >
  Code for training a U-Net model on Carvana dataset.
---

# Training [U-Net](index.html)

This trains a [U-Net](index.html) model on the [Carvana dataset](carvana.html).
You can find the download instructions
[on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data).

Save the training images inside the `carvana/train` folder and the masks in the
`carvana/train_masks` folder, both under the labml data path (`lab.get_data_path()`).

For simplicity, we do not do a training and validation split.
"""

import numpy as np
import torchvision.transforms.functional

import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset
from torch import nn


class Configs(BaseConfigs):
    """
    ## Configurations
    """
    # Device to train the model on.
    # [`DeviceConfigs`](../helpers/device.html)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # [U-Net](index.html) model
    model: UNet

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Number of channels in the output mask. $1$ for binary mask.
    mask_channels: int = 1

    # Batch size
    batch_size: int = 1
    # Learning rate
    learning_rate: float = 2.5e-4

    # Number of training epochs
    epochs: int = 4

    # Dataset
    dataset: CarvanaDataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Loss function. `BCELoss` works on probabilities, so the model's logits
    # are passed through `sigmoid` before the loss is computed.
    loss_func = nn.BCELoss()
    # Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        ### Initialize the dataset, model, data loader and optimizer
        """
        # Initialize the [Carvana dataset](carvana.html)
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
        # Initialize the model
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

        # Create dataloader.
        # Pinned (page-locked) host memory only helps host-to-GPU copies,
        # so request it only when training on a CUDA device.
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True,
                                                       pin_memory=self.device.type == 'cuda')
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        ### Sample images

        Logs a randomly chosen training image multiplied by the model's
        predicted mask probabilities.

        `idx` is not used; it presumably receives the element of the index
        list handed to `monit.mix` in `train`.
        """

        # Get a random sample
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
        # Move data to device
        x = x.to(self.device)

        # Get predicted mask (`x[None, :]` adds a batch dimension)
        mask = self.sigmoid(self.model(x[None, :]))
        # Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
        # Log samples
        tracker.save('sample', x * mask)

    def train(self):
        """
        ### Train for an epoch
        """

        # Iterate through the dataset.
        # Use [`mix`](https://docs.labml.ai/api/monit.html#labml.monit.mix)
        # to sample $50$ times per epoch.
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            image, mask = image.to(self.device), mask.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Get predicted mask logits
            logits = self.model(image)
            # Crop the target mask to the size of the logits. Size of the logits will be smaller if we
            # don't use padding in convolutional layers in the U-Net.
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
            # Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        Trains for `epochs` epochs and saves a checkpoint of the registered
        models after every epoch.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # New line in the console
            tracker.new_line()
            # Save the model (models registered with `experiment.add_pytorch_models`)
            experiment.save_checkpoint()


def main():
    # Create experiment
    experiment.create(name='unet')

    # Create configurations
    configs = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {})

    # Initialize
    configs.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': configs.model})

    # Start and run the training loop
    with experiment.start():
        configs.run()


#
if __name__ == '__main__':
    main()