# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    is_accelerate_available,
    is_accelerate_version,
    randn_tensor,
    replace_example_docstring,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import numpy as np
        >>> import torch
        >>> from PIL import Image
        >>> from stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline

        >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
        >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
        >>> from diffusers.utils import load_image

        >>> def ade_palette():
                return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
                        [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
                        [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
                        [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
                        [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
                        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
                        [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
                        [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
                        [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
                        [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
                        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
                        [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
                        [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
                        [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
                        [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
                        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
                        [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
                        [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
                        [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
                        [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
                        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
                        [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
                        [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
                        [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
                        [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
                        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
                        [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
                        [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
                        [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
                        [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
                        [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
                        [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
                        [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
                        [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
                        [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
                        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
                        [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
                        [102, 255, 0], [92, 0, 255]]

        >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")

        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)

        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
            )

        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        >>> pipe.enable_xformers_memory_efficient_attention()
        >>> pipe.enable_model_cpu_offload()

        >>> def image_to_seg(image):
                pixel_values = image_processor(image, return_tensors="pt").pixel_values
                with torch.no_grad():
                    outputs = image_segmentor(pixel_values)
                seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
                color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
                palette = np.array(ade_palette())
                for label, color in enumerate(palette):
                    color_seg[seg == label, :] = color
                color_seg = color_seg.astype(np.uint8)
                seg_image = Image.fromarray(color_seg)
                return seg_image

        >>> image = load_image(
                "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
            )

        >>> mask_image = load_image(
                "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
            )

        >>> controlnet_conditioning_image = image_to_seg(image)

        >>> image = pipe(
                "Face of a yellow cat, high resolution, sitting on a park bench",
                image,
                mask_image,
                controlnet_conditioning_image,
                num_inference_steps=20,
            ).images[0]

        >>> image.save("out.png")
        ```
"""


def prepare_image(image):
    if isinstance(image, torch.Tensor):
        # Batch single image
        if image.ndim == 3:
            image = image.unsqueeze(0)

        image = image.to(dtype=torch.float32)
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

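        # NHWC -> NCHW, then rescale pixel values from [0, 255] to [-1, 1]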
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    return image


def prepare_mask_image(mask_image):
    if isinstance(mask_image, torch.Tensor):
        if mask_image.ndim == 2:
            # Batch and add channel dim for single mask
            mask_image = mask_image.unsqueeze(0).unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
            # Single mask, the 0'th dimension is considered to be
            # the existing batch size of 1
            mask_image = mask_image.unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
            # Batch of mask, the 0'th dimension is considered to be
            # the batching dimension
            mask_image = mask_image.unsqueeze(1)

        # Binarize mask
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
    else:
        # preprocess mask
        if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
            mask_image = [mask_image]

        if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
            mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
            mask_image = mask_image.astype(np.float32) / 255.0
        elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)

        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
        mask_image = torch.from_numpy(mask_image)

    return mask_image


def prepare_controlnet_conditioning_image(
    controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
):
    if not isinstance(controlnet_conditioning_image, torch.Tensor):
        if isinstance(controlnet_conditioning_image, PIL.Image.Image):
            controlnet_conditioning_image = [controlnet_conditioning_image]

        if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
            controlnet_conditioning_image = [
                np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
                for i in controlnet_conditioning_image
            ]
            controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
            controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
            controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
            controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
        elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
            controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)

    image_batch_size = controlnet_conditioning_image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

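    # repeat the conditioning image along the batch dimension so every generated image gets a copy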
    controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)

    controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)

    return controlnet_conditioning_image


class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
    """
    Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: ControlNetModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `safety_checker=None` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
        text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and are then moved to a
        `torch.device('meta')`, getting loaded to GPU only when their specific submodule has its `forward` method
        called. Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains on GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            # the safety checker can offload the vae again
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # the controlnet hook has to be manually offloaded as it alternates with the unet
        cpu_offload_with_hook(self.controlnet, device)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt`"
                    " matches the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
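        # undo the VAE scaling factor before decoding the latents back to pixel space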
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        image,
        mask_image,
        controlnet_conditioning_image,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
        controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
        controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
            controlnet_conditioning_image[0], PIL.Image.Image
        )
        controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
            controlnet_conditioning_image[0], torch.Tensor
        )

        if (
            not controlnet_cond_image_is_pil
            and not controlnet_cond_image_is_tensor
            and not controlnet_cond_image_is_pil_list
            and not controlnet_cond_image_is_tensor_list
        ):
            raise TypeError(
                "`controlnet_conditioning_image` must be passed and be one of PIL image, torch tensor, list of PIL"
                " images, or list of torch tensors"
            )

        if controlnet_cond_image_is_pil:
            controlnet_cond_image_batch_size = 1
        elif controlnet_cond_image_is_tensor:
            controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
        elif controlnet_cond_image_is_pil_list:
            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
        elif controlnet_cond_image_is_tensor_list:
            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)

        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]

        if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
            )

        if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
            raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")

        if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
            raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")

        if isinstance(image, torch.Tensor):
            if image.ndim != 3 and image.ndim != 4:
                raise ValueError("`image` must have 3 or 4 dimensions")

            if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
                raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")

            if image.ndim == 3:
                image_batch_size = 1
                image_channels, image_height, image_width = image.shape
            elif image.ndim == 4:
                image_batch_size, image_channels, image_height, image_width = image.shape

            if mask_image.ndim == 2:
                mask_image_batch_size = 1
                mask_image_channels = 1
                mask_image_height, mask_image_width = mask_image.shape
            elif mask_image.ndim == 3:
                mask_image_channels = 1
                mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
            elif mask_image.ndim == 4:
                mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape

            if image_channels != 3:
                raise ValueError("`image` must have 3 channels")

            if mask_image_channels != 1:
                raise ValueError("`mask_image` must have 1 channel")

            if image_batch_size != mask_image_batch_size:
                raise ValueError("`image` and `mask_image` must have the same batch sizes")

            if image_height != mask_image_height or image_width != mask_image_width:
                raise ValueError("`image` and `mask_image` must have the same height and width dimensions")

            if image.min() < -1 or image.max() > 1:
                raise ValueError("`image` should be in range [-1, 1]")

            if mask_image.min() < 0 or mask_image.max() > 1:
                raise ValueError("`mask_image` should be in range [0, 1]")
        else:
            mask_image_channels = 1
            image_channels = 3

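        # the inpainting unet expects the noise latents, the downsampled mask, and the masked-image latents
        # concatenated along the channel dimension (e.g. 4 + 1 + 4 = 9 channels for Stable Diffusion inpainting)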
        single_image_latent_channels = self.vae.config.latent_channels

        total_latent_channels = single_image_latent_channels * 2 + mask_image_channels

        if total_latent_channels != self.unet.config.in_channels:
            raise ValueError(
                f"The config of `pipeline.unet` expects {self.unet.config.in_channels} input channels but received"
                f" non-inpainting latent channels: {single_image_latent_channels},"
                f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
                " Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        return latents

    def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
        mask_image = mask_image.to(device=device, dtype=dtype)

        # duplicate mask for each generation per prompt, using mps friendly method
        if mask_image.shape[0] < batch_size:
            if not batch_size % mask_image.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the"
                    " number of masks that you pass is divisible by the total requested batch size."
                )
            mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)

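        # duplicate the mask for classifier-free guidance so it matches the doubled latent batch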
        mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image

        mask_image_latents = mask_image

        return mask_image_latents

    def prepare_masked_image_latents(
        self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        masked_image = masked_image.to(device=device, dtype=dtype)

        # encode the masked image into latent space so we can concatenate it to the latents
        if isinstance(generator, list):
            masked_image_latents = [
                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
                for i in range(batch_size)
            ]
            masked_image_latents = torch.cat(masked_image_latents, dim=0)
        else:
            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents

        # duplicate masked_image_latents for each generation per prompt, using mps friendly method
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        # aligning device to prevent device errors when concatenating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        return masked_image_latents

    def _default_height_width(self, height, width, image):
        if isinstance(image, list):
            image = image[0]

        if height is None:
            if isinstance(image, PIL.Image.Image):
                height = image.height
            elif isinstance(image, torch.Tensor):
                height = image.shape[2]

            height = (height // 8) * 8  # round down to nearest multiple of 8

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[3]

            width = (width // 8) * 8  # round down to nearest multiple of 8

        return height, width

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
        controlnet_conditioning_image: Union[
            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
        ] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: float = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
                be masked out with `mask_image` and repainted according to `prompt`.
            mask_image (`torch.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                instead of 3, so the expected shape would be `(B, H, W, 1)`.
            controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the
                UNet. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is.
                `PIL.Image.Image` can also be accepted as an image. The control image is automatically resized to fit
                the output image.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, controlnet_conditioning_image)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            mask_image,
            controlnet_conditioning_image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare mask, image, and controlnet_conditioning_image
        image = prepare_image(image)

        mask_image = prepare_mask_image(mask_image)

        controlnet_conditioning_image = prepare_controlnet_conditioning_image(
            controlnet_conditioning_image,
            width,
            height,
            batch_size * num_images_per_prompt,
            num_images_per_prompt,
            device,
            self.controlnet.dtype,
        )

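        # zero out the region to be repainted; the VAE only encodes the preserved (unmasked) content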
        masked_image = image * (mask_image < 0.5)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        mask_image_latents = self.prepare_mask_latents(
            mask_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            do_classifier_free_guidance,
        )

        masked_image_latents = self.prepare_masked_image_latents(
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

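        # duplicate the ControlNet conditioning image so it matches the doubled latent batch used for
        # classifier-free guidance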
        if do_classifier_free_guidance:
            controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                non_inpainting_latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )

                non_inpainting_latent_model_input = self.scheduler.scale_model_input(
                    non_inpainting_latent_model_input, t
                )

                inpainting_latent_model_input = torch.cat(
                    [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
                )

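                # the ControlNet consumes the plain (non-inpainting) latents, while the inpainting unet below
                # receives the latents concatenated with the mask and masked-image latents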
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    non_inpainting_latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_cond=controlnet_conditioning_image,
                    return_dict=False,
                )

                down_block_res_samples = [
                    down_block_res_sample * controlnet_conditioning_scale
                    for down_block_res_sample in down_block_res_samples
                ]
                mid_block_res_sample *= controlnet_conditioning_scale

                # predict the noise residual
                noise_pred = self.unet(
                    inpainting_latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)