# Path: blob/main/examples/community/img2img_inpainting.py
# 1448 views
import inspect
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Factor applied to VAE latents after encoding (and divided out before decoding).
# NOTE(review): hard-coded as in the original pipeline; presumably the Stable Diffusion v1 VAE value.
# Confirm before using a VAE with a different scaling factor.
_VAE_SCALING_FACTOR = 0.18215


def prepare_mask_and_masked_image(image, mask):
    """Convert a PIL image and a PIL mask into the tensors consumed by the inpainting UNet.

    Args:
        image (`PIL.Image.Image`): Image to inpaint; converted to RGB.
        mask (`PIL.Image.Image`): Inpainting mask; converted to single-channel luminance ("L").

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: `(mask, masked_image)`.
        - `mask` has shape `(1, 1, H, W)`, binarized to {0, 1} at a 0.5 threshold.
        - `masked_image` has shape `(1, 3, H, W)` with values in [-1, 1]. Pixels where `mask == 1` are set to 0.
    """
    image = np.array(image.convert("RGB"))
    # HWC -> NCHW with a leading batch dimension of 1
    image = image[None].transpose(0, 3, 1, 2)
    # uint8 [0, 255] -> float32 [-1, 1]
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    mask = np.array(mask.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    # binarize: white (>= 0.5) is repainted, black (< 0.5) is preserved
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    masked_image = image * (mask < 0.5)

    return mask, masked_image


def check_size(image, height, width):
    """Validate that `image` is exactly `height` x `width` pixels.

    PIL images are measured through `image.size` (width, height). Tensors are measured through their
    last two dimensions, interpreted as `(..., H, W)`.

    Raises:
        TypeError: if `image` is neither a `PIL.Image.Image` nor a `torch.Tensor`.
        ValueError: if the spatial size does not match `height` x `width`.
    """
    if isinstance(image, PIL.Image.Image):
        w, h = image.size
    elif isinstance(image, torch.Tensor):
        *_, h, w = image.shape
    else:
        # previously this fell through and failed with an UnboundLocalError on `h`/`w`
        raise TypeError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")

    if h != height or w != width:
        raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")


def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, int] = (0, 0)):
    """Blend `inner_image` on top of `image` using the alpha channel of `inner_image`.

    Args:
        image (`PIL.Image.Image`): Background image; converted to RGB.
        inner_image (`PIL.Image.Image`): Foreground image; converted to RGBA.
        paste_offset (`Tuple[int, int]`, defaults to `(0, 0)`): Upper-left corner at which `inner_image` is pasted.

    Returns:
        `PIL.Image.Image`: a new RGB image. The caller's `image` is not modified, because `convert` returns a
        new image before `paste` mutates it in place.
    """
    inner_image = inner_image.convert("RGBA")
    image = image.convert("RGB")

    # passing the RGBA image as the mask argument makes `paste` blend with its alpha channel
    image.paste(inner_image, paste_offset, inner_image)

    return image


class ImageToImageInpaintingPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        # Older scheduler configs may carry `steps_offset != 1`. Warn, then patch the frozen config in place.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`. `None` disables slicing.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            # NOTE(review): assumes `attention_head_dim` is a single int in the UNet config.
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def _encode_prompt(self, prompt, batch_size, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encode `prompt` (and, for classifier-free guidance, `negative_prompt`) into CLIP text embeddings.

        Args:
            prompt (`str` or `List[str]`): Prompt(s) to encode.
            batch_size (`int`): Number of prompts in `prompt`.
            num_images_per_prompt (`int`): How many times each prompt embedding is duplicated.
            do_classifier_free_guidance (`bool`): Whether to also compute the unconditional embeddings.
            negative_prompt (`str` or `List[str]`, *optional*): Prompt(s) for the unconditional branch.

        Returns:
            `torch.Tensor`: embeddings of shape `(batch_size * num_images_per_prompt, seq_len, dim)`.
            With classifier-free guidance the shape is `(2 * batch_size * num_images_per_prompt, seq_len, dim)`,
            with the unconditional embeddings first.

        Raises:
            TypeError: if `negative_prompt` and `prompt` have different types.
            ValueError: if `negative_prompt` is a list whose length differs from `batch_size`.
        """
        model_max_length = self.tokenizer.model_max_length

        # Truncate inside the tokenizer. Otherwise a batch containing one over-long prompt yields ragged sequences
        # that cannot be stacked into a tensor.
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # Tokenize once more without truncation, only to tell the user what was dropped.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] > model_max_length:
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if not do_classifier_free_guidance:
            return text_embeddings

        # get unconditional embeddings for classifier free guidance
        uncond_tokens: List[str]
        if negative_prompt is None:
            # one empty prompt per batch entry, so the duplication below matches the conditional branch
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # only reachable when `prompt` is also a str, i.e. batch_size == 1
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        max_length = text_input_ids.shape[-1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method.
        # `uncond_embeddings` already has `batch_size` rows, so only the per-prompt duplication is applied here
        # (repeating by `batch_size` again produced batch_size**2 rows for list negative prompts).
        seq_len = uncond_embeddings.shape[1]
        uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
        uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        return torch.cat([uncond_embeddings, text_embeddings])

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        inner_image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        NOTE: although `image`, `inner_image` and `mask_image` are type-hinted as tensors too, the overlay and mask
        preparation steps call PIL's `.convert` on them. In practice only `PIL.Image.Image` inputs currently work.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
                be masked out with `mask_image` and repainted according to `prompt`.
            inner_image (`torch.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch which will be overlaid onto `image`. Non-transparent
                regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with
                the last channel representing the alpha channel, which will be used to blend `inner_image` with
                `image`. If it is not already RGBA, it is converted to RGBA.
            mask_image (`PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                to a single channel (luminance) and binarized at 0.5 before use. If it's a tensor, it should contain
                one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image. Must be divisible by 8 and equal to the height of
                `image`, `inner_image` and `mask_image`.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image. Must be divisible by 8 and equal to the width of
                `image`, `inner_image` and `mask_image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is not greater than `1`). Must have the same type as `prompt`, and the same length
                when both are lists.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`. The second element is `None` when no safety checker
            is configured.

        Raises:
            ValueError: on an invalid `prompt` type, non-multiple-of-8 size, invalid `callback_steps`, mismatched
                input sizes, mismatched `negative_prompt` length, wrong `latents` shape, or a UNet whose
                `in_channels` does not match latents + mask + masked-image channels.
            TypeError: if `negative_prompt` and `prompt` have different types, or an input image has an
                unsupported type.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` also fails the isinstance check, so it is rejected here too
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # check if input sizes are correct
        check_size(image, height, width)
        check_size(inner_image, height, width)
        check_size(mask_image, height, width)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        text_embeddings = self._encode_prompt(
            prompt, batch_size, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # get the initial random noise unless the user supplied it
        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        num_channels_latents = self.vae.config.latent_channels
        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # overlay the inner image
        image = overlay_inner_image(image, inner_image)

        # prepare mask and masked_image
        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)

        # resize the mask to latents shape as we concatenate the mask to the latents
        mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))

        # encode the mask image into latents space so we can concatenate it to the latents
        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
        masked_image_latents = _VAE_SCALING_FACTOR * masked_image_latents

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
        masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)

        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]

        # the inpainting UNet consumes latents, mask and masked-image latents stacked along channels
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # concat latents, mask, masked_image_latents in the channel dimension
            latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        latents = 1 / _VAE_SCALING_FACTOR * latents
        image = self.vae.decode(latents).sample

        # [-1, 1] -> [0, 1]
        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)