Path: blob/master/labml_nn/diffusion/ddpm/unet.py
"""1---2title: U-Net model for Denoising Diffusion Probabilistic Models (DDPM)3summary: >4UNet model for Denoising Diffusion Probabilistic Models (DDPM)5---67# U-Net model for [Denoising Diffusion Probabilistic Models (DDPM)](index.html)89This is a [U-Net](../../unet/index.html) based model to predict noise10$\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$.1112U-Net is a gets it's name from the U shape in the model diagram.13It processes a given image by progressively lowering (halving) the feature map resolution and then14increasing the resolution.15There are pass-through connection at each resolution.16171819This implementation contains a bunch of modifications to original U-Net (residual blocks, multi-head attention)20and also adds time-step embeddings $t$.21"""2223import math24from typing import Optional, Tuple, Union, List2526import torch27from torch import nn282930class Swish(nn.Module):31"""32### Swish activation function3334$$x \cdot \sigma(x)$$35"""3637def forward(self, x):38return x * torch.sigmoid(x)394041class TimeEmbedding(nn.Module):42"""43### Embeddings for $t$44"""4546def __init__(self, n_channels: int):47"""48* `n_channels` is the number of dimensions in the embedding49"""50super().__init__()51self.n_channels = n_channels52# First linear layer53self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)54# Activation55self.act = Swish()56# Second linear layer57self.lin2 = nn.Linear(self.n_channels, self.n_channels)5859def forward(self, t: torch.Tensor):60# Create sinusoidal position embeddings61# [same as those from the transformer](../../transformers/positional_encoding.html)62#63# \begin{align}64# PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\65# PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)66# \end{align}67#68# where $d$ is `half_dim`69half_dim = self.n_channels // 870emb = math.log(10_000) / (half_dim - 1)71emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)72emb = t[:, None] * emb[None, :]73emb = torch.cat((emb.sin(), emb.cos()), dim=1)7475# Transform with the MLP76emb = self.act(self.lin1(emb))77emb = self.lin2(emb)7879#80return emb818283class ResidualBlock(nn.Module):84"""85### Residual block8687A residual block has two convolution layers with group normalization.88Each resolution is processed with two residual blocks.89"""9091def __init__(self, in_channels: int, out_channels: int, time_channels: int,92n_groups: int = 32, dropout: float = 0.1):93"""94* `in_channels` is the number of input channels95* `out_channels` is the number of input channels96* `time_channels` is the number channels in the time step ($t$) embeddings97* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)98* `dropout` is the dropout rate99"""100super().__init__()101# Group normalization and the first convolution layer102self.norm1 = nn.GroupNorm(n_groups, in_channels)103self.act1 = Swish()104self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))105106# Group normalization and the second convolution layer107self.norm2 = nn.GroupNorm(n_groups, out_channels)108self.act2 = Swish()109self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))110111# If the number of input channels is not equal to the number of output channels we have to112# project the shortcut connection113if in_channels != out_channels:114self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))115else:116self.shortcut = nn.Identity()117118# Linear 
layer for time embeddings119self.time_emb = nn.Linear(time_channels, out_channels)120self.time_act = Swish()121122self.dropout = nn.Dropout(dropout)123124def forward(self, x: torch.Tensor, t: torch.Tensor):125"""126* `x` has shape `[batch_size, in_channels, height, width]`127* `t` has shape `[batch_size, time_channels]`128"""129# First convolution layer130h = self.conv1(self.act1(self.norm1(x)))131# Add time embeddings132h += self.time_emb(self.time_act(t))[:, :, None, None]133# Second convolution layer134h = self.conv2(self.dropout(self.act2(self.norm2(h))))135136# Add the shortcut connection and return137return h + self.shortcut(x)138139140class AttentionBlock(nn.Module):141"""142### Attention block143144This is similar to [transformer multi-head attention](../../transformers/mha.html).145"""146147def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):148"""149* `n_channels` is the number of channels in the input150* `n_heads` is the number of heads in multi-head attention151* `d_k` is the number of dimensions in each head152* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)153"""154super().__init__()155156# Default `d_k`157if d_k is None:158d_k = n_channels159# Normalization layer160self.norm = nn.GroupNorm(n_groups, n_channels)161# Projections for query, key and values162self.projection = nn.Linear(n_channels, n_heads * d_k * 3)163# Linear layer for final transformation164self.output = nn.Linear(n_heads * d_k, n_channels)165# Scale for dot-product attention166self.scale = d_k ** -0.5167#168self.n_heads = n_heads169self.d_k = d_k170171def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):172"""173* `x` has shape `[batch_size, in_channels, height, width]`174* `t` has shape `[batch_size, time_channels]`175"""176# `t` is not used, but it's kept in the arguments because for the attention layer function signature177# to match with `ResidualBlock`.178_ = t179# Get shape180batch_size, n_channels, height, width = x.shape181# Change `x` to shape `[batch_size, seq, n_channels]`182x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)183# Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`184qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)185# Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`186q, k, v = torch.chunk(qkv, 3, dim=-1)187# Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$188attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale189# Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$190attn = attn.softmax(dim=2)191# Multiply by values192res = torch.einsum('bijh,bjhd->bihd', attn, v)193# Reshape to `[batch_size, seq, n_heads * d_k]`194res = res.view(batch_size, -1, self.n_heads * self.d_k)195# Transform to `[batch_size, seq, n_channels]`196res = self.output(res)197198# Add skip connection199res += x200201# Change to shape `[batch_size, in_channels, height, width]`202res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)203204#205return res206207208class DownBlock(nn.Module):209"""210### Down block211212This combines `ResidualBlock` and `AttentionBlock`. 
These are used in the first half of U-Net at each resolution.213"""214215def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):216super().__init__()217self.res = ResidualBlock(in_channels, out_channels, time_channels)218if has_attn:219self.attn = AttentionBlock(out_channels)220else:221self.attn = nn.Identity()222223def forward(self, x: torch.Tensor, t: torch.Tensor):224x = self.res(x, t)225x = self.attn(x)226return x227228229class UpBlock(nn.Module):230"""231### Up block232233This combines `ResidualBlock` and `AttentionBlock`. These are used in the second half of U-Net at each resolution.234"""235236def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):237super().__init__()238# The input has `in_channels + out_channels` because we concatenate the output of the same resolution239# from the first half of the U-Net240self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)241if has_attn:242self.attn = AttentionBlock(out_channels)243else:244self.attn = nn.Identity()245246def forward(self, x: torch.Tensor, t: torch.Tensor):247x = self.res(x, t)248x = self.attn(x)249return x250251252class MiddleBlock(nn.Module):253"""254### Middle block255256It combines a `ResidualBlock`, `AttentionBlock`, followed by another `ResidualBlock`.257This block is applied at the lowest resolution of the U-Net.258"""259260def __init__(self, n_channels: int, time_channels: int):261super().__init__()262self.res1 = ResidualBlock(n_channels, n_channels, time_channels)263self.attn = AttentionBlock(n_channels)264self.res2 = ResidualBlock(n_channels, n_channels, time_channels)265266def forward(self, x: torch.Tensor, t: torch.Tensor):267x = self.res1(x, t)268x = self.attn(x)269x = self.res2(x, t)270return x271272273class Upsample(nn.Module):274"""275### Scale up the feature map by $2 \times$276"""277278def __init__(self, n_channels):279super().__init__()280self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))281282def forward(self, x: torch.Tensor, t: torch.Tensor):283# `t` is not used, but it's kept in the arguments because for the attention layer function signature284# to match with `ResidualBlock`.285_ = t286return self.conv(x)287288289class Downsample(nn.Module):290"""291### Scale down the feature map by $\frac{1}{2} \times$292"""293294def __init__(self, n_channels):295super().__init__()296self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))297298def forward(self, x: torch.Tensor, t: torch.Tensor):299# `t` is not used, but it's kept in the arguments because for the attention layer function signature300# to match with `ResidualBlock`.301_ = t302return self.conv(x)303304305class UNet(nn.Module):306"""307## U-Net308"""309310def __init__(self, image_channels: int = 3, n_channels: int = 64,311ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),312is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),313n_blocks: int = 2):314"""315* `image_channels` is the number of channels in the image. $3$ for RGB.316* `n_channels` is number of channels in the initial feature map that we transform the image into317* `ch_mults` is the list of channel numbers at each resolution. 
The number of channels is `ch_mults[i] * n_channels`318* `is_attn` is a list of booleans that indicate whether to use attention at each resolution319* `n_blocks` is the number of `UpDownBlocks` at each resolution320"""321super().__init__()322323# Number of resolutions324n_resolutions = len(ch_mults)325326# Project image into feature map327self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))328329# Time embedding layer. Time embedding has `n_channels * 4` channels330self.time_emb = TimeEmbedding(n_channels * 4)331332# #### First half of U-Net - decreasing resolution333down = []334# Number of channels335out_channels = in_channels = n_channels336# For each resolution337for i in range(n_resolutions):338# Number of output channels at this resolution339out_channels = in_channels * ch_mults[i]340# Add `n_blocks`341for _ in range(n_blocks):342down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))343in_channels = out_channels344# Down sample at all resolutions except the last345if i < n_resolutions - 1:346down.append(Downsample(in_channels))347348# Combine the set of modules349self.down = nn.ModuleList(down)350351# Middle block352self.middle = MiddleBlock(out_channels, n_channels * 4, )353354# #### Second half of U-Net - increasing resolution355up = []356# Number of channels357in_channels = out_channels358# For each resolution359for i in reversed(range(n_resolutions)):360# `n_blocks` at the same resolution361out_channels = in_channels362for _ in range(n_blocks):363up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))364# Final block to reduce the number of channels365out_channels = in_channels // ch_mults[i]366up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))367in_channels = out_channels368# Up sample at all resolutions except last369if i > 0:370up.append(Upsample(in_channels))371372# Combine the set of modules373self.up = nn.ModuleList(up)374375# Final normalization and convolution layer376self.norm = nn.GroupNorm(8, n_channels)377self.act = Swish()378self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))379380def forward(self, x: torch.Tensor, t: torch.Tensor):381"""382* `x` has shape `[batch_size, in_channels, height, width]`383* `t` has shape `[batch_size]`384"""385386# Get time-step embeddings387t = self.time_emb(t)388389# Get image projection390x = self.image_proj(x)391392# `h` will store outputs at each resolution for skip connection393h = [x]394# First half of U-Net395for m in self.down:396x = m(x, t)397h.append(x)398399# Middle (bottom)400x = self.middle(x, t)401402# Second half of U-Net403for m in self.up:404if isinstance(m, Upsample):405x = m(x, t)406else:407# Get the skip connection from first half of U-Net and concatenate408s = h.pop()409x = torch.cat((x, s), dim=1)410#411x = m(x, t)412413# Final normalization and convolution414return self.final(self.act(self.norm(x)))415416417
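# ### Sanity checks
#
# A minimal sketch, not part of the original file, that runs the building blocks and the full
# model on random tensors to confirm the shapes described above. The batch size, image size,
# and time-step range used here are arbitrary illustrative choices.
def _test():
    # `TimeEmbedding` maps a batch of time steps to `n_channels`-dimensional embeddings
    time_emb = TimeEmbedding(n_channels=256)
    t = torch.randint(0, 1_000, (4,))
    assert time_emb(t).shape == (4, 256)

    # `ResidualBlock` changes the number of channels; `AttentionBlock` keeps the
    # spatial resolution and the number of channels unchanged
    res = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
    attn = AttentionBlock(n_channels=128)
    x = torch.randn(4, 64, 32, 32)
    h = attn(res(x, time_emb(t)))
    assert h.shape == (4, 128, 32, 32)

    # The full U-Net predicts noise with the same shape as the input image
    model = UNet(image_channels=3, n_channels=64)
    x = torch.randn(4, 3, 32, 32)
    eps = model(x, t)
    assert eps.shape == x.shape


#
if __name__ == '__main__':
    _test()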