# Path: blob/main/examples/community/clip_guided_stable_diffusion.py
import inspect
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput


class MakeCutouts(nn.Module):
    """Take random square crops of an image batch and pool each down to ``cut_size``.

    Args:
        cut_size: output side length passed to ``F.adaptive_avg_pool2d``.
        cut_power: exponent applied to the uniform sample that picks each crop's
            size; values above 1 bias the sampling towards smaller crops.
    """

    def __init__(self, cut_size, cut_power=1.0):
        super().__init__()

        self.cut_size = cut_size
        self.cut_power = cut_power

    def forward(self, pixel_values, num_cutouts):
        """Return ``num_cutouts`` random crops of every image in ``pixel_values``.

        Dims 2 and 3 of ``pixel_values`` are read as height and width (NCHW layout
        assumed). Crops are concatenated along dim 0, so cutout ``k`` of image
        ``b`` ends up at index ``k * batch + b``.
        """
        sideY, sideX = pixel_values.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(num_cutouts):
            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


def spherical_dist_loss(x, y):
    """Return ``2 * arcsin(|x̂ - ŷ| / 2) ** 2`` over the last dimension.

    ``x̂`` and ``ŷ`` are ``x`` and ``y`` L2-normalised along the last dim; for unit
    vectors this equals half the squared angle between them. ``x`` and ``y`` are
    broadcast against each other like ordinary tensor arithmetic.
    """
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def set_requires_grad(model, value):
    """Set ``requires_grad`` to ``value`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = value


class CLIPGuidedStableDiffusion(DiffusionPipeline):
    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
    - https://github.com/Jack000/glid-3-xl
    - https://github.dev/crowsonkb/k-diffusion
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        clip_model: CLIPModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            clip_model=clip_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        # CLIP expects images normalised with the feature extractor's statistics
        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        # `size` is a plain int on older processors and a dict on newer ones
        self.cut_out_size = (
            feature_extractor.size
            if isinstance(feature_extractor.size, int)
            else feature_extractor.size["shortest_edge"]
        )
        self.make_cutouts = MakeCutouts(self.cut_out_size)

        # the text encoders are only used for inference, never trained here
        set_requires_grad(self.text_encoder, False)
        set_requires_grad(self.clip_model, False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        """Compute UNet attention in slices to reduce peak memory.

        Args:
            slice_size: ``"auto"`` halves the UNet's attention head dim, an int is
                forwarded to ``unet.set_attention_slice`` as-is, and ``None``
                disables slicing.
        """
        if slice_size == "auto":
            head_dim = self.unet.config.attention_head_dim
            if isinstance(head_dim, int):
                # half the attention head size is usually a good trade-off between
                # speed and memory
                slice_size = head_dim // 2
            # Otherwise the config holds per-block head dims (e.g. a list), where
            # `// 2` would raise; keep "auto" and let the UNet choose per layer.
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        """Turn attention slicing off again."""
        self.enable_attention_slicing(None)

    def freeze_vae(self):
        """Disable gradients for the VAE parameters."""
        set_requires_grad(self.vae, False)

    def unfreeze_vae(self):
        """Re-enable gradients for the VAE parameters."""
        set_requires_grad(self.vae, True)

    def freeze_unet(self):
        """Disable gradients for the UNet parameters."""
        set_requires_grad(self.unet, False)

    def unfreeze_unet(self):
        """Re-enable gradients for the UNet parameters."""
        set_requires_grad(self.unet, True)

    @torch.enable_grad()
    def cond_fn(
        self,
        latents,
        timestep,
        index,
        text_embeddings,
        noise_pred_original,
        text_embeddings_clip,
        clip_guidance_scale,
        num_cutouts,
        use_cutouts=True,
    ):
        """Apply one step of CLIP guidance to the current noise prediction.

        The latents are decoded to an approximate clean image, embedded with CLIP,
        and the spherical distance to ``text_embeddings_clip`` is differentiated
        with respect to ``latents``.

        Args:
            latents: current noisy latents (detached internally).
            timestep: scheduler timestep ``t``.
            index: position of ``t`` in the scheduler's timestep list (used for LMS sigmas).
            text_embeddings: conditional text embeddings fed to the UNet.
            noise_pred_original: noise prediction (after classifier-free guidance) to adjust.
            text_embeddings_clip: L2-normalised CLIP text features, one row per image
                or a single row shared by all images.
            clip_guidance_scale: multiplier on the CLIP loss.
            num_cutouts: number of random crops per image when ``use_cutouts``.
            use_cutouts: score random cutouts instead of a single resized image.

        Returns:
            ``(noise_pred, latents)``. For ``LMSDiscreteScheduler`` the latents are
            nudged by ``grads * sigma**2`` and the noise prediction is unchanged.
            For the other schedulers the noise prediction is shifted by
            ``-sqrt(beta_prod_t) * grads``.

        Raises:
            ValueError: if the scheduler type is not supported.
        """
        latents = latents.detach().requires_grad_()

        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
            beta_prod_t = 1 - alpha_prod_t
            # compute predicted original sample from predicted noise also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

            # blend predicted x_0 with the current latents, weighted by sqrt(beta_prod_t)
            fac = torch.sqrt(beta_prod_t)
            sample = pred_original_sample * (fac) + latents * (1 - fac)
        elif isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            sample = latents - sigma * noise_pred
        else:
            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

        sample = 1 / self.vae.config.scaling_factor * sample
        image = self.vae.decode(sample).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        if use_cutouts:
            image = self.make_cutouts(image, num_cutouts)
        else:
            image = transforms.Resize(self.cut_out_size)(image)
        image = self.normalize(image).to(latents.dtype)

        image_embeddings_clip = self.clip_model.get_image_features(image)
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        if use_cutouts:
            if text_embeddings_clip.shape[0] > 1:
                # MakeCutouts stores cutout k of image b at row k * batch + b, so tile
                # the per-image text rows the same way. Without this, the
                # (num_cutouts * batch, D) vs (batch, D) subtraction fails for batch > 1.
                text_embeddings_clip = text_embeddings_clip.repeat(num_cutouts, 1)
            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
            dists = dists.view([num_cutouts, sample.shape[0], -1])
            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
        else:
            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale

        grads = -torch.autograd.grad(loss, latents)[0]

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents.detach() + grads * (sigma**2)
            noise_pred = noise_pred_original
        else:
            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
        return noise_pred, latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        clip_guidance_scale: Optional[float] = 100,
        clip_prompt: Optional[Union[str, List[str]]] = None,
        num_cutouts: Optional[int] = 4,
        use_cutouts: Optional[bool] = True,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """Generate images for ``prompt`` with classifier-free and CLIP guidance.

        Args:
            prompt: a prompt or a list of prompts (one batch entry each).
            height, width: output size in pixels; both must be divisible by 8.
            num_inference_steps: number of scheduler steps.
            guidance_scale: classifier-free guidance weight; ``<= 1`` disables it.
            num_images_per_prompt: images generated for each prompt.
            eta: DDIM eta, only forwarded to schedulers whose ``step`` accepts it.
            clip_guidance_scale: CLIP loss weight; ``<= 0`` disables CLIP guidance.
            clip_prompt: optional text for CLIP guidance. Defaults to ``prompt``. A
                single string is shared by every prompt in the batch.
            num_cutouts: random crops per image scored by CLIP.
            use_cutouts: score random cutouts instead of one resized image.
            generator: torch RNG for the initial noise (and the scheduler, if it accepts one).
            latents: optional initial latents of shape
                ``(batch * num_images_per_prompt, unet in_channels, height // 8, width // 8)``.
            output_type: ``"pil"`` returns PIL images, anything else returns a numpy array.
            return_dict: return a ``StableDiffusionPipelineOutput`` instead of a tuple.

        Returns:
            ``StableDiffusionPipelineOutput`` (``nsfw_content_detected=None``), or
            ``(images, None)`` when ``return_dict`` is False.

        Raises:
            ValueError: on a bad ``prompt`` type, a size not divisible by 8, or
                ``latents`` of the wrong shape.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        # duplicate text embeddings for each generation per prompt
        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        if clip_guidance_scale > 0:
            if clip_prompt is not None:
                clip_text_input = self.tokenizer(
                    clip_prompt,
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(self.device)
            else:
                clip_text_input = text_input.input_ids.to(self.device)
            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
            if text_embeddings_clip.shape[0] == 1 and batch_size > 1:
                # a single clip prompt guides every prompt in the batch
                text_embeddings_clip = text_embeddings_clip.expand(batch_size, -1)
            # duplicate text embeddings clip for each generation per prompt
            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            # one empty prompt per batch entry so the unconditional half has exactly
            # as many rows as the conditional half
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
            # duplicate unconditional embeddings for each generation per prompt
            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform classifier free guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # perform clip guidance
            if clip_guidance_scale > 0:
                # CLIP guidance only uses the conditional half of the embeddings
                text_embeddings_for_guidance = (
                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
                )
                noise_pred, latents = self.cond_fn(
                    latents,
                    t,
                    i,
                    text_embeddings_for_guidance,
                    noise_pred,
                    text_embeddings_clip,
                    clip_guidance_scale,
                    num_cutouts,
                    use_cutouts,
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)