Path: blob/main/examples/community/stable_diffusion_mega.py
from typing import Any, Callable, Dict, List, Optional, Union

import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class StableDiffusionMegaPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image, image-to-image, and inpainting generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    @property
    def components(self) -> Dict[str, Any]:
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
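
    # Note: `components` collects the modules registered in `__init__` (vae,
    # text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)
    # by reading their names back from the pipeline config. The task-specific
    # methods below unpack it as `**self.components`, so every sub-pipeline is
    # built from the same module objects rather than loading its own copy of the
    # weights.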

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will
        go back to computing attention in one step.
        """
        # set `slice_size` to None to disable attention slicing
        self.enable_attention_slicing(None)
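
    # A minimal sketch of the trade-off above, assuming `pipe` is an instance of
    # this class (the variable name is illustrative):
    #
    #     pipe.enable_attention_slicing()    # "auto": slices of attention_head_dim // 2
    #     pipe.enable_attention_slicing(1)   # most slices, lowest memory, slowest
    #     pipe.disable_attention_slicing()   # single-step attention, fastest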

    @torch.no_grad()
    def inpaint(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        # For more information on how this function works, see the documentation for
        # `StableDiffusionInpaintPipelineLegacy`: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion
        return StableDiffusionInpaintPipelineLegacy(**self.components)(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def img2img(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
        return StableDiffusionImg2ImgPipeline(**self.components)(
            prompt=prompt,
            image=image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def text2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
        return StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
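

# A minimal usage sketch (editor's assumptions, not part of the pipeline itself:
# the checkpoint id below is reachable, this file is loaded as the
# "stable_diffusion_mega" community pipeline, and a CUDA device is available;
# prompts and the mask geometry are purely illustrative).
if __name__ == "__main__":
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="stable_diffusion_mega",
        torch_dtype=torch.float16,
    )
    pipe.to("cuda")
    pipe.enable_attention_slicing()

    # text-to-image
    init_image = pipe.text2img("A fantasy landscape, trending on artstation").images[0]

    # image-to-image: refine the generated image under a new prompt
    refined = pipe.img2img(
        prompt="A fantasy landscape at sunset", image=init_image, strength=0.75
    ).images[0]

    # inpainting: repaint the white region of the mask; here a centered 256x256 square
    mask_image = PIL.Image.new("L", init_image.size, 0)
    mask_image.paste(255, (128, 128, 384, 384))
    inpainted = pipe.inpaint(
        prompt="A castle on a hill", image=init_image, mask_image=mask_image, strength=0.75
    ).images[0]

    refined.save("mega_img2img.png")
    inpainted.save("mega_inpaint.png")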