Path: blob/master/modules/models/diffusion/ddpm_edit.py
3073 views
"""1wild mixture of2https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py3https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py4https://github.com/CompVis/taming-transformers5-- merci6"""78# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).9# See more details in LICENSE.1011import torch12import torch.nn as nn13import numpy as np14import pytorch_lightning as pl15from torch.optim.lr_scheduler import LambdaLR16from einops import rearrange, repeat17from contextlib import contextmanager18from functools import partial19from tqdm import tqdm20from torchvision.utils import make_grid21from pytorch_lightning.utilities.distributed import rank_zero_only2223from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config24from ldm.modules.ema import LitEma25from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution26from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL27from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like28from ldm.models.diffusion.ddim import DDIMSampler2930try:31from ldm.models.autoencoder import VQModelInterface32except Exception:33class VQModelInterface:34pass3536__conditioning_keys__ = {'concat': 'c_concat',37'crossattn': 'c_crossattn',38'adm': 'y'}394041def disabled_train(self, mode=True):42"""Overwrite model.train with this function to make sure train/eval mode43does not change anymore."""44return self454647def uniform_on_device(r1, r2, shape, device):48return (r1 - r2) * torch.rand(*shape, device=device) + r2495051class DDPM(pl.LightningModule):52# classic DDPM with Gaussian diffusion, in image space53def __init__(self,54unet_config,55timesteps=1000,56beta_schedule="linear",57loss_type="l2",58ckpt_path=None,59ignore_keys=None,60load_only_unet=False,61monitor="val/loss",62use_ema=True,63first_stage_key="image",64image_size=256,65channels=3,66log_every_t=100,67clip_denoised=True,68linear_start=1e-4,69linear_end=2e-2,70cosine_s=8e-3,71given_betas=None,72original_elbo_weight=0.,73v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta74l_simple_weight=1.,75conditioning_key=None,76parameterization="eps", # all assuming fixed variance schedules77scheduler_config=None,78use_positional_encodings=False,79learn_logvar=False,80logvar_init=0.,81load_ema=True,82):83super().__init__()84assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'85self.parameterization = parameterization86print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")87self.cond_stage_model = None88self.clip_denoised = clip_denoised89self.log_every_t = log_every_t90self.first_stage_key = first_stage_key91self.image_size = image_size # try conv?92self.channels = channels93self.use_positional_encodings = use_positional_encodings94self.model = DiffusionWrapper(unet_config, conditioning_key)95count_params(self.model, verbose=True)96self.use_ema = use_ema9798self.use_scheduler = scheduler_config is not None99if self.use_scheduler:100self.scheduler_config = scheduler_config101102self.v_posterior = v_posterior103self.original_elbo_weight = original_elbo_weight104self.l_simple_weight = l_simple_weight105106if monitor is not None:107self.monitor = monitor108109if self.use_ema and load_ema:110self.model_ema = LitEma(self.model)111print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")112113if ckpt_path is not None:114self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)115116# If initialing from EMA-only checkpoint, create EMA model after loading.117if self.use_ema and not load_ema:118self.model_ema = LitEma(self.model)119print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")120121self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,122linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)123124self.loss_type = loss_type125126self.learn_logvar = learn_logvar127self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))128if self.learn_logvar:129self.logvar = nn.Parameter(self.logvar, requires_grad=True)130131132def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,133linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):134if exists(given_betas):135betas = given_betas136else:137betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,138cosine_s=cosine_s)139alphas = 1. - betas140alphas_cumprod = np.cumprod(alphas, axis=0)141alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])142143timesteps, = betas.shape144self.num_timesteps = int(timesteps)145self.linear_start = linear_start146self.linear_end = linear_end147assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'148149to_torch = partial(torch.tensor, dtype=torch.float32)150151self.register_buffer('betas', to_torch(betas))152self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))153self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))154155# calculations for diffusion q(x_t | x_{t-1}) and others156self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))157self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))158self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))159self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))160self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))161162# calculations for posterior q(x_{t-1} | x_t, x_0)163posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (1641. - alphas_cumprod) + self.v_posterior * betas165# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)166self.register_buffer('posterior_variance', to_torch(posterior_variance))167# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain168self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))169self.register_buffer('posterior_mean_coef1', to_torch(170betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))171self.register_buffer('posterior_mean_coef2', to_torch(172(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))173174if self.parameterization == "eps":175lvlb_weights = self.betas ** 2 / (1762 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))177elif self.parameterization == "x0":178lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))179else:180raise NotImplementedError("mu not supported")181# TODO how to choose this term182lvlb_weights[0] = lvlb_weights[1]183self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)184assert not torch.isnan(self.lvlb_weights).all()185186@contextmanager187def ema_scope(self, context=None):188if self.use_ema:189self.model_ema.store(self.model.parameters())190self.model_ema.copy_to(self.model)191if context is not None:192print(f"{context}: Switched to EMA weights")193try:194yield None195finally:196if self.use_ema:197self.model_ema.restore(self.model.parameters())198if context is not None:199print(f"{context}: Restored training weights")200201def init_from_ckpt(self, path, ignore_keys=None, only_model=False):202ignore_keys = ignore_keys or []203204sd = torch.load(path, map_location="cpu")205if "state_dict" in list(sd.keys()):206sd = sd["state_dict"]207keys = list(sd.keys())208209# Our model adds additional channels to the first layer to condition on an input image.210# For the first layer, copy existing channel weights and initialize new channel weights to zero.211input_keys = [212"model.diffusion_model.input_blocks.0.0.weight",213"model_ema.diffusion_modelinput_blocks00weight",214]215216self_sd = self.state_dict()217for input_key in input_keys:218if input_key not in sd or input_key not in self_sd:219continue220221input_weight = self_sd[input_key]222223if input_weight.size() != sd[input_key].size():224print(f"Manual init: {input_key}")225input_weight.zero_()226input_weight[:, :4, :, :].copy_(sd[input_key])227ignore_keys.append(input_key)228229for k in keys:230for ik in ignore_keys:231if k.startswith(ik):232print(f"Deleting key {k} from state_dict.")233del sd[k]234missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(235sd, strict=False)236print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")237if missing:238print(f"Missing Keys: {missing}")239if unexpected:240print(f"Unexpected Keys: {unexpected}")241242def q_mean_variance(self, x_start, t):243"""244Get the distribution q(x_t | x_0).245:param x_start: the [N x C x ...] tensor of noiseless inputs.246:param t: the number of diffusion steps (minus 1). Here, 0 means one step.247:return: A tuple (mean, variance, log_variance), all of x_start's shape.248"""249mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)250variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)251log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)252return mean, variance, log_variance253254def predict_start_from_noise(self, x_t, t, noise):255return (256extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -257extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise258)259260def q_posterior(self, x_start, x_t, t):261posterior_mean = (262extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +263extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t264)265posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)266posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)267return posterior_mean, posterior_variance, posterior_log_variance_clipped268269def p_mean_variance(self, x, t, clip_denoised: bool):270model_out = self.model(x, t)271if self.parameterization == "eps":272x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)273elif self.parameterization == "x0":274x_recon = model_out275if clip_denoised:276x_recon.clamp_(-1., 1.)277278model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)279return model_mean, posterior_variance, posterior_log_variance280281@torch.no_grad()282def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):283b, *_, device = *x.shape, x.device284model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)285noise = noise_like(x.shape, device, repeat_noise)286# no noise when t == 0287nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))288return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise289290@torch.no_grad()291def p_sample_loop(self, shape, return_intermediates=False):292device = self.betas.device293b = shape[0]294img = torch.randn(shape, device=device)295intermediates = [img]296for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):297img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),298clip_denoised=self.clip_denoised)299if i % self.log_every_t == 0 or i == self.num_timesteps - 1:300intermediates.append(img)301if return_intermediates:302return img, intermediates303return img304305@torch.no_grad()306def sample(self, batch_size=16, return_intermediates=False):307image_size = self.image_size308channels = self.channels309return self.p_sample_loop((batch_size, channels, image_size, image_size),310return_intermediates=return_intermediates)311312def q_sample(self, x_start, t, noise=None):313noise = default(noise, lambda: torch.randn_like(x_start))314return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +315extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)316317def get_loss(self, pred, target, mean=True):318if self.loss_type == 'l1':319loss = (target - pred).abs()320if mean:321loss = loss.mean()322elif self.loss_type == 'l2':323if mean:324loss = torch.nn.functional.mse_loss(target, pred)325else:326loss = torch.nn.functional.mse_loss(target, pred, reduction='none')327else:328raise NotImplementedError("unknown loss type '{loss_type}'")329330return loss331332def p_losses(self, x_start, t, noise=None):333noise = default(noise, lambda: torch.randn_like(x_start))334x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)335model_out = self.model(x_noisy, t)336337loss_dict = {}338if self.parameterization == "eps":339target = noise340elif self.parameterization == "x0":341target = x_start342else:343raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")344345loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])346347log_prefix = 'train' if self.training else 'val'348349loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})350loss_simple = loss.mean() * self.l_simple_weight351352loss_vlb = (self.lvlb_weights[t] * loss).mean()353loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})354355loss = loss_simple + self.original_elbo_weight * loss_vlb356357loss_dict.update({f'{log_prefix}/loss': loss})358359return loss, loss_dict360361def forward(self, x, *args, **kwargs):362# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size363# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'364t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()365return self.p_losses(x, t, *args, **kwargs)366367def get_input(self, batch, k):368return batch[k]369370def shared_step(self, batch):371x = self.get_input(batch, self.first_stage_key)372loss, loss_dict = self(x)373return loss, loss_dict374375def training_step(self, batch, batch_idx):376loss, loss_dict = self.shared_step(batch)377378self.log_dict(loss_dict, prog_bar=True,379logger=True, on_step=True, on_epoch=True)380381self.log("global_step", self.global_step,382prog_bar=True, logger=True, on_step=True, on_epoch=False)383384if self.use_scheduler:385lr = self.optimizers().param_groups[0]['lr']386self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)387388return loss389390@torch.no_grad()391def validation_step(self, batch, batch_idx):392_, loss_dict_no_ema = self.shared_step(batch)393with self.ema_scope():394_, loss_dict_ema = self.shared_step(batch)395loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}396self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)397self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)398399def on_train_batch_end(self, *args, **kwargs):400if self.use_ema:401self.model_ema(self.model)402403def _get_rows_from_list(self, samples):404n_imgs_per_row = len(samples)405denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')406denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')407denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)408return denoise_grid409410@torch.no_grad()411def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):412log = {}413x = self.get_input(batch, self.first_stage_key)414N = min(x.shape[0], N)415n_row = min(x.shape[0], n_row)416x = x.to(self.device)[:N]417log["inputs"] = x418419# get diffusion row420diffusion_row = []421x_start = x[:n_row]422423for t in range(self.num_timesteps):424if t % self.log_every_t == 0 or t == self.num_timesteps - 1:425t = repeat(torch.tensor([t]), '1 -> b', b=n_row)426t = t.to(self.device).long()427noise = torch.randn_like(x_start)428x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)429diffusion_row.append(x_noisy)430431log["diffusion_row"] = self._get_rows_from_list(diffusion_row)432433if sample:434# get denoise row435with self.ema_scope("Plotting"):436samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)437438log["samples"] = samples439log["denoise_row"] = self._get_rows_from_list(denoise_row)440441if return_keys:442if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:443return log444else:445return {key: log[key] for key in return_keys}446return log447448def configure_optimizers(self):449lr = self.learning_rate450params = list(self.model.parameters())451if self.learn_logvar:452params = params + [self.logvar]453opt = torch.optim.AdamW(params, lr=lr)454return opt455456457class LatentDiffusion(DDPM):458"""main class"""459def __init__(self,460first_stage_config,461cond_stage_config,462num_timesteps_cond=None,463cond_stage_key="image",464cond_stage_trainable=False,465concat_mode=True,466cond_stage_forward=None,467conditioning_key=None,468scale_factor=1.0,469scale_by_std=False,470load_ema=True,471*args, **kwargs):472self.num_timesteps_cond = default(num_timesteps_cond, 1)473self.scale_by_std = scale_by_std474assert self.num_timesteps_cond <= kwargs['timesteps']475# for backwards compatibility after implementation of DiffusionWrapper476if conditioning_key is None:477conditioning_key = 'concat' if concat_mode else 'crossattn'478if cond_stage_config == '__is_unconditional__':479conditioning_key = None480ckpt_path = kwargs.pop("ckpt_path", None)481ignore_keys = kwargs.pop("ignore_keys", [])482super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)483self.concat_mode = concat_mode484self.cond_stage_trainable = cond_stage_trainable485self.cond_stage_key = cond_stage_key486try:487self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1488except Exception:489self.num_downs = 0490if not scale_by_std:491self.scale_factor = scale_factor492else:493self.register_buffer('scale_factor', torch.tensor(scale_factor))494self.instantiate_first_stage(first_stage_config)495self.instantiate_cond_stage(cond_stage_config)496self.cond_stage_forward = cond_stage_forward497self.clip_denoised = False498self.bbox_tokenizer = None499500self.restarted_from_ckpt = False501if ckpt_path is not None:502self.init_from_ckpt(ckpt_path, ignore_keys)503self.restarted_from_ckpt = True504505if self.use_ema and not load_ema:506self.model_ema = LitEma(self.model)507print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")508509def make_cond_schedule(self, ):510self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)511ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()512self.cond_ids[:self.num_timesteps_cond] = ids513514@rank_zero_only515@torch.no_grad()516def on_train_batch_start(self, batch, batch_idx, dataloader_idx):517# only for very first batch518if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:519assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'520# set rescale weight to 1./std of encodings521print("### USING STD-RESCALING ###")522x = super().get_input(batch, self.first_stage_key)523x = x.to(self.device)524encoder_posterior = self.encode_first_stage(x)525z = self.get_first_stage_encoding(encoder_posterior).detach()526del self.scale_factor527self.register_buffer('scale_factor', 1. / z.flatten().std())528print(f"setting self.scale_factor to {self.scale_factor}")529print("### USING STD-RESCALING ###")530531def register_schedule(self,532given_betas=None, beta_schedule="linear", timesteps=1000,533linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):534super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)535536self.shorten_cond_schedule = self.num_timesteps_cond > 1537if self.shorten_cond_schedule:538self.make_cond_schedule()539540def instantiate_first_stage(self, config):541model = instantiate_from_config(config)542self.first_stage_model = model.eval()543self.first_stage_model.train = disabled_train544for param in self.first_stage_model.parameters():545param.requires_grad = False546547def instantiate_cond_stage(self, config):548if not self.cond_stage_trainable:549if config == "__is_first_stage__":550print("Using first stage also as cond stage.")551self.cond_stage_model = self.first_stage_model552elif config == "__is_unconditional__":553print(f"Training {self.__class__.__name__} as an unconditional model.")554self.cond_stage_model = None555# self.be_unconditional = True556else:557model = instantiate_from_config(config)558self.cond_stage_model = model.eval()559self.cond_stage_model.train = disabled_train560for param in self.cond_stage_model.parameters():561param.requires_grad = False562else:563assert config != '__is_first_stage__'564assert config != '__is_unconditional__'565model = instantiate_from_config(config)566self.cond_stage_model = model567568def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):569denoise_row = []570for zd in tqdm(samples, desc=desc):571denoise_row.append(self.decode_first_stage(zd.to(self.device),572force_not_quantize=force_no_decoder_quantization))573n_imgs_per_row = len(denoise_row)574denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W575denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')576denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')577denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)578return denoise_grid579580def get_first_stage_encoding(self, encoder_posterior):581if isinstance(encoder_posterior, DiagonalGaussianDistribution):582z = encoder_posterior.sample()583elif isinstance(encoder_posterior, torch.Tensor):584z = encoder_posterior585else:586raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")587return self.scale_factor * z588589def get_learned_conditioning(self, c):590if self.cond_stage_forward is None:591if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):592c = self.cond_stage_model.encode(c)593if isinstance(c, DiagonalGaussianDistribution):594c = c.mode()595else:596c = self.cond_stage_model(c)597else:598assert hasattr(self.cond_stage_model, self.cond_stage_forward)599c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)600return c601602def meshgrid(self, h, w):603y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)604x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)605606arr = torch.cat([y, x], dim=-1)607return arr608609def delta_border(self, h, w):610"""611:param h: height612:param w: width613:return: normalized distance to image border,614wtith min distance = 0 at border and max dist = 0.5 at image center615"""616lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)617arr = self.meshgrid(h, w) / lower_right_corner618dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]619dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]620edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]621return edge_dist622623def get_weighting(self, h, w, Ly, Lx, device):624weighting = self.delta_border(h, w)625weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],626self.split_input_params["clip_max_weight"], )627weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)628629if self.split_input_params["tie_braker"]:630L_weighting = self.delta_border(Ly, Lx)631L_weighting = torch.clip(L_weighting,632self.split_input_params["clip_min_tie_weight"],633self.split_input_params["clip_max_tie_weight"])634635L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)636weighting = weighting * L_weighting637return weighting638639def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code640"""641:param x: img of size (bs, c, h, w)642:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])643"""644bs, nc, h, w = x.shape645646# number of crops in image647Ly = (h - kernel_size[0]) // stride[0] + 1648Lx = (w - kernel_size[1]) // stride[1] + 1649650if uf == 1 and df == 1:651fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)652unfold = torch.nn.Unfold(**fold_params)653654fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)655656weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)657normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap658weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))659660elif uf > 1 and df == 1:661fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)662unfold = torch.nn.Unfold(**fold_params)663664fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),665dilation=1, padding=0,666stride=(stride[0] * uf, stride[1] * uf))667fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)668669weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)670normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap671weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))672673elif df > 1 and uf == 1:674fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)675unfold = torch.nn.Unfold(**fold_params)676677fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),678dilation=1, padding=0,679stride=(stride[0] // df, stride[1] // df))680fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)681682weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)683normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap684weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))685686else:687raise NotImplementedError688689return fold, unfold, normalization, weighting690691@torch.no_grad()692def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,693cond_key=None, return_original_cond=False, bs=None, uncond=0.05):694x = super().get_input(batch, k)695if bs is not None:696x = x[:bs]697x = x.to(self.device)698encoder_posterior = self.encode_first_stage(x)699z = self.get_first_stage_encoding(encoder_posterior).detach()700cond_key = cond_key or self.cond_stage_key701xc = super().get_input(batch, cond_key)702if bs is not None:703xc["c_crossattn"] = xc["c_crossattn"][:bs]704xc["c_concat"] = xc["c_concat"][:bs]705cond = {}706707# To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.708random = torch.rand(x.size(0), device=x.device)709prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")710input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")711712null_prompt = self.get_learned_conditioning([""])713cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]714cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]715716out = [z, cond]717if return_first_stage_outputs:718xrec = self.decode_first_stage(z)719out.extend([x, xrec])720if return_original_cond:721out.append(xc)722return out723724@torch.no_grad()725def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):726if predict_cids:727if z.dim() == 4:728z = torch.argmax(z.exp(), dim=1).long()729z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)730z = rearrange(z, 'b h w c -> b c h w').contiguous()731732z = 1. / self.scale_factor * z733734if hasattr(self, "split_input_params"):735if self.split_input_params["patch_distributed_vq"]:736ks = self.split_input_params["ks"] # eg. (128, 128)737stride = self.split_input_params["stride"] # eg. (64, 64)738uf = self.split_input_params["vqf"]739bs, nc, h, w = z.shape740if ks[0] > h or ks[1] > w:741ks = (min(ks[0], h), min(ks[1], w))742print("reducing Kernel")743744if stride[0] > h or stride[1] > w:745stride = (min(stride[0], h), min(stride[1], w))746print("reducing stride")747748fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)749750z = unfold(z) # (bn, nc * prod(**ks), L)751# 1. Reshape to img shape752z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )753754# 2. apply model loop over last dim755if isinstance(self.first_stage_model, VQModelInterface):756output_list = [self.first_stage_model.decode(z[:, :, :, :, i],757force_not_quantize=predict_cids or force_not_quantize)758for i in range(z.shape[-1])]759else:760761output_list = [self.first_stage_model.decode(z[:, :, :, :, i])762for i in range(z.shape[-1])]763764o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)765o = o * weighting766# Reverse 1. reshape to img shape767o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)768# stitch crops together769decoded = fold(o)770decoded = decoded / normalization # norm is shape (1, 1, h, w)771return decoded772else:773if isinstance(self.first_stage_model, VQModelInterface):774return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)775else:776return self.first_stage_model.decode(z)777778else:779if isinstance(self.first_stage_model, VQModelInterface):780return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)781else:782return self.first_stage_model.decode(z)783784# same as above but without decorator785def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):786if predict_cids:787if z.dim() == 4:788z = torch.argmax(z.exp(), dim=1).long()789z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)790z = rearrange(z, 'b h w c -> b c h w').contiguous()791792z = 1. / self.scale_factor * z793794if hasattr(self, "split_input_params"):795if self.split_input_params["patch_distributed_vq"]:796ks = self.split_input_params["ks"] # eg. (128, 128)797stride = self.split_input_params["stride"] # eg. (64, 64)798uf = self.split_input_params["vqf"]799bs, nc, h, w = z.shape800if ks[0] > h or ks[1] > w:801ks = (min(ks[0], h), min(ks[1], w))802print("reducing Kernel")803804if stride[0] > h or stride[1] > w:805stride = (min(stride[0], h), min(stride[1], w))806print("reducing stride")807808fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)809810z = unfold(z) # (bn, nc * prod(**ks), L)811# 1. Reshape to img shape812z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )813814# 2. apply model loop over last dim815if isinstance(self.first_stage_model, VQModelInterface):816output_list = [self.first_stage_model.decode(z[:, :, :, :, i],817force_not_quantize=predict_cids or force_not_quantize)818for i in range(z.shape[-1])]819else:820821output_list = [self.first_stage_model.decode(z[:, :, :, :, i])822for i in range(z.shape[-1])]823824o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)825o = o * weighting826# Reverse 1. reshape to img shape827o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)828# stitch crops together829decoded = fold(o)830decoded = decoded / normalization # norm is shape (1, 1, h, w)831return decoded832else:833if isinstance(self.first_stage_model, VQModelInterface):834return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)835else:836return self.first_stage_model.decode(z)837838else:839if isinstance(self.first_stage_model, VQModelInterface):840return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)841else:842return self.first_stage_model.decode(z)843844@torch.no_grad()845def encode_first_stage(self, x):846if hasattr(self, "split_input_params"):847if self.split_input_params["patch_distributed_vq"]:848ks = self.split_input_params["ks"] # eg. (128, 128)849stride = self.split_input_params["stride"] # eg. (64, 64)850df = self.split_input_params["vqf"]851self.split_input_params['original_image_size'] = x.shape[-2:]852bs, nc, h, w = x.shape853if ks[0] > h or ks[1] > w:854ks = (min(ks[0], h), min(ks[1], w))855print("reducing Kernel")856857if stride[0] > h or stride[1] > w:858stride = (min(stride[0], h), min(stride[1], w))859print("reducing stride")860861fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)862z = unfold(x) # (bn, nc * prod(**ks), L)863# Reshape to img shape864z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )865866output_list = [self.first_stage_model.encode(z[:, :, :, :, i])867for i in range(z.shape[-1])]868869o = torch.stack(output_list, axis=-1)870o = o * weighting871872# Reverse reshape to img shape873o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)874# stitch crops together875decoded = fold(o)876decoded = decoded / normalization877return decoded878879else:880return self.first_stage_model.encode(x)881else:882return self.first_stage_model.encode(x)883884def shared_step(self, batch, **kwargs):885x, c = self.get_input(batch, self.first_stage_key)886loss = self(x, c)887return loss888889def forward(self, x, c, *args, **kwargs):890t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()891if self.model.conditioning_key is not None:892assert c is not None893if self.cond_stage_trainable:894c = self.get_learned_conditioning(c)895if self.shorten_cond_schedule: # TODO: drop this option896tc = self.cond_ids[t].to(self.device)897c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))898return self.p_losses(x, c, t, *args, **kwargs)899900def apply_model(self, x_noisy, t, cond, return_ids=False):901902if isinstance(cond, dict):903# hybrid case, cond is expected to be a dict904pass905else:906if not isinstance(cond, list):907cond = [cond]908key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'909cond = {key: cond}910911if hasattr(self, "split_input_params"):912assert len(cond) == 1 # todo can only deal with one conditioning atm913assert not return_ids914ks = self.split_input_params["ks"] # eg. (128, 128)915stride = self.split_input_params["stride"] # eg. (64, 64)916917h, w = x_noisy.shape[-2:]918919fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)920921z = unfold(x_noisy) # (bn, nc * prod(**ks), L)922# Reshape to img shape923z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )924z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]925926if self.cond_stage_key in ["image", "LR_image", "segmentation",927'bbox_img'] and self.model.conditioning_key: # todo check for completeness928c_key = next(iter(cond.keys())) # get key929c = next(iter(cond.values())) # get value930assert (len(c) == 1) # todo extend to list with more than one elem931c = c[0] # get element932933c = unfold(c)934c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )935936cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]937938elif self.cond_stage_key == 'coordinates_bbox':939assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'940941# assuming padding of unfold is always 0 and its dilation is always 1942n_patches_per_row = int((w - ks[0]) / stride[0] + 1)943full_img_h, full_img_w = self.split_input_params['original_image_size']944# as we are operating on latents, we need the factor from the original image size to the945# spatial latent size to properly rescale the crops for regenerating the bbox annotations946num_downs = self.first_stage_model.encoder.num_resolutions - 1947rescale_latent = 2 ** (num_downs)948949# get top left positions of patches as conforming for the bbbox tokenizer, therefore we950# need to rescale the tl patch coordinates to be in between (0,1)951tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,952rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)953for patch_nr in range(z.shape[-1])]954955# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)956patch_limits = [(x_tl, y_tl,957rescale_latent * ks[0] / full_img_w,958rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]959# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]960961# tokenize crop coordinates for the bounding boxes of the respective patches962patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)963for bbox in patch_limits] # list of length l with tensors of shape (1, 2)964print(patch_limits_tknzd[0].shape)965# cut tknzd crop position from conditioning966assert isinstance(cond, dict), 'cond must be dict to be fed into model'967cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)968print(cut_cond.shape)969970adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])971adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')972print(adapted_cond.shape)973adapted_cond = self.get_learned_conditioning(adapted_cond)974print(adapted_cond.shape)975adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])976print(adapted_cond.shape)977978cond_list = [{'c_crossattn': [e]} for e in adapted_cond]979980else:981cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient982983# apply model by loop over crops984output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]985assert not isinstance(output_list[0],986tuple) # todo cant deal with multiple model outputs check this never happens987988o = torch.stack(output_list, axis=-1)989o = o * weighting990# Reverse reshape to img shape991o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)992# stitch crops together993x_recon = fold(o) / normalization994995else:996x_recon = self.model(x_noisy, t, **cond)997998if isinstance(x_recon, tuple) and not return_ids:999return x_recon[0]1000else:1001return x_recon10021003def _predict_eps_from_xstart(self, x_t, t, pred_xstart):1004return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \1005extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)10061007def _prior_bpd(self, x_start):1008"""1009Get the prior KL term for the variational lower-bound, measured in1010bits-per-dim.1011This term can't be optimized, as it only depends on the encoder.1012:param x_start: the [N x C x ...] tensor of inputs.1013:return: a batch of [N] KL values (in bits), one per batch element.1014"""1015batch_size = x_start.shape[0]1016t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)1017qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)1018kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)1019return mean_flat(kl_prior) / np.log(2.0)10201021def p_losses(self, x_start, cond, t, noise=None):1022noise = default(noise, lambda: torch.randn_like(x_start))1023x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)1024model_output = self.apply_model(x_noisy, t, cond)10251026loss_dict = {}1027prefix = 'train' if self.training else 'val'10281029if self.parameterization == "x0":1030target = x_start1031elif self.parameterization == "eps":1032target = noise1033else:1034raise NotImplementedError()10351036loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])1037loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})10381039logvar_t = self.logvar[t].to(self.device)1040loss = loss_simple / torch.exp(logvar_t) + logvar_t1041# loss = loss_simple / torch.exp(self.logvar) + self.logvar1042if self.learn_logvar:1043loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})1044loss_dict.update({'logvar': self.logvar.data.mean()})10451046loss = self.l_simple_weight * loss.mean()10471048loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))1049loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()1050loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})1051loss += (self.original_elbo_weight * loss_vlb)1052loss_dict.update({f'{prefix}/loss': loss})10531054return loss, loss_dict10551056def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,1057return_x0=False, score_corrector=None, corrector_kwargs=None):1058t_in = t1059model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)10601061if score_corrector is not None:1062assert self.parameterization == "eps"1063model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)10641065if return_codebook_ids:1066model_out, logits = model_out10671068if self.parameterization == "eps":1069x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)1070elif self.parameterization == "x0":1071x_recon = model_out1072else:1073raise NotImplementedError()10741075if clip_denoised:1076x_recon.clamp_(-1., 1.)1077if quantize_denoised:1078x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)1079model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)1080if return_codebook_ids:1081return model_mean, posterior_variance, posterior_log_variance, logits1082elif return_x0:1083return model_mean, posterior_variance, posterior_log_variance, x_recon1084else:1085return model_mean, posterior_variance, posterior_log_variance10861087@torch.no_grad()1088def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,1089return_codebook_ids=False, quantize_denoised=False, return_x0=False,1090temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):1091b, *_, device = *x.shape, x.device1092outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,1093return_codebook_ids=return_codebook_ids,1094quantize_denoised=quantize_denoised,1095return_x0=return_x0,1096score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)1097if return_codebook_ids:1098raise DeprecationWarning("Support dropped.")1099model_mean, _, model_log_variance, logits = outputs1100elif return_x0:1101model_mean, _, model_log_variance, x0 = outputs1102else:1103model_mean, _, model_log_variance = outputs11041105noise = noise_like(x.shape, device, repeat_noise) * temperature1106if noise_dropout > 0.:1107noise = torch.nn.functional.dropout(noise, p=noise_dropout)1108# no noise when t == 01109nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))11101111if return_codebook_ids:1112return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)1113if return_x0:1114return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x01115else:1116return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise11171118@torch.no_grad()1119def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,1120img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,1121score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,1122log_every_t=None):1123if not log_every_t:1124log_every_t = self.log_every_t1125timesteps = self.num_timesteps1126if batch_size is not None:1127b = batch_size if batch_size is not None else shape[0]1128shape = [batch_size] + list(shape)1129else:1130b = batch_size = shape[0]1131if x_T is None:1132img = torch.randn(shape, device=self.device)1133else:1134img = x_T1135intermediates = []1136if cond is not None:1137if isinstance(cond, dict):1138cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else1139[x[:batch_size] for x in cond[key]] for key in cond}1140else:1141cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]11421143if start_T is not None:1144timesteps = min(timesteps, start_T)1145iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',1146total=timesteps) if verbose else reversed(1147range(0, timesteps))1148if type(temperature) == float:1149temperature = [temperature] * timesteps11501151for i in iterator:1152ts = torch.full((b,), i, device=self.device, dtype=torch.long)1153if self.shorten_cond_schedule:1154assert self.model.conditioning_key != 'hybrid'1155tc = self.cond_ids[ts].to(cond.device)1156cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))11571158img, x0_partial = self.p_sample(img, cond, ts,1159clip_denoised=self.clip_denoised,1160quantize_denoised=quantize_denoised, return_x0=True,1161temperature=temperature[i], noise_dropout=noise_dropout,1162score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)1163if mask is not None:1164assert x0 is not None1165img_orig = self.q_sample(x0, ts)1166img = img_orig * mask + (1. - mask) * img11671168if i % log_every_t == 0 or i == timesteps - 1:1169intermediates.append(x0_partial)1170if callback:1171callback(i)1172if img_callback:1173img_callback(img, i)1174return img, intermediates11751176@torch.no_grad()1177def p_sample_loop(self, cond, shape, return_intermediates=False,1178x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,1179mask=None, x0=None, img_callback=None, start_T=None,1180log_every_t=None):11811182if not log_every_t:1183log_every_t = self.log_every_t1184device = self.betas.device1185b = shape[0]1186if x_T is None:1187img = torch.randn(shape, device=device)1188else:1189img = x_T11901191intermediates = [img]1192if timesteps is None:1193timesteps = self.num_timesteps11941195if start_T is not None:1196timesteps = min(timesteps, start_T)1197iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(1198range(0, timesteps))11991200if mask is not None:1201assert x0 is not None1202assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match12031204for i in iterator:1205ts = torch.full((b,), i, device=device, dtype=torch.long)1206if self.shorten_cond_schedule:1207assert self.model.conditioning_key != 'hybrid'1208tc = self.cond_ids[ts].to(cond.device)1209cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))12101211img = self.p_sample(img, cond, ts,1212clip_denoised=self.clip_denoised,1213quantize_denoised=quantize_denoised)1214if mask is not None:1215img_orig = self.q_sample(x0, ts)1216img = img_orig * mask + (1. - mask) * img12171218if i % log_every_t == 0 or i == timesteps - 1:1219intermediates.append(img)1220if callback:1221callback(i)1222if img_callback:1223img_callback(img, i)12241225if return_intermediates:1226return img, intermediates1227return img12281229@torch.no_grad()1230def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,1231verbose=True, timesteps=None, quantize_denoised=False,1232mask=None, x0=None, shape=None,**kwargs):1233if shape is None:1234shape = (batch_size, self.channels, self.image_size, self.image_size)1235if cond is not None:1236if isinstance(cond, dict):1237cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else1238[x[:batch_size] for x in cond[key]] for key in cond}1239else:1240cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]1241return self.p_sample_loop(cond,1242shape,1243return_intermediates=return_intermediates, x_T=x_T,1244verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,1245mask=mask, x0=x0)12461247@torch.no_grad()1248def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):12491250if ddim:1251ddim_sampler = DDIMSampler(self)1252shape = (self.channels, self.image_size, self.image_size)1253samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,1254shape,cond,verbose=False,**kwargs)12551256else:1257samples, intermediates = self.sample(cond=cond, batch_size=batch_size,1258return_intermediates=True,**kwargs)12591260return samples, intermediates126112621263@torch.no_grad()1264def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,1265quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,1266plot_diffusion_rows=False, **kwargs):12671268use_ddim = False12691270log = {}1271z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,1272return_first_stage_outputs=True,1273force_c_encode=True,1274return_original_cond=True,1275bs=N, uncond=0)1276N = min(x.shape[0], N)1277n_row = min(x.shape[0], n_row)1278log["inputs"] = x1279log["reals"] = xc["c_concat"]1280log["reconstruction"] = xrec1281if self.model.conditioning_key is not None:1282if hasattr(self.cond_stage_model, "decode"):1283xc = self.cond_stage_model.decode(c)1284log["conditioning"] = xc1285elif self.cond_stage_key in ["caption"]:1286xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])1287log["conditioning"] = xc1288elif self.cond_stage_key == 'class_label':1289xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])1290log['conditioning'] = xc1291elif isimage(xc):1292log["conditioning"] = xc1293if ismap(xc):1294log["original_conditioning"] = self.to_rgb(xc)12951296if plot_diffusion_rows:1297# get diffusion row1298diffusion_row = []1299z_start = z[:n_row]1300for t in range(self.num_timesteps):1301if t % self.log_every_t == 0 or t == self.num_timesteps - 1:1302t = repeat(torch.tensor([t]), '1 -> b', b=n_row)1303t = t.to(self.device).long()1304noise = torch.randn_like(z_start)1305z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)1306diffusion_row.append(self.decode_first_stage(z_noisy))13071308diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W1309diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')1310diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')1311diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])1312log["diffusion_row"] = diffusion_grid13131314if sample:1315# get denoise row1316with self.ema_scope("Plotting"):1317samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,1318ddim_steps=ddim_steps,eta=ddim_eta)1319# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)1320x_samples = self.decode_first_stage(samples)1321log["samples"] = x_samples1322if plot_denoise_rows:1323denoise_grid = self._get_denoise_row_from_list(z_denoise_row)1324log["denoise_row"] = denoise_grid13251326if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(1327self.first_stage_model, IdentityFirstStage):1328# also display when quantizing x0 while sampling1329with self.ema_scope("Plotting Quantized Denoised"):1330samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,1331ddim_steps=ddim_steps,eta=ddim_eta,1332quantize_denoised=True)1333# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,1334# quantize_denoised=True)1335x_samples = self.decode_first_stage(samples.to(self.device))1336log["samples_x0_quantized"] = x_samples13371338if inpaint:1339# make a simple center square1340h, w = z.shape[2], z.shape[3]1341mask = torch.ones(N, h, w).to(self.device)1342# zeros will be filled in1343mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.1344mask = mask[:, None, ...]1345with self.ema_scope("Plotting Inpaint"):13461347samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,1348ddim_steps=ddim_steps, x0=z[:N], mask=mask)1349x_samples = self.decode_first_stage(samples.to(self.device))1350log["samples_inpainting"] = x_samples1351log["mask"] = mask13521353# outpaint1354with self.ema_scope("Plotting Outpaint"):1355samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,1356ddim_steps=ddim_steps, x0=z[:N], mask=mask)1357x_samples = self.decode_first_stage(samples.to(self.device))1358log["samples_outpainting"] = x_samples13591360if plot_progressive_rows:1361with self.ema_scope("Plotting Progressives"):1362img, progressives = self.progressive_denoising(c,1363shape=(self.channels, self.image_size, self.image_size),1364batch_size=N)1365prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")1366log["progressive_row"] = prog_row13671368if return_keys:1369if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:1370return log1371else:1372return {key: log[key] for key in return_keys}1373return log13741375def configure_optimizers(self):1376lr = self.learning_rate1377params = list(self.model.parameters())1378if self.cond_stage_trainable:1379print(f"{self.__class__.__name__}: Also optimizing conditioner params!")1380params = params + list(self.cond_stage_model.parameters())1381if self.learn_logvar:1382print('Diffusion model optimizing logvar')1383params.append(self.logvar)1384opt = torch.optim.AdamW(params, lr=lr)1385if self.use_scheduler:1386assert 'target' in self.scheduler_config1387scheduler = instantiate_from_config(self.scheduler_config)13881389print("Setting up LambdaLR scheduler...")1390scheduler = [1391{1392'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),1393'interval': 'step',1394'frequency': 11395}]1396return [opt], scheduler1397return opt13981399@torch.no_grad()1400def to_rgb(self, x):1401x = x.float()1402if not hasattr(self, "colorize"):1403self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)1404x = nn.functional.conv2d(x, weight=self.colorize)1405x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.1406return x140714081409class DiffusionWrapper(pl.LightningModule):1410def __init__(self, diff_model_config, conditioning_key):1411super().__init__()1412self.diffusion_model = instantiate_from_config(diff_model_config)1413self.conditioning_key = conditioning_key1414assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']14151416def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):1417if self.conditioning_key is None:1418out = self.diffusion_model(x, t)1419elif self.conditioning_key == 'concat':1420xc = torch.cat([x] + c_concat, dim=1)1421out = self.diffusion_model(xc, t)1422elif self.conditioning_key == 'crossattn':1423cc = torch.cat(c_crossattn, 1)1424out = self.diffusion_model(x, t, context=cc)1425elif self.conditioning_key == 'hybrid':1426xc = torch.cat([x] + c_concat, dim=1)1427cc = torch.cat(c_crossattn, 1)1428out = self.diffusion_model(xc, t, context=cc)1429elif self.conditioning_key == 'adm':1430cc = c_crossattn[0]1431out = self.diffusion_model(x, t, y=cc)1432else:1433raise NotImplementedError()14341435return out143614371438class Layout2ImgDiffusion(LatentDiffusion):1439# TODO: move all layout-specific hacks to this class1440def __init__(self, cond_stage_key, *args, **kwargs):1441assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'1442super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)14431444def log_images(self, batch, N=8, *args, **kwargs):1445logs = super().log_images(*args, batch=batch, N=N, **kwargs)14461447key = 'train' if self.training else 'validation'1448dset = self.trainer.datamodule.datasets[key]1449mapper = dset.conditional_builders[self.cond_stage_key]14501451bbox_imgs = []1452map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))1453for tknzd_bbox in batch[self.cond_stage_key][:N]:1454bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))1455bbox_imgs.append(bboximg)14561457cond_img = torch.stack(bbox_imgs, dim=0)1458logs['bbox_image'] = cond_img1459return logs146014611462