# examples/community/speech_to_image_diffusion.py
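"""Speech-to-image community pipeline.

Transcribes an input waveform with Whisper and uses the resulting text as the
prompt for a Stable Diffusion generation loop. A usage sketch is included at
the bottom of the file under ``if __name__ == "__main__":``.
"""
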
import inspect
from typing import Callable, List, Optional, Union

import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class SpeechToImagePipeline(DiffusionPipeline):
    def __init__(
        self,
        speech_model: WhisperForConditionalGeneration,
        speech_processor: WhisperProcessor,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            speech_model=speech_model,
            speech_processor=speech_processor,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        if slice_size == "auto":
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        audio,
        sampling_rate=16_000,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        # transcribe the input audio into a text prompt with Whisper
        inputs = self.speech_processor.feature_extractor(
            audio, return_tensors="pt", sampling_rate=sampling_rate
        ).input_features.to(self.device)
        predicted_ids = self.speech_model.generate(inputs, max_length=480_000)

        prompt = self.speech_processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[
            0
        ]

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # scale and decode the image latents with the VAE (0.18215 is the SD latent scaling factor)
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return image

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
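

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline code. The checkpoint names
    # ("openai/whisper-small", "CompVis/stable-diffusion-v1-4") and the silent placeholder
    # waveform are illustrative assumptions; substitute your own models and a real 16 kHz
    # mono recording.
    import numpy as np

    from diffusers import StableDiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"

    speech_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    speech_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    # Reuse the components of a standard Stable Diffusion checkpoint for the image side.
    sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

    pipeline = SpeechToImagePipeline(
        speech_model=speech_model,
        speech_processor=speech_processor,
        vae=sd_pipe.vae,
        text_encoder=sd_pipe.text_encoder,
        tokenizer=sd_pipe.tokenizer,
        unet=sd_pipe.unet,
        scheduler=sd_pipe.scheduler,
        safety_checker=sd_pipe.safety_checker,
        feature_extractor=sd_pipe.feature_extractor,
    ).to(device)

    audio = np.zeros(16_000, dtype=np.float32)  # placeholder: one second of silence at 16 kHz
    image = pipeline(audio, sampling_rate=16_000).images[0]
    image.save("speech_to_image.png")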