Path: blob/main/examples/text_to_image/train_text_to_image_flax.py
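
# Fine-tunes Stable Diffusion on an image-caption dataset with JAX/Flax.
# A typical invocation looks roughly like the one below; the model and dataset identifiers and the
# hyperparameter values are illustrative examples, not requirements of this script:
#
#   python train_text_to_image_flax.py \
#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#     --dataset_name="lambdalabs/pokemon-blip-captions" \
#     --resolution=512 --center_crop --random_flip \
#     --train_batch_size=1 \
#     --max_train_steps=15000 \
#     --learning_rate=1e-05 \
#     --output_dir="sd-pokemon-model"
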
import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import torch.utils.checkpoint
import transformers
from datasets import load_dataset
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDDPMScheduler,
    FlaxPNDMScheduler,
    FlaxStableDiffusionPipeline,
    FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


dataset_name_mapping = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}


def get_params_to_save(params):
    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))


def main():
    args = parse_args()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if jax.process_index() == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
        input_ids = inputs.input_ids
        return input_ids

    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
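
    # Note: the torchvision pipeline above yields CHW float tensors normalized to [-1, 1]
    # (ToTensor scales pixels to [0, 1], Normalize([0.5], [0.5]) shifts them to [-1, 1]).
    # preprocess_train below applies it lazily through `with_transform`, while tokenize_captions
    # leaves the token ids unpadded; collate_fn later pads them to `tokenizer.model_max_length`
    # and converts every field to NumPy so each batch can be sharded across the local devices.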

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    if jax.process_index() == 0:
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))

    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = [example["input_ids"] for example in examples]

        padded_tokens = tokenizer.pad(
            {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        )
        batch = {
            "pixel_values": pixel_values,
            "input_ids": padded_tokens.input_ids,
        }
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch

    total_train_batch_size = args.train_batch_size * jax.local_device_count()
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True
    )

    weight_dtype = jnp.float32
    if args.mixed_precision == "fp16":
        weight_dtype = jnp.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = jnp.bfloat16

    # Load models and create wrapper for stable diffusion
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
    )
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
    )

    # Optimization
    if args.scale_lr:
        args.learning_rate = args.learning_rate * total_train_batch_size

    constant_scheduler = optax.constant_schedule(args.learning_rate)

    adamw = optax.adamw(
        learning_rate=constant_scheduler,
        b1=args.adam_beta1,
        b2=args.adam_beta2,
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(args.max_grad_norm),
        adamw,
    )

    state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)

    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    noise_scheduler_state = noise_scheduler.create_state()
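
    # The training step below is compiled with `jax.pmap` over a named "batch" axis: each local
    # device computes the loss on its own shard of the batch, gradients and metrics are averaged
    # across devices with `jax.lax.pmean`, and every device carries its own RNG key, which is
    # split and returned at each step. `donate_argnums=(0,)` lets XLA reuse the buffers of the
    # incoming train state for the updated state.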

    # Initialize our training
    rng = jax.random.PRNGKey(args.seed)
    train_rngs = jax.random.split(rng, jax.local_device_count())

    def train_step(state, text_encoder_params, vae_params, batch, train_rng):
        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)

        def compute_loss(params):
            # Convert images to latent space
            vae_outputs = vae.apply(
                {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
            )
            latents = vae_outputs.latent_dist.sample(sample_rng)
            # (NHWC) -> (NCHW)
            latents = jnp.transpose(latents, (0, 3, 1, 2))
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise_rng, timestep_rng = jax.random.split(sample_rng)
            noise = jax.random.normal(noise_rng, latents.shape)
            # Sample a random timestep for each image
            bsz = latents.shape[0]
            timesteps = jax.random.randint(
                timestep_rng,
                (bsz,),
                0,
                noise_scheduler.config.num_train_timesteps,
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(
                batch["input_ids"],
                params=text_encoder_params,
                train=False,
            )[0]

            # Predict the noise residual and compute loss
            model_pred = unet.apply(
                {"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = (target - model_pred) ** 2
            loss = loss.mean()

            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad)

        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics, new_train_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)
    text_encoder_params = jax_utils.replicate(text_encoder.params)
    vae_params = jax_utils.replicate(vae_params)

    # Train!
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))

    # Scheduler and math around the number of training steps.
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    global_step = 0

    epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================

        train_metrics = []

        steps_per_epoch = len(train_dataset) // total_train_batch_size
        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_dataloader:
            batch = shard(batch)
            state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

            global_step += 1
            if global_step >= args.max_train_steps:
                break

        train_metric = jax_utils.unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
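
    # Only process 0 assembles and saves the final pipeline. `get_params_to_save` takes the first
    # replica of each parameter and copies it back to the host, so the checkpoint is written with
    # unreplicated weights.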

    # Create the pipeline using the trained modules and save it.
    if jax.process_index() == 0:
        scheduler = FlaxPNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
        )
        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", from_pt=True
        )
        pipeline = FlaxStableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
        )

        pipeline.save_pretrained(
            args.output_dir,
            params={
                "text_encoder": get_params_to_save(text_encoder_params),
                "vae": get_params_to_save(vae_params),
                "unet": get_params_to_save(state.params),
                "safety_checker": safety_checker.params,
            },
        )

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
    main()
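

# A minimal sketch of loading the saved pipeline for inference (illustrative only; it assumes the
# default `--output_dir` of "sd-model-finetuned", and multi-device generation additionally needs
# the params replicated and the prompt ids sharded, as described in the FlaxStableDiffusionPipeline
# docs):
#
#   import jax
#   from diffusers import FlaxStableDiffusionPipeline
#
#   pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("sd-model-finetuned")
#   prompt_ids = pipeline.prepare_inputs(["a photo of a corgi"])
#   images = pipeline(prompt_ids, params, jax.random.PRNGKey(0), num_inference_steps=50).images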