# Path: blob/main/examples/dreambooth/train_dreambooth_lora.py
# 1441 views
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DreamBooth fine-tuning of Stable Diffusion by training LoRA attention processors.

The VAE, text encoder and base UNet weights are frozen; only the LoRA layers installed as
UNet attention processors are optimized. Optionally generates class images for the
prior-preservation loss, runs periodic validation, and pushes the result to the Hub.
"""

import argparse
import functools
import hashlib
import logging
import math
import os
import warnings
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__)


def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
    """Write a README.md model card, plus the example images it links to, into ``repo_folder``.

    Args:
        repo_name: Hub repository id, used in the card title.
        images: Optional iterable of PIL images. Each is saved as ``image_{i}.png`` in
            ``repo_folder`` and embedded in the card. ``None`` means no example images.
        base_model: Identifier of the model the LoRA weights were trained on.
        prompt: The instance prompt used for training.
        repo_folder: Local directory that receives README.md and the images.
    """
    img_str = ""
    for i, image in enumerate(images or []):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        # Embed the saved file with a relative Markdown image link.
        img_str += f"![img_{i}](./image_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
    """
    model_card = f"""
# LoRA DreamBooth - {repo_name}

These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
"""
    # Prompts may contain non-ASCII text; do not depend on the platform's default encoding.
    with open(os.path.join(repo_folder, "README.md"), "w", encoding="utf-8") as f:
        f.write(yaml + model_card)


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named in the checkpoint's ``text_encoder`` config.

    Raises:
        ValueError: if the configured architecture is not one of the supported classes.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


def parse_args(input_args=None):
    """Parse and validate the command-line arguments.

    Args:
        input_args: Optional list of argument strings; ``sys.argv`` is used when ``None``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: if prior preservation is requested without a class data dir or class prompt.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="lora-dreambooth-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    # NOTE(review): main() never reads args.gradient_checkpointing, so this flag currently has no effect.
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # A LOCAL_RANK set by the launcher takes precedence over the --local_rank flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        # logger is not available yet
        if args.class_data_dir is not None:
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    return args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    The dataset length is the larger of the instance and class image counts; indices wrap
    around (modulo) within each set so every item yields one instance and, when a class root
    is given, one class example.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        if self.num_instance_images == 0:
            # Fail here rather than with a ZeroDivisionError in __getitem__ or the step math in main().
            raise ValueError(f"No instance images found in {self.instance_data_root}.")
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                truncation=True,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids

        return example


def collate_fn(examples, with_prior_preservation=False):
    """Stack dataset examples into a batch dict with ``input_ids`` and ``pixel_values``.

    With prior preservation, class examples are appended after the instance examples, so the
    first half of the batch is instance data and the second half class data.
    """
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }
    return batch


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return ``"<namespace>/<model_id>"``, using the token owner's username when no organization is given."""
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def _generate_class_images(args, accelerator):
    """Top up ``args.class_data_dir`` to ``args.num_class_images`` images sampled from the base model.

    Does nothing when the directory already holds enough files. New images are saved as
    ``{index}-{sha1}.jpg``.
    """
    class_images_dir = Path(args.class_data_dir)
    if not class_images_dir.exists():
        class_images_dir.mkdir(parents=True)
    cur_class_images = len(list(class_images_dir.iterdir()))

    if cur_class_images >= args.num_class_images:
        return

    # fp16 on CUDA and fp32 otherwise, unless a precision was requested explicitly.
    torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
    if args.prior_generation_precision == "fp32":
        torch_dtype = torch.float32
    elif args.prior_generation_precision == "fp16":
        torch_dtype = torch.float16
    elif args.prior_generation_precision == "bf16":
        torch_dtype = torch.bfloat16
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        safety_checker=None,
        revision=args.revision,
    )
    pipeline.set_progress_bar_config(disable=True)

    num_new_images = args.num_class_images - cur_class_images
    logger.info(f"Number of class images to sample: {num_new_images}.")

    sample_dataset = PromptDataset(args.class_prompt, num_new_images)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

    # Hand the prompt loader to accelerate so generation is spread over the available processes.
    sample_dataloader = accelerator.prepare(sample_dataloader)
    pipeline.to(accelerator.device)

    for example in tqdm(
        sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
    ):
        images = pipeline(example["prompt"]).images

        for i, image in enumerate(images):
            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
            image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
            image.save(image_filename)

    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _init_hub_repo(args):
    """Create (or reuse) the Hub repository and clone it into ``args.output_dir``.

    Returns:
        ``(repo_name, repo)``: the full repository id and the local ``Repository`` clone.
    """
    if args.hub_model_id is None:
        repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
    else:
        repo_name = args.hub_model_id

    create_repo(repo_name, exist_ok=True, token=args.hub_token)
    repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

    # Opening with "w" truncates any existing .gitignore, so the patterns are simply written fresh.
    # NOTE(review): the `checkpoint-<step>` directories saved by main() are not listed here and
    # would be included in the push — confirm whether that is intended.
    with open(os.path.join(args.output_dir, ".gitignore"), "w") as gitignore:
        gitignore.write("step_*\n")
        gitignore.write("epoch_*\n")
    return repo_name, repo


def _build_lora_attn_procs(unet):
    """Create a ``LoRAAttnProcessor`` for every attention processor of ``unet``.

    Processors whose name ends in ``attn1.processor`` get no cross-attention dim; all others use
    ``unet.config.cross_attention_dim``. The hidden size comes from ``unet.config.block_out_channels``
    for the block the processor lives in (reversed for up blocks, last entry for the mid block).

    Raises:
        ValueError: if a processor is not under ``mid_block``, ``up_blocks`` or ``down_blocks``.
    """
    block_out_channels = unet.config.block_out_channels
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Names look like "up_blocks.<id>...."; parse the full id, not just one character.
            block_id = int(name.split(".")[1])
            hidden_size = list(reversed(block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            hidden_size = block_out_channels[block_id]
        else:
            # Otherwise hidden_size would silently carry over from the previous processor.
            raise ValueError(f"Unexpected attention processor name: {name}")

        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
    return lora_attn_procs


def _resolve_resume_checkpoint(args):
    """Return the checkpoint directory name (relative to ``args.output_dir``) to resume from, or ``None``.

    ``"latest"`` picks the ``checkpoint-<step>`` entry with the highest step; any other value is
    reduced to its basename.
    """
    if args.resume_from_checkpoint != "latest":
        return os.path.basename(args.resume_from_checkpoint)
    # Get the most recent checkpoint
    dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
    return dirs[-1] if len(dirs) > 0 else None


def _generate_images(pipeline, args, accelerator):
    """Run ``pipeline`` on ``args.validation_prompt`` ``args.num_validation_images`` times with 25 steps.

    Sampling is seeded from ``args.seed`` when a seed was given (including 0); otherwise it is unseeded.
    """
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
    return [
        pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
        for _ in range(args.num_validation_images)
    ]


def _log_images_to_trackers(accelerator, images, tag, step, prompt):
    """Log ``images`` under ``tag`` to the TensorBoard and wandb trackers; other trackers are skipped."""
    if not images:
        # np.stack raises on an empty list.
        return
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(tag, np_images, step, dataformats="NHWC")
        if tracker.name == "wandb":
            # Imported here so it also works for `--report_to all`, not only `--report_to wandb`.
            import wandb

            tracker.log({tag: [wandb.Image(image, caption=f"{i}: {prompt}") for i, image in enumerate(images)]})


def main(args):
    """Train LoRA attention processors on a frozen Stable Diffusion UNet with DreamBooth.

    The trained LoRA weights are saved to ``args.output_dir`` and optionally pushed to the Hub.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    # Fail fast; the wandb module itself is imported where images are logged.
    if args.report_to == "wandb" and not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Generate class images if prior preservation is enabled.
    if args.with_prior_preservation:
        _generate_class_images(args, accelerator)

    # Handle the repository creation
    repo_name, repo = None, None
    if accelerator.is_main_process:
        if args.push_to_hub:
            repo_name, repo = _init_hub_repo(args)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # We only train the additional adapter LoRA layers
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Now we add new LoRA weights to the attention layers.
    # The sizes of the attention layers depend on only two variables:
    # 1) the "hidden_size", which follows `unet.config.block_out_channels`;
    # 2) the "cross attention size", which is `unet.config.cross_attention_dim`.
    # For Stable Diffusion the number of attention processors should be:
    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
    # => 32 layers
    unet.set_attn_processor(_build_lora_attn_procs(unet))
    lora_layers = AttnProcsLayers(unet.attn_processors)

    accelerator.register_for_checkpointing(lora_layers)

    # Scale the base learning rate (once) by the effective global batch size.
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation: only the LoRA parameters are optimized.
    optimizer = optimizer_class(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        class_num=args.num_class_images,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        # A partial of a module-level function stays picklable for spawn-based worker processes.
        collate_fn=functools.partial(collate_fn, with_prior_preservation=args.with_prior_preservation),
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth-lora", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        path = _resolve_resume_checkpoint(args)

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            # Convert optimizer steps back into dataloader steps to know how many batches to skip.
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Keeps `epoch` defined for the final logging even if the loop below runs zero times.
    epoch = first_epoch
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # collate_fn puts instance examples first and class examples second, so
                    # chunk the predictions and targets into those two halves.
                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    # Compute prior loss
                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = lora_layers.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_prompt}."
                )
                # create pipeline
                pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=accelerator.unwrap_model(unet),
                    text_encoder=accelerator.unwrap_model(text_encoder),
                    revision=args.revision,
                    torch_dtype=weight_dtype,
                )
                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                images = _generate_images(pipeline, args, accelerator)
                _log_images_to_trackers(accelerator, images, "validation", epoch, args.validation_prompt)

                del pipeline
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)

        # Final inference
        # Load previous pipeline
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
        )
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
        pipeline = pipeline.to(accelerator.device)

        # load attention processors
        pipeline.unet.load_attn_procs(args.output_dir)

        # run inference; `images` stays empty when no validation prompt was given so the
        # model card below can still be written.
        images = []
        if args.validation_prompt and args.num_validation_images > 0:
            images = _generate_images(pipeline, args, accelerator)
            _log_images_to_trackers(accelerator, images, "test", epoch, args.validation_prompt)

        if args.push_to_hub:
            save_model_card(
                repo_name,
                images=images,
                base_model=args.pretrained_model_name_or_path,
                prompt=args.instance_prompt,
                repo_folder=args.output_dir,
            )
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)