Path: blob/main/examples/community/interpolate_stable_diffusion.py
import inspect
import time
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """helper function to spherically interpolate two arrays v0 and v1"""

    inputs_are_torch = False
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
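
# Editorial note: a minimal usage sketch for `slerp` (illustrative only; the tensors
# below are placeholders and are not part of the original example).
#
#   a = torch.randn(1, 4, 64, 64)
#   b = torch.randn(1, 4, 64, 64)
#   halfway = slerp(0.5, a, b)  # same shape as the inputs, returned on `a`'s device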


class StableDiffusionWalkPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)
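
    # Editorial note (not part of the original file): attention slicing is an opt-in
    # memory saver. A hypothetical usage pattern, with `pipe` a loaded instance of this
    # pipeline:
    #
    #   pipe.enable_attention_slicing()   # lower peak memory, slightly slower inference
    #   pipe.disable_attention_slicing()  # back to computing attention in one step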

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        text_embeddings: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*, defaults to `None`):
                The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
                Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
                `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
                the supplied `prompt`.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if text_embeddings is None:
            if isinstance(prompt, str):
                batch_size = 1
            elif isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids

            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
                print(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
            text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
        else:
            batch_size = text_embeddings.shape[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = self.tokenizer.model_max_length
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # rescale the latents by the VAE scaling factor (0.18215 for Stable Diffusion v1) before decoding
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
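
    # Editorial note: `__call__` above also accepts precomputed `text_embeddings` and
    # `latents`, which is how `walk` below drives it. A hypothetical sketch, with `pipe`
    # a loaded instance of this pipeline (the prompt and seed are placeholders):
    #
    #   embeds = pipe.embed_text("a photograph of an astronaut riding a horse")
    #   noise = pipe.get_noise(seed=42, dtype=embeds.dtype)
    #   image = pipe(text_embeddings=embeds, latents=noise).images[0]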

    def embed_text(self, text):
        """takes in text and turns it into text embeddings"""
        text_input = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return embed

    def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
        """Takes in random seed and returns corresponding noise vector"""
        return torch.randn(
            (1, self.unet.in_channels, height // 8, width // 8),
            generator=torch.Generator(device=self.device).manual_seed(seed),
            device=self.device,
            dtype=dtype,
        )

    def walk(
        self,
        prompts: List[str],
        seeds: List[int],
        num_interpolation_steps: Optional[int] = 6,
        output_dir: Optional[str] = "./dreams",
        name: Optional[str] = None,
        batch_size: Optional[int] = 1,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        guidance_scale: Optional[float] = 7.5,
        num_inference_steps: Optional[int] = 50,
        eta: Optional[float] = 0.0,
    ) -> List[str]:
        """
        Walks through a series of prompts and seeds, interpolating between them and saving the results to disk.

        Args:
            prompts (`List[str]`):
                List of prompts to generate images for.
            seeds (`List[int]`):
                List of seeds corresponding to provided prompts. Must be the same length as prompts.
            num_interpolation_steps (`int`, *optional*, defaults to 6):
                Number of interpolation steps to take between prompts.
            output_dir (`str`, *optional*, defaults to `./dreams`):
                Directory to save the generated images to.
            name (`str`, *optional*, defaults to `None`):
                Subdirectory of `output_dir` to save the generated images to. If `None`, the name will
                be the current time.
            batch_size (`int`, *optional*, defaults to 1):
                Number of images to generate at once.
            height (`int`, *optional*, defaults to 512):
                Height of the generated images.
            width (`int`, *optional*, defaults to 512):
                Width of the generated images.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.

        Returns:
            `List[str]`: List of paths to the generated images.
        """
        if len(prompts) != len(seeds):
            raise ValueError(
                f"Number of prompts and seeds must be equal. Got {len(prompts)} prompts and {len(seeds)} seeds"
            )

        name = name or time.strftime("%Y%m%d-%H%M%S")
        save_path = Path(output_dir) / name
        save_path.mkdir(exist_ok=True, parents=True)

        frame_idx = 0
        frame_filepaths = []
        for prompt_a, prompt_b, seed_a, seed_b in zip(prompts, prompts[1:], seeds, seeds[1:]):
            # Embed Text
            embed_a = self.embed_text(prompt_a)
            embed_b = self.embed_text(prompt_b)

            # Get Noise
            noise_dtype = embed_a.dtype
            noise_a = self.get_noise(seed_a, noise_dtype, height, width)
            noise_b = self.get_noise(seed_b, noise_dtype, height, width)

            noise_batch, embeds_batch = None, None
            T = np.linspace(0.0, 1.0, num_interpolation_steps)
            for i, t in enumerate(T):
                noise = slerp(float(t), noise_a, noise_b)
                embed = torch.lerp(embed_a, embed_b, t)

                noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise], dim=0)
                embeds_batch = embed if embeds_batch is None else torch.cat([embeds_batch, embed], dim=0)

                batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]
                if batch_is_ready:
                    outputs = self(
                        latents=noise_batch,
                        text_embeddings=embeds_batch,
                        height=height,
                        width=width,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                    )
                    noise_batch, embeds_batch = None, None

                    for image in outputs["images"]:
                        frame_filepath = str(save_path / f"frame_{frame_idx:06d}.png")
                        image.save(frame_filepath)
                        frame_filepaths.append(frame_filepath)
                        frame_idx += 1
        return frame_filepaths
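

# Editorial note: a hedged, end-to-end usage sketch (the model id, prompts, and seeds
# below are placeholders, not part of the original example). The file is loaded as a
# community pipeline via `custom_pipeline`, and `walk` writes interpolation frames to disk.
#
#   import torch
#   from diffusers import DiffusionPipeline
#
#   pipe = DiffusionPipeline.from_pretrained(
#       "CompVis/stable-diffusion-v1-4",
#       custom_pipeline="interpolate_stable_diffusion",
#       torch_dtype=torch.float16,
#   ).to("cuda")
#   frame_filepaths = pipe.walk(
#       prompts=["a photograph of a dog", "a photograph of a cat"],
#       seeds=[42, 1337],
#       num_interpolation_steps=16,
#       batch_size=4,
#       output_dir="./dreams",
#   )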