Path: blob/main/examples/community/stable_diffusion_controlnet_inpaint.py
1448 views
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/12import inspect3from typing import Any, Callable, Dict, List, Optional, Union45import numpy as np6import PIL.Image7import torch8import torch.nn.functional as F9from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer1011from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging12from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker13from diffusers.schedulers import KarrasDiffusionSchedulers14from diffusers.utils import (15PIL_INTERPOLATION,16is_accelerate_available,17is_accelerate_version,18randn_tensor,19replace_example_docstring,20)212223logger = logging.get_logger(__name__) # pylint: disable=invalid-name2425EXAMPLE_DOC_STRING = """26Examples:27```py28>>> import numpy as np29>>> import torch30>>> from PIL import Image31>>> from stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline3233>>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation34>>> from diffusers import ControlNetModel, UniPCMultistepScheduler35>>> from diffusers.utils import load_image3637>>> def ade_palette():38return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],39[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],40[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],41[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],42[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],43[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],44[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],45[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],46[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],47[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],48[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],49[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],50[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],51[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],52[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],53[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],54[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],55[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],56[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],57[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],58[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],59[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],60[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],61[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],62[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],63[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],64[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],65[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],66[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],67[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],68[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],69[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],70[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],71[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],72[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],73[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],74[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],75[102, 255, 0], [92, 0, 255]]7677>>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")78>>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")7980>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)8182>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(83"runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float1684)8586>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)87>>> pipe.enable_xformers_memory_efficient_attention()88>>> pipe.enable_model_cpu_offload()8990>>> def image_to_seg(image):91pixel_values = image_processor(image, return_tensors="pt").pixel_values92with torch.no_grad():93outputs = image_segmentor(pixel_values)94seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]95color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 396palette = np.array(ade_palette())97for label, color in enumerate(palette):98color_seg[seg == label, :] = color99color_seg = color_seg.astype(np.uint8)100seg_image = Image.fromarray(color_seg)101return seg_image102103>>> image = load_image(104"https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"105)106107>>> mask_image = load_image(108"https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"109)110111>>> controlnet_conditioning_image = image_to_seg(image)112113>>> image = pipe(114"Face of a yellow cat, high resolution, sitting on a park bench",115image,116mask_image,117controlnet_conditioning_image,118num_inference_steps=20,119).images[0]120121>>> image.save("out.png")122```123"""124125126def prepare_image(image):127if isinstance(image, torch.Tensor):128# Batch single image129if image.ndim == 3:130image = image.unsqueeze(0)131132image = image.to(dtype=torch.float32)133else:134# preprocess image135if isinstance(image, (PIL.Image.Image, np.ndarray)):136image = [image]137138if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):139image = [np.array(i.convert("RGB"))[None, :] for i in image]140image = np.concatenate(image, axis=0)141elif isinstance(image, list) and isinstance(image[0], np.ndarray):142image = np.concatenate([i[None, :] for i in image], axis=0)143144image = image.transpose(0, 3, 1, 2)145image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0146147return image148149150def prepare_mask_image(mask_image):151if isinstance(mask_image, torch.Tensor):152if mask_image.ndim == 2:153# Batch and add channel dim for single mask154mask_image = mask_image.unsqueeze(0).unsqueeze(0)155elif mask_image.ndim == 3 and mask_image.shape[0] == 1:156# Single mask, the 0'th dimension is considered to be157# the existing batch size of 1158mask_image = mask_image.unsqueeze(0)159elif mask_image.ndim == 3 and mask_image.shape[0] != 1:160# Batch of mask, the 0'th dimension is considered to be161# the batching dimension162mask_image = mask_image.unsqueeze(1)163164# Binarize mask165mask_image[mask_image < 0.5] = 0166mask_image[mask_image >= 0.5] = 1167else:168# preprocess mask169if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):170mask_image = [mask_image]171172if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):173mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)174mask_image = mask_image.astype(np.float32) / 255.0175elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):176mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)177178mask_image[mask_image < 0.5] = 0179mask_image[mask_image >= 0.5] = 1180mask_image = torch.from_numpy(mask_image)181182return mask_image183184185def prepare_controlnet_conditioning_image(186controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype187):188if not isinstance(controlnet_conditioning_image, torch.Tensor):189if isinstance(controlnet_conditioning_image, PIL.Image.Image):190controlnet_conditioning_image = [controlnet_conditioning_image]191192if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):193controlnet_conditioning_image = [194np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]195for i in controlnet_conditioning_image196]197controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)198controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0199controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)200controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)201elif isinstance(controlnet_conditioning_image[0], torch.Tensor):202controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)203204image_batch_size = controlnet_conditioning_image.shape[0]205206if image_batch_size == 1:207repeat_by = batch_size208else:209# image batch size is the same as prompt batch size210repeat_by = num_images_per_prompt211212controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)213214controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)215216return controlnet_conditioning_image217218219class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):220"""221Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/222"""223224_optional_components = ["safety_checker", "feature_extractor"]225226def __init__(227self,228vae: AutoencoderKL,229text_encoder: CLIPTextModel,230tokenizer: CLIPTokenizer,231unet: UNet2DConditionModel,232controlnet: ControlNetModel,233scheduler: KarrasDiffusionSchedulers,234safety_checker: StableDiffusionSafetyChecker,235feature_extractor: CLIPImageProcessor,236requires_safety_checker: bool = True,237):238super().__init__()239240if safety_checker is None and requires_safety_checker:241logger.warning(242f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"243" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"244" results in services or applications open to the public. Both the diffusers team and Hugging Face"245" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"246" it only for use-cases that involve analyzing network behavior or auditing its results. For more"247" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."248)249250if safety_checker is not None and feature_extractor is None:251raise ValueError(252"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"253" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."254)255256self.register_modules(257vae=vae,258text_encoder=text_encoder,259tokenizer=tokenizer,260unet=unet,261controlnet=controlnet,262scheduler=scheduler,263safety_checker=safety_checker,264feature_extractor=feature_extractor,265)266self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)267self.register_to_config(requires_safety_checker=requires_safety_checker)268269def enable_vae_slicing(self):270r"""271Enable sliced VAE decoding.272273When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several274steps. This is useful to save some memory and allow larger batch sizes.275"""276self.vae.enable_slicing()277278def disable_vae_slicing(self):279r"""280Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to281computing decoding in one step.282"""283self.vae.disable_slicing()284285def enable_sequential_cpu_offload(self, gpu_id=0):286r"""287Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,288text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a289`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.290Note that offloading happens on a submodule basis. Memory savings are higher than with291`enable_model_cpu_offload`, but performance is lower.292"""293if is_accelerate_available():294from accelerate import cpu_offload295else:296raise ImportError("Please install accelerate via `pip install accelerate`")297298device = torch.device(f"cuda:{gpu_id}")299300for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:301cpu_offload(cpu_offloaded_model, device)302303if self.safety_checker is not None:304cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)305306def enable_model_cpu_offload(self, gpu_id=0):307r"""308Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared309to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`310method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with311`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.312"""313if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):314from accelerate import cpu_offload_with_hook315else:316raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")317318device = torch.device(f"cuda:{gpu_id}")319320hook = None321for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:322_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)323324if self.safety_checker is not None:325# the safety checker can offload the vae again326_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)327328# control net hook has be manually offloaded as it alternates with unet329cpu_offload_with_hook(self.controlnet, device)330331# We'll offload the last model manually.332self.final_offload_hook = hook333334@property335def _execution_device(self):336r"""337Returns the device on which the pipeline's models will be executed. After calling338`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module339hooks.340"""341if not hasattr(self.unet, "_hf_hook"):342return self.device343for module in self.unet.modules():344if (345hasattr(module, "_hf_hook")346and hasattr(module._hf_hook, "execution_device")347and module._hf_hook.execution_device is not None348):349return torch.device(module._hf_hook.execution_device)350return self.device351352def _encode_prompt(353self,354prompt,355device,356num_images_per_prompt,357do_classifier_free_guidance,358negative_prompt=None,359prompt_embeds: Optional[torch.FloatTensor] = None,360negative_prompt_embeds: Optional[torch.FloatTensor] = None,361):362r"""363Encodes the prompt into text encoder hidden states.364365Args:366prompt (`str` or `List[str]`, *optional*):367prompt to be encoded368device: (`torch.device`):369torch device370num_images_per_prompt (`int`):371number of images that should be generated per prompt372do_classifier_free_guidance (`bool`):373whether to use classifier free guidance or not374negative_prompt (`str` or `List[str]`, *optional*):375The prompt or prompts not to guide the image generation. If not defined, one has to pass376`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.377Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).378prompt_embeds (`torch.FloatTensor`, *optional*):379Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not380provided, text embeddings will be generated from `prompt` input argument.381negative_prompt_embeds (`torch.FloatTensor`, *optional*):382Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt383weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input384argument.385"""386if prompt is not None and isinstance(prompt, str):387batch_size = 1388elif prompt is not None and isinstance(prompt, list):389batch_size = len(prompt)390else:391batch_size = prompt_embeds.shape[0]392393if prompt_embeds is None:394text_inputs = self.tokenizer(395prompt,396padding="max_length",397max_length=self.tokenizer.model_max_length,398truncation=True,399return_tensors="pt",400)401text_input_ids = text_inputs.input_ids402untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids403404if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(405text_input_ids, untruncated_ids406):407removed_text = self.tokenizer.batch_decode(408untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]409)410logger.warning(411"The following part of your input was truncated because CLIP can only handle sequences up to"412f" {self.tokenizer.model_max_length} tokens: {removed_text}"413)414415if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:416attention_mask = text_inputs.attention_mask.to(device)417else:418attention_mask = None419420prompt_embeds = self.text_encoder(421text_input_ids.to(device),422attention_mask=attention_mask,423)424prompt_embeds = prompt_embeds[0]425426prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)427428bs_embed, seq_len, _ = prompt_embeds.shape429# duplicate text embeddings for each generation per prompt, using mps friendly method430prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)431prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)432433# get unconditional embeddings for classifier free guidance434if do_classifier_free_guidance and negative_prompt_embeds is None:435uncond_tokens: List[str]436if negative_prompt is None:437uncond_tokens = [""] * batch_size438elif type(prompt) is not type(negative_prompt):439raise TypeError(440f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="441f" {type(prompt)}."442)443elif isinstance(negative_prompt, str):444uncond_tokens = [negative_prompt]445elif batch_size != len(negative_prompt):446raise ValueError(447f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"448f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"449" the batch size of `prompt`."450)451else:452uncond_tokens = negative_prompt453454max_length = prompt_embeds.shape[1]455uncond_input = self.tokenizer(456uncond_tokens,457padding="max_length",458max_length=max_length,459truncation=True,460return_tensors="pt",461)462463if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:464attention_mask = uncond_input.attention_mask.to(device)465else:466attention_mask = None467468negative_prompt_embeds = self.text_encoder(469uncond_input.input_ids.to(device),470attention_mask=attention_mask,471)472negative_prompt_embeds = negative_prompt_embeds[0]473474if do_classifier_free_guidance:475# duplicate unconditional embeddings for each generation per prompt, using mps friendly method476seq_len = negative_prompt_embeds.shape[1]477478negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)479480negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)481negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)482483# For classifier free guidance, we need to do two forward passes.484# Here we concatenate the unconditional and text embeddings into a single batch485# to avoid doing two forward passes486prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])487488return prompt_embeds489490def run_safety_checker(self, image, device, dtype):491if self.safety_checker is not None:492safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)493image, has_nsfw_concept = self.safety_checker(494images=image, clip_input=safety_checker_input.pixel_values.to(dtype)495)496else:497has_nsfw_concept = None498return image, has_nsfw_concept499500def decode_latents(self, latents):501latents = 1 / self.vae.config.scaling_factor * latents502image = self.vae.decode(latents).sample503image = (image / 2 + 0.5).clamp(0, 1)504# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16505image = image.cpu().permute(0, 2, 3, 1).float().numpy()506return image507508def prepare_extra_step_kwargs(self, generator, eta):509# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature510# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.511# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502512# and should be between [0, 1]513514accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())515extra_step_kwargs = {}516if accepts_eta:517extra_step_kwargs["eta"] = eta518519# check if the scheduler accepts generator520accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())521if accepts_generator:522extra_step_kwargs["generator"] = generator523return extra_step_kwargs524525def check_inputs(526self,527prompt,528image,529mask_image,530controlnet_conditioning_image,531height,532width,533callback_steps,534negative_prompt=None,535prompt_embeds=None,536negative_prompt_embeds=None,537):538if height % 8 != 0 or width % 8 != 0:539raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")540541if (callback_steps is None) or (542callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)543):544raise ValueError(545f"`callback_steps` has to be a positive integer but is {callback_steps} of type"546f" {type(callback_steps)}."547)548549if prompt is not None and prompt_embeds is not None:550raise ValueError(551f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"552" only forward one of the two."553)554elif prompt is None and prompt_embeds is None:555raise ValueError(556"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."557)558elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):559raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")560561if negative_prompt is not None and negative_prompt_embeds is not None:562raise ValueError(563f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"564f" {negative_prompt_embeds}. Please make sure to only forward one of the two."565)566567if prompt_embeds is not None and negative_prompt_embeds is not None:568if prompt_embeds.shape != negative_prompt_embeds.shape:569raise ValueError(570"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"571f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"572f" {negative_prompt_embeds.shape}."573)574575controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)576controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)577controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(578controlnet_conditioning_image[0], PIL.Image.Image579)580controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(581controlnet_conditioning_image[0], torch.Tensor582)583584if (585not controlnet_cond_image_is_pil586and not controlnet_cond_image_is_tensor587and not controlnet_cond_image_is_pil_list588and not controlnet_cond_image_is_tensor_list589):590raise TypeError(591"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"592)593594if controlnet_cond_image_is_pil:595controlnet_cond_image_batch_size = 1596elif controlnet_cond_image_is_tensor:597controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]598elif controlnet_cond_image_is_pil_list:599controlnet_cond_image_batch_size = len(controlnet_conditioning_image)600elif controlnet_cond_image_is_tensor_list:601controlnet_cond_image_batch_size = len(controlnet_conditioning_image)602603if prompt is not None and isinstance(prompt, str):604prompt_batch_size = 1605elif prompt is not None and isinstance(prompt, list):606prompt_batch_size = len(prompt)607elif prompt_embeds is not None:608prompt_batch_size = prompt_embeds.shape[0]609610if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:611raise ValueError(612f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"613)614615if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):616raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")617618if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):619raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")620621if isinstance(image, torch.Tensor):622if image.ndim != 3 and image.ndim != 4:623raise ValueError("`image` must have 3 or 4 dimensions")624625if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:626raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")627628if image.ndim == 3:629image_batch_size = 1630image_channels, image_height, image_width = image.shape631elif image.ndim == 4:632image_batch_size, image_channels, image_height, image_width = image.shape633634if mask_image.ndim == 2:635mask_image_batch_size = 1636mask_image_channels = 1637mask_image_height, mask_image_width = mask_image.shape638elif mask_image.ndim == 3:639mask_image_channels = 1640mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape641elif mask_image.ndim == 4:642mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape643644if image_channels != 3:645raise ValueError("`image` must have 3 channels")646647if mask_image_channels != 1:648raise ValueError("`mask_image` must have 1 channel")649650if image_batch_size != mask_image_batch_size:651raise ValueError("`image` and `mask_image` mush have the same batch sizes")652653if image_height != mask_image_height or image_width != mask_image_width:654raise ValueError("`image` and `mask_image` must have the same height and width dimensions")655656if image.min() < -1 or image.max() > 1:657raise ValueError("`image` should be in range [-1, 1]")658659if mask_image.min() < 0 or mask_image.max() > 1:660raise ValueError("`mask_image` should be in range [0, 1]")661else:662mask_image_channels = 1663image_channels = 3664665single_image_latent_channels = self.vae.config.latent_channels666667total_latent_channels = single_image_latent_channels * 2 + mask_image_channels668669if total_latent_channels != self.unet.config.in_channels:670raise ValueError(671f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"672f" non inpainting latent channels: {single_image_latent_channels},"673f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."674f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."675)676677def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):678shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)679if isinstance(generator, list) and len(generator) != batch_size:680raise ValueError(681f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"682f" size of {batch_size}. Make sure the batch size matches the length of the generators."683)684685if latents is None:686latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)687else:688latents = latents.to(device)689690# scale the initial noise by the standard deviation required by the scheduler691latents = latents * self.scheduler.init_noise_sigma692693return latents694695def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):696# resize the mask to latents shape as we concatenate the mask to the latents697# we do that before converting to dtype to avoid breaking in case we're using cpu_offload698# and half precision699mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))700mask_image = mask_image.to(device=device, dtype=dtype)701702# duplicate mask for each generation per prompt, using mps friendly method703if mask_image.shape[0] < batch_size:704if not batch_size % mask_image.shape[0] == 0:705raise ValueError(706"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"707f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"708" of masks that you pass is divisible by the total requested batch size."709)710mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)711712mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image713714mask_image_latents = mask_image715716return mask_image_latents717718def prepare_masked_image_latents(719self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance720):721masked_image = masked_image.to(device=device, dtype=dtype)722723# encode the mask image into latents space so we can concatenate it to the latents724if isinstance(generator, list):725masked_image_latents = [726self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])727for i in range(batch_size)728]729masked_image_latents = torch.cat(masked_image_latents, dim=0)730else:731masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)732masked_image_latents = self.vae.config.scaling_factor * masked_image_latents733734# duplicate masked_image_latents for each generation per prompt, using mps friendly method735if masked_image_latents.shape[0] < batch_size:736if not batch_size % masked_image_latents.shape[0] == 0:737raise ValueError(738"The passed images and the required batch size don't match. Images are supposed to be duplicated"739f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."740" Make sure the number of images that you pass is divisible by the total requested batch size."741)742masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)743744masked_image_latents = (745torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents746)747748# aligning device to prevent device errors when concating it with the latent model input749masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)750return masked_image_latents751752def _default_height_width(self, height, width, image):753if isinstance(image, list):754image = image[0]755756if height is None:757if isinstance(image, PIL.Image.Image):758height = image.height759elif isinstance(image, torch.Tensor):760height = image.shape[3]761762height = (height // 8) * 8 # round down to nearest multiple of 8763764if width is None:765if isinstance(image, PIL.Image.Image):766width = image.width767elif isinstance(image, torch.Tensor):768width = image.shape[2]769770width = (width // 8) * 8 # round down to nearest multiple of 8771772return height, width773774@torch.no_grad()775@replace_example_docstring(EXAMPLE_DOC_STRING)776def __call__(777self,778prompt: Union[str, List[str]] = None,779image: Union[torch.Tensor, PIL.Image.Image] = None,780mask_image: Union[torch.Tensor, PIL.Image.Image] = None,781controlnet_conditioning_image: Union[782torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]783] = None,784height: Optional[int] = None,785width: Optional[int] = None,786num_inference_steps: int = 50,787guidance_scale: float = 7.5,788negative_prompt: Optional[Union[str, List[str]]] = None,789num_images_per_prompt: Optional[int] = 1,790eta: float = 0.0,791generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,792latents: Optional[torch.FloatTensor] = None,793prompt_embeds: Optional[torch.FloatTensor] = None,794negative_prompt_embeds: Optional[torch.FloatTensor] = None,795output_type: Optional[str] = "pil",796return_dict: bool = True,797callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,798callback_steps: int = 1,799cross_attention_kwargs: Optional[Dict[str, Any]] = None,800controlnet_conditioning_scale: float = 1.0,801):802r"""803Function invoked when calling the pipeline for generation.804805Args:806prompt (`str` or `List[str]`, *optional*):807The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.808instead.809image (`torch.Tensor` or `PIL.Image.Image`):810`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will811be masked out with `mask_image` and repainted according to `prompt`.812mask_image (`torch.Tensor` or `PIL.Image.Image`):813`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be814repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted815to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)816instead of 3, so the expected shape would be `(B, H, W, 1)`.817controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):818The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If819the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can820also be accepted as an image. The control image is automatically resized to fit the output image.821height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):822The height in pixels of the generated image.823width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):824The width in pixels of the generated image.825num_inference_steps (`int`, *optional*, defaults to 50):826The number of denoising steps. More denoising steps usually lead to a higher quality image at the827expense of slower inference.828guidance_scale (`float`, *optional*, defaults to 7.5):829Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).830`guidance_scale` is defined as `w` of equation 2. of [Imagen831Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >8321`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,833usually at the expense of lower image quality.834negative_prompt (`str` or `List[str]`, *optional*):835The prompt or prompts not to guide the image generation. If not defined, one has to pass836`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.837Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).838num_images_per_prompt (`int`, *optional*, defaults to 1):839The number of images to generate per prompt.840eta (`float`, *optional*, defaults to 0.0):841Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to842[`schedulers.DDIMScheduler`], will be ignored for others.843generator (`torch.Generator` or `List[torch.Generator]`, *optional*):844One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)845to make generation deterministic.846latents (`torch.FloatTensor`, *optional*):847Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image848generation. Can be used to tweak the same generation with different prompts. If not provided, a latents849tensor will ge generated by sampling using the supplied random `generator`.850prompt_embeds (`torch.FloatTensor`, *optional*):851Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not852provided, text embeddings will be generated from `prompt` input argument.853negative_prompt_embeds (`torch.FloatTensor`, *optional*):854Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt855weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input856argument.857output_type (`str`, *optional*, defaults to `"pil"`):858The output format of the generate image. Choose between859[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.860return_dict (`bool`, *optional*, defaults to `True`):861Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a862plain tuple.863callback (`Callable`, *optional*):864A function that will be called every `callback_steps` steps during inference. The function will be865called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.866callback_steps (`int`, *optional*, defaults to 1):867The frequency at which the `callback` function will be called. If not specified, the callback will be868called at every step.869cross_attention_kwargs (`dict`, *optional*):870A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under871`self.processor` in872[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).873controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):874The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added875to the residual in the original unet.876877Examples:878879Returns:880[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:881[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.882When returning a tuple, the first element is a list with the generated images, and the second element is a883list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"884(nsfw) content, according to the `safety_checker`.885"""886# 0. Default height and width to unet887height, width = self._default_height_width(height, width, controlnet_conditioning_image)888889# 1. Check inputs. Raise error if not correct890self.check_inputs(891prompt,892image,893mask_image,894controlnet_conditioning_image,895height,896width,897callback_steps,898negative_prompt,899prompt_embeds,900negative_prompt_embeds,901)902903# 2. Define call parameters904if prompt is not None and isinstance(prompt, str):905batch_size = 1906elif prompt is not None and isinstance(prompt, list):907batch_size = len(prompt)908else:909batch_size = prompt_embeds.shape[0]910911device = self._execution_device912# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)913# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`914# corresponds to doing no classifier free guidance.915do_classifier_free_guidance = guidance_scale > 1.0916917# 3. Encode input prompt918prompt_embeds = self._encode_prompt(919prompt,920device,921num_images_per_prompt,922do_classifier_free_guidance,923negative_prompt,924prompt_embeds=prompt_embeds,925negative_prompt_embeds=negative_prompt_embeds,926)927928# 4. Prepare mask, image, and controlnet_conditioning_image929image = prepare_image(image)930931mask_image = prepare_mask_image(mask_image)932933controlnet_conditioning_image = prepare_controlnet_conditioning_image(934controlnet_conditioning_image,935width,936height,937batch_size * num_images_per_prompt,938num_images_per_prompt,939device,940self.controlnet.dtype,941)942943masked_image = image * (mask_image < 0.5)944945# 5. Prepare timesteps946self.scheduler.set_timesteps(num_inference_steps, device=device)947timesteps = self.scheduler.timesteps948949# 6. Prepare latent variables950num_channels_latents = self.vae.config.latent_channels951latents = self.prepare_latents(952batch_size * num_images_per_prompt,953num_channels_latents,954height,955width,956prompt_embeds.dtype,957device,958generator,959latents,960)961962mask_image_latents = self.prepare_mask_latents(963mask_image,964batch_size * num_images_per_prompt,965height,966width,967prompt_embeds.dtype,968device,969do_classifier_free_guidance,970)971972masked_image_latents = self.prepare_masked_image_latents(973masked_image,974batch_size * num_images_per_prompt,975height,976width,977prompt_embeds.dtype,978device,979generator,980do_classifier_free_guidance,981)982983if do_classifier_free_guidance:984controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)985986# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline987extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)988989# 8. Denoising loop990num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order991with self.progress_bar(total=num_inference_steps) as progress_bar:992for i, t in enumerate(timesteps):993# expand the latents if we are doing classifier free guidance994non_inpainting_latent_model_input = (995torch.cat([latents] * 2) if do_classifier_free_guidance else latents996)997998non_inpainting_latent_model_input = self.scheduler.scale_model_input(999non_inpainting_latent_model_input, t1000)10011002inpainting_latent_model_input = torch.cat(1003[non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=11004)10051006down_block_res_samples, mid_block_res_sample = self.controlnet(1007non_inpainting_latent_model_input,1008t,1009encoder_hidden_states=prompt_embeds,1010controlnet_cond=controlnet_conditioning_image,1011return_dict=False,1012)10131014down_block_res_samples = [1015down_block_res_sample * controlnet_conditioning_scale1016for down_block_res_sample in down_block_res_samples1017]1018mid_block_res_sample *= controlnet_conditioning_scale10191020# predict the noise residual1021noise_pred = self.unet(1022inpainting_latent_model_input,1023t,1024encoder_hidden_states=prompt_embeds,1025cross_attention_kwargs=cross_attention_kwargs,1026down_block_additional_residuals=down_block_res_samples,1027mid_block_additional_residual=mid_block_res_sample,1028).sample10291030# perform guidance1031if do_classifier_free_guidance:1032noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)1033noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)10341035# compute the previous noisy sample x_t -> x_t-11036latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample10371038# call the callback, if provided1039if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):1040progress_bar.update()1041if callback is not None and i % callback_steps == 0:1042callback(i, t, latents)10431044# If we do sequential model offloading, let's offload unet and controlnet1045# manually for max memory savings1046if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:1047self.unet.to("cpu")1048self.controlnet.to("cpu")1049torch.cuda.empty_cache()10501051if output_type == "latent":1052image = latents1053has_nsfw_concept = None1054elif output_type == "pil":1055# 8. Post-processing1056image = self.decode_latents(latents)10571058# 9. Run safety checker1059image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)10601061# 10. Convert to PIL1062image = self.numpy_to_pil(image)1063else:1064# 8. Post-processing1065image = self.decode_latents(latents)10661067# 9. Run safety checker1068image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)10691070# Offload last model to CPU1071if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:1072self.final_offload_hook.offload()10731074if not return_dict:1075return (image, has_nsfw_concept)10761077return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)107810791080