# Path: blob/main/examples/community/multilingual_stable_diffusion.py
# 1448 views
import inspect
from typing import Callable, List, Optional, Union

import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    pipeline,
)

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def detect_language(pipe, prompt, batch_size):
    """Detect the language(s) of a prompt with a text-classification pipeline.

    Args:
        pipe: Transformers pipeline whose top-1 prediction ``label`` is used as the language code.
        prompt (`str` or `List[str]`): The prompt or prompts to classify.
        batch_size (`int`): Kept for backward compatibility. The shape of the result is decided by the
            type of `prompt`, so a single-element list is handled like any other list.

    Returns:
        The detected label (`str`) when `prompt` is a string, otherwise a `List[str]` with one label per
        entry of `prompt`.
    """
    # Dispatch on the type, not on `batch_size`: a one-element list must still be classified element-wise.
    # Otherwise the pipeline receives a list, and its per-input output is indexed as if it were a single prediction.
    if isinstance(prompt, str):
        preds = pipe(prompt, top_k=1, truncation=True, max_length=128)
        return preds[0]["label"]

    detected_languages = []
    for p in prompt:
        preds = pipe(p, top_k=1, truncation=True, max_length=128)
        detected_languages.append(preds[0]["label"])
    return detected_languages


def translate_prompt(prompt, translation_tokenizer, translation_model, device):
    """Translate a single prompt string to English.

    Args:
        prompt (`str`): Text to translate.
        translation_tokenizer: Tokenizer of the translation model.
        translation_model: Sequence-to-sequence model providing `generate`.
        device: Device the tokenized inputs are moved to before generation.

    Returns:
        `str`: The first decoded sequence, with special tokens removed.
    """
    encoded_prompt = translation_tokenizer(prompt, return_tensors="pt").to(device)
    generated_tokens = translation_model.generate(**encoded_prompt, max_new_tokens=1000)
    en_trans = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    return en_trans[0]


class MultilingualStableDiffusion(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion in different languages.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        detection_pipeline ([`pipeline`]):
            Transformers pipeline to detect prompt's language.
        translation_model ([`MBartForConditionalGeneration`]):
            Model to translate prompt to English, if necessary. Please refer to the
            [model card](https://huggingface.co/docs/transformers/model_doc/mbart) for details.
        translation_tokenizer ([`MBart50TokenizerFast`]):
            Tokenizer of the translation model.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        detection_pipeline: pipeline,
        translation_model: MBartForConditionalGeneration,
        translation_tokenizer: MBart50TokenizerFast,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        # Outdated scheduler configs are patched in place (with a deprecation notice) so `steps_offset` is 1.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            detection_pipeline=detection_pipeline,
            translation_model=translation_model,
            translation_tokenizer=translation_tokenizer,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def _translate_to_english(self, prompt: Union[str, List[str]], batch_size: int) -> Union[str, List[str]]:
        """Translate every non-English entry of `prompt` to English.

        A string is returned as a string. A list is returned as a *new* list of the same length, so the caller's
        list is never modified and the container type is preserved for later type checks.
        """
        languages = detect_language(self.detection_pipeline, prompt, batch_size)
        if isinstance(prompt, str):
            if languages != "en":
                return translate_prompt(prompt, self.translation_tokenizer, self.translation_model, self.device)
            return prompt
        return [
            text
            if language == "en"
            else translate_prompt(text, self.translation_tokenizer, self.translation_model, self.device)
            for text, language in zip(prompt, languages)
        ]

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. Can be in different languages; non-English
                prompts are translated to English before encoding. A list passed here is not modified.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image. Must be divisible by 8.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Must have the same type as `prompt` (and the
                same length when a list). Ignored when not using guidance (i.e., ignored if `guidance_scale` is not
                greater than `1`). A list passed here is not modified.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.

        Raises:
            ValueError: If `prompt` is neither `str` nor `list`, `height`/`width` are not divisible by 8,
                `callback_steps` is not a positive integer, a list `negative_prompt` has the wrong length, or
                `latents` has an unexpected shape.
            TypeError: If `negative_prompt` is given with a different type than `prompt`.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` also fails the isinstance check, so it is rejected here as well.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # detect language and translate if necessary. A list prompt stays a list (even with one element), so the
        # `negative_prompt` type check below still compares like with like.
        prompt = self._translate_to_english(prompt, batch_size)

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                # detect language and translate it if necessary
                uncond_tokens = [self._translate_to_english(negative_prompt, batch_size)]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                # detect language and translate it if necessary (returns a new list)
                uncond_tokens = self._translate_to_english(negative_prompt, batch_size)

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        # Read `in_channels` from the config, consistent with `attention_head_dim` above.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)