Path: blob/main/examples/community/text_inpainting.py
from typing import Callable, List, Optional, Union

import PIL
import torch
from transformers import (
    CLIPImageProcessor,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
    CLIPTextModel,
    CLIPTokenizer,
)

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, is_accelerate_available, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class TextInpainting(DiffusionPipeline):
    r"""
    Pipeline for text-based inpainting using Stable Diffusion.
    Uses CLIPSeg to get a mask from the given text, then calls the inpainting pipeline with the generated mask.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        segmentation_model ([`CLIPSegForImageSegmentation`]):
            CLIPSeg model to generate a mask from the given text. Please refer to the [model card]() for details.
        segmentation_processor ([`CLIPSegProcessor`]):
            CLIPSeg processor to prepare the image and text inputs for `segmentation_model`. Please refer to the
            [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        segmentation_model: CLIPSegForImageSegmentation,
        segmentation_processor: CLIPSegProcessor,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration"
                " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
                " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
                " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
                " Hub, it would be very nice if you could open a Pull request for the"
                " `scheduler/scheduler_config.json` file"
            )
            deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["skip_prk_steps"] = True
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            segmentation_model=segmentation_model,
            segmentation_processor=segmentation_processor,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will
        go back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device("cuda")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        text: str,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`PIL.Image.Image`):
                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
                be masked out with `mask_image` and repainted according to `prompt`.
            text (`str`):
                The text to use to generate the mask.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        # We use the input text to generate the mask
        inputs = self.segmentation_processor(
            text=[text], images=[image], padding="max_length", return_tensors="pt"
        ).to(self.device)
        outputs = self.segmentation_model(**inputs)
        mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()
        mask_pil = self.numpy_to_pil(mask)[0].resize(image.size)

        # Run inpainting pipeline with the generated mask
        inpainting_pipeline = StableDiffusionInpaintPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            scheduler=self.scheduler,
            safety_checker=self.safety_checker,
            feature_extractor=self.feature_extractor,
        )
        return inpainting_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_pil,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )