Path: blob/main/examples/community/clip_guided_stable_diffusion_img2img.py
import inspect
from typing import List, Optional, Union

import numpy as np
import PIL
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import (
    PIL_INTERPOLATION,
    deprecate,
    randn_tensor,
)


EXAMPLE_DOC_STRING = """
    Examples:
        ```
        from io import BytesIO

        import requests
        import torch
        from diffusers import DiffusionPipeline
        from PIL import Image
        from transformers import CLIPFeatureExtractor, CLIPModel

        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
        )
        clip_model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
        )


        guided_pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="clip_guided_stable_diffusion_img2img",
            clip_model=clip_model,
            feature_extractor=feature_extractor,
            torch_dtype=torch.float16,
        )
        guided_pipeline.enable_attention_slicing()
        guided_pipeline = guided_pipeline.to("cuda")

        prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"

        url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        response = requests.get(url)
        init_image = Image.open(BytesIO(response.content)).convert("RGB")

        image = guided_pipeline(
            prompt=prompt,
            num_inference_steps=30,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            clip_guidance_scale=100,
            num_cutouts=4,
            use_cutouts=False,
        ).images[0]
        display(image)
        ```
"""


def preprocess(image, w, h):
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cut_power=1.0):
        super().__init__()

        self.cut_size = cut_size
        self.cut_power = cut_power

    def forward(self, pixel_values, num_cutouts):
        # sample `num_cutouts` random square crops and pool each one down to `cut_size`
        sideY, sideX = pixel_values.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(num_cutouts):
            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def set_requires_grad(model, value):
    for param in model.parameters():
        param.requires_grad = value


class CLIPGuidedStableDiffusion(DiffusionPipeline):
    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
    - https://github.com/Jack000/glid-3-xl
    - https://github.dev/crowsonkb/k-diffusion
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        clip_model: CLIPModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            clip_model=clip_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        self.cut_out_size = (
            feature_extractor.size
            if isinstance(feature_extractor.size, int)
            else feature_extractor.size["shortest_edge"]
        )
        self.make_cutouts = MakeCutouts(self.cut_out_size)

        set_requires_grad(self.text_encoder, False)
        set_requires_grad(self.clip_model, False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        self.enable_attention_slicing(None)

    def freeze_vae(self):
        set_requires_grad(self.vae, False)

    def unfreeze_vae(self):
        set_requires_grad(self.vae, True)

    def freeze_unet(self):
        set_requires_grad(self.unet, False)

    def unfreeze_unet(self):
        set_requires_grad(self.unet, True)

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    @torch.enable_grad()
    def cond_fn(
        self,
        latents,
        timestep,
        index,
        text_embeddings,
        noise_pred_original,
        text_embeddings_clip,
        clip_guidance_scale,
        num_cutouts,
        use_cutouts=True,
    ):
        latents = latents.detach().requires_grad_()

        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
            beta_prod_t = 1 - alpha_prod_t
            # compute predicted original sample from predicted noise also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

            fac = torch.sqrt(beta_prod_t)
            sample = pred_original_sample * (fac) + latents * (1 - fac)
        elif isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            sample = latents - sigma * noise_pred
        else:
            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

        sample = 1 / self.vae.config.scaling_factor * sample
        image = self.vae.decode(sample).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        if use_cutouts:
            image = self.make_cutouts(image, num_cutouts)
        else:
            image = transforms.Resize(self.cut_out_size)(image)
        image = self.normalize(image).to(latents.dtype)

        image_embeddings_clip = self.clip_model.get_image_features(image)
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        if use_cutouts:
            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
            dists = dists.view([num_cutouts, sample.shape[0], -1])
            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
        else:
            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale

        # gradient of the CLIP guidance loss with respect to the latents
        grads = -torch.autograd.grad(loss, latents)[0]

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents.detach() + grads * (sigma**2)
            noise_pred = noise_pred_original
        else:
            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
        return noise_pred, latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        clip_guidance_scale: Optional[float] = 100,
        clip_prompt: Optional[Union[str, List[str]]] = None,
        num_cutouts: Optional[int] = 4,
        use_cutouts: Optional[bool] = True,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        # duplicate text embeddings for each generation per prompt
        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        self.scheduler.timesteps.to(self.device)

        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # Preprocess image
        image = preprocess(image, width, height)
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, self.device, generator
        )

        if clip_guidance_scale > 0:
            if clip_prompt is not None:
                clip_text_input = self.tokenizer(
                    clip_prompt,
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(self.device)
            else:
                clip_text_input = text_input.input_ids.to(self.device)
            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
            # duplicate text embeddings clip for each generation per prompt
            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
            # duplicate unconditional embeddings for each generation per prompt
            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator

        with self.progress_bar(total=num_inference_steps):
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform classifier free guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # perform clip guidance
                if clip_guidance_scale > 0:
                    text_embeddings_for_guidance = (
                        text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
                    )
                    noise_pred, latents = self.cond_fn(
                        latents,
                        t,
                        i,
                        text_embeddings_for_guidance,
                        noise_pred,
                        text_embeddings_clip,
                        clip_guidance_scale,
                        num_cutouts,
                        use_cutouts,
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
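

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline itself. It mirrors EXAMPLE_DOC_STRING above,
    # but loads this pipeline through the community pipeline name, swaps in a DDIMScheduler
    # (one of the scheduler types accepted by the constructor), and saves the result instead of
    # calling `display`. Assumptions: a CUDA GPU and internet access are available, the model ids
    # from the docstring are used, and the prompt / output filename below are illustrative only.
    from io import BytesIO

    import requests
    from PIL import Image

    feature_extractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
    clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16)

    pipeline = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="clip_guided_stable_diffusion_img2img",
        clip_model=clip_model,
        feature_extractor=feature_extractor,
        torch_dtype=torch.float16,
    ).to("cuda")
    pipeline.enable_attention_slicing()
    # DDIM is one of the schedulers supported by cond_fn; PNDM, LMSDiscrete and
    # DPMSolverMultistep work as well
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

    url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
    init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")

    result = pipeline(
        prompt="fantasy forest landscape, golden vector elements, highly detailed, digital painting",
        image=init_image,
        strength=0.75,
        num_inference_steps=30,
        guidance_scale=7.5,
        clip_guidance_scale=100,
        num_cutouts=4,
        use_cutouts=False,
    ).images[0]
    result.save("clip_guided_img2img.png")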