GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/clip_guided_stable_diffusion_img2img.py
import inspect
from typing import List, Optional, Union

import numpy as np
import PIL
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import (
    PIL_INTERPOLATION,
    deprecate,
    randn_tensor,
)


EXAMPLE_DOC_STRING = """
    Examples:
        ```
        from io import BytesIO

        import requests
        import torch
        from diffusers import DiffusionPipeline
        from PIL import Image
        from transformers import CLIPFeatureExtractor, CLIPModel

        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
        )
        clip_model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
        )


        guided_pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="clip_guided_stable_diffusion_img2img",
            # or pass a local path to examples/community/clip_guided_stable_diffusion_img2img.py
            clip_model=clip_model,
            feature_extractor=feature_extractor,
            torch_dtype=torch.float16,
        )
        guided_pipeline.enable_attention_slicing()
        guided_pipeline = guided_pipeline.to("cuda")

        prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"

        url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        response = requests.get(url)
        init_image = Image.open(BytesIO(response.content)).convert("RGB")

        image = guided_pipeline(
            prompt=prompt,
            num_inference_steps=30,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            clip_guidance_scale=100,
            num_cutouts=4,
            use_cutouts=False,
        ).images[0]
        display(image)
        ```
"""


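# Helper: convert the init image (a PIL image, a tensor, or a list of either) into a
# normalized NCHW float tensor in [-1, 1], resized to (w, h) for the VAE encoder.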
def preprocess(image, w, h):
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


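# MakeCutouts: sample `num_cutouts` random square crops (crop sizes biased by `cut_power`)
# and pool each crop to `cut_size` so it can be fed to the CLIP image encoder.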
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cut_power=1.0):
        super().__init__()

        self.cut_size = cut_size
        self.cut_power = cut_power

    def forward(self, pixel_values, num_cutouts):
        sideY, sideX = pixel_values.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(num_cutouts):
            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


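# Squared great-circle ("spherical") distance between L2-normalized embeddings:
# dist(x, y) = 2 * arcsin(||x_hat - y_hat|| / 2) ** 2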
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def set_requires_grad(model, value):
    for param in model.parameters():
        param.requires_grad = value


class CLIPGuidedStableDiffusion(DiffusionPipeline):
    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
    - https://github.com/Jack000/glid-3-xl
    - https://github.dev/crowsonkb/k-diffusion
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        clip_model: CLIPModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            clip_model=clip_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        self.cut_out_size = (
            feature_extractor.size
            if isinstance(feature_extractor.size, int)
            else feature_extractor.size["shortest_edge"]
        )
        self.make_cutouts = MakeCutouts(self.cut_out_size)

        set_requires_grad(self.text_encoder, False)
        set_requires_grad(self.clip_model, False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        self.enable_attention_slicing(None)

    def freeze_vae(self):
        set_requires_grad(self.vae, False)

    def unfreeze_vae(self):
        set_requires_grad(self.vae, True)

    def freeze_unet(self):
        set_requires_grad(self.unet, False)

    def unfreeze_unet(self):
        set_requires_grad(self.unet, True)

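    # Image-to-image scheduling: `strength` in [0, 1] selects how many of the scheduler's
    # timesteps are actually run (int(num_inference_steps * strength)); higher strength adds
    # more noise to the init image and allows larger changes.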
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

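    # Encode the (preprocessed) init image with the VAE, scale by the VAE scaling factor,
    # duplicate to the effective batch size, and add scheduler noise at the starting timestep.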
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

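    # cond_fn implements the CLIP guidance step: re-predict the noise with gradients enabled,
    # estimate a denoised sample, decode it with the VAE, embed it with CLIP, and use the
    # gradient of the spherical distance to the CLIP text embedding to steer the latents
    # (LMS) or the noise prediction (DDIM/PNDM/DPM-Solver).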
    @torch.enable_grad()
    def cond_fn(
        self,
        latents,
        timestep,
        index,
        text_embeddings,
        noise_pred_original,
        text_embeddings_clip,
        clip_guidance_scale,
        num_cutouts,
        use_cutouts=True,
    ):
        latents = latents.detach().requires_grad_()

        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
            beta_prod_t = 1 - alpha_prod_t
            # compute predicted original sample from predicted noise also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

            fac = torch.sqrt(beta_prod_t)
            sample = pred_original_sample * (fac) + latents * (1 - fac)
        elif isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            sample = latents - sigma * noise_pred
        else:
            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

        sample = 1 / self.vae.config.scaling_factor * sample
        image = self.vae.decode(sample).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        if use_cutouts:
            image = self.make_cutouts(image, num_cutouts)
        else:
            image = transforms.Resize(self.cut_out_size)(image)
        image = self.normalize(image).to(latents.dtype)

        image_embeddings_clip = self.clip_model.get_image_features(image)
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        if use_cutouts:
            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
            dists = dists.view([num_cutouts, sample.shape[0], -1])
            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
        else:
            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale

        grads = -torch.autograd.grad(loss, latents)[0]

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents.detach() + grads * (sigma**2)
            noise_pred = noise_pred_original
        else:
            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
        return noise_pred, latents

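    # Main image-to-image entry point: encodes the prompt, prepares noised latents from the
    # init image, then denoises with classifier-free guidance and (optionally) CLIP guidance.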
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        clip_guidance_scale: Optional[float] = 100,
        clip_prompt: Optional[Union[str, List[str]]] = None,
        num_cutouts: Optional[int] = 4,
        use_cutouts: Optional[bool] = True,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        # duplicate text embeddings for each generation per prompt
        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to the correct device beforehand
        self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)

        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # Preprocess image
        image = preprocess(image, width, height)
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, self.device, generator
        )

        if clip_guidance_scale > 0:
            if clip_prompt is not None:
                clip_text_input = self.tokenizer(
                    clip_prompt,
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(self.device)
            else:
                clip_text_input = text_input.input_ids.to(self.device)
            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
            # duplicate text embeddings clip for each generation per prompt
            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
            # duplicate unconditional embeddings for each generation per prompt
            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

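        # NOTE: `latents` was already set from the init image by `prepare_latents` above, so the
        # random-noise branch below is never taken; only the shape check, device move, and
        # `init_noise_sigma` scaling apply here.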
        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator

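        # Denoising loop: classifier-free guidance mixes the unconditional and conditional noise
        # predictions; CLIP guidance (cond_fn) then adjusts the prediction/latents toward images
        # whose CLIP embedding matches the (clip_)prompt.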
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform classifier free guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # perform clip guidance
                if clip_guidance_scale > 0:
                    text_embeddings_for_guidance = (
                        text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
                    )
                    noise_pred, latents = self.cond_fn(
                        latents,
                        t,
                        i,
                        text_embeddings_for_guidance,
                        noise_pred,
                        text_embeddings_clip,
                        clip_guidance_scale,
                        num_cutouts,
                        use_cutouts,
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
                progress_bar.update()

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)