Path: blob/main/examples/imagic/train_imagic.py
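# Imagic training example: given a single input image and a target text prompt, the
# script first optimizes the target text embedding so that the (frozen) UNet reconstructs
# the input image, then fine-tunes the UNet on the optimized embedding.
#
# Example invocation (model name, paths, and prompt are illustrative only; the numeric
# values simply echo the script defaults):
#
#   accelerate launch train_imagic.py \
#     --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
#     --input_image="./dog.jpg" \
#     --target_text="A photo of a sitting dog" \
#     --output_dir="./imagic-out" \
#     --resolution=512 \
#     --emb_learning_rate=1e-3 \
#     --learning_rate=1e-6 \
#     --emb_train_steps=500 \
#     --max_train_steps=1000 \
#     --mixed_precision="fp16" \
#     --use_8bit_adam \
#     --gradient_checkpointing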
import argparse
import math
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
import numpy as np
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--input_image",
        type=str,
        default=None,
        required=True,
        help="Path to the input image to edit.",
    )
    parser.add_argument(
        "--target_text",
        type=str,
        default=None,
        help="The target text describing the output image.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--emb_train_steps",
        type=int,
        default=500,
        help="Total number of embedding optimization steps to perform.",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=1000,
        help="Total number of model fine-tuning steps to perform.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--emb_learning_rate",
        type=float,
        default=1e-3,
        help="Learning rate for optimizing the embeddings.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-6,
        help="Learning rate for fine-tuning the model.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
help="Weight decay to use.")122parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")123parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")124parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")125parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")126parser.add_argument(127"--hub_model_id",128type=str,129default=None,130help="The name of the repository to keep in sync with the local `output_dir`.",131)132parser.add_argument(133"--logging_dir",134type=str,135default="logs",136help=(137"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"138" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."139),140)141parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")142parser.add_argument(143"--mixed_precision",144type=str,145default="no",146choices=["no", "fp16", "bf16"],147help=(148"Whether to use mixed precision. Choose"149"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."150"and an Nvidia Ampere GPU."151),152)153parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")154155args = parser.parse_args()156env_local_rank = int(os.environ.get("LOCAL_RANK", -1))157if env_local_rank != -1 and env_local_rank != args.local_rank:158args.local_rank = env_local_rank159160return args161162163class AverageMeter:164def __init__(self, name=None):165self.name = name166self.reset()167168def reset(self):169self.sum = self.count = self.avg = 0170171def update(self, val, n=1):172self.sum += val * n173self.count += n174self.avg = self.sum / self.count175176177def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):178if token is None:179token = HfFolder.get_token()180if organization is None:181username = whoami(token)["name"]182return f"{username}/{model_id}"183else:184return f"{organization}/{model_id}"185186187def main():188args = parse_args()189logging_dir = Path(args.output_dir, args.logging_dir)190191accelerator = Accelerator(192gradient_accumulation_steps=args.gradient_accumulation_steps,193mixed_precision=args.mixed_precision,194log_with="tensorboard",195logging_dir=logging_dir,196)197198if args.seed is not None:199set_seed(args.seed)200201# Handle the repository creation202if accelerator.is_main_process:203if args.push_to_hub:204if args.hub_model_id is None:205repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)206else:207repo_name = args.hub_model_id208repo = Repository(args.output_dir, clone_from=repo_name)209210with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:211if "step_*" not in gitignore:212gitignore.write("step_*\n")213if "epoch_*" not in gitignore:214gitignore.write("epoch_*\n")215elif args.output_dir is not None:216os.makedirs(args.output_dir, exist_ok=True)217218# Load the tokenizer219if args.tokenizer_name:220tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)221elif args.pretrained_model_name_or_path:222tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=True)223224# Load models and create wrapper for stable diffusion225text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=True)226vae = 
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.Adam8bit
    else:
        optimizer_class = torch.optim.Adam

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # Encode the input image.
    input_image = Image.open(args.input_image).convert("RGB")

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    init_image = image_transforms(input_image)
    init_image = init_image[None].to(device=accelerator.device, dtype=weight_dtype)
    with torch.inference_mode():
        init_latents = vae.encode(init_image).latent_dist.sample()
        init_latents = 0.18215 * init_latents

    # Encode the target text.
    text_ids = tokenizer(
        args.target_text,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

    text_ids = text_ids.to(device=accelerator.device)
    with torch.inference_mode():
        target_embeddings = text_encoder(text_ids)[0]

    del vae, text_encoder
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    target_embeddings = target_embeddings.float()
    optimized_embeddings = target_embeddings.clone()

    # Optimize the text embeddings first.
    optimized_embeddings.requires_grad_(True)
    optimizer = optimizer_class(
        [optimized_embeddings],  # only optimize the embeddings
        lr=args.emb_learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        # weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    unet, optimizer = accelerator.prepare(unet, optimizer)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("imagic", config=vars(args))

    def train_loop(pbar, optimizer, params):
        loss_avg = AverageMeter()
        for step in pbar:
            with accelerator.accumulate(unet):
                noise = torch.randn_like(init_latents)
                bsz = init_latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(init_latents, noise, timesteps)

                noise_pred = unet(noisy_latents, timesteps, optimized_embeddings).sample

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                # if accelerator.sync_gradients:  # results aren't as good with it; it may need more training.
                #     accelerator.clip_grad_norm_(params, args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                loss_avg.update(loss.detach_(), bsz)

            if not step % args.log_interval:
                logs = {"loss": loss_avg.avg.item()}
                pbar.set_postfix(**logs)
                accelerator.log(logs, step=step)

        accelerator.wait_for_everyone()

    progress_bar = tqdm(range(args.emb_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Optimizing embedding")

    train_loop(progress_bar, optimizer, optimized_embeddings)

    optimized_embeddings.requires_grad_(False)
    if accelerator.is_main_process:
        torch.save(target_embeddings.cpu(), os.path.join(args.output_dir, "target_embeddings.pt"))
        torch.save(optimized_embeddings.cpu(), os.path.join(args.output_dir, "optimized_embeddings.pt"))
        with open(os.path.join(args.output_dir, "target_text.txt"), "w") as f:
            f.write(args.target_text)

    # Fine-tune the diffusion model.
    optimizer = optimizer_class(
        accelerator.unwrap_model(unet).parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        # weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    optimizer = accelerator.prepare(optimizer)

    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Fine Tuning")
    unet.train()

    train_loop(progress_bar, optimizer, unet.parameters())

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            use_auth_token=True,
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
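# Inference sketch (not part of the training script): Imagic produces the edit by
# interpolating between the saved optimized and target embeddings and running the
# fine-tuned pipeline on the result. The snippet below is illustrative only and assumes
# a diffusers version whose StableDiffusionPipeline.__call__ accepts `prompt_embeds`;
# the paths and the interpolation weight `alpha` are placeholders.
#
#   import torch
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained("./imagic-out", torch_dtype=torch.float16).to("cuda")
#   target = torch.load("./imagic-out/target_embeddings.pt").to("cuda", dtype=torch.float16)
#   optimized = torch.load("./imagic-out/optimized_embeddings.pt").to("cuda", dtype=torch.float16)
#   alpha = 0.9  # larger alpha applies the target text more strongly
#   embeddings = alpha * target + (1 - alpha) * optimized
#   image = pipe(prompt_embeds=embeddings, guidance_scale=7.5).images[0]
#   image.save("edited.png")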