GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/diffusion/ddpm/evaluate.py
"""
---
title: Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling
summary: >
  Code to generate samples from a trained
  Denoising Diffusion Probabilistic Model.
---

# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) evaluation/sampling

This is the code to generate images and create interpolations between given images.
"""

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image, resize

from labml import experiment, monit
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
from labml_nn.diffusion.ddpm.experiment import Configs


class Sampler:
    """
    ## Sampler class
    """

    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
        """
        * `diffusion` is the `DenoiseDiffusion` instance
        * `image_channels` is the number of channels in the image
        * `image_size` is the image size
        * `device` is the device of the model
        """
        self.device = device
        self.image_size = image_size
        self.image_channels = image_channels
        self.diffusion = diffusion

        # $T$
        self.n_steps = diffusion.n_steps
        # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
        self.eps_model = diffusion.eps_model
        # $\beta_t$
        self.beta = diffusion.beta
        # $\alpha_t$
        self.alpha = diffusion.alpha
        # $\bar\alpha_t$
        self.alpha_bar = diffusion.alpha_bar
        # $\bar\alpha_{t-1}$
        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])

        # To calculate
        #
        # \begin{align}
        # q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
        # \tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
        # + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
        # \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
        # \end{align}

        # $$\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$$
        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # $$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$$
        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
        # $$\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$$
        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # $\sigma^2 = \beta$
        self.sigma2 = self.beta

    def show_image(self, img, title=""):
        """Helper function to display an image"""
        img = img.clip(0, 1)
        img = img.cpu().numpy()
        plt.imshow(img.transpose(1, 2, 0))
        plt.title(title)
        plt.show()

    def make_video(self, frames, path="video.mp4"):
        """Helper function to create a video"""
        import imageio
        # 20 second video
        writer = imageio.get_writer(path, fps=len(frames) // 20)
        # Add each image
        for f in frames:
            f = f.clip(0, 1)
            f = to_pil_image(resize(f, [368, 368]))
            writer.append_data(np.array(f))
        #
        writer.close()

    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
        """
        #### Sample an image step-by-step using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        We sample an image step-by-step using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$ and at each step
        show the estimate
        $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
        \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        """

        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)

        # Interval to log $\hat{x}_0$
        interval = self.n_steps // n_frames
        # Frames for video
        frames = []
        # Sample $T$ steps
        for t_inv in monit.iterate('Denoise', self.n_steps):
            # $t$
            t_ = self.n_steps - t_inv - 1
            # $t$ in a tensor
            t = xt.new_full((1,), t_, dtype=torch.long)
            # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
            eps_theta = self.eps_model(xt, t)
            if t_ % interval == 0:
                # Get $\hat{x}_0$ and add to frames
                x0 = self.p_x0(xt, t, eps_theta)
                frames.append(x0[0])
                if not create_video:
                    self.show_image(x0[0], f"{t_}")
            # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
            xt = self.p_sample(xt, t, eps_theta)

        # Make video
        if create_video:
            self.make_video(frames)

    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
        """
        #### Interpolate two images $x_0$ and $x'_0$

        We get $x_t \sim q(x_t|x_0)$ and $x'_t \sim q(x'_t|x'_0)$.

        Then interpolate to
        $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$

        Then get
        $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$

        * `x1` is $x_0$
        * `x2` is $x'_0$
        * `lambda_` is $\lambda$
        * `t_` is $t$
        """

        # Number of samples
        n_samples = x1.shape[0]
        # $t$ tensor
        t = torch.full((n_samples,), t_, device=self.device)
        # $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)

        # $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$
        return self._sample_x0(xt, t_)

    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
                            create_video=True):
        """
        #### Interpolate two images $x_0$ and $x'_0$ and make a video

        * `x1` is $x_0$
        * `x2` is $x'_0$
        * `n_frames` is the number of frames in the video
        * `t_` is $t$
        * `create_video` specifies whether to make a video or to show each frame
        """

        # Show original images
        self.show_image(x1, "x1")
        self.show_image(x2, "x2")
        # Add batch dimension
        x1 = x1[None, :, :, :]
        x2 = x2[None, :, :, :]
        # $t$ tensor
        t = torch.full((1,), t_, device=self.device)
        # $x_t \sim q(x_t|x_0)$
        x1t = self.diffusion.q_sample(x1, t)
        # $x'_t \sim q(x'_t|x'_0)$
        x2t = self.diffusion.q_sample(x2, t)

        frames = []
        # Get frames with different $\lambda$
        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
            # $\lambda$
            lambda_ = i / n_frames
            # $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
            xt = (1 - lambda_) * x1t + lambda_ * x2t
            # $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$
            x0 = self._sample_x0(xt, t_)
            # Add to frames
            frames.append(x0[0])
            # Show frame
            if not create_video:
                self.show_image(x0[0], f"{lambda_:.2f}")

        # Make video
        if create_video:
            self.make_video(frames)

    def _sample_x0(self, xt: torch.Tensor, n_steps: int):
        """
        #### Sample an image using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        * `xt` is $x_t$
        * `n_steps` is $t$
        """

        # Number of samples
        n_samples = xt.shape[0]
        # Run $t$ denoising steps
        for t_ in monit.iterate('Denoise', n_steps):
            t = n_steps - t_ - 1
            # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))

        # Return $x_0$
        return xt

    def sample(self, n_samples: int = 16):
        """
        #### Generate images
        """
        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)

        # $$x_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|x_t)$$
        x0 = self._sample_x0(xt, self.n_steps)

        # Show images
        for i in range(n_samples):
            self.show_image(x0[i])

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
        """
        #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        \begin{align}
        \textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
        \textcolor{lightgreen}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
        \textcolor{lightgreen}{\mu_\theta}(x_t, t)
        &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)
        \end{align}
        """
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)
        # $\alpha_t$
        alpha = gather(self.alpha, t)
        # $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        # \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # $\sigma^2$
        var = gather(self.sigma2, t)

        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample
        return mean + (var ** .5) * eps

    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
        """
        #### Estimate $x_0$

        $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
        \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        """
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)

        # $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
        # \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
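

# The sketch below is an illustrative addition, not part of the original module.
# It numerically checks the identity behind `p_sample`: the posterior mean
# $\tilde\mu_t(x_t, \hat{x}_0)$, built from the `mu_tilde_coef1` and
# `mu_tilde_coef2` coefficients computed in `Sampler.__init__`, equals the
# $\epsilon$-parameterized mean
# $\frac{1}{\sqrt{\alpha_t}} \big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}} \epsilon\big)$.
# A linear $\beta$ schedule is assumed; the function name is hypothetical.
def _check_posterior_mean(n_steps: int = 1000, t: int = 500):
    # Linear schedule, as in the DDPM paper
    beta = torch.linspace(0.0001, 0.02, n_steps)
    alpha = 1. - beta
    alpha_bar = torch.cumprod(alpha, dim=0)
    # $\bar\alpha_{t-1}$, with $\bar\alpha_0$ taken as $1$
    alpha_bar_tm1 = torch.cat([alpha_bar.new_ones((1,)), alpha_bar[:-1]])
    # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
    coef1 = beta * (alpha_bar_tm1 ** 0.5) / (1 - alpha_bar)
    # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
    coef2 = (alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - alpha_bar)

    # A random $x_t$ and a noise estimate $\epsilon$ at a fixed $t$
    xt = torch.randn(4, 3, 8, 8)
    eps = torch.randn_like(xt)
    # $\hat{x}_0$, as in `Sampler.p_x0`
    x0_hat = (xt - (1 - alpha_bar[t]) ** 0.5 * eps) / (alpha_bar[t] ** 0.5)
    # Posterior mean $\tilde\mu_t(x_t, \hat{x}_0)$
    mu_tilde = coef1[t] * x0_hat + coef2[t] * xt
    # $\epsilon$-parameterized mean, as in `Sampler.p_sample`
    mean = (xt - (1 - alpha[t]) / (1 - alpha_bar[t]) ** 0.5 * eps) / (alpha[t] ** 0.5)
    # The two forms agree up to floating-point error
    assert torch.allclose(mu_tilde, mean, atol=1e-4)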


def main():
    """Generate samples"""

    # Training experiment run UUID
    run_uuid = "a44333ea251411ec8007d1a1762ed686"

    # Start an evaluation
    experiment.evaluate()

    # Create configs
    configs = Configs()
    # Load custom configuration of the training run
    configs_dict = experiment.load_configs(run_uuid)
    # Set configurations
    experiment.configs(configs, configs_dict)

    # Initialize
    configs.init()

    # Set PyTorch modules for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Load training experiment
    experiment.load(run_uuid)

    # Create sampler
    sampler = Sampler(diffusion=configs.diffusion,
                      image_channels=configs.image_channels,
                      image_size=configs.image_size,
                      device=configs.device)

    # Start evaluation
    with experiment.start():
        # No gradients
        with torch.no_grad():
            # Sample an image with a denoising animation
            sampler.sample_animation()

            # Interpolation is disabled by default; change `False` to `True` to run it
            if False:
                # Get some images from the data
                data = next(iter(configs.data_loader)).to(configs.device)

                # Create an interpolation animation
                sampler.interpolate_animate(data[0], data[1])


#
if __name__ == '__main__':
    main()
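

# Another illustrative addition (hypothetical name), not part of the original
# module: under the forward process
# $x_t = \sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t} \epsilon$,
# the $\hat{x}_0$ formula used in `Sampler.p_x0` inverts it exactly when given
# the true noise $\epsilon$. Again a linear $\beta$ schedule is assumed.
def _check_x0_estimate(n_steps: int = 1000, t: int = 250):
    beta = torch.linspace(0.0001, 0.02, n_steps)
    alpha_bar = torch.cumprod(1. - beta, dim=0)
    x0 = torch.randn(2, 3, 8, 8)
    eps = torch.randn_like(x0)
    # $x_t \sim q(x_t|x_0)$ with known noise
    xt = (alpha_bar[t] ** 0.5) * x0 + (1 - alpha_bar[t]) ** 0.5 * eps
    # Invert, as in `Sampler.p_x0`
    x0_hat = (xt - (1 - alpha_bar[t]) ** 0.5 * eps) / (alpha_bar[t] ** 0.5)
    # Recovers $x_0$ up to floating-point error
    assert torch.allclose(x0, x0_hat, atol=1e-4)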