# Path: blob/main/examples/community/unclip_text_interpolation.py
# 1448 views
import inspect
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import functional as F
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput

from diffusers import (
    DiffusionPipeline,
    ImagePipelineOutput,
    PriorTransformer,
    UnCLIPScheduler,
    UNet2DConditionModel,
    UNet2DModel,
)
from diffusers.pipelines.unclip import UnCLIPTextProjModel
from diffusers.utils import is_accelerate_available, logging, randn_tensor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def slerp(val, low, high, dot_threshold=0.9995):
    """
    Find the interpolation point between the 'low' and 'high' values for the given 'val' using spherical linear
    interpolation. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.

    Each input tensor is treated as a single flat vector: the angle between `low` and `high` is computed from the dot
    product over *all* of their elements, so 1-D embeddings and 2-D hidden-state matrices are both supported.

    Args:
        val (`float` or 0-d `torch.Tensor`):
            Interpolation fraction; 0 returns `low`, 1 returns `high`.
        low (`torch.Tensor`):
            Start point.
        high (`torch.Tensor`):
            End point, same shape as `low`.
        dot_threshold (`float`, *optional*, defaults to 0.9995):
            When the absolute cosine between the inputs exceeds this value they are (anti-)parallel, sin(omega) is
            close to zero and the slerp formula becomes numerically unstable, so plain linear interpolation is
            used instead.

    Returns:
        `torch.Tensor`: The interpolated tensor, same shape as the inputs.
    """
    low_norm = low / torch.norm(low)
    high_norm = high / torch.norm(high)
    # Cosine of the angle between the inputs: a dot product (sum over all elements), not an elementwise product.
    dot = torch.sum(low_norm * high_norm).clamp(-1.0, 1.0)
    if torch.abs(dot) > dot_threshold:
        # Nearly (anti-)parallel inputs (e.g. identical prompts): avoid dividing by sin(omega) ~ 0.
        return (1.0 - val) * low + val * high
    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
    return res


class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
    """
    Pipeline for prompt-to-prompt interpolation on CLIP text embeddings and using the UnCLIP / Dall-E to decode them to images.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution unet. Used in the last step of the super resolution diffusion process.
        prior_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.

    """

    prior: PriorTransformer
    decoder: UNet2DConditionModel
    text_proj: UnCLIPTextProjModel
    text_encoder: CLIPTextModelWithProjection
    tokenizer: CLIPTokenizer
    super_res_first: UNet2DModel
    super_res_last: UNet2DModel

    prior_scheduler: UnCLIPScheduler
    decoder_scheduler: UnCLIPScheduler
    super_res_scheduler: UnCLIPScheduler

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__init__
    def __init__(
        self,
        prior: PriorTransformer,
        decoder: UNet2DConditionModel,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_proj: UnCLIPTextProjModel,
        super_res_first: UNet2DModel,
        super_res_last: UNet2DModel,
        prior_scheduler: UnCLIPScheduler,
        decoder_scheduler: UnCLIPScheduler,
        super_res_scheduler: UnCLIPScheduler,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_proj=text_proj,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            prior_scheduler=prior_scheduler,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
    ):
        if text_model_output is None:
            batch_size = len(prompt) if isinstance(prompt, list) else 1
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            text_mask = text_inputs.attention_mask.bool().to(device)

            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

            text_encoder_output = self.text_encoder(text_input_ids.to(device))

            prompt_embeds = text_encoder_output.text_embeds
            text_encoder_hidden_states = text_encoder_output.last_hidden_state

        else:
            batch_size = text_model_output[0].shape[0]
            prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
            text_mask = text_attention_mask

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
        when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
        models = [
            self.decoder,
            self.text_proj,
            self.text_encoder,
            self.super_res_first,
            self.super_res_last,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
            return self.device
        for module in self.decoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        start_prompt: str,
        end_prompt: str,
        steps: int = 5,
        prior_num_inference_steps: int = 25,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prior_guidance_scale: float = 4.0,
        decoder_guidance_scale: float = 8.0,
        enable_sequential_cpu_offload=True,
        gpu_id=0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            start_prompt (`str`):
                The prompt to start the image generation interpolation from.
            end_prompt (`str`):
                The prompt to end the image generation interpolation at.
            steps (`int`, *optional*, defaults to 5):
                The number of steps over which to interpolate from start_prompt to end_prompt. The pipeline returns
                the same number of images as this value. Must be at least 1.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            enable_sequential_cpu_offload (`bool`, *optional*, defaults to `True`):
                If True, offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
                models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
                when their specific submodule has its `forward` method called.
            gpu_id (`int`, *optional*, defaults to `0`):
                The gpu_id to be passed to enable_sequential_cpu_offload. Only works when enable_sequential_cpu_offload is set to True.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Raises:
            ValueError: If either prompt is not a `str`, or if `steps` is smaller than 1.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: an `ImagePipelineOutput` holding the `steps` images when
            `return_dict` is True, otherwise a one-element tuple containing the images.
        """

        if not isinstance(start_prompt, str) or not isinstance(end_prompt, str):
            raise ValueError(
                f"`start_prompt` and `end_prompt` should be of type `str` but got {type(start_prompt)} and"
                f" {type(end_prompt)} instead"
            )

        # Fail early with a clear message instead of an opaque error from torch.cat on an empty list below.
        if steps < 1:
            raise ValueError(f"`steps` should be a positive integer but got {steps}")

        if enable_sequential_cpu_offload:
            self.enable_sequential_cpu_offload(gpu_id=gpu_id)

        device = self._execution_device

        # Turn the prompts into embeddings.
        inputs = self.tokenizer(
            [start_prompt, end_prompt],
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        inputs = inputs.to(device)
        text_model_output = self.text_encoder(**inputs)

        # A token is attended to if it is attended to in either prompt. Cast to bool so the mask has the same dtype
        # as the unconditional mask that `_encode_prompt` concatenates it with.
        text_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1]).bool()
        text_attention_mask = torch.cat([text_attention_mask.unsqueeze(0)] * steps).to(device)

        # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline
        batch_text_embeds = []
        batch_last_hidden_state = []

        for interp_val in torch.linspace(0, 1, steps):
            text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1])
            last_hidden_state = slerp(
                interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1]
            )
            batch_text_embeds.append(text_embeds.unsqueeze(0))
            batch_last_hidden_state.append(last_hidden_state.unsqueeze(0))

        batch_text_embeds = torch.cat(batch_text_embeds)
        batch_last_hidden_state = torch.cat(batch_last_hidden_state)

        text_model_output = CLIPTextModelOutput(
            text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state
        )

        batch_size = text_model_output[0].shape[0]

        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt=None,
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            text_model_output=text_model_output,
            text_attention_mask=text_attention_mask,
        )

        # prior

        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
        prior_timesteps_tensor = self.prior_scheduler.timesteps

        embedding_dim = self.prior.config.embedding_dim

        prior_latents = self.prepare_latents(
            (batch_size, embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            None,
            self.prior_scheduler,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            prior_latents = self.prior_scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=prior_latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

        prior_latents = self.prior.post_process_latents(prior_latents)

        image_embeddings = prior_latents

        # done prior

        # decoder

        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        if device.type == "mps":
            # HACK: MPS: There is a panic when padding bool tensors,
            # so cast to int tensor for the pad and back to bool afterwards
            text_mask = text_mask.type(torch.int)
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
            decoder_text_mask = decoder_text_mask.type(torch.bool)
        else:
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.in_channels
        height = self.decoder.sample_size
        width = self.decoder.sample_size

        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            None,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        channels = self.super_res_first.in_channels // 2
        height = self.super_res_first.sample_size
        width = self.super_res_first.sample_size

        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            None,
            self.super_res_scheduler,
        )

        if device.type == "mps":
            # MPS does not support many interpolations
            image_upscaled = F.interpolate(image_small, size=[height, width])
        else:
            interpolate_antialias = {}
            if "antialias" in inspect.signature(F.interpolate).parameters:
                interpolate_antialias["antialias"] = True

            image_upscaled = F.interpolate(
                image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
            )

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents
        # done super res

        # post processing

        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)