GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/stable_diffusion_controlnet_img2img.py
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    is_accelerate_available,
    is_accelerate_version,
    randn_tensor,
    replace_example_docstring,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import numpy as np
        >>> import torch
        >>> from PIL import Image
        >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
        >>> from diffusers.utils import load_image

        >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)

        >>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5",
        ...     controlnet=controlnet,
        ...     safety_checker=None,
        ...     torch_dtype=torch.float16,
        ... )

        >>> pipe_controlnet.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
        >>> pipe_controlnet.enable_xformers_memory_efficient_attention()
        >>> pipe_controlnet.enable_model_cpu_offload()

        >>> # use an edge map as the conditioning image for the canny ControlNet
        >>> control_image = load_image(
        ...     "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png"
        ... )

        >>> result_img = pipe_controlnet(
        ...     controlnet_conditioning_image=control_image,
        ...     image=input_image,
        ...     prompt="an android robot, cyberpunk, digital art masterpiece",
        ...     num_inference_steps=20,
        ... ).images[0]

        >>> result_img.show()
        ```
"""


def prepare_image(image):
    if isinstance(image, torch.Tensor):
        # Batch single image
        if image.ndim == 3:
            image = image.unsqueeze(0)

        image = image.to(dtype=torch.float32)
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    return image

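# The helper above normalizes the init image the way the VAE encoder expects: pixel values
# are mapped from [0, 255] to [-1, 1] (x / 127.5 - 1.0) and the result is an NCHW float
# tensor. A minimal sketch of the expected output, assuming a plain white 512x512 RGB PIL
# image (illustration only, not executed by this module):
#
#     >>> img = PIL.Image.new("RGB", (512, 512), color=(255, 255, 255))
#     >>> t = prepare_image(img)
#     >>> t.shape
#     torch.Size([1, 3, 512, 512])
#     >>> float(t.min()), float(t.max())
#     (1.0, 1.0)
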
def prepare_controlnet_conditioning_image(
    controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
):
    if not isinstance(controlnet_conditioning_image, torch.Tensor):
        if isinstance(controlnet_conditioning_image, PIL.Image.Image):
            controlnet_conditioning_image = [controlnet_conditioning_image]

        if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
            controlnet_conditioning_image = [
                np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
                for i in controlnet_conditioning_image
            ]
            controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
            controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
            controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
            controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
        elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
            controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)

    image_batch_size = controlnet_conditioning_image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)

    controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)

    return controlnet_conditioning_image

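# Unlike `prepare_image`, the ControlNet conditioning image is kept in [0, 1]. A rough shape
# sketch of the batch handling above (illustration only): `__call__` passes
# `batch_size * num_images_per_prompt` as `batch_size` here, so with 2 prompts and
# num_images_per_prompt=2 a single 512x512 PIL control image becomes a (4, 3, 512, 512)
# tensor, while a per-prompt list of control images is repeated `num_images_per_prompt`
# times each via `repeat_interleave`.
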
class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
    """
    Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: ControlNetModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `safety_checker=None` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
        text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and are then moved to a
        `torch.device('meta')`, being loaded onto the GPU only when their specific submodule's `forward` method is
        called. Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than
        with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the
        `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            # the safety checker can offload the vae again
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # the controlnet hook has to be offloaded manually, as it alternates with the unet
        cpu_offload_with_hook(self.controlnet, device)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        image,
        controlnet_conditioning_image,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        strength=None,
        controlnet_guidance_start=None,
        controlnet_guidance_end=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
        controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
        controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
            controlnet_conditioning_image[0], PIL.Image.Image
        )
        controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
            controlnet_conditioning_image[0], torch.Tensor
        )

        if (
            not controlnet_cond_image_is_pil
            and not controlnet_cond_image_is_tensor
            and not controlnet_cond_image_is_pil_list
            and not controlnet_cond_image_is_tensor_list
        ):
            raise TypeError(
                "`controlnet_conditioning_image` must be passed and must be a PIL image, a torch tensor, a list of"
                " PIL images, or a list of torch tensors"
            )

        if controlnet_cond_image_is_pil:
            controlnet_cond_image_batch_size = 1
        elif controlnet_cond_image_is_tensor:
            controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
        elif controlnet_cond_image_is_pil_list:
            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
        elif controlnet_cond_image_is_tensor_list:
            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)

        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]

        if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
            raise ValueError(
                "If the conditioning image batch size is not 1, it must match the prompt batch size. image batch"
                f" size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
            )

        if isinstance(image, torch.Tensor):
            if image.ndim != 3 and image.ndim != 4:
                raise ValueError("`image` must have 3 or 4 dimensions")

            # if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
            #     raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")

            if image.ndim == 3:
                image_batch_size = 1
                image_channels, image_height, image_width = image.shape
            elif image.ndim == 4:
                image_batch_size, image_channels, image_height, image_width = image.shape

            if image_channels != 3:
                raise ValueError("`image` must have 3 channels")

            if image.min() < -1 or image.max() > 1:
                raise ValueError("`image` should be in range [-1, 1]")

        if self.vae.config.latent_channels != self.unet.config.in_channels:
            raise ValueError(
                f"The config of `pipeline.unet` expects {self.unet.config.in_channels} latent channels but received"
                f" {self.vae.config.latent_channels} from `pipeline.vae`. Please verify the configs of"
                " `pipeline.unet` and `pipeline.vae`."
            )

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of `strength` should be in [0.0, 1.0] but is {strength}")

        if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
            raise ValueError(
                f"The value of `controlnet_guidance_start` should be in [0.0, 1.0] but is {controlnet_guidance_start}"
            )

        if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
            raise ValueError(
                f"The value of `controlnet_guidance_end` should be in [0.0, 1.0] but is {controlnet_guidance_end}"
            )

        if controlnet_guidance_start > controlnet_guidance_end:
            raise ValueError(
                "The value of `controlnet_guidance_start` should not be greater than `controlnet_guidance_end`, but"
                f" got `controlnet_guidance_start` {controlnet_guidance_start} > `controlnet_guidance_end`"
                f" {controlnet_guidance_end}"
            )

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

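    # A worked example of the strength-to-timestep mapping above (illustration only): with
    # num_inference_steps=50 and strength=0.8, init_timestep = min(int(50 * 0.8), 50) = 40, so
    # t_start = 10 and only the last 40 scheduler timesteps are run. strength=1.0 keeps all 50
    # steps (the init image is fully noised), while strength=0.0 would skip every step.
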
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        elif batch_size > init_latents.shape[0]:
            # the requested batch size is a multiple of the image batch size, so duplicate the latents
            init_latents = torch.cat([init_latents] * (batch_size // init_latents.shape[0]), dim=0)
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    def _default_height_width(self, height, width, image):
        if isinstance(image, list):
            image = image[0]

        if height is None:
            if isinstance(image, PIL.Image.Image):
                height = image.height
            elif isinstance(image, torch.Tensor):
                height = image.shape[2]

            height = (height // 8) * 8  # round down to nearest multiple of 8

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[3]

            width = (width // 8) * 8  # round down to nearest multiple of 8

        return height, width

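    # A small sketch of the rounding behavior above (illustration only): a 600x403 PIL control
    # image (width x height) yields width=600 and height=400, since both sides are rounded down
    # to the nearest multiple of 8 to stay compatible with the VAE downsampling factor.
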
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        controlnet_conditioning_image: Union[
            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
        ] = None,
        strength: float = 0.8,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: float = 1.0,
        controlnet_guidance_start: float = 0.0,
        controlnet_guidance_end: float = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                image-to-image generation process.
            controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the
                UNet. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is.
                `PIL.Image.Image` can also be accepted as an image. The control image is automatically resized to fit
                the output image.
            strength (`float`, *optional*):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original UNet.
            controlnet_guidance_start (`float`, *optional*, defaults to 0.0):
                The fraction of total steps at which the ControlNet starts applying. Must be between 0 and 1.
            controlnet_guidance_end (`float`, *optional*, defaults to 1.0):
                The fraction of total steps at which the ControlNet stops applying. Must be between 0 and 1 and
                greater than `controlnet_guidance_start`.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, controlnet_conditioning_image)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            # mask_image,
            controlnet_conditioning_image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            strength,
            controlnet_guidance_start,
            controlnet_guidance_end,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare mask, image, and controlnet_conditioning_image
        image = prepare_image(image)

        # mask_image = prepare_mask_image(mask_image)

        controlnet_conditioning_image = prepare_controlnet_conditioning_image(
            controlnet_conditioning_image,
            width,
            height,
            batch_size * num_images_per_prompt,
            num_images_per_prompt,
            device,
            self.controlnet.dtype,
        )

        # masked_image = image * (mask_image < 0.5)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
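        # `latent_timestep` is the first (noisiest) retained timestep; `prepare_latents` below adds
        # noise to the encoded init image at this timestep, so a larger `strength` means the denoising
        # loop starts from a noisier version of the input image.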
        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

        if do_classifier_free_guidance:
            controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # compute the fraction of total steps we are at
                current_sampling_percent = i / len(timesteps)

                if (
                    current_sampling_percent < controlnet_guidance_start
                    or current_sampling_percent > controlnet_guidance_end
                ):
                    # do not apply the controlnet
                    down_block_res_samples = None
                    mid_block_res_sample = None
                else:
                    # apply the controlnet
                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        controlnet_cond=controlnet_conditioning_image,
                        return_dict=False,
                    )

                    down_block_res_samples = [
                        down_block_res_sample * controlnet_conditioning_scale
                        for down_block_res_sample in down_block_res_samples
                    ]
                    mid_block_res_sample *= controlnet_conditioning_scale

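                # A quick sketch of the gating above (illustration only): with 20 retained timesteps,
                # controlnet_guidance_start=0.1 and controlnet_guidance_end=0.8, the ControlNet residuals
                # are applied for i = 2..16 (current_sampling_percent in [0.1, 0.8]); the UNet runs
                # without ControlNet residuals for the remaining steps.
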
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)