# Path: blob/main/examples/dreambooth/train_inpainting_dreambooth.py
# 1441 views
import argparse
import hashlib
import itertools
import json
import math
import os
import random
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler,
                       StableDiffusionInpaintPipeline, UNet2DConditionModel)
from diffusers.optimization import get_scheduler

# Let cuDNN benchmark and cache the fastest convolution algorithms; training inputs are
# resized/cropped to a fixed --resolution, so the chosen algorithms stay valid.
torch.backends.cudnn.benchmark = True


logger = get_logger(__name__)


def parse_args(input_args=None):
    """Parse command-line options for inpainting DreamBooth training.

    Args:
        input_args: Optional list of argument strings to parse instead of ``sys.argv``.

    Returns:
        An ``argparse.Namespace``. ``local_rank`` is replaced by the ``LOCAL_RANK``
        environment variable when that variable is set and differs from the flag.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="fp16",
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    # Preview samples are generated from each concept's instance_prompt (see save_weights).
    parser.add_argument(
        "--save_sample_negative_prompt",
        type=str,
        default=None,
        help="The negative prompt used to generate sample outputs to save.",
    )
    parser.add_argument(
        "--n_save_sample",
        type=int,
        default=4,
        help="The number of samples to save.",
    )
    parser.add_argument(
        "--save_guidance_scale",
        type=float,
        default=7.5,
        help="CFG for save sample.",
    )
    parser.add_argument(
        "--save_infer_steps",
        type=int,
        default=50,
        help="The number of inference steps for save sample.",
    )
    parser.add_argument(
        "--pad_tokens",
        default=False,
        action="store_true",
        help="Flag to pad tokens to length 77.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
            " sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
    parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
    parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.")
    parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--concepts_list",
        type=str,
        default=None,
        help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # A launcher-provided LOCAL_RANK takes precedence over the --local_rank flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


def get_cutout_holes(height, width, min_holes=8, max_holes=32, min_height=16, max_height=128, min_width=16, max_width=128):
    """Sample random rectangular holes that lie inside a ``height`` x ``width`` canvas.

    Args:
        height, width: Canvas size in pixels.
        min_holes, max_holes: Inclusive range for the number of holes.
        min_height, max_height, min_width, max_width: Inclusive ranges for hole sizes.

    Returns:
        A list of ``(x1, y1, x2, y2)`` tuples, with ``x2``/``y2`` exclusive, so they can
        be used directly as slice bounds.
    """
    # Clamp hole sizes to the canvas so that small resolutions do not make
    # random.randint raise on an empty range. For canvases at least
    # max_height x max_width, the bounds (and so the random draws) are unchanged.
    max_height = min(max_height, height)
    min_height = min(min_height, max_height)
    max_width = min(max_width, width)
    min_width = min(min_width, max_width)

    holes = []
    for _n in range(random.randint(min_holes, max_holes)):
        hole_height = random.randint(min_height, max_height)
        hole_width = random.randint(min_width, max_width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        y2 = y1 + hole_height
        x2 = x1 + hole_width
        holes.append((x1, y1, x2, y2))
    return holes


def generate_random_mask(image):
    """Build a random inpainting mask for ``image`` and the correspondingly masked image.

    ``image`` is indexed as (channels, height, width): the mask keeps a single channel
    (``image[:1]``) and takes its height/width from dims 1 and 2.

    Returns:
        ``(mask, masked_image)``. ``mask`` is 1 inside random cutout holes and 0 elsewhere.
        With probability 0.25 the whole mask is set to 1, so the entire image must be
        generated. ``masked_image`` is ``image`` with the masked pixels zeroed.
    """
    mask = torch.zeros_like(image[:1])
    holes = get_cutout_holes(mask.shape[1], mask.shape[2])
    for (x1, y1, x2, y2) in holes:
        mask[:, y1:y2, x1:x2] = 1.
    if random.uniform(0, 1) < 0.25:
        mask.fill_(1.)
    masked_image = image * (mask < 0.5)
    return mask, masked_image


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images, tokenizes the prompts and attaches a random inpainting mask
    (see ``generate_random_mask``) to every image.
    """

    def __init__(
        self,
        concepts_list,
        tokenizer,
        with_prior_preservation=True,
        size=512,
        center_crop=False,
        num_class_images=None,
        pad_tokens=False,
        hflip=False
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.with_prior_preservation = with_prior_preservation
        self.pad_tokens = pad_tokens

        # Lists of (image_path, prompt) pairs gathered from every concept.
        self.instance_images_path = []
        self.class_images_path = []

        for concept in concepts_list:
            inst_img_path = [(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()]
            self.instance_images_path.extend(inst_img_path)

            if with_prior_preservation:
                class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()]
                self.class_images_path.extend(class_img_path[:num_class_images])

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)
        self.num_class_images = len(self.class_images_path)
        # The shorter list is cycled via modulo indexing in __getitem__.
        self._length = max(self.num_class_images, self.num_instance_images)

        self.image_transforms = transforms.Compose(
            [
                # Flip probability is 0.5 when hflip is set, 0 otherwise.
                transforms.RandomHorizontalFlip(0.5 * hflip),
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                # Map pixel values from [0, 1] to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
        instance_image = Image.open(instance_path)
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_masks"], example["instance_masked_images"] = generate_random_mask(example["instance_images"])
        example["instance_prompt_ids"] = self.tokenizer(
            instance_prompt,
            padding="max_length" if self.pad_tokens else "do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.with_prior_preservation:
            class_path, class_prompt = self.class_images_path[index % self.num_class_images]
            class_image = Image.open(class_path)
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_masks"], example["class_masked_images"] = generate_random_mask(example["class_images"])
            example["class_prompt_ids"] = self.tokenizer(
                class_prompt,
                padding="max_length" if self.pad_tokens else "do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


class LatentsDataset(Dataset):
    """Dataset over per-batch caches produced before training.

    Each item is ``(latents, text)`` or, when both optional caches are given,
    ``(latents, text, masked_image_latents, mask)``. The mask entry is whatever
    was cached by the caller.
    """

    def __init__(self, latents_cache, text_encoder_cache, masked_latents_cache=None, mask_cache=None):
        if (masked_latents_cache is None) != (mask_cache is None):
            raise ValueError("masked_latents_cache and mask_cache must be provided together.")
        self.latents_cache = latents_cache
        self.text_encoder_cache = text_encoder_cache
        self.masked_latents_cache = masked_latents_cache
        self.mask_cache = mask_cache

    def __len__(self):
        return len(self.latents_cache)

    def __getitem__(self, index):
        item = (self.latents_cache[index], self.text_encoder_cache[index])
        if self.masked_latents_cache is not None:
            item += (self.masked_latents_cache[index], self.mask_cache[index])
        return item


class AverageMeter:
    """Running weighted mean of every value passed to ``update`` since the last ``reset``."""

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh ``avg``."""
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return ``"<owner>/<model_id>"``, where the owner is ``organization`` or the token's user name."""
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def main(args):
    """Generate missing class images (optional), then fine-tune the inpainting UNet with DreamBooth.

    Pipelines are saved under ``args.output_dir/<step>`` every ``args.save_interval`` steps
    and once more at the end of training.
    """
    logging_dir = Path(args.output_dir, "0", args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    # Normalize single-concept CLI flags and --concepts_list JSON into one list of concept dicts.
    if args.concepts_list is None:
        args.concepts_list = [
            {
                "instance_prompt": args.instance_prompt,
                "class_prompt": args.class_prompt,
                "instance_data_dir": args.instance_data_dir,
                "class_data_dir": args.class_data_dir
            }
        ]
    else:
        with open(args.concepts_list, "r") as f:
            args.concepts_list = json.load(f)

    if args.with_prior_preservation:
        # Top up each concept's class_data_dir to num_class_images by sampling the base model
        # with a fully masked black input (i.e. plain text-to-image generation).
        pipeline = None
        for concept in args.concepts_list:
            class_images_dir = Path(concept["class_data_dir"])
            class_images_dir.mkdir(parents=True, exist_ok=True)
            cur_class_images = len(list(class_images_dir.iterdir()))

            if cur_class_images < args.num_class_images:
                torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
                if pipeline is None:
                    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        vae=AutoencoderKL.from_pretrained(
                            args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
                            # The base model keeps its VAE in a "vae" subfolder; a standalone VAE does not.
                            subfolder=None if args.pretrained_vae_name_or_path else "vae",
                            revision=None if args.pretrained_vae_name_or_path else args.revision,
                            torch_dtype=torch_dtype
                        ),
                        torch_dtype=torch_dtype,
                        safety_checker=None,
                        revision=args.revision
                    )
                    pipeline.set_progress_bar_config(disable=True)
                    pipeline.to(accelerator.device)

                num_new_images = args.num_class_images - cur_class_images
                logger.info(f"Number of class images to sample: {num_new_images}.")

                sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
                sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

                sample_dataloader = accelerator.prepare(sample_dataloader)

                inp_img = Image.new("RGB", (512, 512), color=(0, 0, 0))
                inp_mask = Image.new("L", (512, 512), color=255)

                with torch.autocast("cuda"), torch.inference_mode():
                    for example in tqdm(
                        sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
                    ):
                        images = pipeline(
                            prompt=example["prompt"],
                            image=inp_img,
                            mask_image=inp_mask,
                            num_inference_steps=args.save_infer_steps
                        ).images

                        for i, image in enumerate(images):
                            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                            image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                            image.save(image_filename)

        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
        )
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
        )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
        torch_dtype=torch.float32
    )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        concepts_list=args.concepts_list,
        tokenizer=tokenizer,
        with_prior_preservation=args.with_prior_preservation,
        size=args.resolution,
        center_crop=args.center_crop,
        num_class_images=args.num_class_images,
        pad_tokens=args.pad_tokens,
        hflip=args.hflip
    )

    def collate_fn(examples):
        """Stack images/masks and pad prompt ids; class examples are appended after instance examples."""
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        mask_values = [example["instance_masks"] for example in examples]
        masked_image_values = [example["instance_masked_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            mask_values += [example["class_masks"] for example in examples]
            masked_image_values += [example["class_masked_images"] for example in examples]

        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
            "masked_image_values": masked_image_values
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=8
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    if not args.not_cache_latents:
        # Encode every batch once up front. Image latents, masked-image latents and the
        # latent-resolution masks are all cached, because the VAE is deleted below and the
        # training step must not need it. The random masks are therefore fixed per cached batch.
        latents_cache = []
        text_encoder_cache = []
        masked_latents_cache = []
        mask_cache = []
        for batch in tqdm(train_dataloader, desc="Caching latents"):
            with torch.no_grad():
                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
                batch["masked_image_values"] = batch["masked_image_values"].to(
                    accelerator.device, non_blocking=True, dtype=weight_dtype
                )
                batch["mask_values"] = batch["mask_values"].to(accelerator.device, non_blocking=True)
                batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
                masked_latents_cache.append(vae.encode(batch["masked_image_values"]).latent_dist)
                # Same downsampling the uncached training path applies on the fly.
                mask_cache.append(F.interpolate(batch["mask_values"], scale_factor=1 / 8))
                if args.train_text_encoder:
                    text_encoder_cache.append(batch["input_ids"])
                else:
                    text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
        train_dataset = LatentsDataset(latents_cache, text_encoder_cache, masked_latents_cache, mask_cache)
        # Each cached entry is already a full batch, hence batch_size=1 and an identity collate.
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)

        del vae
        if not args.train_text_encoder:
            del text_encoder
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth")

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    def save_weights(step):
        """Save an fp16 inpainting pipeline to ``output_dir/<step>`` and render preview samples.

        Only the main process does any work. Besides the pipeline, ``args.json`` and a copy of
        this script are written, and ``n_save_sample`` images per concept are generated from
        the concept's instance_prompt under ``samples/<concept index>``.
        """
        # Create the pipeline using using the trained modules and save it.
        if accelerator.is_main_process:
            if args.train_text_encoder:
                text_enc_model = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
            else:
                text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
            scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
            # NOTE(review): `.to(torch.float16)` converts the *live* training modules in place;
            # they are cast back to fp32 at the end of this function, which rounds the trained
            # weights to fp16 precision on every save. Consider saving from a copy instead.
            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).to(torch.float16),
                text_encoder=text_enc_model.to(torch.float16),
                vae=AutoencoderKL.from_pretrained(
                    args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
                    subfolder=None if args.pretrained_vae_name_or_path else "vae",
                    revision=None if args.pretrained_vae_name_or_path else args.revision,
                ),
                safety_checker=None,
                scheduler=scheduler,
                torch_dtype=torch.float16,
                revision=args.revision,
            )
            save_dir = os.path.join(args.output_dir, f"{step}")
            pipeline.save_pretrained(save_dir)
            with open(os.path.join(save_dir, "args.json"), "w") as f:
                json.dump(args.__dict__, f, indent=2)

            # Copy the script that is actually running, independent of the working directory.
            shutil.copy(os.path.abspath(__file__), save_dir)

            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)
            for idx, concept in enumerate(args.concepts_list):
                # Reproducible previews only when a seed was given; manual_seed(None) would raise.
                g_cuda = None
                if args.seed is not None:
                    g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed)
                sample_dir = os.path.join(save_dir, "samples", str(idx))
                os.makedirs(sample_dir, exist_ok=True)
                inp_img = Image.new("RGB", (512, 512), color=(0, 0, 0))
                inp_mask = Image.new("L", (512, 512), color=255)
                # The VAE above is loaded without torch_dtype (fp32) while unet/text encoder are
                # fp16; autocast keeps sampling dtype-consistent, as in class-image generation.
                with torch.autocast("cuda"), torch.inference_mode():
                    for i in tqdm(range(args.n_save_sample), desc="Generating samples"):
                        images = pipeline(
                            prompt=concept["instance_prompt"],
                            image=inp_img,
                            mask_image=inp_mask,
                            negative_prompt=args.save_sample_negative_prompt,
                            guidance_scale=args.save_guidance_scale,
                            num_inference_steps=args.save_infer_steps,
                            generator=g_cuda
                        ).images
                        images[0].save(os.path.join(sample_dir, f"{i}.png"))
            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"[*] Weights saved at {save_dir}")
            # Restore full precision for continued training.
            unet.to(torch.float32)
            text_enc_model.to(torch.float32)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    # Counts optimizer updates (not micro-batches), matching max_train_steps.
    global_step = 0
    loss_avg = AverageMeter()
    text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
    for epoch in range(args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()
        # Re-pair class images with instance images each epoch; only the raw image dataset
        # has class paths (the cached LatentsDataset does not).
        if isinstance(train_dataset, DreamBoothDataset):
            random.shuffle(train_dataset.class_images_path)
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images and masked images to latent space; downsample masks by the VAE factor of 8.
                with torch.no_grad():
                    if not args.not_cache_latents:
                        # Cached item: (latent_dist, text cache, masked_latent_dist, latent-res mask).
                        latent_dist, _, masked_latent_dist, mask = batch[0]
                    else:
                        latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
                        masked_latent_dist = vae.encode(batch["masked_image_values"].to(dtype=weight_dtype)).latent_dist
                        mask = F.interpolate(batch["mask_values"], scale_factor=1 / 8)
                    latents = latent_dist.sample() * 0.18215
                    masked_image_latents = masked_latent_dist.sample() * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                with text_enc_context:
                    if not args.not_cache_latents:
                        if args.train_text_encoder:
                            encoder_hidden_states = text_encoder(batch[0][1])[0]
                        else:
                            encoder_hidden_states = batch[0][1]
                    else:
                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Inpainting UNet input: noisy latents, mask and masked-image latents stacked on channels.
                latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
                # Predict the noise residual
                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    noise, noise_prior = torch.chunk(noise, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                # NOTE: gradient clipping is disabled, so --max_grad_norm currently has no effect.
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                loss_avg.update(loss.detach_(), bsz)

            # Count only real optimizer updates so gradient accumulation does not end training early.
            if accelerator.sync_gradients:
                if not global_step % args.log_interval:
                    logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=global_step)

                if global_step > 0 and not global_step % args.save_interval and global_step >= args.save_min_steps:
                    save_weights(global_step)

                progress_bar.update(1)
                global_step += 1

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()
        # Stop the epoch loop too; otherwise each remaining epoch would run one extra step.
        if global_step >= args.max_train_steps:
            break

    save_weights(global_step)

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)