# Path: blob/main/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
# 1979 views
import argparse
import hashlib
import itertools
import math
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionInpaintPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.13.0.dev0")

logger = get_logger(__name__)


def prepare_mask_and_masked_image(image, mask):
    """Convert a PIL image and a PIL mask into tensors for inpainting training.

    Args:
        image: PIL image; it is converted to RGB and scaled to ``[-1, 1]``.
        mask: PIL image; it is converted to single-channel ``"L"`` mode and
            binarized at 0.5 (after scaling to ``[0, 1]``).

    Returns:
        A ``(mask, masked_image)`` tuple. ``mask`` is a float32 tensor of shape
        ``(1, 1, H, W)`` holding only 0 and 1. ``masked_image`` is the image
        tensor of shape ``(1, 3, H, W)`` with every pixel where the mask is 1
        set to zero.
    """
    image = np.array(image.convert("RGB"))
    # HWC -> 1CHW
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    mask = np.array(mask.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    # Keep only the pixels outside the masked region.
    masked_image = image * (mask < 0.5)

    return mask, masked_image


# generate random masks
def random_mask(im_shape, ratio=1, mask_full_image=False):
    """Draw a random rectangular or elliptical mask on a black canvas.

    Args:
        im_shape: ``(width, height)`` of the mask image to create.
        ratio: Upper bound of the shape's size, as a fraction of ``im_shape``.
        mask_full_image: When true, the shape always has the maximal size
            ``im_shape * ratio`` and is always a rectangle.

    Returns:
        A PIL image in ``"L"`` mode where the drawn shape is 255 and the rest is 0.
    """
    mask = Image.new("L", im_shape, 0)
    draw = ImageDraw.Draw(mask)
    size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
    # use this to always mask the whole image
    if mask_full_image:
        size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
    # The center is sampled so that the shape does not leave the canvas.
    limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
    center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
    draw_type = random.randint(0, 1)
    bounding_box = (
        center[0] - size[0] // 2,
        center[1] - size[1] // 2,
        center[0] + size[0] // 2,
        center[1] + size[1] // 2,
    )
    if draw_type == 0 or mask_full_image:
        draw.rectangle(bounding_box, fill=255)
    else:
        draw.ellipse(bounding_box, fill=255)

    return mask


def parse_args():
    """Parse and validate the command-line arguments of the training script.

    The ``LOCAL_RANK`` environment variable, when set, overrides ``--local_rank``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: If the instance data directory is missing, or if prior
            preservation is requested without a class data directory or a
            class prompt.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
            " sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            # The adjacent literals previously lacked separating spaces, which
            # rendered as "Choosebetween" and "1.10.and" in --help output.
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
            " using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.instance_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")

    return args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Each item is a dict with the cropped PIL instance image (``"PIL_images"``), its normalized
    tensor (``"instance_images"``) and the unpadded token ids of the instance prompt
    (``"instance_prompt_ids"``). When a class data root is given, the analogous
    ``"class_images"``, ``"class_PIL_images"`` and ``"class_prompt_ids"`` entries are added.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            # Iterate over the larger of the two sets; the smaller one wraps around in __getitem__.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        # Applied to PIL images; the result is kept as PIL so that masks can be drawn at the same size.
        self.image_transforms_resize_and_crop = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            ]
        )

        self.image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        instance_image = self.image_transforms_resize_and_crop(instance_image)

        example["PIL_images"] = instance_image
        example["instance_images"] = self.image_transforms(instance_image)

        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            class_image = self.image_transforms_resize_and_crop(class_image)
            example["class_images"] = self.image_transforms(class_image)
            example["class_PIL_images"] = class_image
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return ``"<namespace>/<model_id>"`` for a Hub repository.

    The namespace is ``organization`` when given, otherwise the user name that
    ``whoami`` reports for ``token`` (falling back to the locally stored token).
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def _ensure_gitignore_patterns(gitignore_path, patterns=("step_*", "epoch_*")):
    """Append to the ``.gitignore`` at ``gitignore_path`` every pattern it does not list yet.

    Existing content is preserved; the file is created when it does not exist.
    """
    gitignore_file = Path(gitignore_path)
    current = gitignore_file.read_text(encoding="utf-8") if gitignore_file.exists() else ""
    present = {line.strip() for line in current.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in present]
    if not missing:
        return
    with gitignore_file.open("a", encoding="utf-8") as gitignore:
        # Do not glue the first new pattern onto an unterminated last line.
        if current and not current.endswith("\n"):
            gitignore.write("\n")
        gitignore.writelines(f"{pattern}\n" for pattern in missing)


def main():
    """Run DreamBooth fine-tuning of an inpainting UNet (and optionally the text encoder).

    Steps: optionally generate missing class images for prior preservation, set up the
    output directory / Hub repository, load the models, train with randomly generated
    masks, periodically save accelerator checkpoints, and finally save the pipeline.
    """
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
        accelerator_project_config=accelerator_project_config,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        # exist_ok avoids a FileExistsError when several processes reach this point together.
        class_images_dir.mkdir(parents=True, exist_ok=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(
                sample_dataset, batch_size=args.sample_batch_size, num_workers=1
            )

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)
            transform_to_pil = transforms.ToPILImage()
            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                # The whole image is masked, so the pipeline regenerates everything from the prompt
                # and the random input image only serves as a placeholder.
                fake_images = torch.rand((3, args.resolution, args.resolution))
                fake_pil_images = transform_to_pil(fake_images)

                fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)

                images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Previously the file was opened with "w+", which truncated the cloned repository's
            # .gitignore before its content could be inspected.
            _ensure_gitignore_patterns(os.path.join(args.output_dir, ".gitignore"))
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        """Batch examples, drawing a fresh random mask for every image."""
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            prior_pil = [example["class_PIL_images"] for example in examples]

        masks = []
        masked_images = []
        for example in examples:
            pil_image = example["PIL_images"]
            # generate a random mask
            mask = random_mask(pil_image.size, 1, False)
            # prepare mask and masked image
            mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)

            masks.append(mask)
            masked_images.append(masked_image)

        if args.with_prior_preservation:
            for pil_image in prior_pil:
                # generate a random mask
                mask = random_mask(pil_image.size, 1, False)
                # prepare mask and masked image
                mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)

                masks.append(mask)
                masked_images.append(masked_image)

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
        masks = torch.stack(masks)
        masked_images = torch.stack(masked_images)
        batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )
    accelerator.register_for_checkpointing(lr_scheduler)

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint directories are named "checkpoint-<global_step>".
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            # from_pretrained returns modules in eval mode; without this the text encoder
            # would be optimized in eval mode and its gradient checkpointing would stay inactive.
            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space

                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Convert masked images to latent space
                masked_latents = vae.encode(
                    batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                ).latent_dist.sample()
                masked_latents = masked_latents * vae.config.scaling_factor

                masks = batch["masks"]
                # resize the mask to latents shape as we concatenate the mask to the latents
                mask = torch.stack(
                    [
                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
                        for mask in masks
                    ]
                )
                mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # concatenate the noised latents with the mask and the masked latents
                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    # NOTE(review): the model is trained for inpainting but saved through StableDiffusionPipeline;
    # confirm whether StableDiffusionInpaintPipeline is the intended class here.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()