Path: blob/main/examples/unconditional_image_generation/train_unconditional.py
1448 views
import argparse1import inspect2import logging3import math4import os5from pathlib import Path6from typing import Optional78import accelerate9import datasets10import torch11import torch.nn.functional as F12from accelerate import Accelerator13from accelerate.logging import get_logger14from accelerate.utils import ProjectConfiguration15from datasets import load_dataset16from huggingface_hub import HfFolder, Repository, create_repo, whoami17from packaging import version18from torchvision import transforms19from tqdm.auto import tqdm2021import diffusers22from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel23from diffusers.optimization import get_scheduler24from diffusers.training_utils import EMAModel25from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available26from diffusers.utils.import_utils import is_xformers_available272829# Will error if the minimal version of diffusers is not installed. Remove at your own risks.30check_min_version("0.15.0.dev0")3132logger = get_logger(__name__, log_level="INFO")333435def _extract_into_tensor(arr, timesteps, broadcast_shape):36"""37Extract values from a 1-D numpy array for a batch of indices.3839:param arr: the 1-D numpy array.40:param timesteps: a tensor of indices into the array to extract.41:param broadcast_shape: a larger shape of K dimensions with the batch42dimension equal to the length of timesteps.43:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.44"""45if not isinstance(arr, torch.Tensor):46arr = torch.from_numpy(arr)47res = arr[timesteps].float().to(timesteps.device)48while len(res.shape) < len(broadcast_shape):49res = res[..., None]50return res.expand(broadcast_shape)515253def parse_args():54parser = argparse.ArgumentParser(description="Simple example of a training script.")55parser.add_argument(56"--dataset_name",57type=str,58default=None,59help=(60"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"61" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"62" or to a folder containing files that HF Datasets can understand."63),64)65parser.add_argument(66"--dataset_config_name",67type=str,68default=None,69help="The config of the Dataset, leave as None if there's only one config.",70)71parser.add_argument(72"--model_config_name_or_path",73type=str,74default=None,75help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",76)77parser.add_argument(78"--train_data_dir",79type=str,80default=None,81help=(82"A folder containing the training data. Folder contents must follow the structure described in"83" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"84" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."85),86)87parser.add_argument(88"--output_dir",89type=str,90default="ddpm-model-64",91help="The output directory where the model predictions and checkpoints will be written.",92)93parser.add_argument("--overwrite_output_dir", action="store_true")94parser.add_argument(95"--cache_dir",96type=str,97default=None,98help="The directory where the downloaded models and datasets will be stored.",99)100parser.add_argument(101"--resolution",102type=int,103default=64,104help=(105"The resolution for input images, all the images in the train/validation dataset will be resized to this"106" resolution"107),108)109parser.add_argument(110"--center_crop",111default=False,112action="store_true",113help=(114"Whether to center crop the input images to the resolution. If not set, the images will be randomly"115" cropped. The images will be resized to the resolution first before cropping."116),117)118parser.add_argument(119"--random_flip",120default=False,121action="store_true",122help="whether to randomly flip images horizontally",123)124parser.add_argument(125"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."126)127parser.add_argument(128"--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."129)130parser.add_argument(131"--dataloader_num_workers",132type=int,133default=0,134help=(135"The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"136" process."137),138)139parser.add_argument("--num_epochs", type=int, default=100)140parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")141parser.add_argument(142"--save_model_epochs", type=int, default=10, help="How often to save the model during training."143)144parser.add_argument(145"--gradient_accumulation_steps",146type=int,147default=1,148help="Number of updates steps to accumulate before performing a backward/update pass.",149)150parser.add_argument(151"--learning_rate",152type=float,153default=1e-4,154help="Initial learning rate (after the potential warmup period) to use.",155)156parser.add_argument(157"--lr_scheduler",158type=str,159default="cosine",160help=(161'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'162' "constant", "constant_with_warmup"]'163),164)165parser.add_argument(166"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."167)168parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")169parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")170parser.add_argument(171"--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."172)173parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")174parser.add_argument(175"--use_ema",176action="store_true",177help="Whether to use Exponential Moving Average for the final model weights.",178)179parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")180parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")181parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")182parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")183parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")184parser.add_argument(185"--hub_model_id",186type=str,187default=None,188help="The name of the repository to keep in sync with the local `output_dir`.",189)190parser.add_argument(191"--hub_private_repo", action="store_true", help="Whether or not to create a private repository."192)193parser.add_argument(194"--logger",195type=str,196default="tensorboard",197choices=["tensorboard", "wandb"],198help=(199"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"200" for experiment tracking and logging of model metrics and model checkpoints"201),202)203parser.add_argument(204"--logging_dir",205type=str,206default="logs",207help=(208"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"209" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."210),211)212parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")213parser.add_argument(214"--mixed_precision",215type=str,216default="no",217choices=["no", "fp16", "bf16"],218help=(219"Whether to use mixed precision. Choose"220"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."221"and an Nvidia Ampere GPU."222),223)224parser.add_argument(225"--prediction_type",226type=str,227default="epsilon",228choices=["epsilon", "sample"],229help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",230)231parser.add_argument("--ddpm_num_steps", type=int, default=1000)232parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)233parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")234parser.add_argument(235"--checkpointing_steps",236type=int,237default=500,238help=(239"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"240" training using `--resume_from_checkpoint`."241),242)243parser.add_argument(244"--checkpoints_total_limit",245type=int,246default=None,247help=(248"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."249" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"250" for more docs"251),252)253parser.add_argument(254"--resume_from_checkpoint",255type=str,256default=None,257help=(258"Whether training should be resumed from a previous checkpoint. Use a path saved by"259' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'260),261)262parser.add_argument(263"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."264)265266args = parser.parse_args()267env_local_rank = int(os.environ.get("LOCAL_RANK", -1))268if env_local_rank != -1 and env_local_rank != args.local_rank:269args.local_rank = env_local_rank270271if args.dataset_name is None and args.train_data_dir is None:272raise ValueError("You must specify either a dataset name from the hub or a train data directory.")273274return args275276277def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):278if token is None:279token = HfFolder.get_token()280if organization is None:281username = whoami(token)["name"]282return f"{username}/{model_id}"283else:284return f"{organization}/{model_id}"285286287def main(args):288logging_dir = os.path.join(args.output_dir, args.logging_dir)289290accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)291292accelerator = Accelerator(293gradient_accumulation_steps=args.gradient_accumulation_steps,294mixed_precision=args.mixed_precision,295log_with=args.logger,296logging_dir=logging_dir,297project_config=accelerator_project_config,298)299300if args.logger == "tensorboard":301if not is_tensorboard_available():302raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")303304elif args.logger == "wandb":305if not is_wandb_available():306raise ImportError("Make sure to install wandb if you want to use it for logging during training.")307import wandb308309# `accelerate` 0.16.0 will have better support for customized saving310if version.parse(accelerate.__version__) >= version.parse("0.16.0"):311# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format312def save_model_hook(models, weights, output_dir):313if args.use_ema:314ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))315316for i, model in enumerate(models):317model.save_pretrained(os.path.join(output_dir, "unet"))318319# make sure to pop weight so that corresponding model is not saved again320weights.pop()321322def load_model_hook(models, input_dir):323if args.use_ema:324load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)325ema_model.load_state_dict(load_model.state_dict())326ema_model.to(accelerator.device)327del load_model328329for i in range(len(models)):330# pop models so that they are not loaded again331model = models.pop()332333# load diffusers style into model334load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")335model.register_to_config(**load_model.config)336337model.load_state_dict(load_model.state_dict())338del load_model339340accelerator.register_save_state_pre_hook(save_model_hook)341accelerator.register_load_state_pre_hook(load_model_hook)342343# Make one log on every process with the configuration for debugging.344logging.basicConfig(345format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",346datefmt="%m/%d/%Y %H:%M:%S",347level=logging.INFO,348)349logger.info(accelerator.state, main_process_only=False)350if accelerator.is_local_main_process:351datasets.utils.logging.set_verbosity_warning()352diffusers.utils.logging.set_verbosity_info()353else:354datasets.utils.logging.set_verbosity_error()355diffusers.utils.logging.set_verbosity_error()356357# Handle the repository creation358if accelerator.is_main_process:359if args.push_to_hub:360if args.hub_model_id is None:361repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)362else:363repo_name = args.hub_model_id364create_repo(repo_name, exist_ok=True, token=args.hub_token)365repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)366367with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:368if "step_*" not in gitignore:369gitignore.write("step_*\n")370if "epoch_*" not in gitignore:371gitignore.write("epoch_*\n")372elif args.output_dir is not None:373os.makedirs(args.output_dir, exist_ok=True)374375# Initialize the model376if args.model_config_name_or_path is None:377model = UNet2DModel(378sample_size=args.resolution,379in_channels=3,380out_channels=3,381layers_per_block=2,382block_out_channels=(128, 128, 256, 256, 512, 512),383down_block_types=(384"DownBlock2D",385"DownBlock2D",386"DownBlock2D",387"DownBlock2D",388"AttnDownBlock2D",389"DownBlock2D",390),391up_block_types=(392"UpBlock2D",393"AttnUpBlock2D",394"UpBlock2D",395"UpBlock2D",396"UpBlock2D",397"UpBlock2D",398),399)400else:401config = UNet2DModel.load_config(args.model_config_name_or_path)402model = UNet2DModel.from_config(config)403404# Create EMA for the model.405if args.use_ema:406ema_model = EMAModel(407model.parameters(),408decay=args.ema_max_decay,409use_ema_warmup=True,410inv_gamma=args.ema_inv_gamma,411power=args.ema_power,412model_cls=UNet2DModel,413model_config=model.config,414)415416if args.enable_xformers_memory_efficient_attention:417if is_xformers_available():418import xformers419420xformers_version = version.parse(xformers.__version__)421if xformers_version == version.parse("0.0.16"):422logger.warn(423"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."424)425model.enable_xformers_memory_efficient_attention()426else:427raise ValueError("xformers is not available. Make sure it is installed correctly")428429# Initialize the scheduler430accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())431if accepts_prediction_type:432noise_scheduler = DDPMScheduler(433num_train_timesteps=args.ddpm_num_steps,434beta_schedule=args.ddpm_beta_schedule,435prediction_type=args.prediction_type,436)437else:438noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)439440# Initialize the optimizer441optimizer = torch.optim.AdamW(442model.parameters(),443lr=args.learning_rate,444betas=(args.adam_beta1, args.adam_beta2),445weight_decay=args.adam_weight_decay,446eps=args.adam_epsilon,447)448449# Get the datasets: you can either provide your own training and evaluation files (see below)450# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).451452# In distributed training, the load_dataset function guarantees that only one local process can concurrently453# download the dataset.454if args.dataset_name is not None:455dataset = load_dataset(456args.dataset_name,457args.dataset_config_name,458cache_dir=args.cache_dir,459split="train",460)461else:462dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")463# See more about loading custom images at464# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder465466# Preprocessing the datasets and DataLoaders creation.467augmentations = transforms.Compose(468[469transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),470transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),471transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),472transforms.ToTensor(),473transforms.Normalize([0.5], [0.5]),474]475)476477def transform_images(examples):478images = [augmentations(image.convert("RGB")) for image in examples["image"]]479return {"input": images}480481logger.info(f"Dataset size: {len(dataset)}")482483dataset.set_transform(transform_images)484train_dataloader = torch.utils.data.DataLoader(485dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers486)487488# Initialize the learning rate scheduler489lr_scheduler = get_scheduler(490args.lr_scheduler,491optimizer=optimizer,492num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,493num_training_steps=(len(train_dataloader) * args.num_epochs),494)495496# Prepare everything with our `accelerator`.497model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(498model, optimizer, train_dataloader, lr_scheduler499)500501if args.use_ema:502ema_model.to(accelerator.device)503504# We need to initialize the trackers we use, and also store our configuration.505# The trackers initializes automatically on the main process.506if accelerator.is_main_process:507run = os.path.split(__file__)[-1].split(".")[0]508accelerator.init_trackers(run)509510total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps511num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)512max_train_steps = args.num_epochs * num_update_steps_per_epoch513514logger.info("***** Running training *****")515logger.info(f" Num examples = {len(dataset)}")516logger.info(f" Num Epochs = {args.num_epochs}")517logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")518logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")519logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")520logger.info(f" Total optimization steps = {max_train_steps}")521522global_step = 0523first_epoch = 0524525# Potentially load in the weights and states from a previous save526if args.resume_from_checkpoint:527if args.resume_from_checkpoint != "latest":528path = os.path.basename(args.resume_from_checkpoint)529else:530# Get the most recent checkpoint531dirs = os.listdir(args.output_dir)532dirs = [d for d in dirs if d.startswith("checkpoint")]533dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))534path = dirs[-1] if len(dirs) > 0 else None535536if path is None:537accelerator.print(538f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."539)540args.resume_from_checkpoint = None541else:542accelerator.print(f"Resuming from checkpoint {path}")543accelerator.load_state(os.path.join(args.output_dir, path))544global_step = int(path.split("-")[1])545546resume_global_step = global_step * args.gradient_accumulation_steps547first_epoch = global_step // num_update_steps_per_epoch548resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)549550# Train!551for epoch in range(first_epoch, args.num_epochs):552model.train()553progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)554progress_bar.set_description(f"Epoch {epoch}")555for step, batch in enumerate(train_dataloader):556# Skip steps until we reach the resumed step557if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:558if step % args.gradient_accumulation_steps == 0:559progress_bar.update(1)560continue561562clean_images = batch["input"]563# Sample noise that we'll add to the images564noise = torch.randn(clean_images.shape).to(clean_images.device)565bsz = clean_images.shape[0]566# Sample a random timestep for each image567timesteps = torch.randint(5680, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device569).long()570571# Add noise to the clean images according to the noise magnitude at each timestep572# (this is the forward diffusion process)573noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)574575with accelerator.accumulate(model):576# Predict the noise residual577model_output = model(noisy_images, timesteps).sample578579if args.prediction_type == "epsilon":580loss = F.mse_loss(model_output, noise) # this could have different weights!581elif args.prediction_type == "sample":582alpha_t = _extract_into_tensor(583noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)584)585snr_weights = alpha_t / (1 - alpha_t)586loss = snr_weights * F.mse_loss(587model_output, clean_images, reduction="none"588) # use SNR weighting from distillation paper589loss = loss.mean()590else:591raise ValueError(f"Unsupported prediction type: {args.prediction_type}")592593accelerator.backward(loss)594595if accelerator.sync_gradients:596accelerator.clip_grad_norm_(model.parameters(), 1.0)597optimizer.step()598lr_scheduler.step()599optimizer.zero_grad()600601# Checks if the accelerator has performed an optimization step behind the scenes602if accelerator.sync_gradients:603if args.use_ema:604ema_model.step(model.parameters())605progress_bar.update(1)606global_step += 1607608if global_step % args.checkpointing_steps == 0:609if accelerator.is_main_process:610save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")611accelerator.save_state(save_path)612logger.info(f"Saved state to {save_path}")613614logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}615if args.use_ema:616logs["ema_decay"] = ema_model.cur_decay_value617progress_bar.set_postfix(**logs)618accelerator.log(logs, step=global_step)619progress_bar.close()620621accelerator.wait_for_everyone()622623# Generate sample images for visual inspection624if accelerator.is_main_process:625if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:626unet = accelerator.unwrap_model(model)627628if args.use_ema:629ema_model.store(unet.parameters())630ema_model.copy_to(unet.parameters())631632pipeline = DDPMPipeline(633unet=unet,634scheduler=noise_scheduler,635)636637generator = torch.Generator(device=pipeline.device).manual_seed(0)638# run pipeline in inference (sample random noise and denoise)639images = pipeline(640generator=generator,641batch_size=args.eval_batch_size,642num_inference_steps=args.ddpm_num_inference_steps,643output_type="numpy",644).images645646if args.use_ema:647ema_model.restore(unet.parameters())648649# denormalize the images and save to tensorboard650images_processed = (images * 255).round().astype("uint8")651652if args.logger == "tensorboard":653if is_accelerate_version(">=", "0.17.0.dev0"):654tracker = accelerator.get_tracker("tensorboard", unwrap=True)655else:656tracker = accelerator.get_tracker("tensorboard")657tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)658elif args.logger == "wandb":659# Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files660accelerator.get_tracker("wandb").log(661{"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},662step=global_step,663)664665if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:666# save the model667unet = accelerator.unwrap_model(model)668669if args.use_ema:670ema_model.store(unet.parameters())671ema_model.copy_to(unet.parameters())672673pipeline = DDPMPipeline(674unet=unet,675scheduler=noise_scheduler,676)677678pipeline.save_pretrained(args.output_dir)679680if args.use_ema:681ema_model.restore(unet.parameters())682683if args.push_to_hub:684repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)685686accelerator.end_training()687688689if __name__ == "__main__":690args = parse_args()691main(args)692693694