Path: examples/community/composable_stable_diffusion.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union

import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

# Import from the installed `diffusers` package (not via relative imports), since community
# pipelines are loaded as standalone modules.
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import deprecate, is_accelerate_available, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ComposableStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for composable text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
            Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly, as leaving `clip_sample=True` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker."
                " If you do not want to use the safety checker, you can pass `safety_checker=None` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and are then moved to
        `torch.device("meta")`; each submodule is loaded onto the GPU only when its `forward` method is called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
            # fix by only offloading self.safety_checker for now
            cpu_offload(self.safety_checker.vision_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed.
        After calling `pipeline.enable_sequential_cpu_offload()`, the execution device can only be inferred from
        Accelerate's module hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}."
                    " Please make sure that the passed `negative_prompt` matches the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        weights: Optional[str] = "",
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. Prompts separated by "|" are composed.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`, usually
                at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image.
                Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            weights (`str`, *optional*, defaults to `""`):
                A "|"-separated string of guidance weights, one per composed prompt (e.g. `"7.5 | 7.5"`). If empty,
                every prompt is weighted with `guidance_scale`.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if "|" in prompt:
            prompt = [x.strip() for x in prompt.split("|")]
            print(f"composing {prompt}...")

            if not weights:
                # specify weights for prompts (excluding the unconditional score)
                print("using equal positive weights (conjunction) for all prompts...")
                weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
            else:
                # set a weight for each prompt; fall back to `guidance_scale` if too few are given
                num_prompts = len(prompt) if isinstance(prompt, list) else 1
                weights = [float(w.strip()) for w in weights.split("|")]
                if len(weights) < num_prompts:
                    weights.append(guidance_scale)
                else:
                    weights = weights[:num_prompts]
                assert len(weights) == len(prompt), "weights specified are not equal to the number of prompts"
                weights = torch.tensor(weights, device=self.device).reshape(-1, 1, 1, 1)
        else:
            weights = guidance_scale

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # composable diffusion
        if isinstance(prompt, list) and batch_size == 1:
            # remove the extra unconditional embeddings:
            # keep one unconditional embedding + the conditional embedding of each composed prompt
            text_embeddings = text_embeddings[len(prompt) - 1 :]
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual separately for the unconditional and each composed-prompt embedding
                noise_pred = []
                for j in range(text_embeddings.shape[0]):
                    noise_pred.append(
                        self.unet(latent_model_input[:1], t, encoder_hidden_states=text_embeddings[j : j + 1]).sample
                    )
                noise_pred = torch.cat(noise_pred, dim=0)

                # perform guidance
                if do_classifier_free_guidance:
                    # compose the scores: eps = eps_uncond + sum_j w_j * (eps_j - eps_uncond)
                    noise_pred_uncond, noise_pred_text = noise_pred[:1], noise_pred[1:]
                    noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
                        dim=0, keepdims=True
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
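

# A minimal usage sketch, not part of the pipeline class above: the model id, prompts, and weights
# below are illustrative assumptions. It uses the standard `DiffusionPipeline.from_pretrained(...,
# custom_pipeline=...)` mechanism for loading community pipelines. Prompts are composed by separating
# them with "|", and `weights` supplies one guidance weight per prompt; a negative weight steers the
# sample away from that prompt (see the score composition in the denoising loop).
if __name__ == "__main__":
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # assumed base checkpoint; other SD 1.x checkpoints should also work
        custom_pipeline="composable_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")

    # "a red car | a snowy forest" is treated as two prompts composed with weights 7.5 and 7.5
    image = pipe(
        prompt="a red car | a snowy forest",
        weights="7.5 | 7.5",
        num_inference_steps=50,
    ).images[0]
    image.save("composable_diffusion_example.png")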