Path: blob/master/labml_nn/diffusion/stable_diffusion/sampler/ddpm.py
"""1---2title: Denoising Diffusion Probabilistic Models (DDPM) Sampling3summary: >4Annotated PyTorch implementation/tutorial of5Denoising Diffusion Probabilistic Models (DDPM) Sampling6for stable diffusion model.7---89# Denoising Diffusion Probabilistic Models (DDPM) Sampling1011For a simpler DDPM implementation refer to our [DDPM implementation](../../ddpm/index.html).12We use same notations for $\alpha_t$, $\beta_t$ schedules, etc.13"""1415from typing import Optional, List1617import numpy as np18import torch1920from labml import monit21from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion22from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler232425class DDPMSampler(DiffusionSampler):26"""27## DDPM Sampler2829This extends the [`DiffusionSampler` base class](index.html).3031DDPM samples images by repeatedly removing noise by sampling step by step from32$p_\theta(x_{t-1} | x_t)$,3334\begin{align}3536p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\3738\mu_t(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_039+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\4041\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\4243x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta \\4445\end{align}46"""4748model: LatentDiffusion4950def __init__(self, model: LatentDiffusion):51"""52:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$53"""54super().__init__(model)5556# Sampling steps $1, 2, \dots, T$57self.time_steps = np.asarray(list(range(self.n_steps)))5859with torch.no_grad():60# $\bar\alpha_t$61alpha_bar = self.model.alpha_bar62# $\beta_t$ schedule63beta = self.model.beta64# $\bar\alpha_{t-1}$65alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])6667# $\sqrt{\bar\alpha}$68self.sqrt_alpha_bar = alpha_bar ** .569# $\sqrt{1 - \bar\alpha}$70self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .571# $\frac{1}{\sqrt{\bar\alpha_t}}$72self.sqrt_recip_alpha_bar = alpha_bar ** -.573# $\sqrt{\frac{1}{\bar\alpha_t} - 1}$74self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .57576# $\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$77variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)78# Clamped log of $\tilde\beta_t$79self.log_var = torch.log(torch.clamp(variance, min=1e-20))80# $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$81self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)82# $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$83self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)8485@torch.no_grad()86def sample(self,87shape: List[int],88cond: torch.Tensor,89repeat_noise: bool = False,90temperature: float = 1.,91x_last: Optional[torch.Tensor] = None,92uncond_scale: float = 1.,93uncond_cond: Optional[torch.Tensor] = None,94skip_steps: int = 0,95):96"""97### Sampling Loop9899:param shape: is the shape of the generated images in the100form `[batch_size, channels, height, width]`101:param cond: is the conditional embeddings $c$102:param temperature: is the noise temperature (random noise gets multiplied by this)103:param x_last: is $x_T$. If not provided random noise will be used.104:param uncond_scale: is the unconditional guidance scale $s$. 

    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        """
        ### Sampling Loop

        :param shape: is the shape of the generated images in the
            form `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is $x_T$. If not provided random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        :param skip_steps: is the number of time steps to skip $t'$. We start sampling from $T - t'$,
            and `x_last` is then $x_{T - t'}$.
        """

        # Get device and batch size
        device = self.model.device
        bs = shape[0]

        # Get $x_T$
        x = x_last if x_last is not None else torch.randn(shape, device=device)

        # Time steps to sample at $T - t', T - t' - 1, \dots, 1$
        time_steps = np.flip(self.time_steps)[skip_steps:]

        # Sampling loop
        for step in monit.iterate('Sample', time_steps):
            # Time step $t$
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{t-1}$
            x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
                                            repeat_noise=repeat_noise,
                                            temperature=temperature,
                                            uncond_scale=uncond_scale,
                                            uncond_cond=uncond_cond)

        # Return $x_0$
        return x

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):
        """
        ### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$

        :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param t: is $t$ of shape `[batch_size]`
        :param step: is the step $t$ as an integer
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        """

        # Get $\epsilon_\theta$
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)

        # Get batch size
        bs = x.shape[0]

        # $\frac{1}{\sqrt{\bar\alpha_t}}$
        sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])
        # $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
        sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])

        # Calculate $x_0$ with current $\epsilon_\theta$
        #
        # $$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta$$
        x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t

        # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
        mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])
        # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
        mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])

        # Calculate $\mu_\theta(x_t, t)$
        #
        # $$\mu_\theta(x_t, t) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
        #    + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$$
        mean = mean_x0_coef * x0 + mean_xt_coef * x
        # $\log \tilde\beta_t$
        log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])

        # Do not add noise when $t = 1$ (the final step of the sampling process).
        # Note that `step` is `0` when $t = 1$.
        if step == 0:
            noise = 0
        # If the same noise is to be used for all samples in the batch
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        # Different noise for each sample
        else:
            noise = torch.randn(x.shape, device=x.device)

        # Multiply noise by the temperature
        noise = noise * temperature

        # Sample from,
        #
        # $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$
        x_prev = mean + (0.5 * log_var).exp() * noise

        #
        return x_prev, x0, e_t

    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        """
        ### Sample from $q(x_t|x_0)$

        $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step $t$ index
        :param noise: is the noise, $\epsilon$
        """

        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0)

        # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
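
# ---
#
# ### Usage sketch
#
# A minimal sketch of how this sampler might be driven, assuming you already
# have a trained [`LatentDiffusion`](../latent_diffusion.html) model and the
# conditional/unconditional prompt embeddings.
# The helper names (`generate`, `img2img`), the guidance scale `7.5`, and the
# `4 x 64 x 64` latent shape (common for `512 x 512` stable diffusion images)
# are illustrative assumptions, not part of the sampler itself.

def generate(model: LatentDiffusion, cond: torch.Tensor, uncond_cond: torch.Tensor):
    # Create a sampler for the model
    sampler = DDPMSampler(model)
    # Latent shape `[batch_size, channels, height, width]` (assumed configuration)
    shape = [cond.shape[0], 4, 64, 64]
    # Run the full reverse process $x_T \rightarrow x_0$ starting from random noise
    return sampler.sample(shape, cond,
                          uncond_scale=7.5,
                          uncond_cond=uncond_cond)


def img2img(model: LatentDiffusion, x0: torch.Tensor, cond: torch.Tensor,
            uncond_cond: torch.Tensor, strength: float = 0.75):
    # Create a sampler for the model
    sampler = DDPMSampler(model)
    # Number of reverse steps to actually run, $t' = T \cdot \text{strength}$;
    # `strength` should be in $(0, 1]$
    t_index = int(strength * sampler.n_steps)
    # Diffuse the given latent to $x_{t'}$ by sampling from $q(x_{t'}|x_0)$
    x_t = sampler.q_sample(x0, t_index - 1)
    # Denoise back to $x_0$, skipping the first $T - t'$ reverse steps
    return sampler.sample(list(x0.shape), cond,
                          x_last=x_t,
                          skip_steps=sampler.n_steps - t_index,
                          uncond_scale=7.5,
                          uncond_cond=uncond_cond)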