Path: blob/main/examples/community/stable_diffusion_controlnet_img2img.py

# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    is_accelerate_available,
    is_accelerate_version,
    randn_tensor,
    replace_example_docstring,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import numpy as np
>>> import torch
>>> from PIL import Image
>>> from diffusers import ControlNetModel, UniPCMultistepScheduler
>>> from diffusers.utils import load_image

>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)

>>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=torch.float16,
    )

>>> pipe_controlnet.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
>>> pipe_controlnet.enable_xformers_memory_efficient_attention()
>>> pipe_controlnet.enable_model_cpu_offload()

# using an image with edges for our canny controlnet
>>> control_image = load_image(
        "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png"
    )

>>> result_img = pipe_controlnet(
        controlnet_conditioning_image=control_image,
        image=input_image,
        prompt="an android robot, cyberpunk, digital art masterpiece",
        num_inference_steps=20,
    ).images[0]

>>> result_img.show()
```
"""


def prepare_image(image):
    if isinstance(image, torch.Tensor):
        # Batch single image
        if image.ndim == 3:
            image = image.unsqueeze(0)

        image = image.to(dtype=torch.float32)
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    return image


def prepare_controlnet_conditioning_image(
    controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
):
    if not isinstance(controlnet_conditioning_image, torch.Tensor):
        if isinstance(controlnet_conditioning_image, PIL.Image.Image):
            controlnet_conditioning_image = [controlnet_conditioning_image]

        if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
            controlnet_conditioning_image = [
                np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
                for i in controlnet_conditioning_image
            ]
            controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
            controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
            controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
            controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
        elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
            controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)

    image_batch_size = controlnet_conditioning_image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)

    controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)

    return controlnet_conditioning_image
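

# Note: the two preprocessing helpers above intentionally use different value ranges:
# `prepare_image` maps pixel values to [-1, 1], which is what the VAE encoder expects,
# while `prepare_controlnet_conditioning_image` keeps values in [0, 1] (and resizes to
# the target width/height), which is what the ControlNet expects. Both return NCHW
# float tensors, e.g. shape (1, 3, height, width) for a single RGB PIL image.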


class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
    """
    Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: ControlNetModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `safety_checker=None` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
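
    # Note: with the standard Stable Diffusion 1.x VAE, `block_out_channels` has four
    # entries, so `vae_scale_factor` evaluates to 2 ** 3 = 8, i.e. a 512x512 image is
    # encoded into a 64x64 latent.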

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
        text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and are then moved to a
        `torch.device('meta')`, getting loaded to GPU only when their specific submodule's `forward` method is called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than
        with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the
        `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            # the safety checker can offload the vae again
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # the controlnet hook has to be offloaded manually as it alternates with the unet
        cpu_offload_with_hook(self.controlnet, device)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt`"
                    " matches the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds
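
    # Note: when classifier-free guidance is used, `_encode_prompt` returns the negative
    # (unconditional) embeddings concatenated *before* the positive ones, so a single
    # UNet forward pass handles both and `noise_pred.chunk(2)` in `__call__` yields
    # (uncond, text) in that order.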

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
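
    # Note: for the standard Stable Diffusion VAE, `vae.config.scaling_factor` is 0.18215,
    # so `decode_latents` divides by it before decoding, mirroring the multiplication by
    # `scaling_factor` applied after encoding in `prepare_latents`.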

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        image,
        controlnet_conditioning_image,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        strength=None,
        controlnet_guidance_start=None,
        controlnet_guidance_end=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
        controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
        controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
            controlnet_conditioning_image[0], PIL.Image.Image
        )
        controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
            controlnet_conditioning_image[0], torch.Tensor
        )

        if (
            not controlnet_cond_image_is_pil
            and not controlnet_cond_image_is_tensor
            and not controlnet_cond_image_is_pil_list
            and not controlnet_cond_image_is_tensor_list
        ):
            raise TypeError(
                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
            )

        if controlnet_cond_image_is_pil:
            controlnet_cond_image_batch_size = 1
        elif controlnet_cond_image_is_tensor:
            controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
        elif controlnet_cond_image_is_pil_list:
            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
        elif controlnet_cond_image_is_tensor_list:
            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)

        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]

        if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
            raise ValueError(
                "If image batch size is not 1, image batch size must be same as prompt batch size. image batch"
                f" size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
            )

        if isinstance(image, torch.Tensor):
            if image.ndim != 3 and image.ndim != 4:
                raise ValueError("`image` must have 3 or 4 dimensions")

            # if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
            #     raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")

            if image.ndim == 3:
                image_batch_size = 1
                image_channels, image_height, image_width = image.shape
            elif image.ndim == 4:
                image_batch_size, image_channels, image_height, image_width = image.shape

            if image_channels != 3:
                raise ValueError("`image` must have 3 channels")

            if image.min() < -1 or image.max() > 1:
                raise ValueError("`image` should be in range [-1, 1]")

        if self.vae.config.latent_channels != self.unet.config.in_channels:
            raise ValueError(
                f"The config of `pipeline.unet` expects {self.unet.config.in_channels} input channels but received"
                f" {self.vae.config.latent_channels} latent channels from `pipeline.vae`. Please verify the configs"
                " of `pipeline.unet` and `pipeline.vae`."
            )

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of `strength` should be in [0.0, 1.0] but is {strength}")

        if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
            raise ValueError(
                f"The value of `controlnet_guidance_start` should be in [0.0, 1.0] but is {controlnet_guidance_start}"
            )

        if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
            raise ValueError(
                f"The value of `controlnet_guidance_end` should be in [0.0, 1.0] but is {controlnet_guidance_end}"
            )

        if controlnet_guidance_start > controlnet_guidance_end:
            raise ValueError(
                "The value of `controlnet_guidance_start` should not be greater than `controlnet_guidance_end`, but"
                f" got `controlnet_guidance_start` {controlnet_guidance_start} > `controlnet_guidance_end`"
                f" {controlnet_guidance_end}"
            )

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start
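
    # Worked example: with `num_inference_steps=50` and `strength=0.8`,
    # `init_timestep = min(int(50 * 0.8), 50) = 40` and `t_start = 10`, so the first 10
    # scheduler timesteps are skipped and 40 denoising steps actually run, starting from
    # the noise level of the first kept timestep.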

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    def _default_height_width(self, height, width, image):
        if isinstance(image, list):
            image = image[0]

        if height is None:
            if isinstance(image, PIL.Image.Image):
                height = image.height
            elif isinstance(image, torch.Tensor):
                height = image.shape[3]

            height = (height // 8) * 8  # round down to nearest multiple of 8

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[2]

            width = (width // 8) * 8  # round down to nearest multiple of 8

        return height, width

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        controlnet_conditioning_image: Union[
            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
        ] = None,
        strength: float = 0.8,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: float = 1.0,
        controlnet_guidance_start: float = 0.0,
        controlnet_guidance_end: float = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, which will be used as the starting point for the
                generation and transformed according to `prompt`.
            controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the
                UNet. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is.
                `PIL.Image.Image` can also be accepted as an image. The control image is automatically resized to fit
                the output image.
            strength (`float`, *optional*):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet.
            controlnet_guidance_start (`float`, *optional*, defaults to 0.0):
                The percentage of total steps at which the controlnet starts being applied. Must be between 0 and 1.
            controlnet_guidance_end (`float`, *optional*, defaults to 1.0):
                The percentage of total steps at which the controlnet stops being applied. Must be between 0 and 1 and
                must be greater than `controlnet_guidance_start`.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, controlnet_conditioning_image)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            # mask_image,
            controlnet_conditioning_image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            strength,
            controlnet_guidance_start,
            controlnet_guidance_end,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare image and controlnet_conditioning_image
        image = prepare_image(image)

        # mask_image = prepare_mask_image(mask_image)

        controlnet_conditioning_image = prepare_controlnet_conditioning_image(
            controlnet_conditioning_image,
            width,
            height,
            batch_size * num_images_per_prompt,
            num_images_per_prompt,
            device,
            self.controlnet.dtype,
        )

        # masked_image = image * (mask_image < 0.5)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )
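
        # Note: unlike text-to-image pipelines, the initial latents here are not pure
        # noise: `prepare_latents` VAE-encodes the input image and adds noise at
        # `latent_timestep`, so lower `strength` values preserve more of the input image.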

        if do_classifier_free_guidance:
            controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # compute the percentage of total steps we are at
                current_sampling_percent = i / len(timesteps)

                if (
                    current_sampling_percent < controlnet_guidance_start
                    or current_sampling_percent > controlnet_guidance_end
                ):
                    # do not apply the controlnet
                    down_block_res_samples = None
                    mid_block_res_sample = None
                else:
                    # apply the controlnet
                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        controlnet_cond=controlnet_conditioning_image,
                        return_dict=False,
                    )

                    down_block_res_samples = [
                        down_block_res_sample * controlnet_conditioning_scale
                        for down_block_res_sample in down_block_res_samples
                    ]
                    mid_block_res_sample *= controlnet_conditioning_scale

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 9. Post-processing
            image = self.decode_latents(latents)

            # 10. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 11. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 9. Post-processing
            image = self.decode_latents(latents)

            # 10. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)