Path: blob/main/examples/community/magic_mix.py
from typing import Union

import torch
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)


class MagicMixPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
    ):
        super().__init__()

        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    # convert PIL image to latents
    def encode(self, img):
        with torch.no_grad():
            latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
            latent = 0.18215 * latent.latent_dist.sample()
        return latent

    # convert latents to PIL image
    def decode(self, latent):
        latent = (1 / 0.18215) * latent
        with torch.no_grad():
            img = self.vae.decode(latent).sample
        img = (img / 2 + 0.5).clamp(0, 1)
        img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
        img = (img * 255).round().astype("uint8")
        return Image.fromarray(img[0])

    # convert prompt into text embeddings, also unconditional embeddings
    def prep_text(self, prompt):
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]

        uncond_input = self.tokenizer(
            "",
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        return torch.cat([uncond_embedding, text_embedding])

    def __call__(
        self,
        img: Image.Image,
        prompt: str,
        kmin: float = 0.3,
        kmax: float = 0.6,
        mix_factor: float = 0.5,
        seed: int = 42,
        steps: int = 50,
        guidance_scale: float = 7.5,
    ) -> Image.Image:
        tmin = steps - int(kmin * steps)
        tmax = steps - int(kmax * steps)

        text_embeddings = self.prep_text(prompt)

        self.scheduler.set_timesteps(steps)

        width, height = img.size
        encoded = self.encode(img)

        torch.manual_seed(seed)
        noise = torch.randn(
            (1, self.unet.in_channels, height // 8, width // 8),
        ).to(self.device)

        latents = self.scheduler.add_noise(
            encoded,
            noise,
            timesteps=self.scheduler.timesteps[tmax],
        )

        input = torch.cat([latents] * 2)

        input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])

        with torch.no_grad():
            pred = self.unet(
                input,
                self.scheduler.timesteps[tmax],
                encoder_hidden_states=text_embeddings,
            ).sample

        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            if i > tmax:
                if i < tmin:  # layout generation phase
                    orig_latents = self.scheduler.add_noise(
                        encoded,
                        noise,
                        timesteps=t,
                    )

                    # interpolating between layout noise and conditionally generated noise to preserve layout semantics
                    input = (mix_factor * latents) + (1 - mix_factor) * orig_latents
                    input = torch.cat([input] * 2)

                else:  # content generation phase
                    input = torch.cat([latents] * 2)

                input = self.scheduler.scale_model_input(input, t)

                with torch.no_grad():
                    pred = self.unet(
                        input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    ).sample

                pred_uncond, pred_text = pred.chunk(2)
                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

                latents = self.scheduler.step(pred, t, latents).prev_sample

        return self.decode(latents)
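

if __name__ == "__main__":
    # Example usage: a minimal sketch, not part of the pipeline class itself.
    # It mixes the layout of a local input image with the semantics of a text
    # prompt. The checkpoint name, image path, prompt, and kmin/kmax values
    # below are illustrative assumptions; substitute your own. Assumes a
    # Stable Diffusion v1.x checkpoint and a CUDA device.
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="magic_mix",
        scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
    ).to("cuda")

    img = Image.open("phone.jpg")  # hypothetical layout image
    mix_img = pipe(img, prompt="bed", kmin=0.3, kmax=0.5, mix_factor=0.5)
    mix_img.save("phone_bed_mix.jpg")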