Path: blob/master/extensions-builtin/LDSR/ldsr_model_arch.py
import os
import gc
import time

import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack, devices

cached_ldsr_model: torch.nn.Module = None


# Create LDSR Class
class LDSR:
    def load_model_from_config(self, half_attention):
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print("Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            _, extension = os.path.splitext(self.modelPath)
            if extension.lower() == ".safetensors":
                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
            else:
                pl_sd = torch.load(self.modelPath, map_location="cpu")
            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
            config = OmegaConf.load(self.yamlPath)
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model)  # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}

    def __init__(self, model_path, yaml_path):
        self.modelPath = model_path
        self.yamlPath = yaml_path

    @staticmethod
    def run(model, selected_path, custom_steps, eta):
        example = get_cond(selected_path)

        n_runs = 1
        guider = None
        ckwargs = None
        ddim_use_x0_pred = False
        temperature = 1.
        eta = eta
        custom_shape = None

        height, width = example["image"].shape[1:3]
        split_input = height >= 128 and width >= 128

        if split_input:
            ks = 128
            stride = 64
            vqf = 4  #
            model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                        "vqf": vqf,
                                        "patch_distributed_vq": True,
                                        "tie_braker": False,
                                        "clip_max_weight": 0.5,
                                        "clip_min_weight": 0.01,
                                        "clip_max_tie_weight": 0.5,
                                        "clip_min_tie_weight": 0.01}
        else:
            if hasattr(model, "split_input_params"):
                delattr(model, "split_input_params")

        x_t = None
        logs = None
        for _ in range(n_runs):
            if custom_shape is not None:
                x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])

            logs = make_convolutional_sample(example, model,
                                             custom_steps=custom_steps,
                                             eta=eta, quantize_x0=False,
                                             custom_shape=custom_shape,
                                             temperature=temperature, noise_dropout=0.,
                                             corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
                                             ddim_use_x0_pred=ddim_use_x0_pred
                                             )
        return logs

    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
        model = self.load_model_from_config(half_attention)

        # Run settings
        diffusion_steps = int(steps)
        eta = 1.0


        gc.collect()
        devices.torch_gc()

        im_og = image
        width_og, height_og = im_og.size
        # If we can adjust the max upscale size, then the 4 below should be our variable
        down_sample_rate = target_scale / 4
        wd = width_og * down_sample_rate
        hd = height_og * down_sample_rate
        width_downsampled_pre = int(np.ceil(wd))
        height_downsampled_pre = int(np.ceil(hd))

        if down_sample_rate != 1:
            print(
                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
        else:
            print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")

        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))

        logs = self.run(model["model"], im_padded, diffusion_steps, eta)

        sample = logs["sample"]
        sample = sample.detach().cpu()
        sample = torch.clamp(sample, -1., 1.)
        sample = (sample + 1.) / 2. * 255
        sample = sample.numpy().astype(np.uint8)
        sample = np.transpose(sample, (0, 2, 3, 1))
        a = Image.fromarray(sample[0])

        # remove padding
        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))

        del model
        gc.collect()
        devices.torch_gc()

        return a


def get_cond(selected_path):
    example = {}
    up_f = 4
    c = selected_path.convert('RGB')
    c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
    c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
                                                    antialias=True)
    c_up = rearrange(c_up, '1 c h w -> 1 h w c')
    c = rearrange(c, '1 c h w -> 1 h w c')
    c = 2. * c - 1.

    c = c.to(shared.device)
    example["LR_image"] = c
    example["image"] = c_up

    return example


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
                    corrector_kwargs=None, x_t=None
                    ):
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]
    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
                                         mask=mask, x0=x0, temperature=temperature, verbose=False,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs, x_t=x_t)

    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
    log = {}

    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    z0 = None

    log["input"] = x
    log["reconstruction"] = xrec

    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            log[model.cond_stage_key] = model.to_rgb(xc)

    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == 'class_label':
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()

        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, mask=None, x0=z0,
                                                temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
                                                x_t=x_T)
        t1 = time.time()

        if ddim_use_x0_pred:
            sample = intermediates['pred_x0'][-1]

    x_sample = model.decode_first_stage(sample)

    try:
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0

    return log
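
Below is a minimal usage sketch, not part of the file above. It assumes the webui runtime has already initialized `shared`, `sd_hijack`, and `devices`, and the model/config paths are placeholders (in the webui they are resolved by the LDSR upscaler script).

# Hypothetical example: run LDSR super-resolution on a single image.
from PIL import Image

ldsr = LDSR(model_path="models/LDSR/model.ckpt", yaml_path="models/LDSR/project.yaml")
img = Image.open("input.png")

# target_scale controls the pre-downsampling step; the diffusion model itself always
# upscales 4x, so the returned image is target_scale times the input size.
result = ldsr.super_resolution(img, steps=100, target_scale=2)
result.save("output.png")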