# Path: blob/main/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py
# 1979 views
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from multi_token_clip import MultiTokenCLIPTokenizer

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Pillow >= 9.1.0 exposes the resampling filters under `PIL.Image.Resampling`;
# older releases only have them as module-level constants.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.14.0.dev0")

logger = get_logger(__name__)


def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
    """
    Add tokens to the tokenizer and set the initial value of token embeddings.

    Args:
        tokenizer: Tokenizer providing `add_placeholder_tokens` and `encode`
            (a `MultiTokenCLIPTokenizer` in this script).
        text_encoder: Text encoder whose input embedding matrix is resized and initialised in place.
        placeholder_token: The placeholder string to register.
        num_vec_per_token: Number of embedding vectors requested for the placeholder.
        initializer_token: Optional text whose token embeddings seed the new vectors. When falsy,
            the new vectors are drawn from a standard normal distribution.

    Returns:
        `placeholder_token`, unchanged.

    Raises:
        ValueError: If `initializer_token` is given but encodes to no token ids.
    """
    tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data
    placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
    if initializer_token:
        token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
        if not token_ids:
            # Without this guard the indexing below fails with an opaque IndexError.
            raise ValueError(f"The initializer token {initializer_token!r} does not map to any token id.")
        for i, placeholder_token_id in enumerate(placeholder_token_ids):
            # Spread the initializer's ids evenly over the placeholder vectors.
            token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
    else:
        for placeholder_token_id in placeholder_token_ids:
            token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])
    return placeholder_token


def save_progress(tokenizer, text_encoder, accelerator, save_path):
    """
    Save the learned embeddings of every placeholder token in `tokenizer.token_map` to `save_path`.

    The file written by `torch.save` is a dict mapping each placeholder token string to a CPU tensor
    holding its embedding rows. All placeholders are stored in one dict; previously the dict was
    rebuilt on every loop iteration, so only the last placeholder was saved.
    """
    embedding_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    learned_embeds_dict = {}
    for placeholder_token in tokenizer.token_map:
        placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
        learned_embeds = embedding_weight[placeholder_token_ids]
        if len(placeholder_token_ids) == 1:
            # Kept as in the original on-disk layout: single-id placeholders get an extra leading axis.
            learned_embeds = learned_embeds[None]
        learned_embeds_dict[placeholder_token] = learned_embeds.detach().cpu()
    torch.save(learned_embeds_dict, save_path)


def load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict):
    """
    Register every placeholder in `learned_embeds_dict` and copy its vectors into the text encoder.

    `learned_embeds_dict` maps a placeholder token string to a tensor whose first dimension is the
    number of vectors for that placeholder (the layout written by `save_progress`).
    """
    for placeholder_token in learned_embeds_dict:
        placeholder_embeds = learned_embeds_dict[placeholder_token]
        num_vec_per_token = placeholder_embeds.shape[0]
        placeholder_embeds = placeholder_embeds.to(dtype=text_encoder.dtype)
        add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=num_vec_per_token)
        placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
        token_embeds = text_encoder.get_input_embeddings().weight.data
        for i, placeholder_token_id in enumerate(placeholder_token_ids):
            token_embeds[placeholder_token_id] = placeholder_embeds[i]


def load_multitoken_tokenizer_from_automatic(tokenizer, text_encoder, automatic_dict, placeholder_token):
    """
    Load an embedding saved in Automatic1111's format under the name `placeholder_token`.

    Automatic1111's tokens have format
    {'string_to_token': {'*': 265}, 'string_to_param': {'*': tensor([[ 0.0833, 0.0030, 0.0057, ..., -0.0264, -0.0616, -0.0529],
    [ 0.0058, -0.0190, -0.0584, ..., -0.0025, -0.0945, -0.0490],
    [ 0.0916, 0.0025, 0.0365, ..., -0.0685, -0.0124, 0.0728],
    [ 0.0812, -0.0199, -0.0100, ..., -0.0581, -0.0780, 0.0254]],
    requires_grad=True)}, 'name': 'FloralMarble-400', 'step': 399, 'sd_checkpoint': '4bdfc29c', 'sd_checkpoint_name': 'SD2.1-768'}

    Only `automatic_dict["string_to_param"]["*"]` is read.
    """
    learned_embeds_dict = {}
    learned_embeds_dict[placeholder_token] = automatic_dict["string_to_param"]["*"]
    load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)


def get_mask(tokenizer, accelerator):
    """
    Return a boolean mask over the vocabulary that is True for every embedding row that must stay frozen.

    The mask has `len(tokenizer)` entries on `accelerator.device`; entries belonging to the ids of
    any placeholder token in `tokenizer.token_map` are False, all others True.
    """
    # Get the mask of the weights that won't change
    mask = torch.ones(len(tokenizer), dtype=torch.bool, device=accelerator.device)
    for placeholder_token in tokenizer.token_map:
        placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
        # One index assignment instead of building a vocabulary-sized comparison per id.
        ids = torch.as_tensor(placeholder_token_ids, dtype=torch.long, device=mask.device)
        mask[ids] = False
    return mask


def parse_args():
    """
    Parse the command-line arguments of the training script.

    `args.local_rank` is overridden by the `LOCAL_RANK` environment variable when that is set.

    Raises:
        ValueError: If no training data directory is given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--progressive_tokens_max_steps",
        type=int,
        default=2000,
        help="The number of steps until all tokens will be used.",
    )
    parser.add_argument(
        "--progressive_tokens",
        action="store_true",
        help="Progressively train the tokens. For example, first train for 1 token, then 2 tokens and so on.",
    )
    parser.add_argument("--vector_shuffle", action="store_true", help="Shuffling tokens during training")
    parser.add_argument(
        "--num_vec_per_token",
        type=int,
        default=1,
        help=(
            "The number of vectors used to represent the placeholder token. The higher the number, the better the"
            " result at the cost of editability. This can be fixed by prompt editing."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run validation every X epochs. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args


imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


class TextualInversionDataset(Dataset):
    """
    Dataset pairing each image in `data_root` with a randomly templated prompt containing the placeholder.

    Every entry of `os.listdir(data_root)` is treated as an image path. With `set="train"` the
    reported length is `num_images * repeats`; indices wrap around the image list.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
        vector_shuffle=False,
        progressive_tokens=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p
        self.vector_shuffle = vector_shuffle
        self.progressive_tokens = progressive_tokens
        # Updated from the training loop when progressive token training is enabled.
        self.prop_tokens_to_load = 0

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        """Return a dict with `input_ids` (tokenized prompt) and `pixel_values` (CHW float32 in [-1, 1])."""
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer.encode(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
            vector_shuffle=self.vector_shuffle,
            prop_tokens_to_load=self.prop_tokens_to_load if self.progressive_tokens else 1.0,
        )[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            # Crop the largest centered square.
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Return `"<namespace>/<model_id>"`, using `organization` when given and otherwise the user name
    reported by `whoami` for `token` (falling back to the locally stored Hub token).
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def main():
    """Run multi-token textual inversion training as configured by the command-line arguments."""
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Append only the missing ignore patterns. Opening with "w+" would truncate an
            # existing .gitignore and make the membership test meaningless.
            with open(os.path.join(args.output_dir, ".gitignore"), "a+") as gitignore:
                gitignore.seek(0)
                existing = gitignore.read()
                if existing and not existing.endswith("\n"):
                    gitignore.write("\n")
                known_patterns = set(existing.splitlines())
                if "step_*" not in known_patterns:
                    gitignore.write("step_*\n")
                if "epoch_*" not in known_patterns:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load tokenizer
    if args.tokenizer_name:
        tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )
    # Best effort: use xformers whenever it is importable, but never fail because of it.
    if is_xformers_available():
        try:
            unet.enable_xformers_memory_efficient_attention()
        except Exception as e:
            logger.warning(
                "Could not enable memory efficient attention. Make sure xformers is installed"
                f" correctly and a GPU is available: {e}"
            )
    add_tokens(tokenizer, text_encoder, args.placeholder_token, args.num_vec_per_token, args.initializer_token)

    # Freeze vae and unet
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    if args.gradient_checkpointing:
        # Keep unet in train mode if we are using gradient checkpointing to save memory.
        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
        unet.train()
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
        # These two flags were parsed but never forwarded, so they had no effect.
        vector_shuffle=args.vector_shuffle,
        progressive_tokens=args.progressive_tokens,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # For mixed precision training we cast the unet and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # keep original embeddings as reference
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()

    # The set of placeholder ids does not change during training, so the mask of frozen
    # embedding rows is computed once instead of on every step.
    index_no_updates = get_mask(tokenizer, accelerator)

    for epoch in range(first_epoch, args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue
            if args.progressive_tokens:
                # NOTE(review): with `--dataloader_num_workers > 0` worker processes hold their own
                # copy of the dataset, so this update may only take effect when workers are
                # re-created - confirm.
                train_dataset.prop_tokens_to_load = float(global_step) / args.progressive_tokens_max_steps

            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                        index_no_updates
                    ] = orig_embeds_params[index_no_updates]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                # Only one process writes the embeddings file, to avoid concurrent writes to the same path.
                if global_step % args.save_steps == 0 and accelerator.is_main_process:
                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
                    save_progress(tokenizer, text_encoder, accelerator, save_path)

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
            logger.info(
                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                f" {args.validation_prompt}."
            )
            # create pipeline (note: unet and vae are loaded again in float32)
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder),
                tokenizer=tokenizer,
                unet=unet,
                vae=vae,
                revision=args.revision,
                torch_dtype=weight_dtype,
            )
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)

            # run inference
            generator = (
                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
            )
            images = []
            for _ in range(args.num_validation_images):
                with torch.autocast("cuda"):
                    image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
                images.append(image)

            for tracker in accelerator.trackers:
                if tracker.name == "tensorboard":
                    np_images = np.stack([np.asarray(img) for img in images])
                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                if tracker.name == "wandb":
                    tracker.log(
                        {
                            "validation": [
                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )

            del pipeline
            torch.cuda.empty_cache()

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
            logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
            save_full_model = True
        else:
            save_full_model = not args.only_save_embeds
        if save_full_model:
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
            )
            pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
        save_progress(tokenizer, text_encoder, accelerator, save_path)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()