Path: blob/master/extensions-builtin/LDSR/sd_hijack_autoencoder.py
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager

from torch.optim.lr_scheduler import LambdaLR

from ldm.modules.ema import LitEma
from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config

import ldm.models.autoencoder
from packaging import version

class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys or []:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if missing:
            print(f"Missing Keys: {missing}")
        if unexpected:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info
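
    # Note: self.quantize is taming-transformers' VectorQuantizer2 (vendored
    # locally as vqvae_quantize), whose forward pass returns
    # (z_q, loss, (perplexity, min_encodings, min_encoding_indices));
    # forward() below relies on that shape when it unpacks info as (_, _, ind).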

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val" + suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val" + suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict
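
    # configure_optimizers below returns two optimizers; under PyTorch
    # Lightning 1.x automatic optimization, Lightning then calls
    # training_step() once per batch for each optimizer, passing the matching
    # optimizer_idx so the autoencoder (0) and discriminator (1) updates
    # alternate, GAN-style.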

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = {}
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


ldm.models.autoencoder.VQModel = VQModel
ldm.models.autoencoder.VQModelInterface = VQModelInterface
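
# A minimal usage sketch (illustrative, not part of the original file),
# assuming this file's directory is on sys.path: importing the module is
# enough to apply the hijack, after which code and configs that reference the
# removed classes resolve again:
#
#     import sd_hijack_autoencoder  # noqa: F401 -- patches ldm.models.autoencoder on import
#     from ldm.models.autoencoder import VQModelInterface  # available once more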