Path: blob/main/examples/community/sd_text2img_k_diffusion.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import warnings
from typing import Callable, List, Optional, Union

import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser

from diffusers import DiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import is_accelerate_available, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Adapter that exposes the diffusers UNet through the CompVis-style `apply_model(x, t, cond)`
# interface expected by k-diffusion's CompVisDenoiser / CompVisVDenoiser wrappers.
class ModelWrapper:
    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        if len(args) == 3:
            encoder_hidden_states = args[-1]
            args = args[:2]
        if kwargs.get("cond", None) is not None:
            encoder_hidden_states = kwargs.pop("cond")
        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample


class StableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
    ):
        super().__init__()

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
Ensure"87" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"88" results in services or applications open to the public. Both the diffusers team and Hugging Face"89" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"90" it only for use-cases that involve analyzing network behavior or auditing its results. For more"91" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."92)9394# get correct sigmas from LMS95scheduler = LMSDiscreteScheduler.from_config(scheduler.config)96self.register_modules(97vae=vae,98text_encoder=text_encoder,99tokenizer=tokenizer,100unet=unet,101scheduler=scheduler,102safety_checker=safety_checker,103feature_extractor=feature_extractor,104)105106model = ModelWrapper(unet, scheduler.alphas_cumprod)107if scheduler.prediction_type == "v_prediction":108self.k_diffusion_model = CompVisVDenoiser(model)109else:110self.k_diffusion_model = CompVisDenoiser(model)111112def set_sampler(self, scheduler_type: str):113warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")114return self.set_scheduler(scheduler_type)115116def set_scheduler(self, scheduler_type: str):117library = importlib.import_module("k_diffusion")118sampling = getattr(library, "sampling")119self.sampler = getattr(sampling, scheduler_type)120121def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):122r"""123Enable sliced attention computation.124125When this option is enabled, the attention module will split the input tensor in slices, to compute attention126in several steps. This is useful to save some memory in exchange for a small speed decrease.127128Args:129slice_size (`str` or `int`, *optional*, defaults to `"auto"`):130When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If131a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,132`attention_head_dim` must be a multiple of `slice_size`.133"""134if slice_size == "auto":135# half the attention head size is usually a good trade-off between136# speed and memory137slice_size = self.unet.config.attention_head_dim // 2138self.unet.set_attention_slice(slice_size)139140def disable_attention_slicing(self):141r"""142Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go143back to computing attention in one step.144"""145# set slice_size = `None` to disable `attention slicing`146self.enable_attention_slicing(None)147148def enable_sequential_cpu_offload(self, gpu_id=0):149r"""150Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,151text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a152`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.153"""154if is_accelerate_available():155from accelerate import cpu_offload156else:157raise ImportError("Please install accelerate via `pip install accelerate`")158159device = torch.device(f"cuda:{gpu_id}")160161for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:162if cpu_offloaded_model is not None:163cpu_offload(cpu_offloaded_model, device)164165@property166def _execution_device(self):167r"""168Returns the device on which the pipeline's models will be executed. 
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

        if not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // 8, width // 8)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # the caller scales the initial noise by the standard deviation required by the sampler (sigmas[0])
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = True
        if guidance_scale <= 1.0:
            raise ValueError("This pipeline always applies classifier-free guidance; `guidance_scale` must be > 1.")

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
        sigmas = self.scheduler.sigmas
        sigmas = sigmas.to(text_embeddings.dtype)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

        # 6. Define the denoising function used by the k-diffusion sampler
        def model_fn(x, t):
            # run the unconditional and text-conditioned branches in a single batch
            latent_model_input = torch.cat([x] * 2)

            noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)

            # classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            return noise_pred

        # 7. Run the sampler selected via `set_scheduler`
        latents = self.sampler(model_fn, latents, sigmas)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
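

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the pipeline class). It assumes this file is
# loaded as a diffusers community pipeline; the checkpoint id, dtype, device, and sampler
# name below are assumptions, and any sampling function exposed by `k_diffusion.sampling`
# (e.g. "sample_heun", "sample_lms") can be passed to `set_scheduler`.
if __name__ == "__main__":
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # assumed Stable Diffusion checkpoint
        custom_pipeline="sd_text2img_k_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")

    # pick a k-diffusion sampler by the name of a function in k_diffusion.sampling
    pipe.set_scheduler("sample_heun")

    image = pipe("an astronaut riding a horse on mars", num_inference_steps=25).images[0]
    image.save("astronaut_heun.png")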