Path: blob/main/examples/community/seed_resize_stable_diffusion.py
"""
Modified based on the Stable Diffusion pipeline from the Hugging Face diffusers library:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
"""
import inspect
from typing import Callable, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will
        go back to computing attention in one step.
        """
        # set slice_size = `None` to disable attention slicing
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        text_embeddings: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

        if text_embeddings is None:
            text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""]
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents_reference = torch.randn(
                    latents_shape_reference, generator=generator, device="cpu", dtype=latents_dtype
                ).to(self.device)
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents_reference = torch.randn(
                    latents_shape_reference, generator=generator, device=self.device, dtype=latents_dtype
                )
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            # the fixed-size reference noise is still drawn from the generator so that the
            # seed-resize alignment below has something to copy from
            latents_reference = torch.randn(
                latents_shape_reference,
                generator=generator,
                device="cpu" if self.device.type == "mps" else self.device,
                dtype=latents_dtype,
            ).to(self.device)
            latents = latents.to(self.device)

        # This is the key part of the pipeline: so that images generated with the same seed
        # but at different sizes come out looking similar, the fixed 64x64 reference noise is
        # pasted into (or center-cropped to fit) the target latent canvas.
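        # For example, at height=512, width=768 the latent canvas is 64x96 while the reference
        # noise is 64x64: dx=16, dy=0, w=h=64, tx=16, ty=0, so the reference is pasted into the
        # horizontal center of the wider canvas. At height=width=384 the canvas is 48x48:
        # dx=dy=-8 gives w=h=48, tx=ty=0 and dx=dy=8, i.e. a 48x48 center crop of the reference.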
        dx = (latents_shape[3] - latents_shape_reference[3]) // 2
        dy = (latents_shape[2] - latents_shape_reference[2]) // 2
        w = latents_shape_reference[3] if dx >= 0 else latents_shape_reference[3] + 2 * dx
        h = latents_shape_reference[2] if dy >= 0 else latents_shape_reference[2] + 2 * dy
        tx = 0 if dx < 0 else dx
        ty = 0 if dy < 0 else dy
        dx = max(-dx, 0)
        dy = max(-dy, 0)
        latents[:, :, ty : ty + h, tx : tx + w] = latents_reference[:, :, dy : dy + h, dx : dx + w]

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to the correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # scale and decode the latents with the VAE
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
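

# Example usage sketch. Assumptions not taken from this file: that it is consumed as the
# "seed_resize_stable_diffusion" community pipeline via the `custom_pipeline` argument of
# `DiffusionPipeline.from_pretrained`, and that a CUDA device is available. Reusing the same
# seed at two resolutions should give visibly similar compositions, which is what the shared
# 64x64 reference latents above are for.
if __name__ == "__main__":
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="seed_resize_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")

    prompt = "a photograph of an astronaut riding a horse"
    seed = 0

    # generate a square image from a fixed seed
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image_square = pipe(prompt, height=512, width=512, generator=generator).images[0]

    # same seed, wider image: the composition should stay close to the square result
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image_wide = pipe(prompt, height=512, width=768, generator=generator).images[0]

    image_square.save("astronaut_512x512.png")
    image_wide.save("astronaut_512x768.png")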