GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/imagic_stable_diffusion.py
"""
    modeled after the textual_inversion.py / train_dreambooth.py and the work
    of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
"""
import inspect
import warnings
from typing import List, Optional, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as F
from accelerate import Accelerator

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def preprocess(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


class ImagicStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for imagic image editing.
    See paper here: https://arxiv.org/pdf/2210.09276.pdf

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
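
    Examples:
        A minimal usage sketch (it assumes a CUDA GPU, the `CompVis/stable-diffusion-v1-4` checkpoint, a local
        `"input.jpg"`, and that this file is loadable as a community pipeline, e.g. via
        `custom_pipeline="imagic_stable_diffusion"`; adjust those pieces to your setup):

        ```py
        import torch
        from PIL import Image
        from diffusers import DiffusionPipeline

        pipe = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="imagic_stable_diffusion",
            safety_checker=None,
        ).to("cuda")

        init_image = Image.open("input.jpg").convert("RGB")
        generator = torch.Generator("cuda").manual_seed(0)

        # Stage 1 + 2: optimize the text embedding, then fine-tune the UNet on the input image.
        pipe.train("A photo of a bird spreading its wings", image=init_image, generator=generator)

        # Interpolate between the optimized and the target embedding and decode the edited image.
        edited = pipe(alpha=1.2, guidance_scale=7.5, num_inference_steps=50).images[0]
        edited.save("edited.png")
        ```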
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # `train()` fills these in; `__call__` checks them so that generating before training fails with a clear error.
        self.text_embeddings_orig = None
        self.text_embeddings = None

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.
        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.
        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
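
        Example:
            A hedged sketch, assuming `pipe` is an already-loaded `ImagicStableDiffusionPipeline`:

            ```py
            pipe.enable_attention_slicing()  # lower peak memory, slightly slower attention
            ```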
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def train(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        generator: Optional[torch.Generator] = None,
        embedding_learning_rate: float = 0.001,
        diffusion_model_learning_rate: float = 2e-6,
        text_embedding_optimization_steps: int = 500,
        model_fine_tuning_optimization_steps: int = 1000,
        **kwargs,
    ):
        r"""
        Fine-tunes the pipeline on a single image: first optimizes the text embedding of `prompt` to reconstruct
        `image`, then fine-tunes the U-Net around that embedding, following the Imagic procedure.
        Args:
            prompt (`str` or `List[str]`):
                The target text prompt describing the desired edit.
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                The image to be edited. A `PIL.Image.Image` is resized and normalized automatically.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            embedding_learning_rate (`float`, *optional*, defaults to 0.001):
                Learning rate used while optimizing the text embedding.
            diffusion_model_learning_rate (`float`, *optional*, defaults to 2e-6):
                Learning rate used while fine-tuning the U-Net.
            text_embedding_optimization_steps (`int`, *optional*, defaults to 500):
                Number of optimization steps for the text embedding.
            model_fine_tuning_optimization_steps (`int`, *optional*, defaults to 1000):
                Number of optimization steps for fine-tuning the U-Net.
        Returns:
            `None`. The optimized and original text embeddings are stored on the pipeline as `self.text_embeddings`
            and `self.text_embeddings_orig` and are consumed by `__call__`.
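
        Example:
            A hedged sketch, assuming `pipe` is an `ImagicStableDiffusionPipeline` on a CUDA device and `init_image`
            is a `PIL.Image.Image`; the step counts below simply restate the defaults:

            ```py
            pipe.train(
                prompt="A photo of a bird spreading its wings",
                image=init_image,
                generator=torch.Generator("cuda").manual_seed(0),
                text_embedding_optimization_steps=500,
                model_fine_tuning_optimization_steps=1000,
            )
            ```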
        """
        accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision="fp16",
        )

        if "torch_device" in kwargs:
            device = kwargs.pop("torch_device")
            warnings.warn(
                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
                " Consider using `pipe.to(torch_device)` instead."
            )

            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Freeze vae and unet
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.eval()
        self.vae.eval()
        self.text_encoder.eval()

        if accelerator.is_main_process:
            accelerator.init_trackers(
                "imagic",
                config={
                    "embedding_learning_rate": embedding_learning_rate,
                    "text_embedding_optimization_steps": text_embedding_optimization_steps,
                },
            )

        # get text embeddings for prompt
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = torch.nn.Parameter(
            self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
        )
        text_embeddings = text_embeddings.detach()
        text_embeddings.requires_grad_()
        text_embeddings_orig = text_embeddings.clone()

        # Initialize the optimizer
        optimizer = torch.optim.Adam(
            [text_embeddings],  # only optimize the embeddings
            lr=embedding_learning_rate,
        )

        if isinstance(image, PIL.Image.Image):
            image = preprocess(image)

        latents_dtype = text_embeddings.dtype
        image = image.to(device=self.device, dtype=latents_dtype)
        init_latent_image_dist = self.vae.encode(image).latent_dist
        image_latents = init_latent_image_dist.sample(generator=generator)
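        # 0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE (`vae.config.scaling_factor`);
        # latents are multiplied by it before diffusion and divided by it again before decoding.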
        image_latents = 0.18215 * image_latents

        progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
        progress_bar.set_description("Steps")

        global_step = 0

        logger.info("First optimizing the text embedding to better reconstruct the init image")
        for _ in range(text_embedding_optimization_steps):
            with accelerator.accumulate(text_embeddings):
                # Sample noise that we'll add to the latents
                noise = torch.randn(image_latents.shape).to(image_latents.device)
                timesteps = torch.randint(1000, (1,), device=image_latents.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()

        text_embeddings.requires_grad_(False)

        # Now we fine tune the unet to better reconstruct the image
        self.unet.requires_grad_(True)
        self.unet.train()
        optimizer = torch.optim.Adam(
            self.unet.parameters(),  # only optimize unet
            lr=diffusion_model_learning_rate,
        )
        progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)

        logger.info("Next fine tuning the entire model to better reconstruct the init image")
        for _ in range(model_fine_tuning_optimization_steps):
            with accelerator.accumulate(self.unet.parameters()):
                # Sample noise that we'll add to the latents
                noise = torch.randn(image_latents.shape).to(image_latents.device)
                timesteps = torch.randint(1000, (1,), device=image_latents.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()
        self.text_embeddings_orig = text_embeddings_orig
        self.text_embeddings = text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        alpha: float = 1.2,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation. Requires `train()` to have been run first.
        Args:
            alpha (`float`, *optional*, defaults to 1.2):
                Interpolation weight between the optimized text embedding and the original target-prompt embedding.
                Values closer to 0 stay closer to the input image; values of 1 and above apply the edit more strongly.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages images that are closely linked to the text prompt, usually at
                the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
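
        Example:
            A hedged sketch of sweeping the edit strength after `train()` has been run (assumes `pipe` was trained as
            in the class-level example):

            ```py
            for alpha in (0.8, 1.0, 1.2, 1.5):
                image = pipe(alpha=alpha, guidance_scale=7.5, num_inference_steps=50).images[0]
                image.save(f"imagic_alpha_{alpha}.png")
            ```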
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        if self.text_embeddings is None:
            raise ValueError("Please run the pipe.train() before trying to generate an image.")
        if self.text_embeddings_orig is None:
            raise ValueError("Please run the pipe.train() before trying to generate an image.")

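        # Interpolate the embeddings: alpha = 0 reproduces the optimized (reconstruction) embedding, alpha = 1 the
        # original target-prompt embedding, and alpha > 1 extrapolates past it for a stronger edit.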
        text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens = [""]
            max_length = self.tokenizer.model_max_length
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if self.device.type == "mps":
            # randn does not exist on mps
            latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                self.device
            )
        else:
            latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)