Path: blob/main/examples/community/unclip_image_interpolation.py
import inspect
from typing import List, Optional, Union

import PIL
import torch
from torch.nn import functional as F
from transformers import (
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DiffusionPipeline,
    ImagePipelineOutput,
    UnCLIPScheduler,
    UNet2DConditionModel,
    UNet2DModel,
)
from diffusers.pipelines.unclip import UnCLIPTextProjModel
from diffusers.utils import is_accelerate_available, logging, randn_tensor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def slerp(val, low, high):
    """
    Find the interpolation point between the 'low' and 'high' values for the given 'val'. See
    https://en.wikipedia.org/wiki/Slerp for more details on the topic.
    """
    low_norm = low / torch.norm(low)
    high_norm = high / torch.norm(high)
    omega = torch.acos((low_norm * high_norm))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
    return res
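# A quick sanity check of the helper above (illustrative only; the values are an arbitrary example
# and the snippet is not executed by the pipeline):
#
#     low, high = torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])
#     slerp(0.5, low, high)  # tensor([0.7071, 0.7071]) -- halfway along the arc between the two vectors
#
# In `__call__` below, this helper is applied to the two CLIP image embeddings for `steps` values of
# `val` drawn from `torch.linspace(0, 1, steps)`.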
class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
    """
    Pipeline to generate image interpolations between two input images using unCLIP.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `image_encoder`.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution unet. Used in the last step of the super resolution diffusion process.
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
    """

    decoder: UNet2DConditionModel
    text_proj: UnCLIPTextProjModel
    text_encoder: CLIPTextModelWithProjection
    tokenizer: CLIPTokenizer
    feature_extractor: CLIPImageProcessor
    image_encoder: CLIPVisionModelWithProjection
    super_res_first: UNet2DModel
    super_res_last: UNet2DModel

    decoder_scheduler: UnCLIPScheduler
    super_res_scheduler: UnCLIPScheduler

    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.__init__
    def __init__(
        self,
        decoder: UNet2DConditionModel,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_proj: UnCLIPTextProjModel,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection,
        super_res_first: UNet2DModel,
        super_res_last: UNet2DModel,
        decoder_scheduler: UnCLIPScheduler,
        super_res_scheduler: UnCLIPScheduler,
    ):
        super().__init__()

        self.register_modules(
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_proj=text_proj,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_prompt
    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        text_mask = text_inputs.attention_mask.bool().to(device)
        text_encoder_output = self.text_encoder(text_input_ids.to(device))

        prompt_embeds = text_encoder_output.text_embeds
        text_encoder_hidden_states = text_encoder_output.last_hidden_state

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_image
    def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
        dtype = next(self.image_encoder.parameters()).dtype

        if image_embeddings is None:
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

            image = image.to(device=device, dtype=dtype)
            image_embeddings = self.image_encoder(image).image_embeds

        image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        return image_embeddings

    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.decoder,
            self.text_proj,
            self.text_encoder,
            self.super_res_first,
            self.super_res_last,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
            return self.device
        for module in self.decoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
        steps: int = 5,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        image_embeddings: Optional[torch.Tensor] = None,
        decoder_latents: Optional[torch.FloatTensor] = None,
        super_res_latents: Optional[torch.FloatTensor] = None,
        decoder_guidance_scale: float = 8.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image (`List[PIL.Image.Image]` or `torch.FloatTensor`):
                The images to use for the image interpolation. Only accepts a list of two PIL Images, or a tensor that
                complies with the configuration of
                [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor` and has a shape of two in the 0th dimension. Can be left as `None` only when
                `image_embeddings` are passed.
            steps (`int`, *optional*, defaults to 5):
                The number of interpolation images to generate.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            image_embeddings (`torch.Tensor`, *optional*):
                Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
                can be passed for tasks like image interpolation. `image` can then be left as `None`.
            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
                Pre-generated noisy latents to be used as inputs for the decoder.
            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
                Pre-generated noisy latents to be used as inputs for the super resolution.
            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        """

        batch_size = steps

        device = self._execution_device

        if isinstance(image, List):
            if len(image) != 2:
                raise AssertionError(
                    f"Expected 'image' List to be of size 2, but passed 'image' length is {len(image)}"
                )
            elif not (isinstance(image[0], PIL.Image.Image) and isinstance(image[1], PIL.Image.Image)):
                raise AssertionError(
                    f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
                )
        elif isinstance(image, torch.FloatTensor):
            if image.shape[0] != 2:
                raise AssertionError(
                    f"Expected 'image' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
                )
        elif isinstance(image_embeddings, torch.Tensor):
            if image_embeddings.shape[0] != 2:
                raise AssertionError(
                    f"Expected 'image_embeddings' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
                )
        else:
            raise AssertionError(
                f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or torch.FloatTensor respectively. Received {type(image)} and {type(image_embeddings)} respectively"
            )

        original_image_embeddings = self._encode_image(
            image=image, device=device, num_images_per_prompt=1, image_embeddings=image_embeddings
        )

        image_embeddings = []

        for interp_step in torch.linspace(0, 1, steps):
            temp_image_embeddings = slerp(
                interp_step, original_image_embeddings[0], original_image_embeddings[1]
            ).unsqueeze(0)
            image_embeddings.append(temp_image_embeddings)

        image_embeddings = torch.cat(image_embeddings).to(device)

        do_classifier_free_guidance = decoder_guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt=["" for i in range(steps)],
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        if device.type == "mps":
            # HACK: MPS: There is a panic when padding bool tensors,
            # so cast to int tensor for the pad and back to bool afterwards
            text_mask = text_mask.type(torch.int)
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
            decoder_text_mask = decoder_text_mask.type(torch.bool)
        else:
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.in_channels
        height = self.decoder.sample_size
        width = self.decoder.sample_size

        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            decoder_latents,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        channels = self.super_res_first.in_channels // 2
        height = self.super_res_first.sample_size
        width = self.super_res_first.sample_size

        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            super_res_latents,
            self.super_res_scheduler,
        )

        if device.type == "mps":
            # MPS does not support many interpolations
            image_upscaled = F.interpolate(image_small, size=[height, width])
        else:
            interpolate_antialias = {}
            if "antialias" in inspect.signature(F.interpolate).parameters:
                interpolate_antialias["antialias"] = True

            image_upscaled = F.interpolate(
                image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
            )

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents
        # done super res

        # post processing

        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
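

# The block below is a minimal usage sketch, not part of the pipeline itself. It assumes the
# `kakaobrain/karlo-v1-alpha-image-variations` checkpoint and two local image files
# ("starry_night.jpg" and "flowers.jpg" are placeholder paths); adjust paths, dtype, and device as needed.
if __name__ == "__main__":
    import PIL.Image

    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32
    device = torch.device("cuda" if use_cuda else "cpu")

    # Load the unCLIP weights and route them through this community pipeline.
    pipe = DiffusionPipeline.from_pretrained(
        "kakaobrain/karlo-v1-alpha-image-variations",
        torch_dtype=dtype,
        custom_pipeline="unclip_image_interpolation",
    )
    pipe.to(device)

    # Exactly two images are interpolated; `steps` controls how many in-between frames are produced.
    images = [
        PIL.Image.open("starry_night.jpg").convert("RGB"),
        PIL.Image.open("flowers.jpg").convert("RGB"),
    ]
    generator = torch.Generator(device=device).manual_seed(42)

    output = pipe(image=images, steps=6, generator=generator)
    for i, frame in enumerate(output.images):
        frame.save(f"interpolation_{i}.jpg")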