GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/diffusion/stable_diffusion/sampler/ddpm.py
"""
2
---
3
title: Denoising Diffusion Probabilistic Models (DDPM) Sampling
4
summary: >
5
Annotated PyTorch implementation/tutorial of
6
Denoising Diffusion Probabilistic Models (DDPM) Sampling
7
for stable diffusion model.
8
---
9
10
# Denoising Diffusion Probabilistic Models (DDPM) Sampling
11
12
For a simpler DDPM implementation refer to our [DDPM implementation](../../ddpm/index.html).
13
We use same notations for $\alpha_t$, $\beta_t$ schedules, etc.
14
"""
15
16
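# Notation recap, matching the simpler [DDPM implementation](../../ddpm/index.html)
# linked above: $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$.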
from typing import Optional, List

import numpy as np
import torch

from labml import monit
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler


class DDPMSampler(DiffusionSampler):
    """
    ## DDPM Sampler

    This extends the [`DiffusionSampler` base class](index.html).

    DDPM samples images by repeatedly removing noise, sampling step by step from
    $p_\theta(x_{t-1} | x_t)$,

    \begin{align}
    p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\
    \mu_\theta(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
    + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
    \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\
    x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta
    \end{align}
    """

    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion):
        """
        :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
        """
        super().__init__(model)

        # Sampling steps $1, 2, \dots, T$
        self.time_steps = np.asarray(list(range(self.n_steps)))

        with torch.no_grad():
            # $\bar\alpha_t$
            alpha_bar = self.model.alpha_bar
            # $\beta_t$ schedule
            beta = self.model.beta
            # $\bar\alpha_{t-1}$, with $\bar\alpha_0 = 1$
            alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])

            # $\sqrt{\bar\alpha}$
            self.sqrt_alpha_bar = alpha_bar ** .5
            # $\sqrt{1 - \bar\alpha}$
            self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .5
            # $\frac{1}{\sqrt{\bar\alpha_t}}$
            self.sqrt_recip_alpha_bar = alpha_bar ** -.5
            # $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
            self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .5

            # $\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$
            variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)
            # Clamped log of $\tilde\beta_t$ (the variance is $0$ at $t = 1$)
            self.log_var = torch.log(torch.clamp(variance, min=1e-20))
            # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
            self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)
            # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$; note that $\sqrt{\alpha_t} = \sqrt{1 - \beta_t}$
            self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)

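    # A quick consistency check on the two mean coefficients above (a worked
    # identity added for this write-up, not from the original file): substituting
    # a noise-free $x_t = \sqrt{\bar\alpha_t} x_0$ into $\mu_\theta$ should give
    # back $\sqrt{\bar\alpha_{t-1}} x_0$, and indeed
    #
    # $$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}
    # + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}\sqrt{\bar\alpha_t}
    # = \sqrt{\bar\alpha_{t-1}}$$
    #
    # because $\sqrt{\bar\alpha_t} = \sqrt{\alpha_t}\sqrt{\bar\alpha_{t-1}}$ and
    # $\alpha_t + \beta_t = 1$. A runnable version of this check is sketched at
    # the end of this file.
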
    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        """
        ### Sampling Loop

        :param shape: is the shape of the generated images in the
            form `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param repeat_noise: specifies whether to use the same noise for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is $x_T$. If not provided, random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        :param skip_steps: is the number of time steps to skip $t'$. We start sampling from $T - t'$.
            And `x_last` is then $x_{T - t'}$.
        """

        # Get device and batch size
        device = self.model.device
        bs = shape[0]

        # Get $x_T$
        x = x_last if x_last is not None else torch.randn(shape, device=device)

        # Time steps to sample at $T - t', T - t' - 1, \dots, 1$
        time_steps = np.flip(self.time_steps)[skip_steps:]

        # Sampling loop
        for step in monit.iterate('Sample', time_steps):
            # Time step $t$
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{t-1}$
            x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
                                            repeat_noise=repeat_noise,
                                            temperature=temperature,
                                            uncond_scale=uncond_scale,
                                            uncond_cond=uncond_cond)

        # Return $x_0$
        return x

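    # Classifier-free guidance itself happens inside `self.get_eps` from the
    # [`DiffusionSampler` base class](index.html). A minimal sketch of that
    # combination, assuming the base class follows the standard formulation
    # (see the base class for the actual code):
    #
    #     e_t_uncond, e_t_cond = ...  # illustrative: model outputs for $c_u$ and $c$
    #     e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
    #
    # which is the same as $s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$.
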
    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):
        """
        ### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$

        :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param t: is $t$ of shape `[batch_size]`
        :param step: is the step $t$ as an integer
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        """

        # Get $\epsilon_\theta$
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)

        # Get batch size
        bs = x.shape[0]

        # $\frac{1}{\sqrt{\bar\alpha_t}}$
        sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])
        # $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
        sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])

        # Calculate $x_0$ with current $\epsilon_\theta$
        #
        # $$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta$$
        x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t

        # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
        mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])
        # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
        mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])

        # Calculate $\mu_\theta(x_t, t)$
        #
        # $$\mu_\theta(x_t, t) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
        # + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$$
        mean = mean_x0_coef * x0 + mean_xt_coef * x
        # $\log \tilde\beta_t$
        log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])

        # Do not add noise when $t = 1$ (the final step of the sampling process).
        # Note that `step` is `0` when $t = 1$.
        if step == 0:
            noise = 0
        # If the same noise is used for all samples in the batch
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        # Different noise for each sample
        else:
            noise = torch.randn(x.shape, device=x.device)

        # Multiply noise by the temperature
        noise = noise * temperature

        # Sample from
        #
        # $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$
        #
        # using $\exp\big(\tfrac{1}{2}\log\tilde\beta_t\big) = \sqrt{\tilde\beta_t}$
        x_prev = mean + (0.5 * log_var).exp() * noise

        #
        return x_prev, x0, e_t

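    # A hedged usage note for `q_sample` below (an assumption, not from this
    # file): image-to-image style pipelines typically encode an image to a
    # latent $x_0$, noise it with `q_sample(x0, index)` to get $x_t$, and then
    # run the reverse loop via `sample(..., x_last=x_t, skip_steps=...)` to
    # denoise from that intermediate step.
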
    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        """
        ### Sample from $q(x_t|x_0)$

        $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step $t$ index
        :param noise: is the noise, $\epsilon$
        """

        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0)

        # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
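

# ---
# The rest of this file is an illustrative appendix added for this write-up,
# not part of the sampler itself.
#
# A hedged end-to-end usage sketch (`model`, `cond` and `uncond_cond` are
# hypothetical; constructing a trained `LatentDiffusion` is out of scope here):
#
#     sampler = DDPMSampler(model)
#     images = sampler.sample(shape=[4, 4, 64, 64], cond=cond,
#                             uncond_scale=7.5, uncond_cond=uncond_cond)
#
# The function below numerically verifies the posterior-mean identity noted
# after `__init__`, assuming a toy linear $\beta_t$ schedule.
def _check_posterior_coefficients(n_steps: int = 1000):
    # Toy linear $\beta_t$ schedule (an assumption for this check)
    beta = torch.linspace(1e-4, 2e-2, n_steps)
    # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$ with $\alpha_t = 1 - \beta_t$
    alpha_bar = torch.cumprod(1. - beta, dim=0)
    # $\bar\alpha_{t-1}$, with $\bar\alpha_0 = 1$
    alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])

    # Coefficients exactly as computed in `DDPMSampler.__init__`
    mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)
    mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)

    # With a noise-free $x_t = \sqrt{\bar\alpha_t} x_0$, the posterior mean
    # $c_0 x_0 + c_t x_t$ must reduce to $\sqrt{\bar\alpha_{t-1}} x_0$
    lhs = mean_x0_coef + mean_xt_coef * alpha_bar ** .5
    rhs = alpha_bar_prev ** .5
    assert torch.allclose(lhs, rhs, atol=1e-5)


if __name__ == '__main__':
    _check_posterior_coefficients()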