# Path: blob/master/labml_nn/diffusion/ddpm/evaluate.py
# 4921 views
"""
---
title: Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling
summary: >
  Code to generate samples from a trained
  Denoising Diffusion Probabilistic Model.
---

# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) evaluation/sampling

This is the code to generate images and create interpolations between given images.
"""

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image, resize

from labml import experiment, monit
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
from labml_nn.diffusion.ddpm.experiment import Configs


class Sampler:
    """
    ## Sampler class
    """

    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
        """
        * `diffusion` is the `DenoiseDiffusion` instance
        * `image_channels` is the number of channels in the image
        * `image_size` is the image size
        * `device` is the device of the model
        """
        self.device = device
        self.image_size = image_size
        self.image_channels = image_channels
        self.diffusion = diffusion

        # $T$
        self.n_steps = diffusion.n_steps
        # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
        self.eps_model = diffusion.eps_model
        # $\beta_t$
        self.beta = diffusion.beta
        # $\alpha_t$
        self.alpha = diffusion.alpha
        # $\bar\alpha_t$
        self.alpha_bar = diffusion.alpha_bar
        # $\bar\alpha_{t-1}$ (shifted by one step, with $\bar\alpha_0 = 1$ prepended)
        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])

        # To calculate
        #
        # \begin{align}
        # q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
        # \tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
        #                          + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
        # \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
        # \end{align}

        # $$\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$$
        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # $$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$$
        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
        # $$\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$$
        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # $\sigma^2 = \beta$
        self.sigma2 = self.beta

    def show_image(self, img: torch.Tensor, title: str = "") -> None:
        """Helper function to display an image"""
        # Clamp to the displayable range before handing over to matplotlib
        img = img.clip(0, 1)
        img = img.cpu().numpy()
        # Move the first axis to the end for `imshow`
        plt.imshow(img.transpose(1, 2, 0))
        plt.title(title)
        plt.show()

    def make_video(self, frames, path: str = "video.mp4") -> None:
        """
        Helper function to create a video

        * `frames` is the sequence of image tensors to write, in order
        * `path` is the output file path
        """
        import imageio
        # Aim for a 20 second video. Clamp to at least one frame per second,
        # otherwise fewer than 20 frames would give an invalid `fps` of zero.
        fps = max(1, len(frames) // 20)
        writer = imageio.get_writer(path, fps=fps)
        try:
            # Add each image
            for f in frames:
                f = f.clip(0, 1)
                f = to_pil_image(resize(f, [368, 368]))
                writer.append_data(np.array(f))
        finally:
            # Always release the writer, even if a frame fails to convert
            writer.close()

    def sample_animation(self, n_frames: int = 1000, create_video: bool = True) -> None:
        r"""
        #### Sample an image step-by-step using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        We sample an image step-by-step using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$ and at each step
        show the estimate
        $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
         \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$

        * `n_frames` is the approximate number of $\hat{x}_0$ estimates to record
        * `create_video` specifies whether to make a video or to show each frame
        """

        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)

        # Interval to log $\hat{x}_0$.
        # Clamp to at least 1: when `n_frames > n_steps` the integer division
        # gives zero and `t_ % interval` below would raise `ZeroDivisionError`.
        interval = max(1, self.n_steps // n_frames)
        # Frames for video
        frames = []
        # Sample $T$ steps
        for t_inv in monit.iterate('Denoise', self.n_steps):
            # $t$
            t_ = self.n_steps - t_inv - 1
            # $t$ in a tensor
            t = xt.new_full((1,), t_, dtype=torch.long)
            # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
            eps_theta = self.eps_model(xt, t)
            if t_ % interval == 0:
                # Get $\hat{x}_0$ and add to frames
                x0 = self.p_x0(xt, t, eps_theta)
                frames.append(x0[0])
                if not create_video:
                    self.show_image(x0[0], f"{t_}")
            # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
            xt = self.p_sample(xt, t, eps_theta)

        # Make video
        if create_video:
            self.make_video(frames)

    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100) -> torch.Tensor:
        r"""
        #### Interpolate two images $x_0$ and $x'_0$

        We get $x_t \sim q(x_t|x_0)$ and $x'_t \sim q(x'_t|x'_0)$.

        Then interpolate to
        $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$

        Then get
        $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$

        * `x1` is $x_0$
        * `x2` is $x'_0$
        * `lambda_` is $\lambda$
        * `t_` is $t$
        """

        # Number of samples
        n_samples = x1.shape[0]
        # $t$ tensor; explicit `long` dtype to match the other index tensors in this class
        t = torch.full((n_samples,), t_, dtype=torch.long, device=self.device)
        # $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)

        # NOTE(review): noise is added at index `t_` while `_sample_x0` starts
        # denoising at index `t_ - 1` - confirm this one-step offset is intended.
        # $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$
        return self._sample_x0(xt, t_)

    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
                            create_video=True) -> None:
        r"""
        #### Interpolate two images $x_0$ and $x'_0$ and make a video

        * `x1` is $x_0$
        * `x2` is $x'_0$
        * `n_frames` is the number of interpolation steps (`n_frames + 1` frames are produced)
        * `t_` is $t$
        * `create_video` specifies whether to make a video or to show each frame
        """

        # Show original images
        self.show_image(x1, "x1")
        self.show_image(x2, "x2")
        # Add batch dimension
        x1 = x1[None, :, :, :]
        x2 = x2[None, :, :, :]
        # $t$ tensor; explicit `long` dtype to match the other index tensors in this class
        t = torch.full((1,), t_, dtype=torch.long, device=self.device)
        # $x_t \sim q(x_t|x_0)$
        x1t = self.diffusion.q_sample(x1, t)
        # $x'_t \sim q(x'_t|x'_0)$
        x2t = self.diffusion.q_sample(x2, t)

        frames = []
        # Get frames with different $\lambda$
        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
            # $\lambda$; guard the denominator so `n_frames == 0` yields a single frame at $\lambda = 0$
            lambda_ = i / max(1, n_frames)
            # $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
            xt = (1 - lambda_) * x1t + lambda_ * x2t
            # $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$
            x0 = self._sample_x0(xt, t_)
            # Add to frames
            frames.append(x0[0])
            # Show frame
            if not create_video:
                self.show_image(x0[0], f"{lambda_ :.2f}")

        # Make video
        if create_video:
            self.make_video(frames)

    def _sample_x0(self, xt: torch.Tensor, n_steps: int) -> torch.Tensor:
        r"""
        #### Sample an image using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        * `xt` is $x_t$
        * `n_steps` is $t$
        """

        # Number of samples
        n_samples = xt.shape[0]
        # Iterate until $t$ steps
        for t_ in monit.iterate('Denoise', n_steps):
            # Count down from `n_steps - 1` to `0`
            t = n_steps - t_ - 1
            # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))

        # Return $x_0$
        return xt

    def sample(self, n_samples: int = 16) -> None:
        """
        #### Generate images

        * `n_samples` is the number of images to generate and show
        """
        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)

        # $$x_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|x_t)$$
        x0 = self._sample_x0(xt, self.n_steps)

        # Show images
        for i in range(n_samples):
            self.show_image(x0[i])

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor) -> torch.Tensor:
        r"""
        #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        \begin{align}
        \textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
        \textcolor{lightgreen}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
        \textcolor{lightgreen}{\mu_\theta}(x_t, t)
          &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
            \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)
        \end{align}

        * `xt` is $x_t$
        * `t` is the tensor of time step indexes
        * `eps_theta` is the already computed $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
        """
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)
        # $\alpha_t$
        alpha = gather(self.alpha, t)
        # $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$, written with $\beta_t = 1 - \alpha_t$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        #      \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # $\sigma^2$
        var = gather(self.sigma2, t)

        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample
        return mean + (var ** .5) * eps

    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
        r"""
        #### Estimate $x_0$

        $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
         \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$

        * `xt` is $x_t$
        * `t` is the tensor of time step indexes
        * `eps` is the predicted noise $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
        """
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)

        # $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
        #  \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)


def main():
    """Generate samples"""

    # Training experiment run UUID
    run_uuid = "a44333ea251411ec8007d1a1762ed686"

    # Start an evaluation
    experiment.evaluate()

    # Create configs
    configs = Configs()
    # Load custom configuration of the training run
    configs_dict = experiment.load_configs(run_uuid)
    # Set configurations
    experiment.configs(configs, configs_dict)

    # Initialize
    configs.init()

    # Set PyTorch modules for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Load training experiment
    experiment.load(run_uuid)

    # Create sampler
    sampler = Sampler(diffusion=configs.diffusion,
                      image_channels=configs.image_channels,
                      image_size=configs.image_size,
                      device=configs.device)

    # Start evaluation
    with experiment.start():
        # No gradients
        with torch.no_grad():
            # Sample an image with a denoising animation
            sampler.sample_animation()

            # Manual toggle: flip to `True` to also render the interpolation animation
            if False:
                # Get some images from data
                data = next(iter(configs.data_loader)).to(configs.device)

                # Create an interpolation animation
                sampler.interpolate_animate(data[0], data[1])


#
if __name__ == '__main__':
    main()