Path: blob/main/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
1448 views
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/12import inspect3from typing import Any, Callable, Dict, List, Optional, Union45import numpy as np6import PIL.Image7import torch8import torch.nn.functional as F9from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer1011from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging12from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker13from diffusers.schedulers import KarrasDiffusionSchedulers14from diffusers.utils import (15PIL_INTERPOLATION,16is_accelerate_available,17is_accelerate_version,18randn_tensor,19replace_example_docstring,20)212223logger = logging.get_logger(__name__) # pylint: disable=invalid-name2425EXAMPLE_DOC_STRING = """26Examples:27```py28>>> import numpy as np29>>> import torch30>>> from PIL import Image31>>> from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline3233>>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation34>>> from diffusers import ControlNetModel, UniPCMultistepScheduler35>>> from diffusers.utils import load_image3637>>> def ade_palette():38return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],39[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],40[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],41[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],42[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],43[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],44[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],45[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],46[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],47[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],48[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],49[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],50[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],51[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],52[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],53[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],54[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],55[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],56[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],57[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],58[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],59[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],60[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],61[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],62[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],63[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],64[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],65[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],66[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],67[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],68[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],69[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],70[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],71[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],72[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],73[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],74[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],75[102, 255, 0], [92, 0, 255]]7677>>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")78>>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")7980>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)8182>>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(83"runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float1684)8586>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)87>>> pipe.enable_xformers_memory_efficient_attention()88>>> pipe.enable_model_cpu_offload()8990>>> def image_to_seg(image):91pixel_values = image_processor(image, return_tensors="pt").pixel_values92with torch.no_grad():93outputs = image_segmentor(pixel_values)94seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]95color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 396palette = np.array(ade_palette())97for label, color in enumerate(palette):98color_seg[seg == label, :] = color99color_seg = color_seg.astype(np.uint8)100seg_image = Image.fromarray(color_seg)101return seg_image102103>>> image = load_image(104"https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"105)106107>>> mask_image = load_image(108"https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"109)110111>>> controlnet_conditioning_image = image_to_seg(image)112113>>> image = pipe(114"Face of a yellow cat, high resolution, sitting on a park bench",115image,116mask_image,117controlnet_conditioning_image,118num_inference_steps=20,119).images[0]120121>>> image.save("out.png")122```123"""124125126def prepare_image(image):127if isinstance(image, torch.Tensor):128# Batch single image129if image.ndim == 3:130image = image.unsqueeze(0)131132image = image.to(dtype=torch.float32)133else:134# preprocess image135if isinstance(image, (PIL.Image.Image, np.ndarray)):136image = [image]137138if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):139image = [np.array(i.convert("RGB"))[None, :] for i in image]140image = np.concatenate(image, axis=0)141elif isinstance(image, list) and isinstance(image[0], np.ndarray):142image = np.concatenate([i[None, :] for i in image], axis=0)143144image = image.transpose(0, 3, 1, 2)145image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0146147return image148149150def prepare_mask_image(mask_image):151if isinstance(mask_image, torch.Tensor):152if mask_image.ndim == 2:153# Batch and add channel dim for single mask154mask_image = mask_image.unsqueeze(0).unsqueeze(0)155elif mask_image.ndim == 3 and mask_image.shape[0] == 1:156# Single mask, the 0'th dimension is considered to be157# the existing batch size of 1158mask_image = mask_image.unsqueeze(0)159elif mask_image.ndim == 3 and mask_image.shape[0] != 1:160# Batch of mask, the 0'th dimension is considered to be161# the batching dimension162mask_image = mask_image.unsqueeze(1)163164# Binarize mask165mask_image[mask_image < 0.5] = 0166mask_image[mask_image >= 0.5] = 1167else:168# preprocess mask169if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):170mask_image = [mask_image]171172if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):173mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)174mask_image = mask_image.astype(np.float32) / 255.0175elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):176mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)177178mask_image[mask_image < 0.5] = 0179mask_image[mask_image >= 0.5] = 1180mask_image = torch.from_numpy(mask_image)181182return mask_image183184185def prepare_controlnet_conditioning_image(186controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype187):188if not isinstance(controlnet_conditioning_image, torch.Tensor):189if isinstance(controlnet_conditioning_image, PIL.Image.Image):190controlnet_conditioning_image = [controlnet_conditioning_image]191192if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):193controlnet_conditioning_image = [194np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]195for i in controlnet_conditioning_image196]197controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)198controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0199controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)200controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)201elif isinstance(controlnet_conditioning_image[0], torch.Tensor):202controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)203204image_batch_size = controlnet_conditioning_image.shape[0]205206if image_batch_size == 1:207repeat_by = batch_size208else:209# image batch size is the same as prompt batch size210repeat_by = num_images_per_prompt211212controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)213214controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)215216return controlnet_conditioning_image217218219class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):220"""221Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/222"""223224_optional_components = ["safety_checker", "feature_extractor"]225226def __init__(227self,228vae: AutoencoderKL,229text_encoder: CLIPTextModel,230tokenizer: CLIPTokenizer,231unet: UNet2DConditionModel,232controlnet: ControlNetModel,233scheduler: KarrasDiffusionSchedulers,234safety_checker: StableDiffusionSafetyChecker,235feature_extractor: CLIPImageProcessor,236requires_safety_checker: bool = True,237):238super().__init__()239240if safety_checker is None and requires_safety_checker:241logger.warning(242f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"243" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"244" results in services or applications open to the public. Both the diffusers team and Hugging Face"245" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"246" it only for use-cases that involve analyzing network behavior or auditing its results. For more"247" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."248)249250if safety_checker is not None and feature_extractor is None:251raise ValueError(252"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"253" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."254)255256self.register_modules(257vae=vae,258text_encoder=text_encoder,259tokenizer=tokenizer,260unet=unet,261controlnet=controlnet,262scheduler=scheduler,263safety_checker=safety_checker,264feature_extractor=feature_extractor,265)266self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)267self.register_to_config(requires_safety_checker=requires_safety_checker)268269def enable_vae_slicing(self):270r"""271Enable sliced VAE decoding.272273When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several274steps. This is useful to save some memory and allow larger batch sizes.275"""276self.vae.enable_slicing()277278def disable_vae_slicing(self):279r"""280Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to281computing decoding in one step.282"""283self.vae.disable_slicing()284285def enable_sequential_cpu_offload(self, gpu_id=0):286r"""287Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,288text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a289`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.290Note that offloading happens on a submodule basis. Memory savings are higher than with291`enable_model_cpu_offload`, but performance is lower.292"""293if is_accelerate_available():294from accelerate import cpu_offload295else:296raise ImportError("Please install accelerate via `pip install accelerate`")297298device = torch.device(f"cuda:{gpu_id}")299300for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:301cpu_offload(cpu_offloaded_model, device)302303if self.safety_checker is not None:304cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)305306def enable_model_cpu_offload(self, gpu_id=0):307r"""308Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared309to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`310method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with311`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.312"""313if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):314from accelerate import cpu_offload_with_hook315else:316raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")317318device = torch.device(f"cuda:{gpu_id}")319320hook = None321for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:322_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)323324if self.safety_checker is not None:325# the safety checker can offload the vae again326_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)327328# control net hook has be manually offloaded as it alternates with unet329cpu_offload_with_hook(self.controlnet, device)330331# We'll offload the last model manually.332self.final_offload_hook = hook333334@property335def _execution_device(self):336r"""337Returns the device on which the pipeline's models will be executed. After calling338`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module339hooks.340"""341if not hasattr(self.unet, "_hf_hook"):342return self.device343for module in self.unet.modules():344if (345hasattr(module, "_hf_hook")346and hasattr(module._hf_hook, "execution_device")347and module._hf_hook.execution_device is not None348):349return torch.device(module._hf_hook.execution_device)350return self.device351352def _encode_prompt(353self,354prompt,355device,356num_images_per_prompt,357do_classifier_free_guidance,358negative_prompt=None,359prompt_embeds: Optional[torch.FloatTensor] = None,360negative_prompt_embeds: Optional[torch.FloatTensor] = None,361):362r"""363Encodes the prompt into text encoder hidden states.364365Args:366prompt (`str` or `List[str]`, *optional*):367prompt to be encoded368device: (`torch.device`):369torch device370num_images_per_prompt (`int`):371number of images that should be generated per prompt372do_classifier_free_guidance (`bool`):373whether to use classifier free guidance or not374negative_prompt (`str` or `List[str]`, *optional*):375The prompt or prompts not to guide the image generation. If not defined, one has to pass376`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.377Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).378prompt_embeds (`torch.FloatTensor`, *optional*):379Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not380provided, text embeddings will be generated from `prompt` input argument.381negative_prompt_embeds (`torch.FloatTensor`, *optional*):382Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt383weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input384argument.385"""386if prompt is not None and isinstance(prompt, str):387batch_size = 1388elif prompt is not None and isinstance(prompt, list):389batch_size = len(prompt)390else:391batch_size = prompt_embeds.shape[0]392393if prompt_embeds is None:394text_inputs = self.tokenizer(395prompt,396padding="max_length",397max_length=self.tokenizer.model_max_length,398truncation=True,399return_tensors="pt",400)401text_input_ids = text_inputs.input_ids402untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids403404if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(405text_input_ids, untruncated_ids406):407removed_text = self.tokenizer.batch_decode(408untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]409)410logger.warning(411"The following part of your input was truncated because CLIP can only handle sequences up to"412f" {self.tokenizer.model_max_length} tokens: {removed_text}"413)414415if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:416attention_mask = text_inputs.attention_mask.to(device)417else:418attention_mask = None419420prompt_embeds = self.text_encoder(421text_input_ids.to(device),422attention_mask=attention_mask,423)424prompt_embeds = prompt_embeds[0]425426prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)427428bs_embed, seq_len, _ = prompt_embeds.shape429# duplicate text embeddings for each generation per prompt, using mps friendly method430prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)431prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)432433# get unconditional embeddings for classifier free guidance434if do_classifier_free_guidance and negative_prompt_embeds is None:435uncond_tokens: List[str]436if negative_prompt is None:437uncond_tokens = [""] * batch_size438elif type(prompt) is not type(negative_prompt):439raise TypeError(440f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="441f" {type(prompt)}."442)443elif isinstance(negative_prompt, str):444uncond_tokens = [negative_prompt]445elif batch_size != len(negative_prompt):446raise ValueError(447f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"448f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"449" the batch size of `prompt`."450)451else:452uncond_tokens = negative_prompt453454max_length = prompt_embeds.shape[1]455uncond_input = self.tokenizer(456uncond_tokens,457padding="max_length",458max_length=max_length,459truncation=True,460return_tensors="pt",461)462463if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:464attention_mask = uncond_input.attention_mask.to(device)465else:466attention_mask = None467468negative_prompt_embeds = self.text_encoder(469uncond_input.input_ids.to(device),470attention_mask=attention_mask,471)472negative_prompt_embeds = negative_prompt_embeds[0]473474if do_classifier_free_guidance:475# duplicate unconditional embeddings for each generation per prompt, using mps friendly method476seq_len = negative_prompt_embeds.shape[1]477478negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)479480negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)481negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)482483# For classifier free guidance, we need to do two forward passes.484# Here we concatenate the unconditional and text embeddings into a single batch485# to avoid doing two forward passes486prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])487488return prompt_embeds489490def run_safety_checker(self, image, device, dtype):491if self.safety_checker is not None:492safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)493image, has_nsfw_concept = self.safety_checker(494images=image, clip_input=safety_checker_input.pixel_values.to(dtype)495)496else:497has_nsfw_concept = None498return image, has_nsfw_concept499500def decode_latents(self, latents):501latents = 1 / self.vae.config.scaling_factor * latents502image = self.vae.decode(latents).sample503image = (image / 2 + 0.5).clamp(0, 1)504# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16505image = image.cpu().permute(0, 2, 3, 1).float().numpy()506return image507508def prepare_extra_step_kwargs(self, generator, eta):509# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature510# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.511# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502512# and should be between [0, 1]513514accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())515extra_step_kwargs = {}516if accepts_eta:517extra_step_kwargs["eta"] = eta518519# check if the scheduler accepts generator520accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())521if accepts_generator:522extra_step_kwargs["generator"] = generator523return extra_step_kwargs524525def check_inputs(526self,527prompt,528image,529mask_image,530controlnet_conditioning_image,531height,532width,533callback_steps,534negative_prompt=None,535prompt_embeds=None,536negative_prompt_embeds=None,537strength=None,538):539if height % 8 != 0 or width % 8 != 0:540raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")541542if (callback_steps is None) or (543callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)544):545raise ValueError(546f"`callback_steps` has to be a positive integer but is {callback_steps} of type"547f" {type(callback_steps)}."548)549550if prompt is not None and prompt_embeds is not None:551raise ValueError(552f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"553" only forward one of the two."554)555elif prompt is None and prompt_embeds is None:556raise ValueError(557"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."558)559elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):560raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")561562if negative_prompt is not None and negative_prompt_embeds is not None:563raise ValueError(564f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"565f" {negative_prompt_embeds}. Please make sure to only forward one of the two."566)567568if prompt_embeds is not None and negative_prompt_embeds is not None:569if prompt_embeds.shape != negative_prompt_embeds.shape:570raise ValueError(571"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"572f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"573f" {negative_prompt_embeds.shape}."574)575576controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)577controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)578controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(579controlnet_conditioning_image[0], PIL.Image.Image580)581controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(582controlnet_conditioning_image[0], torch.Tensor583)584585if (586not controlnet_cond_image_is_pil587and not controlnet_cond_image_is_tensor588and not controlnet_cond_image_is_pil_list589and not controlnet_cond_image_is_tensor_list590):591raise TypeError(592"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"593)594595if controlnet_cond_image_is_pil:596controlnet_cond_image_batch_size = 1597elif controlnet_cond_image_is_tensor:598controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]599elif controlnet_cond_image_is_pil_list:600controlnet_cond_image_batch_size = len(controlnet_conditioning_image)601elif controlnet_cond_image_is_tensor_list:602controlnet_cond_image_batch_size = len(controlnet_conditioning_image)603604if prompt is not None and isinstance(prompt, str):605prompt_batch_size = 1606elif prompt is not None and isinstance(prompt, list):607prompt_batch_size = len(prompt)608elif prompt_embeds is not None:609prompt_batch_size = prompt_embeds.shape[0]610611if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:612raise ValueError(613f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"614)615616if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):617raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")618619if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):620raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")621622if isinstance(image, torch.Tensor):623if image.ndim != 3 and image.ndim != 4:624raise ValueError("`image` must have 3 or 4 dimensions")625626if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:627raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")628629if image.ndim == 3:630image_batch_size = 1631image_channels, image_height, image_width = image.shape632elif image.ndim == 4:633image_batch_size, image_channels, image_height, image_width = image.shape634635if mask_image.ndim == 2:636mask_image_batch_size = 1637mask_image_channels = 1638mask_image_height, mask_image_width = mask_image.shape639elif mask_image.ndim == 3:640mask_image_channels = 1641mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape642elif mask_image.ndim == 4:643mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape644645if image_channels != 3:646raise ValueError("`image` must have 3 channels")647648if mask_image_channels != 1:649raise ValueError("`mask_image` must have 1 channel")650651if image_batch_size != mask_image_batch_size:652raise ValueError("`image` and `mask_image` mush have the same batch sizes")653654if image_height != mask_image_height or image_width != mask_image_width:655raise ValueError("`image` and `mask_image` must have the same height and width dimensions")656657if image.min() < -1 or image.max() > 1:658raise ValueError("`image` should be in range [-1, 1]")659660if mask_image.min() < 0 or mask_image.max() > 1:661raise ValueError("`mask_image` should be in range [0, 1]")662else:663mask_image_channels = 1664image_channels = 3665666single_image_latent_channels = self.vae.config.latent_channels667668total_latent_channels = single_image_latent_channels * 2 + mask_image_channels669670if total_latent_channels != self.unet.config.in_channels:671raise ValueError(672f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"673f" non inpainting latent channels: {single_image_latent_channels},"674f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."675f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."676)677678if strength < 0 or strength > 1:679raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")680681def get_timesteps(self, num_inference_steps, strength, device):682# get the original timestep using init_timestep683init_timestep = min(int(num_inference_steps * strength), num_inference_steps)684685t_start = max(num_inference_steps - init_timestep, 0)686timesteps = self.scheduler.timesteps[t_start:]687688return timesteps, num_inference_steps - t_start689690def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):691if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):692raise ValueError(693f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"694)695696image = image.to(device=device, dtype=dtype)697698batch_size = batch_size * num_images_per_prompt699if isinstance(generator, list) and len(generator) != batch_size:700raise ValueError(701f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"702f" size of {batch_size}. Make sure the batch size matches the length of the generators."703)704705if isinstance(generator, list):706init_latents = [707self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)708]709init_latents = torch.cat(init_latents, dim=0)710else:711init_latents = self.vae.encode(image).latent_dist.sample(generator)712713init_latents = self.vae.config.scaling_factor * init_latents714715if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:716raise ValueError(717f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."718)719else:720init_latents = torch.cat([init_latents], dim=0)721722shape = init_latents.shape723noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)724725# get latents726init_latents = self.scheduler.add_noise(init_latents, noise, timestep)727latents = init_latents728729return latents730731def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):732# resize the mask to latents shape as we concatenate the mask to the latents733# we do that before converting to dtype to avoid breaking in case we're using cpu_offload734# and half precision735mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))736mask_image = mask_image.to(device=device, dtype=dtype)737738# duplicate mask for each generation per prompt, using mps friendly method739if mask_image.shape[0] < batch_size:740if not batch_size % mask_image.shape[0] == 0:741raise ValueError(742"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"743f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"744" of masks that you pass is divisible by the total requested batch size."745)746mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)747748mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image749750mask_image_latents = mask_image751752return mask_image_latents753754def prepare_masked_image_latents(755self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance756):757masked_image = masked_image.to(device=device, dtype=dtype)758759# encode the mask image into latents space so we can concatenate it to the latents760if isinstance(generator, list):761masked_image_latents = [762self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])763for i in range(batch_size)764]765masked_image_latents = torch.cat(masked_image_latents, dim=0)766else:767masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)768masked_image_latents = self.vae.config.scaling_factor * masked_image_latents769770# duplicate masked_image_latents for each generation per prompt, using mps friendly method771if masked_image_latents.shape[0] < batch_size:772if not batch_size % masked_image_latents.shape[0] == 0:773raise ValueError(774"The passed images and the required batch size don't match. Images are supposed to be duplicated"775f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."776" Make sure the number of images that you pass is divisible by the total requested batch size."777)778masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)779780masked_image_latents = (781torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents782)783784# aligning device to prevent device errors when concating it with the latent model input785masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)786return masked_image_latents787788def _default_height_width(self, height, width, image):789if isinstance(image, list):790image = image[0]791792if height is None:793if isinstance(image, PIL.Image.Image):794height = image.height795elif isinstance(image, torch.Tensor):796height = image.shape[3]797798height = (height // 8) * 8 # round down to nearest multiple of 8799800if width is None:801if isinstance(image, PIL.Image.Image):802width = image.width803elif isinstance(image, torch.Tensor):804width = image.shape[2]805806width = (width // 8) * 8 # round down to nearest multiple of 8807808return height, width809810@torch.no_grad()811@replace_example_docstring(EXAMPLE_DOC_STRING)812def __call__(813self,814prompt: Union[str, List[str]] = None,815image: Union[torch.Tensor, PIL.Image.Image] = None,816mask_image: Union[torch.Tensor, PIL.Image.Image] = None,817controlnet_conditioning_image: Union[818torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]819] = None,820strength: float = 0.8,821height: Optional[int] = None,822width: Optional[int] = None,823num_inference_steps: int = 50,824guidance_scale: float = 7.5,825negative_prompt: Optional[Union[str, List[str]]] = None,826num_images_per_prompt: Optional[int] = 1,827eta: float = 0.0,828generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,829latents: Optional[torch.FloatTensor] = None,830prompt_embeds: Optional[torch.FloatTensor] = None,831negative_prompt_embeds: Optional[torch.FloatTensor] = None,832output_type: Optional[str] = "pil",833return_dict: bool = True,834callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,835callback_steps: int = 1,836cross_attention_kwargs: Optional[Dict[str, Any]] = None,837controlnet_conditioning_scale: float = 1.0,838):839r"""840Function invoked when calling the pipeline for generation.841842Args:843prompt (`str` or `List[str]`, *optional*):844The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.845instead.846image (`torch.Tensor` or `PIL.Image.Image`):847`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will848be masked out with `mask_image` and repainted according to `prompt`.849mask_image (`torch.Tensor` or `PIL.Image.Image`):850`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be851repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted852to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)853instead of 3, so the expected shape would be `(B, H, W, 1)`.854controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):855The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If856the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can857also be accepted as an image. The control image is automatically resized to fit the output image.858strength (`float`, *optional*):859Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`860will be used as a starting point, adding more noise to it the larger the `strength`. The number of861denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will862be maximum and the denoising process will run for the full number of iterations specified in863`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.864height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):865The height in pixels of the generated image.866width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):867The width in pixels of the generated image.868num_inference_steps (`int`, *optional*, defaults to 50):869The number of denoising steps. More denoising steps usually lead to a higher quality image at the870expense of slower inference.871guidance_scale (`float`, *optional*, defaults to 7.5):872Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).873`guidance_scale` is defined as `w` of equation 2. of [Imagen874Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >8751`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,876usually at the expense of lower image quality.877negative_prompt (`str` or `List[str]`, *optional*):878The prompt or prompts not to guide the image generation. If not defined, one has to pass879`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.880Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).881num_images_per_prompt (`int`, *optional*, defaults to 1):882The number of images to generate per prompt.883eta (`float`, *optional*, defaults to 0.0):884Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to885[`schedulers.DDIMScheduler`], will be ignored for others.886generator (`torch.Generator` or `List[torch.Generator]`, *optional*):887One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)888to make generation deterministic.889latents (`torch.FloatTensor`, *optional*):890Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image891generation. Can be used to tweak the same generation with different prompts. If not provided, a latents892tensor will ge generated by sampling using the supplied random `generator`.893prompt_embeds (`torch.FloatTensor`, *optional*):894Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not895provided, text embeddings will be generated from `prompt` input argument.896negative_prompt_embeds (`torch.FloatTensor`, *optional*):897Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt898weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input899argument.900output_type (`str`, *optional*, defaults to `"pil"`):901The output format of the generate image. Choose between902[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.903return_dict (`bool`, *optional*, defaults to `True`):904Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a905plain tuple.906callback (`Callable`, *optional*):907A function that will be called every `callback_steps` steps during inference. The function will be908called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.909callback_steps (`int`, *optional*, defaults to 1):910The frequency at which the `callback` function will be called. If not specified, the callback will be911called at every step.912cross_attention_kwargs (`dict`, *optional*):913A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under914`self.processor` in915[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).916controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):917The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added918to the residual in the original unet.919920Examples:921922Returns:923[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:924[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.925When returning a tuple, the first element is a list with the generated images, and the second element is a926list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"927(nsfw) content, according to the `safety_checker`.928"""929# 0. Default height and width to unet930height, width = self._default_height_width(height, width, controlnet_conditioning_image)931932# 1. Check inputs. Raise error if not correct933self.check_inputs(934prompt,935image,936mask_image,937controlnet_conditioning_image,938height,939width,940callback_steps,941negative_prompt,942prompt_embeds,943negative_prompt_embeds,944strength,945)946947# 2. Define call parameters948if prompt is not None and isinstance(prompt, str):949batch_size = 1950elif prompt is not None and isinstance(prompt, list):951batch_size = len(prompt)952else:953batch_size = prompt_embeds.shape[0]954955device = self._execution_device956# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)957# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`958# corresponds to doing no classifier free guidance.959do_classifier_free_guidance = guidance_scale > 1.0960961# 3. Encode input prompt962prompt_embeds = self._encode_prompt(963prompt,964device,965num_images_per_prompt,966do_classifier_free_guidance,967negative_prompt,968prompt_embeds=prompt_embeds,969negative_prompt_embeds=negative_prompt_embeds,970)971972# 4. Prepare mask, image, and controlnet_conditioning_image973image = prepare_image(image)974975mask_image = prepare_mask_image(mask_image)976977controlnet_conditioning_image = prepare_controlnet_conditioning_image(978controlnet_conditioning_image,979width,980height,981batch_size * num_images_per_prompt,982num_images_per_prompt,983device,984self.controlnet.dtype,985)986987masked_image = image * (mask_image < 0.5)988989# 5. Prepare timesteps990self.scheduler.set_timesteps(num_inference_steps, device=device)991timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)992latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)993994# 6. Prepare latent variables995latents = self.prepare_latents(996image,997latent_timestep,998batch_size,999num_images_per_prompt,1000prompt_embeds.dtype,1001device,1002generator,1003)10041005mask_image_latents = self.prepare_mask_latents(1006mask_image,1007batch_size * num_images_per_prompt,1008height,1009width,1010prompt_embeds.dtype,1011device,1012do_classifier_free_guidance,1013)10141015masked_image_latents = self.prepare_masked_image_latents(1016masked_image,1017batch_size * num_images_per_prompt,1018height,1019width,1020prompt_embeds.dtype,1021device,1022generator,1023do_classifier_free_guidance,1024)10251026if do_classifier_free_guidance:1027controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)10281029# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline1030extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)10311032# 8. Denoising loop1033num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order1034with self.progress_bar(total=num_inference_steps) as progress_bar:1035for i, t in enumerate(timesteps):1036# expand the latents if we are doing classifier free guidance1037non_inpainting_latent_model_input = (1038torch.cat([latents] * 2) if do_classifier_free_guidance else latents1039)10401041non_inpainting_latent_model_input = self.scheduler.scale_model_input(1042non_inpainting_latent_model_input, t1043)10441045inpainting_latent_model_input = torch.cat(1046[non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=11047)10481049down_block_res_samples, mid_block_res_sample = self.controlnet(1050non_inpainting_latent_model_input,1051t,1052encoder_hidden_states=prompt_embeds,1053controlnet_cond=controlnet_conditioning_image,1054return_dict=False,1055)10561057down_block_res_samples = [1058down_block_res_sample * controlnet_conditioning_scale1059for down_block_res_sample in down_block_res_samples1060]1061mid_block_res_sample *= controlnet_conditioning_scale10621063# predict the noise residual1064noise_pred = self.unet(1065inpainting_latent_model_input,1066t,1067encoder_hidden_states=prompt_embeds,1068cross_attention_kwargs=cross_attention_kwargs,1069down_block_additional_residuals=down_block_res_samples,1070mid_block_additional_residual=mid_block_res_sample,1071).sample10721073# perform guidance1074if do_classifier_free_guidance:1075noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)1076noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)10771078# compute the previous noisy sample x_t -> x_t-11079latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample10801081# call the callback, if provided1082if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):1083progress_bar.update()1084if callback is not None and i % callback_steps == 0:1085callback(i, t, latents)10861087# If we do sequential model offloading, let's offload unet and controlnet1088# manually for max memory savings1089if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:1090self.unet.to("cpu")1091self.controlnet.to("cpu")1092torch.cuda.empty_cache()10931094if output_type == "latent":1095image = latents1096has_nsfw_concept = None1097elif output_type == "pil":1098# 8. Post-processing1099image = self.decode_latents(latents)11001101# 9. Run safety checker1102image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)11031104# 10. Convert to PIL1105image = self.numpy_to_pil(image)1106else:1107# 8. Post-processing1108image = self.decode_latents(latents)11091110# 9. Run safety checker1111image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)11121113# Offload last model to CPU1114if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:1115self.final_offload_hook.offload()11161117if not return_dict:1118return (image, has_nsfw_concept)11191120return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)112111221123