# Path: blob/main/examples/textual_inversion/textual_inversion_flax.py
# 1448 views
import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
import PIL
import torch
import torch.utils.checkpoint
import transformers
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDDPMScheduler,
    FlaxPNDMScheduler,
    FlaxStableDiffusionPipeline,
    FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version


# Map interpolation names to Pillow resampling filters. Pillow releases older than 9.1.0
# are read through the flat PIL.Image constants; 9.1.0 and newer through PIL.Image.Resampling.
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    PIL_INTERPOLATION = dict(
        linear=PIL.Image.LINEAR,
        bilinear=PIL.Image.BILINEAR,
        bicubic=PIL.Image.BICUBIC,
        lanczos=PIL.Image.LANCZOS,
        nearest=PIL.Image.NEAREST,
    )
else:
    # Both "linear" and "bilinear" resolve to the bilinear filter on this branch.
    PIL_INTERPOLATION = dict(
        linear=PIL.Image.Resampling.BILINEAR,
        bilinear=PIL.Image.Resampling.BILINEAR,
        bicubic=PIL.Image.Resampling.BICUBIC,
        lanczos=PIL.Image.Resampling.LANCZOS,
        nearest=PIL.Image.Resampling.NEAREST,
    )
# ------------------------------------------------------------------------------

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0")

logger = logging.getLogger(__name__)


def parse_args():
    """Build the command-line parser for this script and return the parsed arguments.

    `LOCAL_RANK` from the environment, when set, overrides `--local_rank`.

    Raises:
        ValueError: if no training data directory was given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save the learned embeddings (learned_embeds-<step>.npy) every X update steps; <= 0 disables this.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    # `--scale_lr` defaults to True, so store_true alone can never turn it off; this flag can.
    parser.add_argument(
        "--no_scale_lr",
        dest="scale_lr",
        action="store_false",
        help="Disable the learning-rate scaling that --scale_lr enables by default.",
    )
    # NOTE(review): --lr_warmup_steps, --lr_scheduler, --use_auth_token and --logging_dir are parsed
    # but never read by main(); the optimizer below always uses a constant schedule.
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help=(
            "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
            " private models)."
        ),
    )
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Defensive: argparse already enforces required=True for this option.
    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args


imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


class TextualInversionDataset(Dataset):
    """Image/prompt pairs for textual inversion.

    Every file in `data_root` is treated as one training image. Each example pairs an image with a
    randomly chosen prompt template filled in with `placeholder_token`. In the "train" split the
    images are repeated `repeats` times.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        # Sorted so the index -> image mapping does not depend on the filesystem's listing order;
        # sub-directories are skipped because Image.open cannot read them.
        candidate_paths = [os.path.join(self.data_root, name) for name in sorted(os.listdir(self.data_root))]
        self.image_paths = [path for path in candidate_paths if os.path.isfile(path)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        # Any key of PIL_INTERPOLATION is accepted; unknown names raise KeyError.
        self.interpolation = PIL_INTERPOLATION[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        """Return {"input_ids": token ids of a random prompt, "pixel_values": CHW float32 image in [-1, 1]}."""
        example = {}
        # Close the file handle as soon as the pixels are decoded; np.array() forces the read.
        with Image.open(self.image_paths[i % self.num_images]) as image:
            if not image.mode == "RGB":
                image = image.convert("RGB")
            img = np.array(image).astype(np.uint8)

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return "<organization or current user>/<model_id>", resolving the user via `whoami` when needed."""
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
    """Grow the text encoder's token-embedding matrix to `new_num_tokens` rows.

    Existing rows are copied over. Rows beyond them are drawn from a normal initializer using `rng`.
    The placeholder row is then overwritten with the initializer token's embedding.

    Returns:
        The (possibly updated) model. It is returned unchanged when no resize is needed.
    """
    if new_num_tokens is None or model.config.vocab_size == new_num_tokens:
        # Callers rebind their model to the return value, so never return None here.
        return model
    model.config.vocab_size = new_num_tokens

    params = model.params
    old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
    old_num_tokens, emb_dim = old_embeddings.shape

    initializer = jax.nn.initializers.normal()

    new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
    new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
    new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
    params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings

    model.params = params
    return model


def get_params_to_save(params):
    """Take the first replica of pmap-replicated params and fetch it to host memory."""
    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))


def _learned_embedding(replicated_params, token_id):
    """Return the host copy of row `token_id` of the (replicated) text-encoder token embedding."""
    return get_params_to_save(replicated_params)["text_model"]["embeddings"]["token_embedding"]["embedding"][token_id]


def _ensure_gitignore_entries(gitignore_path, patterns):
    """Append each pattern in `patterns` to `gitignore_path` unless it is already a line there.

    Existing content is preserved. The file is created if it does not exist.
    """
    content = ""
    if os.path.exists(gitignore_path):
        with open(gitignore_path, "r", encoding="utf-8") as existing_file:
            content = existing_file.read()
    existing = {line.strip() for line in content.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in existing]
    if not missing:
        return
    with open(gitignore_path, "a", encoding="utf-8") as gitignore:
        # Don't glue the first new pattern onto an unterminated last line.
        if content and not content.endswith("\n"):
            gitignore.write("\n")
        gitignore.write("".join(pattern + "\n" for pattern in missing))


def main():
    """Run textual-inversion training of a single placeholder-token embedding and save the results."""
    args = parse_args()

    if args.seed is not None:
        set_seed(args.seed)

    if jax.process_index() == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Keep intermediate step/epoch artifacts out of the pushed repo without discarding
            # whatever .gitignore the cloned repo already has.
            _ensure_gitignore_entries(os.path.join(args.output_dir, ".gitignore"), ("step_*", "epoch_*"))
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Load the tokenizer and add the placeholder token as a additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token, placeholder_token to ids
    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # Create sampling rng
    rng = jax.random.PRNGKey(args.seed)
    rng, _ = jax.random.split(rng)
    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder = resize_token_embeddings(
        text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng
    )
    # Snapshot of the full embedding table; train_step restores every row except the placeholder from it.
    original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"]

    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
    )

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        input_ids = torch.stack([example["input_ids"] for example in examples])

        batch = {"pixel_values": pixel_values, "input_ids": input_ids}
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch

    total_train_batch_size = args.train_batch_size * jax.local_device_count()
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )
    # With drop_last=True a dataset smaller than one global batch yields no batches at all, which
    # would otherwise surface later as a ZeroDivisionError / unbound train_metric.
    if len(train_dataloader) == 0:
        raise ValueError(
            f"The training dataset has {len(train_dataset)} examples, fewer than the total train batch size"
            f" ({total_train_batch_size}), so no full batch can be formed. Add images to --train_data_dir,"
            " increase --repeats, or lower --train_batch_size."
        )

    # Optimization
    if args.scale_lr:
        args.learning_rate = args.learning_rate * total_train_batch_size

    constant_scheduler = optax.constant_schedule(args.learning_rate)

    optimizer = optax.adamw(
        learning_rate=constant_scheduler,
        b1=args.adam_beta1,
        b2=args.adam_beta2,
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

    def create_mask(params, label_fn):
        """Label every subtree whose key satisfies `label_fn` as "token_embedding", all other leaves "zero"."""

        def _map(params, mask, label_fn):
            for k in params:
                if label_fn(k):
                    mask[k] = "token_embedding"
                else:
                    if isinstance(params[k], dict):
                        mask[k] = {}
                        _map(params[k], mask[k], label_fn)
                    else:
                        mask[k] = "zero"

        mask = {}
        _map(params, mask, label_fn)
        return mask

    def zero_grads():
        # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
        def init_fn(_):
            return ()

        def update_fn(updates, state, params=None):
            return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

        return optax.GradientTransformation(init_fn, update_fn)

    # Zero out gradients of layers other than the token embedding layer
    tx = optax.multi_transform(
        {"token_embedding": optimizer, "zero": zero_grads()},
        create_mask(text_encoder.params, lambda s: s == "token_embedding"),
    )

    state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)

    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    noise_scheduler_state = noise_scheduler.create_state()

    # Initialize our training
    train_rngs = jax.random.split(rng, jax.local_device_count())

    # Define gradient train step fn
    def train_step(state, vae_params, unet_params, batch, train_rng):
        """One pmapped optimization step; returns (new_state, averaged metrics, next rng)."""
        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)

        def compute_loss(params):
            vae_outputs = vae.apply(
                {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
            )
            # NOTE(review): sample_rng is used here and then split again below for noise/timesteps.
            latents = vae_outputs.latent_dist.sample(sample_rng)
            # (NHWC) -> (NCHW)
            latents = jnp.transpose(latents, (0, 3, 1, 2))
            latents = latents * vae.config.scaling_factor

            noise_rng, timestep_rng = jax.random.split(sample_rng)
            noise = jax.random.normal(noise_rng, latents.shape)
            bsz = latents.shape[0]
            timesteps = jax.random.randint(
                timestep_rng,
                (bsz,),
                0,
                noise_scheduler.config.num_train_timesteps,
            )
            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
            encoder_hidden_states = state.apply_fn(
                batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
            )[0]
            # Predict the noise residual and compute loss
            model_pred = unet.apply(
                {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = (target - model_pred) ** 2
            loss = loss.mean()

            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        # Keep the token embeddings fixed except the newly added embeddings for the concept,
        # as we only want to optimize the concept embeddings
        token_embeds = original_token_embeds.at[placeholder_token_id].set(
            new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
        )
        new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds

        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics, new_train_rng

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)
    vae_params = jax_utils.replicate(vae_params)
    unet_params = jax_utils.replicate(unet_params)

    # Train!
    # drop_last=True, so this is len(train_dataset) // total_train_batch_size (> 0, checked above).
    num_update_steps_per_epoch = len(train_dataloader)

    # Scheduler and math around the number of training steps.
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    global_step = 0

    epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================

        train_step_progress_bar = tqdm(
            total=num_update_steps_per_epoch, desc="Training...", position=1, leave=False
        )
        # train
        for batch in train_dataloader:
            batch = shard(batch)
            state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)

            train_step_progress_bar.update(1)
            global_step += 1

            if global_step >= args.max_train_steps:
                break
            # Only process 0 writes: it is the only process that created output_dir above.
            if args.save_steps > 0 and global_step % args.save_steps == 0 and jax.process_index() == 0:
                learned_embeds = _learned_embedding(state.params, placeholder_token_id)
                learned_embeds_dict = {args.placeholder_token: learned_embeds}
                jnp.save(
                    os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
                )

        train_metric = jax_utils.unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")

    # Create the pipeline using the trained modules and save it.
    if jax.process_index() == 0:
        scheduler = FlaxPNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
        )
        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", from_pt=True
        )
        pipeline = FlaxStableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
        )

        pipeline.save_pretrained(
            args.output_dir,
            params={
                "text_encoder": get_params_to_save(state.params),
                "vae": get_params_to_save(vae_params),
                "unet": get_params_to_save(unet_params),
                "safety_checker": safety_checker.params,
            },
        )

        # Also save the newly trained embeddings
        learned_embeds = _learned_embedding(state.params, placeholder_token_id)
        learned_embeds_dict = {args.placeholder_token: learned_embeds}
        jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
    main()