# Path: blob/main/examples/research_projects/colossalai/train_dreambooth_colossalai.py
# 1979 views
"""DreamBooth fine-tuning of a Stable Diffusion UNet using ColossalAI Gemini + ZeRO.

The script optionally generates class images for prior preservation, wraps the
UNet in ``GeminiDDP``, trains it with ``GeminiAdamOptimizer`` and saves the
result as a ``DiffusionPipeline``.
"""

import argparse
import hashlib
import math
import os
from pathlib import Path
from typing import Optional

import colossalai
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler


disable_existing_loggers()
logger = get_dist_logger()


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: Optional[str] = None):
    """Return the text-encoder class named in the checkpoint's ``text_encoder`` config.

    Args:
        pretrained_model_name_or_path: Model path or hub identifier whose
            ``text_encoder`` subfolder config is inspected.
        revision: Optional model revision forwarded to ``from_pretrained``.
            Passed explicitly so the function does not depend on a
            module-level ``args`` object.

    Returns:
        ``CLIPTextModel`` or ``RobertaSeriesModelWithTransformation``.

    Raises:
        ValueError: If the first entry of ``architectures`` is neither of the
            two supported class names.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


def parse_args(input_args=None):
    """Parse and validate the command-line arguments of the training script.

    Args:
        input_args: Optional list of argument strings. When ``None`` the
            arguments are taken from ``sys.argv`` by ``argparse``.

    Returns:
        The populated ``argparse.Namespace``. ``local_rank`` is overridden by
        the ``LOCAL_RANK`` environment variable when that variable is set.

    Raises:
        ValueError: If ``--with_prior_preservation`` is given without
            ``--class_data_dir`` or without ``--class_prompt``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default="a photo of sks dog",
        required=False,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--placement",
        type=str,
        default="cpu",
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )

    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # The LOCAL_RANK environment variable, when set, wins over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        if args.class_data_dir is not None:
            logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            logger.warning("You need not use --class_prompt without --with_prior_preservation.")

    return args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and the tokenizes prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        """Index the image folders and build the image transform pipeline.

        Args:
            instance_data_root: Directory whose entries are the instance images.
            instance_prompt: Prompt tokenized for every instance image.
            tokenizer: Tokenizer called on the prompts in ``__getitem__``.
            class_data_root: Optional directory of class images; it is created
                if missing. When ``None`` no class examples are produced.
            class_prompt: Prompt tokenized for every class image.
            size: Target size used for resizing and cropping.
            center_crop: Use ``CenterCrop`` when true, ``RandomCrop`` otherwise.

        Raises:
            ValueError: If ``instance_data_root`` does not exist.
        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            # The dataset is as long as the larger of the two image sets;
            # the shorter one is cycled through with a modulo in __getitem__.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        """Return the number of examples (max of instance and class image counts)."""
        return self._length

    def __getitem__(self, index):
        """Return the transformed image(s) and unpadded prompt token ids for ``index``.

        The result always holds ``instance_images`` and ``instance_prompt_ids``;
        ``class_images`` and ``class_prompt_ids`` are added when a class data
        root was configured.
        """
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        """Store the prompt to repeat and the number of samples to expose."""
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, index):
        """Return a dict with the shared ``prompt`` and the requested ``index``."""
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the ``<namespace>/<model_id>`` repository name.

    Args:
        model_id: Repository name without namespace.
        organization: Namespace to use. When ``None`` the name returned by
            ``whoami`` for the token is used instead.
        token: Hub token; falls back to ``HfFolder.get_token()`` when ``None``.

    Returns:
        The namespaced repository name.
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
    """Wrap ``model`` in ``GeminiDDP`` on the current device.

    Args:
        model: Module to wrap.
        placememt_policy: Value forwarded as ``placement_policy`` (the
            parameter name keeps its original spelling for compatibility).

    Returns:
        The ``GeminiDDP``-wrapped model.
    """
    from colossalai.nn.parallel import GeminiDDP

    model = GeminiDDP(
        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=64
    )
    return model


def _ensure_gitignore_patterns(gitignore_path, patterns=("step_*", "epoch_*")):
    """Append each of ``patterns`` to the gitignore file unless already listed.

    Existing content is preserved; the file is created when it is missing.
    """
    path = Path(gitignore_path)
    existing_text = path.read_text() if path.exists() else ""
    listed = {line.strip() for line in existing_text.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in listed]
    if not missing:
        return

    with path.open("a") as gitignore:
        # Keep the first appended pattern on its own line when the existing
        # file does not end with a newline.
        if existing_text and not existing_text.endswith("\n"):
            gitignore.write("\n")
        gitignore.writelines(f"{pattern}\n" for pattern in missing)


def main(args):
    """Run the full DreamBooth training job described by ``args``.

    Steps: launch ColossalAI, optionally sample missing class images, prepare
    the output directory / hub repository on local rank 0, load tokenizer and
    models, wrap the UNet with Gemini, train, and save the resulting pipeline
    (periodically every ``save_steps`` and once at the end).

    Args:
        args: Namespace as returned by ``parse_args``.
    """
    if args.seed is None:
        colossalai.launch_from_torch(config={})
    else:
        colossalai.launch_from_torch(config={}, seed=args.seed)

    local_rank = gpc.get_local_rank(ParallelMode.DATA)
    world_size = gpc.get_world_size(ParallelMode.DATA)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        # exist_ok avoids a FileExistsError when several processes reach this
        # point at the same time.
        class_images_dir.mkdir(parents=True, exist_ok=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            # Compare the device *type*: a device object such as "cuda:0" does
            # not compare equal to the bare string "cuda".
            on_cuda = torch.device(get_current_device()).type == "cuda"
            torch_dtype = torch.float16 if on_cuda else torch.float32
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            pipeline.to(get_current_device())

            for example in tqdm(
                sample_dataloader,
                desc="Generating class images",
                disable=not local_rank == 0,
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    # The content hash is part of the file name, after the running index.
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            # Drop the sampling pipeline before the training models are loaded.
            del pipeline

    # Handle the repository creation
    if local_rank == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Add the ignore patterns without discarding rules that the
            # cloned repository may already contain.
            _ensure_gitignore_patterns(os.path.join(args.output_dir, ".gitignore"))
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
            use_fast=False,
        )
    elif args.pretrained_model_name_or_path:
        logger.info("Loading tokenizer from pretrained model", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )
    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(
        args.pretrained_model_name_or_path, revision=args.revision
    )

    # Load models and create wrapper for stable diffusion

    logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])

    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )

    logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )

    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
    with ColoInitContext(device=get_current_device()):
        unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
        )

    # Only the UNet is trained; the VAE and text encoder stay frozen.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = args.learning_rate * args.train_batch_size * world_size

    unet = gemini_zero_dpp(unet, args.placement)

    # config optimizer for colossalai zero
    optimizer = GeminiAdamOptimizer(
        unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
    )

    # load noise_scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # prepare dataset
    logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        """Stack images and pad prompt ids into a batch dict.

        With prior preservation the class examples are appended after the
        instance examples, so the batch can later be split in half.
        """
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
    )

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = len(train_dataloader)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
    )
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(get_current_device(), dtype=weight_dtype)
    text_encoder.to(get_current_device(), dtype=weight_dtype)

    # Derive the number of epochs from the (possibly user-supplied) step budget.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * world_size

    logger.info("***** Running training *****", ranks=[0])
    logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
    logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0])
    logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
    logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
    progress_bar.set_description("Steps")
    global_step = 0

    torch.cuda.synchronize()
    for epoch in range(args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            torch.cuda.reset_peak_memory_stats()
            # Move batch to gpu
            for key, value in batch.items():
                batch[key] = value.to(get_current_device(), non_blocking=True)

            # Convert images to latent space
            optimizer.zero_grad()

            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            # NOTE(review): 0.18215 is presumably the latent scaling factor of
            # the VAE used by this model family - confirm against the VAE config.
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute instance loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Backward goes through the optimizer wrapper rather than loss.backward().
            optimizer.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
            # One optimizer step was taken for this batch.
            progress_bar.update(1)
            global_step += 1
            logs = {
                "loss": loss.detach().item(),
                "lr": optimizer.param_groups[0]["lr"],
            }
            progress_bar.set_postfix(**logs)

            if global_step % args.save_steps == 0:
                torch.cuda.synchronize()
                # Called on every rank; only local rank 0 writes the checkpoint.
                torch_unet = get_static_torch_model(unet)
                if local_rank == 0:
                    pipeline = DiffusionPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        unet=torch_unet,
                        revision=args.revision,
                    )
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    pipeline.save_pretrained(save_path)
                    logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
            if global_step >= args.max_train_steps:
                break

    torch.cuda.synchronize()
    unet = get_static_torch_model(unet)

    if local_rank == 0:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=unet,
            revision=args.revision,
        )

        pipeline.save_pretrained(args.output_dir)
        logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])

        # `repo` only exists on local rank 0 when --push_to_hub was given.
        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)