Path: blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
1448 views
#!/usr/bin/env python1# coding=utf-82# Copyright 2023 The HuggingFace Inc. team. All rights reserved.3#4# Licensed under the Apache License, Version 2.0 (the "License");5# you may not use this file except in compliance with the License.6# You may obtain a copy of the License at7#8# http://www.apache.org/licenses/LICENSE-2.09#10# Unless required by applicable law or agreed to in writing, software11# distributed under the License is distributed on an "AS IS" BASIS,12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.13# See the License for the specific language governing permissions and14# limitations under the License.1516"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""1718import argparse19import logging20import math21import os22from pathlib import Path23from typing import Optional2425import accelerate26import datasets27import numpy as np28import PIL29import requests30import torch31import torch.nn as nn32import torch.nn.functional as F33import torch.utils.checkpoint34import transformers35from accelerate import Accelerator36from accelerate.logging import get_logger37from accelerate.utils import ProjectConfiguration, set_seed38from datasets import load_dataset39from huggingface_hub import HfFolder, Repository, create_repo, whoami40from packaging import version41from torchvision import transforms42from tqdm.auto import tqdm43from transformers import CLIPTextModel, CLIPTokenizer4445import diffusers46from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel47from diffusers.optimization import get_scheduler48from diffusers.training_utils import EMAModel49from diffusers.utils import check_min_version, deprecate, is_wandb_available50from diffusers.utils.import_utils import is_xformers_available515253# Will error if the minimal version of diffusers is not installed. Remove at your own risks.54check_min_version("0.15.0.dev0")5556logger = get_logger(__name__, log_level="INFO")5758DATASET_NAME_MAPPING = {59"fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"),60}61WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]626364def parse_args():65parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")66parser.add_argument(67"--pretrained_model_name_or_path",68type=str,69default=None,70required=True,71help="Path to pretrained model or model identifier from huggingface.co/models.",72)73parser.add_argument(74"--revision",75type=str,76default=None,77required=False,78help="Revision of pretrained model identifier from huggingface.co/models.",79)80parser.add_argument(81"--dataset_name",82type=str,83default=None,84help=(85"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"86" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"87" or to a folder containing files that 🤗 Datasets can understand."88),89)90parser.add_argument(91"--dataset_config_name",92type=str,93default=None,94help="The config of the Dataset, leave as None if there's only one config.",95)96parser.add_argument(97"--train_data_dir",98type=str,99default=None,100help=(101"A folder containing the training data. Folder contents must follow the structure described in"102" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"103" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."104),105)106parser.add_argument(107"--original_image_column",108type=str,109default="input_image",110help="The column of the dataset containing the original image on which edits where made.",111)112parser.add_argument(113"--edited_image_column",114type=str,115default="edited_image",116help="The column of the dataset containing the edited image.",117)118parser.add_argument(119"--edit_prompt_column",120type=str,121default="edit_prompt",122help="The column of the dataset containing the edit instruction.",123)124parser.add_argument(125"--val_image_url",126type=str,127default=None,128help="URL to the original image that you would like to edit (used during inference for debugging purposes).",129)130parser.add_argument(131"--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."132)133parser.add_argument(134"--num_validation_images",135type=int,136default=4,137help="Number of images that should be generated during validation with `validation_prompt`.",138)139parser.add_argument(140"--validation_epochs",141type=int,142default=1,143help=(144"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"145" `args.validation_prompt` multiple times: `args.num_validation_images`."146),147)148parser.add_argument(149"--max_train_samples",150type=int,151default=None,152help=(153"For debugging purposes or quicker training, truncate the number of training examples to this "154"value if set."155),156)157parser.add_argument(158"--output_dir",159type=str,160default="instruct-pix2pix-model",161help="The output directory where the model predictions and checkpoints will be written.",162)163parser.add_argument(164"--cache_dir",165type=str,166default=None,167help="The directory where the downloaded models and datasets will be stored.",168)169parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")170parser.add_argument(171"--resolution",172type=int,173default=256,174help=(175"The resolution for input images, all the images in the train/validation dataset will be resized to this"176" resolution"177),178)179parser.add_argument(180"--center_crop",181default=False,182action="store_true",183help=(184"Whether to center crop the input images to the resolution. If not set, the images will be randomly"185" cropped. The images will be resized to the resolution first before cropping."186),187)188parser.add_argument(189"--random_flip",190action="store_true",191help="whether to randomly flip images horizontally",192)193parser.add_argument(194"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."195)196parser.add_argument("--num_train_epochs", type=int, default=100)197parser.add_argument(198"--max_train_steps",199type=int,200default=None,201help="Total number of training steps to perform. If provided, overrides num_train_epochs.",202)203parser.add_argument(204"--gradient_accumulation_steps",205type=int,206default=1,207help="Number of updates steps to accumulate before performing a backward/update pass.",208)209parser.add_argument(210"--gradient_checkpointing",211action="store_true",212help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",213)214parser.add_argument(215"--learning_rate",216type=float,217default=1e-4,218help="Initial learning rate (after the potential warmup period) to use.",219)220parser.add_argument(221"--scale_lr",222action="store_true",223default=False,224help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",225)226parser.add_argument(227"--lr_scheduler",228type=str,229default="constant",230help=(231'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'232' "constant", "constant_with_warmup"]'233),234)235parser.add_argument(236"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."237)238parser.add_argument(239"--conditioning_dropout_prob",240type=float,241default=None,242help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",243)244parser.add_argument(245"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."246)247parser.add_argument(248"--allow_tf32",249action="store_true",250help=(251"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"252" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"253),254)255parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")256parser.add_argument(257"--non_ema_revision",258type=str,259default=None,260required=False,261help=(262"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"263" remote repository specified with --pretrained_model_name_or_path."264),265)266parser.add_argument(267"--dataloader_num_workers",268type=int,269default=0,270help=(271"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."272),273)274parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")275parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")276parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")277parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")278parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")279parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")280parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")281parser.add_argument(282"--hub_model_id",283type=str,284default=None,285help="The name of the repository to keep in sync with the local `output_dir`.",286)287parser.add_argument(288"--logging_dir",289type=str,290default="logs",291help=(292"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"293" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."294),295)296parser.add_argument(297"--mixed_precision",298type=str,299default=None,300choices=["no", "fp16", "bf16"],301help=(302"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="303" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"304" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."305),306)307parser.add_argument(308"--report_to",309type=str,310default="tensorboard",311help=(312'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'313' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'314),315)316parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")317parser.add_argument(318"--checkpointing_steps",319type=int,320default=500,321help=(322"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"323" training using `--resume_from_checkpoint`."324),325)326parser.add_argument(327"--checkpoints_total_limit",328type=int,329default=None,330help=(331"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."332" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"333" for more docs"334),335)336parser.add_argument(337"--resume_from_checkpoint",338type=str,339default=None,340help=(341"Whether training should be resumed from a previous checkpoint. Use a path saved by"342' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'343),344)345parser.add_argument(346"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."347)348349args = parser.parse_args()350env_local_rank = int(os.environ.get("LOCAL_RANK", -1))351if env_local_rank != -1 and env_local_rank != args.local_rank:352args.local_rank = env_local_rank353354# Sanity checks355if args.dataset_name is None and args.train_data_dir is None:356raise ValueError("Need either a dataset name or a training folder.")357358# default to using the same revision for the non-ema model if not specified359if args.non_ema_revision is None:360args.non_ema_revision = args.revision361362return args363364365def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):366if token is None:367token = HfFolder.get_token()368if organization is None:369username = whoami(token)["name"]370return f"{username}/{model_id}"371else:372return f"{organization}/{model_id}"373374375def convert_to_np(image, resolution):376image = image.convert("RGB").resize((resolution, resolution))377return np.array(image).transpose(2, 0, 1)378379380def download_image(url):381image = PIL.Image.open(requests.get(url, stream=True).raw)382image = PIL.ImageOps.exif_transpose(image)383image = image.convert("RGB")384return image385386387def main():388args = parse_args()389390if args.non_ema_revision is not None:391deprecate(392"non_ema_revision!=None",393"0.15.0",394message=(395"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"396" use `--variant=non_ema` instead."397),398)399logging_dir = os.path.join(args.output_dir, args.logging_dir)400accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)401accelerator = Accelerator(402gradient_accumulation_steps=args.gradient_accumulation_steps,403mixed_precision=args.mixed_precision,404log_with=args.report_to,405logging_dir=logging_dir,406project_config=accelerator_project_config,407)408409generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)410411if args.report_to == "wandb":412if not is_wandb_available():413raise ImportError("Make sure to install wandb if you want to use it for logging during training.")414import wandb415416# Make one log on every process with the configuration for debugging.417logging.basicConfig(418format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",419datefmt="%m/%d/%Y %H:%M:%S",420level=logging.INFO,421)422logger.info(accelerator.state, main_process_only=False)423if accelerator.is_local_main_process:424datasets.utils.logging.set_verbosity_warning()425transformers.utils.logging.set_verbosity_warning()426diffusers.utils.logging.set_verbosity_info()427else:428datasets.utils.logging.set_verbosity_error()429transformers.utils.logging.set_verbosity_error()430diffusers.utils.logging.set_verbosity_error()431432# If passed along, set the training seed now.433if args.seed is not None:434set_seed(args.seed)435436# Handle the repository creation437if accelerator.is_main_process:438if args.push_to_hub:439if args.hub_model_id is None:440repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)441else:442repo_name = args.hub_model_id443create_repo(repo_name, exist_ok=True, token=args.hub_token)444repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)445446with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:447if "step_*" not in gitignore:448gitignore.write("step_*\n")449if "epoch_*" not in gitignore:450gitignore.write("epoch_*\n")451elif args.output_dir is not None:452os.makedirs(args.output_dir, exist_ok=True)453454# Load scheduler, tokenizer and models.455noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")456tokenizer = CLIPTokenizer.from_pretrained(457args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision458)459text_encoder = CLIPTextModel.from_pretrained(460args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision461)462vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)463unet = UNet2DConditionModel.from_pretrained(464args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision465)466467# InstructPix2Pix uses an additional image for conditioning. To accommodate that,468# it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is469# then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized470# from the pre-trained checkpoints. For the extra channels added to the first layer, they are471# initialized to zero.472if accelerator.is_main_process:473logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")474in_channels = 8475out_channels = unet.conv_in.out_channels476unet.register_to_config(in_channels=in_channels)477478with torch.no_grad():479new_conv_in = nn.Conv2d(480in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding481)482new_conv_in.weight.zero_()483new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)484unet.conv_in = new_conv_in485486# Freeze vae and text_encoder487vae.requires_grad_(False)488text_encoder.requires_grad_(False)489490# Create EMA for the unet.491if args.use_ema:492ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)493494if args.enable_xformers_memory_efficient_attention:495if is_xformers_available():496import xformers497498xformers_version = version.parse(xformers.__version__)499if xformers_version == version.parse("0.0.16"):500logger.warn(501"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."502)503unet.enable_xformers_memory_efficient_attention()504else:505raise ValueError("xformers is not available. Make sure it is installed correctly")506507# `accelerate` 0.16.0 will have better support for customized saving508if version.parse(accelerate.__version__) >= version.parse("0.16.0"):509# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format510def save_model_hook(models, weights, output_dir):511if args.use_ema:512ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))513514for i, model in enumerate(models):515model.save_pretrained(os.path.join(output_dir, "unet"))516517# make sure to pop weight so that corresponding model is not saved again518weights.pop()519520def load_model_hook(models, input_dir):521if args.use_ema:522load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)523ema_unet.load_state_dict(load_model.state_dict())524ema_unet.to(accelerator.device)525del load_model526527for i in range(len(models)):528# pop models so that they are not loaded again529model = models.pop()530531# load diffusers style into model532load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")533model.register_to_config(**load_model.config)534535model.load_state_dict(load_model.state_dict())536del load_model537538accelerator.register_save_state_pre_hook(save_model_hook)539accelerator.register_load_state_pre_hook(load_model_hook)540541if args.gradient_checkpointing:542unet.enable_gradient_checkpointing()543544# Enable TF32 for faster training on Ampere GPUs,545# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices546if args.allow_tf32:547torch.backends.cuda.matmul.allow_tf32 = True548549if args.scale_lr:550args.learning_rate = (551args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes552)553554# Initialize the optimizer555if args.use_8bit_adam:556try:557import bitsandbytes as bnb558except ImportError:559raise ImportError(560"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"561)562563optimizer_cls = bnb.optim.AdamW8bit564else:565optimizer_cls = torch.optim.AdamW566567optimizer = optimizer_cls(568unet.parameters(),569lr=args.learning_rate,570betas=(args.adam_beta1, args.adam_beta2),571weight_decay=args.adam_weight_decay,572eps=args.adam_epsilon,573)574575# Get the datasets: you can either provide your own training and evaluation files (see below)576# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).577578# In distributed training, the load_dataset function guarantees that only one local process can concurrently579# download the dataset.580if args.dataset_name is not None:581# Downloading and loading a dataset from the hub.582dataset = load_dataset(583args.dataset_name,584args.dataset_config_name,585cache_dir=args.cache_dir,586)587else:588data_files = {}589if args.train_data_dir is not None:590data_files["train"] = os.path.join(args.train_data_dir, "**")591dataset = load_dataset(592"imagefolder",593data_files=data_files,594cache_dir=args.cache_dir,595)596# See more about loading custom images at597# https://huggingface.co/docs/datasets/main/en/image_load#imagefolder598599# Preprocessing the datasets.600# We need to tokenize inputs and targets.601column_names = dataset["train"].column_names602603# 6. Get the column names for input/target.604dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)605if args.original_image_column is None:606original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]607else:608original_image_column = args.original_image_column609if original_image_column not in column_names:610raise ValueError(611f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"612)613if args.edit_prompt_column is None:614edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]615else:616edit_prompt_column = args.edit_prompt_column617if edit_prompt_column not in column_names:618raise ValueError(619f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"620)621if args.edited_image_column is None:622edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]623else:624edited_image_column = args.edited_image_column625if edited_image_column not in column_names:626raise ValueError(627f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"628)629630# Preprocessing the datasets.631# We need to tokenize input captions and transform the images.632def tokenize_captions(captions):633inputs = tokenizer(634captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"635)636return inputs.input_ids637638# Preprocessing the datasets.639train_transforms = transforms.Compose(640[641transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),642transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),643]644)645646def preprocess_images(examples):647original_images = np.concatenate(648[convert_to_np(image, args.resolution) for image in examples[original_image_column]]649)650edited_images = np.concatenate(651[convert_to_np(image, args.resolution) for image in examples[edited_image_column]]652)653# We need to ensure that the original and the edited images undergo the same654# augmentation transforms.655images = np.concatenate([original_images, edited_images])656images = torch.tensor(images)657images = 2 * (images / 255) - 1658return train_transforms(images)659660def preprocess_train(examples):661# Preprocess images.662preprocessed_images = preprocess_images(examples)663# Since the original and edited images were concatenated before664# applying the transformations, we need to separate them and reshape665# them accordingly.666original_images, edited_images = preprocessed_images.chunk(2)667original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)668edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)669670# Collate the preprocessed images into the `examples`.671examples["original_pixel_values"] = original_images672examples["edited_pixel_values"] = edited_images673674# Preprocess the captions.675captions = [caption for caption in examples[edit_prompt_column]]676examples["input_ids"] = tokenize_captions(captions)677return examples678679with accelerator.main_process_first():680if args.max_train_samples is not None:681dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))682# Set the training transforms683train_dataset = dataset["train"].with_transform(preprocess_train)684685def collate_fn(examples):686original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])687original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()688edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])689edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()690input_ids = torch.stack([example["input_ids"] for example in examples])691return {692"original_pixel_values": original_pixel_values,693"edited_pixel_values": edited_pixel_values,694"input_ids": input_ids,695}696697# DataLoaders creation:698train_dataloader = torch.utils.data.DataLoader(699train_dataset,700shuffle=True,701collate_fn=collate_fn,702batch_size=args.train_batch_size,703num_workers=args.dataloader_num_workers,704)705706# Scheduler and math around the number of training steps.707overrode_max_train_steps = False708num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)709if args.max_train_steps is None:710args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch711overrode_max_train_steps = True712713lr_scheduler = get_scheduler(714args.lr_scheduler,715optimizer=optimizer,716num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,717num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,718)719720# Prepare everything with our `accelerator`.721unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(722unet, optimizer, train_dataloader, lr_scheduler723)724725if args.use_ema:726ema_unet.to(accelerator.device)727728# For mixed precision training we cast the text_encoder and vae weights to half-precision729# as these models are only used for inference, keeping weights in full precision is not required.730weight_dtype = torch.float32731if accelerator.mixed_precision == "fp16":732weight_dtype = torch.float16733elif accelerator.mixed_precision == "bf16":734weight_dtype = torch.bfloat16735736# Move text_encode and vae to gpu and cast to weight_dtype737text_encoder.to(accelerator.device, dtype=weight_dtype)738vae.to(accelerator.device, dtype=weight_dtype)739740# We need to recalculate our total training steps as the size of the training dataloader may have changed.741num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)742if overrode_max_train_steps:743args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch744# Afterwards we recalculate our number of training epochs745args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)746747# We need to initialize the trackers we use, and also store our configuration.748# The trackers initializes automatically on the main process.749if accelerator.is_main_process:750accelerator.init_trackers("instruct-pix2pix", config=vars(args))751752# Train!753total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps754755logger.info("***** Running training *****")756logger.info(f" Num examples = {len(train_dataset)}")757logger.info(f" Num Epochs = {args.num_train_epochs}")758logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")759logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")760logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")761logger.info(f" Total optimization steps = {args.max_train_steps}")762global_step = 0763first_epoch = 0764765# Potentially load in the weights and states from a previous save766if args.resume_from_checkpoint:767if args.resume_from_checkpoint != "latest":768path = os.path.basename(args.resume_from_checkpoint)769else:770# Get the most recent checkpoint771dirs = os.listdir(args.output_dir)772dirs = [d for d in dirs if d.startswith("checkpoint")]773dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))774path = dirs[-1] if len(dirs) > 0 else None775776if path is None:777accelerator.print(778f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."779)780args.resume_from_checkpoint = None781else:782accelerator.print(f"Resuming from checkpoint {path}")783accelerator.load_state(os.path.join(args.output_dir, path))784global_step = int(path.split("-")[1])785786resume_global_step = global_step * args.gradient_accumulation_steps787first_epoch = global_step // num_update_steps_per_epoch788resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)789790# Only show the progress bar once on each machine.791progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)792progress_bar.set_description("Steps")793794for epoch in range(first_epoch, args.num_train_epochs):795unet.train()796train_loss = 0.0797for step, batch in enumerate(train_dataloader):798# Skip steps until we reach the resumed step799if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:800if step % args.gradient_accumulation_steps == 0:801progress_bar.update(1)802continue803804with accelerator.accumulate(unet):805# We want to learn the denoising process w.r.t the edited images which806# are conditioned on the original image (which was edited) and the edit instruction.807# So, first, convert images to latent space.808latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()809latents = latents * vae.config.scaling_factor810811# Sample noise that we'll add to the latents812noise = torch.randn_like(latents)813bsz = latents.shape[0]814# Sample a random timestep for each image815timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)816timesteps = timesteps.long()817818# Add noise to the latents according to the noise magnitude at each timestep819# (this is the forward diffusion process)820noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)821822# Get the text embedding for conditioning.823encoder_hidden_states = text_encoder(batch["input_ids"])[0]824825# Get the additional image embedding for conditioning.826# Instead of getting a diagonal Gaussian here, we simply take the mode.827original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()828829# Conditioning dropout to support classifier-free guidance during inference. For more details830# check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.831if args.conditioning_dropout_prob is not None:832random_p = torch.rand(bsz, device=latents.device, generator=generator)833# Sample masks for the edit prompts.834prompt_mask = random_p < 2 * args.conditioning_dropout_prob835prompt_mask = prompt_mask.reshape(bsz, 1, 1)836# Final text conditioning.837null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]838encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)839840# Sample masks for the original images.841image_mask_dtype = original_image_embeds.dtype842image_mask = 1 - (843(random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)844* (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)845)846image_mask = image_mask.reshape(bsz, 1, 1, 1)847# Final image conditioning.848original_image_embeds = image_mask * original_image_embeds849850# Concatenate the `original_image_embeds` with the `noisy_latents`.851concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)852853# Get the target for loss depending on the prediction type854if noise_scheduler.config.prediction_type == "epsilon":855target = noise856elif noise_scheduler.config.prediction_type == "v_prediction":857target = noise_scheduler.get_velocity(latents, noise, timesteps)858else:859raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")860861# Predict the noise residual and compute loss862model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample863loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")864865# Gather the losses across all processes for logging (if we use distributed training).866avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()867train_loss += avg_loss.item() / args.gradient_accumulation_steps868869# Backpropagate870accelerator.backward(loss)871if accelerator.sync_gradients:872accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)873optimizer.step()874lr_scheduler.step()875optimizer.zero_grad()876877# Checks if the accelerator has performed an optimization step behind the scenes878if accelerator.sync_gradients:879if args.use_ema:880ema_unet.step(unet.parameters())881progress_bar.update(1)882global_step += 1883accelerator.log({"train_loss": train_loss}, step=global_step)884train_loss = 0.0885886if global_step % args.checkpointing_steps == 0:887if accelerator.is_main_process:888save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")889accelerator.save_state(save_path)890logger.info(f"Saved state to {save_path}")891892logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}893progress_bar.set_postfix(**logs)894895if global_step >= args.max_train_steps:896break897898if accelerator.is_main_process:899if (900(args.val_image_url is not None)901and (args.validation_prompt is not None)902and (epoch % args.validation_epochs == 0)903):904logger.info(905f"Running validation... \n Generating {args.num_validation_images} images with prompt:"906f" {args.validation_prompt}."907)908# create pipeline909if args.use_ema:910# Store the UNet parameters temporarily and load the EMA parameters to perform inference.911ema_unet.store(unet.parameters())912ema_unet.copy_to(unet.parameters())913pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(914args.pretrained_model_name_or_path,915unet=unet,916revision=args.revision,917torch_dtype=weight_dtype,918)919pipeline = pipeline.to(accelerator.device)920pipeline.set_progress_bar_config(disable=True)921922# run inference923original_image = download_image(args.val_image_url)924edited_images = []925with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):926for _ in range(args.num_validation_images):927edited_images.append(928pipeline(929args.validation_prompt,930image=original_image,931num_inference_steps=20,932image_guidance_scale=1.5,933guidance_scale=7,934generator=generator,935).images[0]936)937938for tracker in accelerator.trackers:939if tracker.name == "wandb":940wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)941for edited_image in edited_images:942wandb_table.add_data(943wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt944)945tracker.log({"validation": wandb_table})946if args.use_ema:947# Switch back to the original UNet parameters.948ema_unet.restore(unet.parameters())949950del pipeline951torch.cuda.empty_cache()952953# Create the pipeline using the trained modules and save it.954accelerator.wait_for_everyone()955if accelerator.is_main_process:956unet = accelerator.unwrap_model(unet)957if args.use_ema:958ema_unet.copy_to(unet.parameters())959960pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(961args.pretrained_model_name_or_path,962text_encoder=accelerator.unwrap_model(text_encoder),963vae=accelerator.unwrap_model(vae),964unet=unet,965revision=args.revision,966)967pipeline.save_pretrained(args.output_dir)968969if args.push_to_hub:970repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)971972if args.validation_prompt is not None:973edited_images = []974pipeline = pipeline.to(accelerator.device)975with torch.autocast(str(accelerator.device)):976for _ in range(args.num_validation_images):977edited_images.append(978pipeline(979args.validation_prompt,980image=original_image,981num_inference_steps=20,982image_guidance_scale=1.5,983guidance_scale=7,984generator=generator,985).images[0]986)987988for tracker in accelerator.trackers:989if tracker.name == "wandb":990wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)991for edited_image in edited_images:992wandb_table.add_data(993wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt994)995tracker.log({"test": wandb_table})996997accelerator.end_training()9989991000if __name__ == "__main__":1001main()100210031004