"""
2
---
3
title: U-Net model for Denoising Diffusion Probabilistic Models (DDPM)
4
summary: >
5
UNet model for Denoising Diffusion Probabilistic Models (DDPM)
6
---
7
8
# U-Net model for [Denoising Diffusion Probabilistic Models (DDPM)](index.html)
9
10
This is a [U-Net](../../unet/index.html) based model to predict noise
11
$\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$.
12
13
U-Net is a gets it's name from the U shape in the model diagram.
14
It processes a given image by progressively lowering (halving) the feature map resolution and then
15
increasing the resolution.
16
There are pass-through connection at each resolution.
17
18
![U-Net diagram from paper](../../unet/unet.png)
19
20
This implementation contains a bunch of modifications to original U-Net (residual blocks, multi-head attention)
21
and also adds time-step embeddings $t$.
22
"""

import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn


class Swish(nn.Module):
    """
    ### Swish activation function

    $$x \cdot \sigma(x)$$
    """

    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding
        """
        super().__init__()
        self.n_channels = n_channels
        # First linear layer
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        # Create sinusoidal position embeddings
        # [same as those from the transformer](../../transformers/positional_encoding.html)
        #
        # \begin{align}
        # PE^{(1)}_{t,i} &= \sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
        # PE^{(2)}_{t,i} &= \cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
        # \end{align}
        #
        # where $d$ is `half_dim`
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        #
        return emb
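
# A small shape sketch for `TimeEmbedding` (illustrative only, not part of the original file).
# With `n_channels = 256`: `half_dim = 32`, so the concatenated sin/cos features have
# `64 = n_channels // 4` dimensions, which `lin1` and `lin2` then map up to `n_channels`:
#
#     time_emb = TimeEmbedding(n_channels=256)
#     t = torch.randint(0, 1_000, (16,))  # one diffusion time step per sample in the batch
#     assert time_emb(t).shape == (16, 256)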


class ResidualBlock(nn.Module):
    """
    ### Residual block

    A residual block has two convolution layers with group normalization.
    Each resolution is processed with two residual blocks.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step ($t$) embeddings
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        * `dropout` is the dropout rate
        """
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # If the number of input channels is not equal to the number of output channels we have to
        # project the shortcut connection
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

        # Linear layer for time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # Second convolution layer
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))

        # Add the shortcut connection and return
        return h + self.shortcut(x)
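
# A minimal shape sketch for `ResidualBlock` (illustrative only, not in the original file).
# The time embedding is broadcast over the spatial dimensions via `[:, :, None, None]`,
# and the 1x1 shortcut convolution matches the channel counts for the residual addition:
#
#     block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
#     x = torch.randn(4, 64, 32, 32)
#     t = torch.randn(4, 256)
#     assert block(x, t).shape == (4, 128, 32, 32)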


class AttentionBlock(nn.Module):
    """
    ### Attention block

    This is similar to [transformer multi-head attention](../../transformers/mha.html).
    """

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        """
        * `n_channels` is the number of channels in the input
        * `n_heads` is the number of heads in multi-head attention
        * `d_k` is the number of dimensions in each head
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        """
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention
        self.scale = d_k ** -0.5
        #
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # `t` is not used; it is kept in the arguments so that the function signature
        # matches that of `ResidualBlock`.
        _ = t
        # Get shape
        batch_size, n_channels, height, width = x.shape
        # Change `x` to shape `[batch_size, seq, n_channels]`
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = attn.softmax(dim=2)
        # Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)

        # Add skip connection
        res += x

        # Change to shape `[batch_size, in_channels, height, width]`
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

        #
        return res
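
# An illustrative shape check for `AttentionBlock` (not part of the original file).
# The `height * width` spatial positions are treated as a sequence of length `seq`,
# so attention is computed between every pair of positions in the feature map:
#
#     attn_block = AttentionBlock(n_channels=64)
#     x = torch.randn(4, 64, 16, 16)      # `seq` is 16 * 16 = 256
#     assert attn_block(x).shape == (4, 64, 16, 16)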


class DownBlock(nn.Module):
    """
    ### Down block

    This combines `ResidualBlock` and `AttentionBlock`. These are used in the first half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class UpBlock(nn.Module):
    """
    ### Up block

    This combines `ResidualBlock` and `AttentionBlock`. These are used in the second half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The input has `in_channels + out_channels` because we concatenate the output of the same resolution
        # from the first half of the U-Net
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x
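
# For example (illustrative numbers, not from the original file): if the encoder produced a
# 128-channel feature map at some resolution and the decoder arrives with 128 channels, the
# `UpBlock` sees `in_channels + out_channels = 256` channels after concatenation and maps
# them back down to 128:
#
#     up = UpBlock(in_channels=128, out_channels=128, time_channels=256, has_attn=False)
#     x = torch.cat((torch.randn(2, 128, 16, 16), torch.randn(2, 128, 16, 16)), dim=1)
#     assert up(x, torch.randn(2, 256)).shape == (2, 128, 16, 16)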


class MiddleBlock(nn.Module):
    """
    ### Middle block

    It combines a `ResidualBlock` and an `AttentionBlock`, followed by another `ResidualBlock`.
    This block is applied at the lowest resolution of the U-Net.
    """

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x


class Upsample(nn.Module):
    """
    ### Scale up the feature map by $2 \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used; it is kept in the arguments so that the function signature
        # matches that of `ResidualBlock`.
        _ = t
        return self.conv(x)


class Downsample(nn.Module):
    """
    ### Scale down the feature map by $\frac{1}{2} \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used; it is kept in the arguments so that the function signature
        # matches that of `ResidualBlock`.
        _ = t
        return self.conv(x)
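
# A quick sanity sketch (illustrative only, not part of the original file): the stride-2 3x3
# convolution with padding 1 halves the spatial size, and the 4x4 stride-2 transposed
# convolution doubles it back:
#
#     x = torch.randn(1, 64, 32, 32)
#     t = torch.zeros(1, 256)  # unused by these modules
#     assert Downsample(64)(x, t).shape == (1, 64, 16, 16)
#     assert Upsample(64)(x, t).shape == (1, 64, 64, 64)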


class UNet(nn.Module):
    """
    ## U-Net
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image. $3$ for RGB.
        * `n_channels` is the number of channels in the initial feature map that we transform the image into
        * `ch_mults` is the list of channel multipliers at each resolution. The number of channels at resolution `i` is `ch_mults[i] * n_channels`
        * `is_attn` is a list of booleans that indicate whether to use attention at each resolution
        * `n_blocks` is the number of `DownBlock`/`UpBlock` blocks at each resolution
        """
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks`
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4)

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except the last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """

        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        x = self.image_proj(x)

        # `h` will store outputs at each resolution for the skip connections
        h = [x]
        # First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from the first half of the U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                #
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))
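

# A minimal usage sketch (not part of the original file). It builds a small `UNet` and checks
# that the predicted noise $\epsilon_\theta(x_t, t)$ has the same shape as the input image.
# The channel multipliers and image size below are illustrative choices, not values from the paper.
if __name__ == '__main__':
    model = UNet(image_channels=3, n_channels=32,
                 ch_mults=(1, 2, 2), is_attn=(False, False, True), n_blocks=1)
    x_t = torch.randn(2, 3, 32, 32)        # a batch of noisy images
    t = torch.randint(0, 1_000, (2,))      # a diffusion time step for each image
    eps = model(x_t, t)
    # Feature map resolutions go 32x32 -> 16x16 -> 8x8 and back up again
    assert eps.shape == x_t.shape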