# Path: blob/main/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
# 1980 views
"""Train an unconditional DDPM image-generation model, running the UNet through ONNX Runtime's ORTModule."""

import argparse
import inspect
import logging
import math
import os
from pathlib import Path
from typing import Optional

import datasets
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from onnxruntime.training.ortmodule import ORTModule
from torchvision import transforms
from tqdm.auto import tqdm

import diffusers
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.13.0.dev0")

logger = get_logger(__name__, log_level="INFO")


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D array for a batch of indices.

    :param arr: the 1-D numpy array (or torch tensor) to look values up in.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a float tensor on the device of `timesteps`, expanded to `broadcast_shape`.
    """
    if not isinstance(arr, torch.Tensor):
        arr = torch.from_numpy(arr)
    # Move the lookup table to the index device *before* indexing so that the table and the indices
    # live on the same device (the table may still be on CPU while `timesteps` is on an accelerator).
    res = arr.to(timesteps.device)[timesteps].float()
    # Append trailing singleton dimensions until the rank matches `broadcast_shape`.
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)


def parse_args():
    """
    Parse the command-line arguments of the training script.

    The `LOCAL_RANK` environment variable, when set, overrides `--local_rank`.

    :return: the parsed `argparse.Namespace`.
    :raises ValueError: if neither `--dataset_name` nor `--train_data_dir` is given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that HF Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="ddpm-model-64",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=64,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        default=False,
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
    parser.add_argument(
        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
    )
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use Exponential Moving Average for the final model weights.",
    )
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
    )
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
        help=(
            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
            " for experiment tracking and logging of model metrics and model checkpoints"
        ),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--prediction_type",
        type=str,
        default="epsilon",
        choices=["epsilon", "sample"],
        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
    )
    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
    parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    args = parser.parse_args()
    # The launcher-provided LOCAL_RANK takes precedence over the command-line value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Build the `<namespace>/<model_id>` repository name used on the Hub.

    :param model_id: the repository name without namespace.
    :param organization: namespace to use; when None the name of the user owning `token` is looked up.
    :param token: Hub token; when None the locally stored token is used.
    :return: the namespaced repository name.
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def _ensure_gitignore_entries(output_dir, patterns=("step_*", "epoch_*")):
    """Append the `patterns` missing from `<output_dir>/.gitignore`, keeping the entries already in the file."""
    gitignore_path = os.path.join(output_dir, ".gitignore")
    # "a+" creates the file if needed but, unlike "w+", does not truncate what a cloned repo already ships.
    with open(gitignore_path, "a+") as gitignore:
        gitignore.seek(0)
        content = gitignore.read()
        existing = {line.strip() for line in content.splitlines()}
        missing = [pattern for pattern in patterns if pattern not in existing]
        if missing and content and not content.endswith("\n"):
            # Do not glue the first new pattern onto an unterminated last line.
            gitignore.write("\n")
        for pattern in missing:
            gitignore.write(f"{pattern}\n")


def main(args):
    """
    Run the full training loop described by `args` (as returned by `parse_args`).

    Sets up the accelerator and trackers, builds the UNet, noise scheduler, optimizer and dataloader, wraps the
    model in `ORTModule`, optionally resumes from a checkpoint, trains for `args.num_epochs` epochs and
    periodically logs sample images and saves the pipeline (pushing to the Hub when requested).
    """
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.logger,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.logger == "tensorboard":
        if not is_tensorboard_available():
            raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")

    elif args.logger == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Keep intermediate training artefacts out of the pushed repository.
            _ensure_gitignore_entries(args.output_dir)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Initialize the model
    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

    # Create EMA for the model.
    if args.use_ema:
        ema_model = EMAModel(
            model.parameters(),
            decay=args.ema_max_decay,
            use_ema_warmup=True,
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
        )

    # Initialize the scheduler
    # Only pass `prediction_type` when the installed DDPMScheduler accepts it.
    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
    if accepts_prediction_type:
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=args.ddpm_num_steps,
            beta_schedule=args.ddpm_beta_schedule,
            prediction_type=args.prediction_type,
        )
    else:
        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            split="train",
        )
    else:
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets and DataLoaders creation.
    augmentations = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def transform_images(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    logger.info(f"Dataset size: {len(dataset)}")

    dataset.set_transform(transform_images)
    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    # Initialize the learning rate scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs),
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    model = ORTModule(model)

    if args.use_ema:
        accelerator.register_for_checkpointing(ema_model)
        ema_model.to(accelerator.device)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_train_steps = args.num_epochs * num_update_steps_per_epoch

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(dataset)}")
    logger.info(f" Num Epochs = {args.num_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint directories are named `checkpoint-<global_step>`.
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Train!
    for epoch in range(first_epoch, args.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            clean_images = batch["input"]
            # Sample noise that we'll add to the images (created directly on the images' device)
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                model_output = model(noisy_images, timesteps, return_dict=False)[0]

                if args.prediction_type == "epsilon":
                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
                elif args.prediction_type == "sample":
                    alpha_t = _extract_into_tensor(
                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                    )
                    snr_weights = alpha_t / (1 - alpha_t)
                    loss = snr_weights * F.mse_loss(
                        model_output, clean_images, reduction="none"
                    )  # use SNR weighting from distillation paper
                    loss = loss.mean()
                else:
                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_model.step(model.parameters())
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection and/or save the model
        if accelerator.is_main_process:
            is_last_epoch = epoch == args.num_epochs - 1
            save_images = epoch % args.save_images_epochs == 0 or is_last_epoch
            save_model = epoch % args.save_model_epochs == 0 or is_last_epoch

            if save_images or save_model:
                # Build the pipeline whenever either action needs it, so that saving the model does not depend
                # on a pipeline left over from an earlier image-logging epoch (which may not exist after resuming).
                unet = accelerator.unwrap_model(model)
                if args.use_ema:
                    # `copy_to` overwrites the live parameters in place; keep a copy so training can continue
                    # from the actual training weights instead of the EMA average.
                    training_weights = [param.detach().clone() for param in unet.parameters()]
                    ema_model.copy_to(unet.parameters())
                pipeline = DDPMPipeline(
                    unet=unet,
                    scheduler=noise_scheduler,
                )

                if save_images:
                    generator = torch.Generator(device=pipeline.device).manual_seed(0)
                    # run pipeline in inference (sample random noise and denoise)
                    images = pipeline(
                        generator=generator,
                        batch_size=args.eval_batch_size,
                        output_type="numpy",
                        num_inference_steps=args.ddpm_num_inference_steps,
                    ).images

                    # denormalize the images and save to tensorboard
                    images_processed = (images * 255).round().astype("uint8")

                    if args.logger == "tensorboard":
                        accelerator.get_tracker("tensorboard").add_images(
                            "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                        )
                    elif args.logger == "wandb":
                        accelerator.get_tracker("wandb").log(
                            {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
                            step=global_step,
                        )

                if save_model:
                    # save the model
                    pipeline.save_pretrained(args.output_dir)
                    if args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)

                if args.use_ema:
                    # Put the training weights back now that evaluation/saving with EMA weights is done.
                    with torch.no_grad():
                        for param, saved in zip(unet.parameters(), training_weights):
                            param.copy_(saved)
                    del training_weights

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)