Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/diffusion/ddpm/experiment.py
4921 views
1
"""
2
---
3
title: Denoising Diffusion Probabilistic Models (DDPM) training
4
summary: >
5
Training code for
6
Denoising Diffusion Probabilistic Model.
7
---
8
9
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training
10
11
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/diffusion/ddpm/experiment.ipynb)
12
13
This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this
14
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
15
Save the images inside [`data/celebA` folder](#dataset_path).
16
17
The paper used an exponential moving average of the model with a decay of $0.9999$. We have skipped this for
18
simplicity.
19
"""
20
from typing import List
21
22
import torchvision
23
from PIL import Image
24
25
import torch
26
import torch.utils.data
27
from labml import lab, tracker, experiment, monit
28
from labml.configs import BaseConfigs, option
29
from labml_nn.diffusion.ddpm import DenoiseDiffusion
30
from labml_nn.diffusion.ddpm.unet import UNet
31
from labml_nn.helpers.device import DeviceConfigs
32
33
34
class Configs(BaseConfigs):
    """
    ## Configurations

    Hyper-parameters and components for DDPM training.
    Fields without defaults (`eps_model`, `diffusion`, `dataset`, ...) are
    built in `init()` or supplied through labml config options.
    """
    # Device to train the model on.
    # [`DeviceConfigs`](../../device.html)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size (images are square: `image_size x image_size`)
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate when sampling after each epoch
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        ### Initialize the model, diffusion wrapper, data loader and optimizer
        """
        # Create $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader (shuffled; `pin_memory` speeds up host-to-GPU copies)
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging: mark the 'sample' indicator as an image
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images

        Runs the full $T$-step reverse process starting from Gaussian noise
        and logs the resulting images.
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$ runs from $T - 1$ down to $0$
                t = self.n_steps - t_ - 1
                # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train

        One pass over the dataset, optimizing the simplified DDPM loss.
        """

        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        Alternates training and sampling for `epochs` epochs.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
157
158
159
class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset

    Yields image tensors (no labels), resized to `image_size` and scaled to $[0, 1]$.
    Expects JPEG images under `data/celebA` (see the module docstring for download
    instructions).
    """

    def __init__(self, image_size: int):
        """
        * `image_size` is the side length images are resized to
        """
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of image files (recursive); the f-string in the original glob
        # pattern had no placeholders, so a plain string is used here
        self._files = list(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Size of the dataset
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get the transformed image at `index`
        """
        img = Image.open(self._files[index])
        return self._transform(img)
190
191
192
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Registered config option: build the CelebA HQ dataset
    at the configured image size.
    """
    size = c.image_size
    return CelebADataset(size)
198
199
200
class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset

    Thin wrapper over `torchvision.datasets.MNIST` that resizes images and
    returns only the image tensor, dropping the class label.
    """

    def __init__(self, image_size):
        # Resize to the configured size, then convert the PIL image to a tensor
        resize_and_tensorize = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=resize_and_tensorize)

    def __getitem__(self, item):
        # Keep only the image; discard the label
        image, _ = super().__getitem__(item)
        return image
215
216
217
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Registered config option: build the MNIST dataset
    at the configured image size.
    """
    size = c.image_size
    return MNISTDataset(size)
223
224
225
def main():
    """
    Entry point: create the experiment, set up configurations,
    and run the DDPM training loop.
    """
    # Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})

    # Create configurations
    conf = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    overrides = {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    }
    experiment.configs(conf, overrides)

    # Initialize models, data loader and optimizer
    conf.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': conf.eps_model})

    # Start and run the training loop
    with experiment.start():
        conf.run()
248
249
250
# Run training when executed as a script
if __name__ == '__main__':
    main()
253
254