# Source: labmlai/annotated_deep_learning_paper_implementations
# Path: labml_nn/unet/experiment.py
1
"""
---
title: Training a U-Net on Carvana dataset
summary: >
  Code for training a U-Net model on Carvana dataset.
---

# Training [U-Net](index.html)

This trains a [U-Net](index.html) model on [Carvana dataset](carvana.html).
You can find the download instructions
[on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data).

Save the training images inside `carvana/train` folder and the masks in `carvana/train_masks` folder.

For simplicity, we do not do a training and validation split.
"""

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.functional
from torch import nn

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset
32
class Configs(BaseConfigs):
    """
    ## Configurations

    Hyper-parameters and components for training a [U-Net](index.html)
    on the [Carvana dataset](carvana.html).
    """
    # Device to train the model on.
    # [`DeviceConfigs`](../helpers/device.html)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # [U-Net](index.html) model
    model: UNet

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Number of channels in the output mask. $1$ for binary mask.
    mask_channels: int = 1

    # Batch size
    batch_size: int = 1
    # Learning rate
    learning_rate: float = 2.5e-4

    # Number of training epochs
    epochs: int = 4

    # Dataset
    dataset: CarvanaDataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Loss function.
    # NOTE(review): `nn.BCELoss` on sigmoid outputs is less numerically stable
    # than `nn.BCEWithLogitsLoss` on raw logits; kept as-is to preserve behavior.
    loss_func = nn.BCELoss()
    # Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        ### Initialize the dataset, model, dataloader and optimizer
        """
        # Initialize the [Carvana dataset](carvana.html)
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
        # Initialize the model and move it to the training device
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        ### Sample images

        Pick a random training image, predict its mask and log the
        masked image to the tracker. Runs under `torch.no_grad()`.
        """
        # Get a random sample
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
        # Move data to device
        x = x.to(self.device)

        # Get predicted mask
        mask = self.sigmoid(self.model(x[None, :]))
        # Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
        # Log samples
        tracker.save('sample', x * mask)

    def train(self):
        """
        ### Train for an epoch
        """
        # Iterate through the dataset.
        # Use [`mix`](https://docs.labml.ai/api/monit.html#labml.monit.mix)
        # to sample $50$ times per epoch.
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            image, mask = image.to(self.device), mask.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Get predicted mask logits
            logits = self.model(image)
            # Crop the target mask to the size of the logits. Size of the logits will be smaller if we
            # don't use padding in convolutional layers in the U-Net.
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
            # Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # New line in the console
            tracker.new_line()
            # Save the model.
            # FIX: the original had this comment with no call, so checkpoints
            # were never written despite `experiment.add_pytorch_models` in `main`.
            experiment.save_checkpoint()
146
def main():
    """Create the experiment, configure it and run the training loop."""
    # Set up a new experiment run
    experiment.create(name='unet')

    # Build the configuration object
    configs = Configs()
    # Apply configuration overrides (none here; pass values in the dict to override defaults)
    experiment.configs(configs, {})

    # Construct the dataset, model, dataloader and optimizer
    configs.init()

    # Register the model so checkpoints can be saved and loaded
    experiment.add_pytorch_models({'model': configs.model})

    # Launch the experiment and train
    with experiment.start():
        configs.run()
167
# Run training when executed as a script
if __name__ == '__main__':
    main()