Path: blob/main/examples/dreambooth/train_dreambooth_flax.py
import argparse
import hashlib
import logging
import math
import os
from pathlib import Path
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import torch.utils.checkpoint
import transformers
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from jax.experimental.compilation_cache import compilation_cache as cc
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDDPMScheduler,
    FlaxPNDMScheduler,
    FlaxStableDiffusionPipeline,
    FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
    parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the total train batch size (per-device batch size * number of devices).",
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.instance_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")

    return args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def get_params_to_save(params):
    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))


def main():
    args = parse_args()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    rng = jax.random.PRNGKey(args.seed)

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
            ):
                prompt_ids = pipeline.prepare_inputs(example["prompt"])
                prompt_ids = shard(prompt_ids)
                p_params = jax_utils.replicate(params)
                rng = jax.random.split(rng)[0]
                sample_rng = jax.random.split(rng, jax.device_count())
                images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
                images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
                images = pipeline.numpy_to_pil(np.array(images))

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline

    # Handle the repository creation
    if jax.process_index() == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
        )
    else:
        raise NotImplementedError("No tokenizer specified!")

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        class_num=args.num_class_images,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        batch = {k: v.numpy() for k, v in batch.items()}
        return batch

    total_train_batch_size = args.train_batch_size * jax.local_device_count()
    if len(train_dataset) < total_train_batch_size:
        raise ValueError(
            f"Training batch size is {total_train_batch_size}, but your dataset only contains"
            f" {len(train_dataset)} images. Please use a larger dataset or reduce the effective batch size. Note that"
            f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
        )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
    )

    weight_dtype = jnp.float32
    if args.mixed_precision == "fp16":
        weight_dtype = jnp.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = jnp.bfloat16

    if args.pretrained_vae_name_or_path:
        # TODO(patil-suraj): Upload flax weights for the VAE
        vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
    else:
        vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})

    # Load models and create wrapper for stable diffusion
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        vae_arg,
        dtype=weight_dtype,
        **vae_kwargs,
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
    )

    # Optimization
    if args.scale_lr:
        args.learning_rate = args.learning_rate * total_train_batch_size

    constant_scheduler = optax.constant_schedule(args.learning_rate)

    adamw = optax.adamw(
        learning_rate=constant_scheduler,
        b1=args.adam_beta1,
        b2=args.adam_beta2,
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(args.max_grad_norm),
        adamw,
    )

    unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
    text_encoder_state = train_state.TrainState.create(
        apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
    )

    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    noise_scheduler_state = noise_scheduler.create_state()

    # Initialize our training
    train_rngs = jax.random.split(rng, jax.local_device_count())

    def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng):
        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)

        if args.train_text_encoder:
            params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params}
        else:
            params = {"unet": unet_state.params}

        def compute_loss(params):
            # Convert images to latent space
            vae_outputs = vae.apply(
                {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
            )
            latents = vae_outputs.latent_dist.sample(sample_rng)
            # (NHWC) -> (NCHW)
            latents = jnp.transpose(latents, (0, 3, 1, 2))
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise_rng, timestep_rng = jax.random.split(sample_rng)
            noise = jax.random.normal(noise_rng, latents.shape)
            # Sample a random timestep for each image
            bsz = latents.shape[0]
            timesteps = jax.random.randint(
                timestep_rng,
                (bsz,),
                0,
                noise_scheduler.config.num_train_timesteps,
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

            # Get the text embedding for conditioning
            if args.train_text_encoder:
                encoder_hidden_states = text_encoder_state.apply_fn(
                    batch["input_ids"], params=params["text_encoder"], dropout_rng=dropout_rng, train=True
                )[0]
            else:
                encoder_hidden_states = text_encoder(
                    batch["input_ids"], params=text_encoder_state.params, train=False
                )[0]

            # Predict the noise residual
            model_pred = unet.apply(
                {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
                target, target_prior = jnp.split(target, 2, axis=0)

                # Compute instance loss
                loss = (target - model_pred) ** 2
                loss = loss.mean()

                # Compute prior loss
                prior_loss = (target_prior - model_pred_prior) ** 2
                prior_loss = prior_loss.mean()

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = (target - model_pred) ** 2
                loss = loss.mean()

            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(params)
        grad = jax.lax.pmean(grad, "batch")

        new_unet_state = unet_state.apply_gradients(grads=grad["unet"])
        if args.train_text_encoder:
            new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad["text_encoder"])
        else:
            new_text_encoder_state = text_encoder_state

        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_unet_state, new_text_encoder_state, metrics, new_train_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))

    # Replicate the train state on each device
    unet_state = jax_utils.replicate(unet_state)
    text_encoder_state = jax_utils.replicate(text_encoder_state)
    vae_params = jax_utils.replicate(vae_params)

    # Train!
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))

    # Scheduler and math around the number of training steps.
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    def checkpoint(step=None):
        # Create the pipeline using the trained modules and save it.
        scheduler, _ = FlaxPNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", from_pt=True
        )
        pipeline = FlaxStableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
        )

        outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
        pipeline.save_pretrained(
            outdir,
            params={
                "text_encoder": get_params_to_save(text_encoder_state.params),
                "vae": get_params_to_save(vae_params),
                "unet": get_params_to_save(unet_state.params),
                "safety_checker": safety_checker.params,
            },
        )

        if args.push_to_hub:
            message = f"checkpoint-{step}" if step is not None else "End of training"
            repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)

    global_step = 0

    epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================

        train_metrics = []

        steps_per_epoch = len(train_dataset) // total_train_batch_size
        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_dataloader:
            batch = shard(batch)
            unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
                unet_state, text_encoder_state, vae_params, batch, train_rngs
            )
            train_metrics.append(train_metric)

            train_step_progress_bar.update(jax.local_device_count())

            global_step += 1
            if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
                checkpoint(global_step)
            if global_step >= args.max_train_steps:
                break

        train_metric = jax_utils.unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")

    if jax.process_index() == 0:
        checkpoint()


if __name__ == "__main__":
    main()
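

# Example invocation (illustrative sketch only): the model identifier, directories, prompt, and
# hyperparameters below are placeholders, not values prescribed by this script. The flags map
# directly to the arguments defined in parse_args() above.
#
#   python train_dreambooth_flax.py \
#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#     --instance_data_dir="./instance_images" \
#     --instance_prompt="a photo of sks dog" \
#     --output_dir="./dreambooth-flax-model" \
#     --resolution=512 \
#     --train_batch_size=1 \
#     --learning_rate=5e-6 \
#     --max_train_steps=400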