Path: blob/main/examples/text_to_image/train_text_to_image.py
1448 views
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tune the UNet of a Stable Diffusion checkpoint on an image/caption dataset.

The VAE and CLIP text encoder are loaded from the same checkpoint and frozen; only the
UNet is optimized (optionally tracked by an EMA copy). Training is driven by
`accelerate`, and the resulting `StableDiffusionPipeline` is saved to `--output_dir`
and optionally pushed to the Hub.
"""

import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__, log_level="INFO")


def parse_args():
    """Parse the command-line options for this training script.

    Returns:
        argparse.Namespace: the parsed options. ``local_rank`` is overridden by the
        ``LOCAL_RANK`` environment variable when that is set, and ``non_ema_revision``
        defaults to ``revision`` when not given.

    Raises:
        ValueError: if neither ``--dataset_name`` nor ``--train_data_dir`` is provided.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--non_ema_revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
            " remote repository specified with --pretrained_model_name_or_path."
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")

    args = parser.parse_args()
    # A LOCAL_RANK set by the distributed launcher takes precedence over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    # default to using the same revision for the non-ema model if not specified
    if args.non_ema_revision is None:
        args.non_ema_revision = args.revision

    return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return the ``"<namespace>/<model_id>"`` Hub repository id.

    The namespace is ``organization`` when given. Otherwise it is the username that
    ``whoami`` returns for ``token``; when ``token`` is None, the locally stored token
    from ``HfFolder`` is used.
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    return f"{organization}/{model_id}"


# Default (image_column, caption_column) pairs for known Hub datasets.
dataset_name_mapping = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}


def _ensure_gitignore_patterns(output_dir: str, patterns=("step_*", "epoch_*")) -> None:
    """Append every entry of ``patterns`` missing from ``output_dir/.gitignore``.

    Existing lines are preserved. The file is created if it does not exist.
    """
    gitignore_path = os.path.join(output_dir, ".gitignore")
    existing = ""
    if os.path.isfile(gitignore_path):
        with open(gitignore_path, "r", encoding="utf-8") as gitignore:
            existing = gitignore.read()
    present = {line.strip() for line in existing.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in present]
    if not missing:
        return
    with open(gitignore_path, "a", encoding="utf-8") as gitignore:
        # Do not glue the first new pattern onto an unterminated last line.
        if existing and not existing.endswith("\n"):
            gitignore.write("\n")
        gitignore.writelines(f"{pattern}\n" for pattern in missing)


def _latest_checkpoint_name(output_dir: str) -> Optional[str]:
    """Return the ``checkpoint-<step>`` entry of ``output_dir`` with the largest step, or None.

    Entries whose suffix is not a plain integer are ignored rather than crashing the sort.
    """
    prefix = "checkpoint-"
    if not os.path.isdir(output_dir):
        return None
    candidates = [
        name for name in os.listdir(output_dir) if name.startswith(prefix) and name[len(prefix) :].isdigit()
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda name: int(name[len(prefix) :]))


def main():
    """Run text-to-image fine-tuning of the UNet end to end.

    Steps: parse the arguments, set up `accelerate`, load the models and the dataset,
    train (optionally resuming from a checkpoint), save the final
    `StableDiffusionPipeline` and optionally push it to the Hub.
    """
    args = parse_args()

    # parse_args() copies `--revision` into `non_ema_revision` when the latter is omitted, so only an
    # explicitly different revision means the user asked for the deprecated non-EMA branch workflow.
    if args.non_ema_revision is not None and args.non_ema_revision != args.revision:
        deprecate(
            "non_ema_revision!=None",
            "0.15.0",
            message=(
                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
                " use `--variant=non_ema` instead."
            ),
        )
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Keep intermediate save folders out of the pushed repo without clobbering an
            # existing .gitignore from the cloned repository.
            _ensure_gitignore_patterns(args.output_dir, ("step_*", "epoch_*"))
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
    )

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Create EMA for the unet.
    if args.use_ema:
        ema_unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
        )
        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if args.use_ema:
                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

            for model in models:
                model.save_pretrained(os.path.join(output_dir, "unet"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            if args.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
                ema_unet.load_state_dict(load_model.state_dict())
                ema_unet.to(accelerator.device)
                del load_model

            for _ in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # Get the column names for input/target.
    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        """Tokenize the caption column of a batch, padded/truncated to the tokenizer's max length.

        A row holding several captions contributes a random one when training, else its first.
        """
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            # Maps ToTensor's [0, 1] output to [-1, 1].
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        """Add `pixel_values` (transformed RGB images) and `input_ids` (tokenized captions) to a batch."""
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        """Stack per-example tensors into a batch dict with `pixel_values` and `input_ids`."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    if args.use_ema:
        ema_unet.to(accelerator.device)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu and cast to weight_dtype
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            path = _latest_checkpoint_name(args.output_dir)

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            # Convert optimizer steps back into dataloader steps so the partially
            # completed epoch can be fast-forwarded below.
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                    )

                bsz = latents.shape[0]
                # Sample a random timestep for each image. Read the timestep count from the
                # scheduler config; direct attribute access to config values is deprecated.
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                )
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Predict the noise residual and compute loss
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_unet.step(unet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        if args.use_ema:
            # Export the EMA weights rather than the raw trained weights.
            ema_unet.copy_to(unet.parameters())

        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            revision=args.revision,
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()