GitHub Repository: AUTOMATIC1111/stable-diffusion-webui
Path: blob/master/extensions-builtin/LDSR/ldsr_model_arch.py
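"""Model-loading and sampling code for the built-in LDSR upscaler extension.

LDSR (Latent Diffusion Super Resolution) performs 4x upscaling by running
DDIM sampling in the latent space of a pretrained latent diffusion model.
"""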
import os
import gc
import time

import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack, devices

cached_ldsr_model: torch.nn.Module = None


class LDSR:
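    """Latent Diffusion Super Resolution (LDSR) upscaler.

    Wraps a pretrained LDSR checkpoint and exposes super_resolution(), which
    upscales a PIL image via DDIM sampling in latent space.
    """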
    def load_model_from_config(self, half_attention):
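        """Load the LDSR checkpoint, optionally reusing the in-memory cache.

        Supports .safetensors and torch checkpoints, moves the model to the
        configured device, and returns it wrapped as {"model": model}.
        """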
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print("Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            _, extension = os.path.splitext(self.modelPath)
            if extension.lower() == ".safetensors":
                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
            else:
                pl_sd = torch.load(self.modelPath, map_location="cpu")
            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
            config = OmegaConf.load(self.yamlPath)
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model)  # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}

    def __init__(self, model_path, yaml_path):
        self.modelPath = model_path
        self.yamlPath = yaml_path

    @staticmethod
    def run(model, selected_path, custom_steps, eta):
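        """Run DDIM sampling on a conditioning example built from selected_path.

        Enables patch-wise (split-input) decoding for inputs of 128px or
        larger, then samples via make_convolutional_sample().
        """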
        example = get_cond(selected_path)

        n_runs = 1
        guider = None
        ckwargs = None
        ddim_use_x0_pred = False
        temperature = 1.
        custom_shape = None

        height, width = example["image"].shape[1:3]
        split_input = height >= 128 and width >= 128

        if split_input:
            ks = 128
            stride = 64
            vqf = 4  # first-stage (VQ) downscaling factor
            model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                        "vqf": vqf,
                                        "patch_distributed_vq": True,
                                        "tie_braker": False,  # spelling matches the upstream ldm key
                                        "clip_max_weight": 0.5,
                                        "clip_min_weight": 0.01,
                                        "clip_max_tie_weight": 0.5,
                                        "clip_min_tie_weight": 0.01}
        else:
            if hasattr(model, "split_input_params"):
                delattr(model, "split_input_params")

        x_t = None
        logs = None
        for _ in range(n_runs):
            if custom_shape is not None:
                x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])

            logs = make_convolutional_sample(example, model,
                                             custom_steps=custom_steps,
                                             eta=eta, quantize_x0=False,
                                             custom_shape=custom_shape,
                                             temperature=temperature, noise_dropout=0.,
                                             corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
                                             ddim_use_x0_pred=ddim_use_x0_pred
                                             )
        return logs

    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
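        """Upscale a PIL image with LDSR.

        LDSR itself always upscales 4x; target_scale is approximated by
        downsampling the input first. Returns the upscaled PIL image.
        """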
        model = self.load_model_from_config(half_attention)

        # Run settings
        diffusion_steps = int(steps)
        eta = 1.0

        gc.collect()
        devices.torch_gc()

        im_og = image
        width_og, height_og = im_og.size
        # LDSR always upscales by a fixed 4x, so to reach target_scale the
        # input is first downsampled by target_scale / 4.
        down_sample_rate = target_scale / 4
        wd = width_og * down_sample_rate
        hd = height_og * down_sample_rate
        width_downsampled_pre = int(np.ceil(wd))
        height_downsampled_pre = int(np.ceil(hd))

        if down_sample_rate != 1:
            print(
                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
        else:
            print(f"Down sample rate is 1 from {target_scale} / 4 (not downsampling)")

        # Pad width and height up to multiples of 64 (minimum 128px per side),
        # using the image's edge values to avoid border artifacts.
        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))

        logs = self.run(model["model"], im_padded, diffusion_steps, eta)

        sample = logs["sample"]
        sample = sample.detach().cpu()
        sample = torch.clamp(sample, -1., 1.)
        sample = (sample + 1.) / 2. * 255  # map from [-1, 1] to [0, 255]
        sample = sample.numpy().astype(np.uint8)
        sample = np.transpose(sample, (0, 2, 3, 1))  # NCHW -> NHWC
        a = Image.fromarray(sample[0])

        # remove padding
        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))

        del model
        gc.collect()
        devices.torch_gc()

        return a


def get_cond(selected_path):
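    """Build the conditioning example LDSR expects from a PIL image.

    "LR_image" holds the input rescaled to [-1, 1]; "image" holds a 4x
    resized copy, both in (1, h, w, c) layout.
    """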
    example = {}
    up_f = 4
    c = selected_path.convert('RGB')
    c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
    c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
                                                    antialias=True)
    c_up = rearrange(c_up, '1 c h w -> 1 h w c')
    c = rearrange(c, '1 c h w -> 1 h w c')
    c = 2. * c - 1.

    c = c.to(shared.device)
    example["LR_image"] = c
    example["image"] = c_up

    return example


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
                    corrector_kwargs=None, x_t=None
                    ):
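    """Draw DDIM samples of the given latent shape, conditioned on cond."""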
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]
    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
                                         mask=mask, x0=x0, temperature=temperature, verbose=False,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs, x_t=x_t)

    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None,
                              temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
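    """Encode the batch, sample with DDIM under the model's EMA weights, and
    return a log dict with the input, reconstruction, decoded sample and timing.
    """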
    log = {}

    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    z0 = None

    log["input"] = x
    log["reconstruction"] = xrec

    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            log[model.cond_stage_key] = model.to_rgb(xc)

    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == 'class_label':
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()

        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, mask=None, x0=z0,
                                                temperature=temperature, score_corrector=corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_t=x_T)
        t1 = time.time()

    if ddim_use_x0_pred:
        sample = intermediates['pred_x0'][-1]

    x_sample = model.decode_first_stage(sample)

    try:
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0

    return log
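
# Example usage (a minimal sketch; the checkpoint and YAML paths below are
# hypothetical and depend on where the LDSR weights live on disk):
#
#   from PIL import Image
#   ldsr = LDSR(model_path="models/LDSR/model.ckpt", yaml_path="models/LDSR/project.yaml")
#   upscaled = ldsr.super_resolution(Image.open("input.png"), steps=100, target_scale=2)
#   upscaled.save("output.png")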