# Path: blob/main/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
# 1980 views
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import datasets
import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from onnxruntime.training.ortmodule import ORTModule

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Pillow >= 9.1.0 exposes the resampling filters under `PIL.Image.Resampling`;
# older releases expose them directly on `PIL.Image`.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.13.0.dev0")

logger = get_logger(__name__)


def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    """Save the learned embedding of the placeholder token to ``save_path``.

    The embedding row at ``placeholder_token_id`` is read from the unwrapped
    text encoder's input embeddings, detached, moved to CPU and stored with
    ``torch.save`` as ``{args.placeholder_token: tensor}``.
    """
    logger.info("Saving embeddings")
    embedding_weights = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    learned_embeds = embedding_weights[placeholder_token_id]
    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, save_path)


def parse_args():
    """Build the command-line parser, parse ``sys.argv`` and return the namespace.

    If the ``LOCAL_RANK`` environment variable is set and differs from
    ``--local_rank``, the environment value wins.

    Raises:
        ValueError: if no training data directory is available after parsing.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        # The fragments previously concatenated to "Choosebetween ... 1.10.and"; spaces restored.
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run validation every X epochs. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    args = parser.parse_args()

    # A launcher-provided LOCAL_RANK takes precedence over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args


imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


class TextualInversionDataset(Dataset):
    """Dataset pairing every file in ``data_root`` with a random templated prompt.

    Each item is a dict with:
      * ``"input_ids"``: token ids of a randomly chosen template formatted with
        ``placeholder_token``, padded/truncated to ``tokenizer.model_max_length``.
      * ``"pixel_values"``: the image converted to RGB, optionally center-cropped,
        resized to ``size`` x ``size``, randomly flipped, scaled to ``[-1, 1]``
        and laid out channels-first.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        # Every directory entry is treated as an image path.
        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        # For the training split the dataset is virtually repeated `repeats` times.
        self._length = self.num_images * repeats if set == "train" else self.num_images

        # Look the filter up in the shared table (this also accepts "nearest").
        self.interpolation = PIL_INTERPOLATION[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        # Indices wrap around so the repeated dataset cycles through the images.
        image = Image.open(self.image_paths[i % self.num_images])

        if image.mode != "RGB":
            image = image.convert("RGB")

        text = random.choice(self.templates).format(self.placeholder_token)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None) -> str:
    """Return ``"<namespace>/<model_id>"`` for a Hub repository.

    The namespace is ``organization`` when given; otherwise it is the ``name``
    field returned by ``whoami`` for ``token`` (falling back to the token stored
    by ``HfFolder`` when ``token`` is ``None``).
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    return f"{organization}/{model_id}"


def _ensure_gitignore_entries(output_dir, patterns=("step_*", "epoch_*")):
    """Append any of ``patterns`` missing from ``output_dir/.gitignore``.

    Existing content is preserved; the file is created if it does not exist.
    """
    gitignore_path = os.path.join(output_dir, ".gitignore")
    try:
        with open(gitignore_path, encoding="utf-8") as gitignore:
            content = gitignore.read()
    except FileNotFoundError:
        content = ""

    present = {line.strip() for line in content.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in present]
    if not missing:
        return

    with open(gitignore_path, "a", encoding="utf-8") as gitignore:
        # Avoid gluing the first new pattern onto an unterminated last line.
        if content and not content.endswith("\n"):
            gitignore.write("\n")
        gitignore.writelines(f"{pattern}\n" for pattern in missing)


def _run_validation(args, accelerator, text_encoder, epoch):
    """Generate ``args.num_validation_images`` images and log them to the trackers."""
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    # create pipeline (note: unet and vae are loaded again in float32)
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=accelerator.unwrap_model(text_encoder),
        revision=args.revision,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
    prompt = args.num_validation_images * [args.validation_prompt]
    images = pipeline(prompt, num_inference_steps=25, generator=generator).images

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            # Imported at the point of use so that `--report_to all` also works.
            import wandb

            tracker.log(
                {
                    "validation": [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                        for i, image in enumerate(images)
                    ]
                }
            )

    del pipeline
    torch.cuda.empty_cache()


def main():
    """Train a textual-inversion embedding with the text encoder wrapped in ``ORTModule``."""
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb" and not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            _ensure_gitignore_entries(args.output_dir)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token, placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    # Freeze vae and unet
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    if args.gradient_checkpointing:
        # Keep unet in train mode if we are using gradient checkpointing to save memory.
        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
        unet.train()
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    text_encoder = ORTModule(text_encoder)

    # For mixed precision training we cast the unet and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # keep original embeddings as reference
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()

    # Mask of every embedding row except the placeholder token; it does not
    # change during training, so it is built once instead of on every step.
    index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id

    for epoch in range(first_epoch, args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                        index_no_updates
                    ] = orig_embeds_params[index_no_updates]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                # Only the main process writes, so processes do not race on the same file.
                if accelerator.is_main_process and global_step % args.save_steps == 0:
                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
            _run_validation(args, accelerator, text_encoder, epoch)

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
            logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
            save_full_model = True
        else:
            save_full_model = not args.only_save_embeds
        if save_full_model:
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
            )
            pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()