# GitHub repository: shivamshrirao/diffusers
# Path: examples/community/stable_diffusion_controlnet_inpaint_img2img.py
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    is_accelerate_available,
    is_accelerate_version,
    randn_tensor,
    replace_example_docstring,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import numpy as np
        >>> import torch
        >>> from PIL import Image
        >>> from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline

        >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
        >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
        >>> from diffusers.utils import load_image

        >>> def ade_palette():
                return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
                        [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
                        [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
                        [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
                        [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
                        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
                        [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
                        [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
                        [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
                        [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
                        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
                        [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
                        [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
                        [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
                        [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
                        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
                        [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
                        [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
                        [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
                        [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
                        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
                        [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
                        [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
                        [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
                        [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
                        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
                        [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
                        [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
                        [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
                        [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
                        [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
                        [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
                        [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
                        [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
                        [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
                        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
                        [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
                        [102, 255, 0], [92, 0, 255]]

        >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")

        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)

        >>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
        )

        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        >>> pipe.enable_xformers_memory_efficient_attention()
        >>> pipe.enable_model_cpu_offload()

        >>> def image_to_seg(image):
                pixel_values = image_processor(image, return_tensors="pt").pixel_values
                with torch.no_grad():
                    outputs = image_segmentor(pixel_values)
                seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
                color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
                palette = np.array(ade_palette())
                for label, color in enumerate(palette):
                    color_seg[seg == label, :] = color
                color_seg = color_seg.astype(np.uint8)
                seg_image = Image.fromarray(color_seg)
                return seg_image

        >>> image = load_image(
                "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        )

        >>> mask_image = load_image(
                "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
        )

        >>> controlnet_conditioning_image = image_to_seg(image)

        >>> image = pipe(
                "Face of a yellow cat, high resolution, sitting on a park bench",
                image,
                mask_image,
                controlnet_conditioning_image,
                num_inference_steps=20,
        ).images[0]

        >>> image.save("out.png")
        ```
"""


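# The helpers below normalize the different accepted input formats (PIL images,
# numpy arrays, or torch tensors) into batched float tensors:
# `prepare_image` returns RGB images scaled to [-1, 1] with shape (B, 3, H, W),
# `prepare_mask_image` returns single-channel masks binarized to {0, 1}, and
# `prepare_controlnet_conditioning_image` resizes the conditioning image to the
# output resolution, scales it to [0, 1], and repeats it to match the batch size.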
def prepare_image(image):
    if isinstance(image, torch.Tensor):
        # Batch single image
        if image.ndim == 3:
            image = image.unsqueeze(0)

        image = image.to(dtype=torch.float32)
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    return image


def prepare_mask_image(mask_image):
    if isinstance(mask_image, torch.Tensor):
        if mask_image.ndim == 2:
            # Batch and add channel dim for single mask
            mask_image = mask_image.unsqueeze(0).unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
            # Single mask, the 0'th dimension is considered to be
            # the existing batch size of 1
            mask_image = mask_image.unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
            # Batch of mask, the 0'th dimension is considered to be
            # the batching dimension
            mask_image = mask_image.unsqueeze(1)

        # Binarize mask
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
    else:
        # preprocess mask
        if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
            mask_image = [mask_image]

        if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
            mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
            mask_image = mask_image.astype(np.float32) / 255.0
        elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)

        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
        mask_image = torch.from_numpy(mask_image)

    return mask_image


def prepare_controlnet_conditioning_image(
    controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
):
    if not isinstance(controlnet_conditioning_image, torch.Tensor):
        if isinstance(controlnet_conditioning_image, PIL.Image.Image):
            controlnet_conditioning_image = [controlnet_conditioning_image]

        if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
            controlnet_conditioning_image = [
                np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
                for i in controlnet_conditioning_image
            ]
            controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
            controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
            controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
            controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
        elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
            controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)

    image_batch_size = controlnet_conditioning_image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)

    controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)

    return controlnet_conditioning_image


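# Pipeline overview: a single denoising loop that combines ControlNet
# conditioning with the Stable Diffusion inpainting UNet (which consumes the
# noised image latents, the downsampled mask, and the masked-image latents
# concatenated along the channel axis) and img2img-style partial noising
# controlled by `strength`.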
class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
    """
    Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: ControlNetModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
        text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than
        with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the
        `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            # the safety checker can offload the vae again
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # the ControlNet hook has to be offloaded manually as it alternates with the unet
        cpu_offload_with_hook(self.controlnet, device)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

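    # Rough guidance, not enforced by the pipeline: `enable_model_cpu_offload`
    # moves whole sub-models on and off the GPU and is usually the better
    # speed/memory trade-off, while `enable_sequential_cpu_offload` offloads at
    # the submodule level for maximum memory savings at a noticeable speed cost.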
    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

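    # Note on the returned layout: when classifier-free guidance is enabled,
    # `_encode_prompt` concatenates the unconditional (negative) embeddings
    # *before* the text embeddings, so the result has
    # 2 * batch_size * num_images_per_prompt rows and lines up with the doubled
    # latent batch built in the denoising loop.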
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        image,
        mask_image,
        controlnet_conditioning_image,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        strength=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
        controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
        controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
            controlnet_conditioning_image[0], PIL.Image.Image
        )
        controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
            controlnet_conditioning_image[0], torch.Tensor
        )

        if (
            not controlnet_cond_image_is_pil
            and not controlnet_cond_image_is_tensor
            and not controlnet_cond_image_is_pil_list
            and not controlnet_cond_image_is_tensor_list
        ):
            raise TypeError(
                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
            )

        if controlnet_cond_image_is_pil:
            controlnet_cond_image_batch_size = 1
        elif controlnet_cond_image_is_tensor:
            controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
        elif controlnet_cond_image_is_pil_list:
            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
        elif controlnet_cond_image_is_tensor_list:
            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)

        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]

        if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
            )

        if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
            raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")

        if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
            raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")

        if isinstance(image, torch.Tensor):
            if image.ndim != 3 and image.ndim != 4:
                raise ValueError("`image` must have 3 or 4 dimensions")

            if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
                raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")

            if image.ndim == 3:
                image_batch_size = 1
                image_channels, image_height, image_width = image.shape
            elif image.ndim == 4:
                image_batch_size, image_channels, image_height, image_width = image.shape

            if mask_image.ndim == 2:
                mask_image_batch_size = 1
                mask_image_channels = 1
                mask_image_height, mask_image_width = mask_image.shape
            elif mask_image.ndim == 3:
                mask_image_channels = 1
                mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
            elif mask_image.ndim == 4:
                mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape

            if image_channels != 3:
                raise ValueError("`image` must have 3 channels")

            if mask_image_channels != 1:
                raise ValueError("`mask_image` must have 1 channel")

            if image_batch_size != mask_image_batch_size:
                raise ValueError("`image` and `mask_image` must have the same batch sizes")

            if image_height != mask_image_height or image_width != mask_image_width:
                raise ValueError("`image` and `mask_image` must have the same height and width dimensions")

            if image.min() < -1 or image.max() > 1:
                raise ValueError("`image` should be in range [-1, 1]")

            if mask_image.min() < 0 or mask_image.max() > 1:
                raise ValueError("`mask_image` should be in range [0, 1]")
        else:
            mask_image_channels = 1
            image_channels = 3

        single_image_latent_channels = self.vae.config.latent_channels

        total_latent_channels = single_image_latent_channels * 2 + mask_image_channels

        if total_latent_channels != self.unet.config.in_channels:
            raise ValueError(
                f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
                f" non inpainting latent channels: {single_image_latent_channels},"
                f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
                f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
            )

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

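    # `strength` controls how many scheduler timesteps are actually run. As an
    # illustration (not code from this file): with num_inference_steps=50 and
    # strength=0.8, init_timestep=40 and t_start=10, so denoising starts from a
    # partially noised version of `image` and runs 40 steps; strength=1.0
    # starts from (almost) pure noise and essentially ignores `image`.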
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

        # only raise when the requested batch size cannot be reached by duplicating the image latents
        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

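    # The mask is not passed through the VAE: it is simply downsampled to the
    # latent resolution (height // vae_scale_factor, width // vae_scale_factor)
    # so it can be concatenated channel-wise with the latents, and duplicated
    # when classifier-free guidance is used.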
    def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
        mask_image = mask_image.to(device=device, dtype=dtype)

        # duplicate mask for each generation per prompt, using mps friendly method
        if mask_image.shape[0] < batch_size:
            if not batch_size % mask_image.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)

        mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image

        mask_image_latents = mask_image

        return mask_image_latents

    def prepare_masked_image_latents(
        self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        masked_image = masked_image.to(device=device, dtype=dtype)

        # encode the mask image into latents space so we can concatenate it to the latents
        if isinstance(generator, list):
            masked_image_latents = [
                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
                for i in range(batch_size)
            ]
            masked_image_latents = torch.cat(masked_image_latents, dim=0)
        else:
            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents

        # duplicate masked_image_latents for each generation per prompt, using mps friendly method
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        # aligning device to prevent device errors when concating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        return masked_image_latents

    def _default_height_width(self, height, width, image):
        if isinstance(image, list):
            image = image[0]

        if height is None:
            if isinstance(image, PIL.Image.Image):
                height = image.height
            elif isinstance(image, torch.Tensor):
                # tensors are expected in (batch, channels, height, width) layout
                height = image.shape[2]

            height = (height // 8) * 8  # round down to nearest multiple of 8

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[3]

            width = (width // 8) * 8  # round down to nearest multiple of 8

        return height, width

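    # The UNet input assembled in `__call__` follows the Stable Diffusion
    # inpainting layout: the (noised) image latents, the downsampled mask, and
    # the VAE-encoded masked image concatenated along the channel dimension
    # (4 + 1 + 4 = 9 channels for standard Stable Diffusion checkpoints). The
    # ControlNet itself only sees the plain latents plus the conditioning image.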
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
        controlnet_conditioning_image: Union[
            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
        ] = None,
        strength: float = 0.8,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: float = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
                be masked out with `mask_image` and repainted according to `prompt`.
            mask_image (`torch.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                instead of 3, so the expected shape would be `(B, 1, H, W)`.
            controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance to the UNet.
                If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image`
                can also be accepted as an image. The control image is automatically resized to fit the output image.
            strength (`float`, *optional*):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            height (`int`, *optional*, defaults to the height of `controlnet_conditioning_image`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to the width of `controlnet_conditioning_image`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original UNet.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, controlnet_conditioning_image)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            mask_image,
            controlnet_conditioning_image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            strength,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare mask, image, and controlnet_conditioning_image
        image = prepare_image(image)

        mask_image = prepare_mask_image(mask_image)

        controlnet_conditioning_image = prepare_controlnet_conditioning_image(
            controlnet_conditioning_image,
            width,
            height,
            batch_size * num_images_per_prompt,
            num_images_per_prompt,
            device,
            self.controlnet.dtype,
        )

        masked_image = image * (mask_image < 0.5)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

        mask_image_latents = self.prepare_mask_latents(
            mask_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            do_classifier_free_guidance,
        )

        masked_image_latents = self.prepare_masked_image_latents(
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        if do_classifier_free_guidance:
            controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

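        # Per-step flow: the ControlNet consumes the plain latents plus the
        # conditioning image and returns residuals (scaled by
        # `controlnet_conditioning_scale`); the UNet consumes the 9-channel
        # inpainting input together with those residuals; classifier-free
        # guidance then combines the unconditional/text halves before the
        # scheduler step.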
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                non_inpainting_latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )

                non_inpainting_latent_model_input = self.scheduler.scale_model_input(
                    non_inpainting_latent_model_input, t
                )

                inpainting_latent_model_input = torch.cat(
                    [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
                )

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    non_inpainting_latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_cond=controlnet_conditioning_image,
                    return_dict=False,
                )

                down_block_res_samples = [
                    down_block_res_sample * controlnet_conditioning_scale
                    for down_block_res_sample in down_block_res_samples
                ]
                mid_block_res_sample *= controlnet_conditioning_scale

                # predict the noise residual
                noise_pred = self.unet(
                    inpainting_latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)