Path: blob/main/scripts/convert_dit_to_diffusers.py
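# Converts the original facebookresearch/DiT checkpoints (DiT-XL/2 at 256x256 or
# 512x512) into a diffusers DiTPipeline: the transformer weights are remapped to
# Transformer2DModel key names, paired with a DDIM scheduler and a Stable
# Diffusion VAE, and saved with save_pretrained.
#
# Example invocation (the output path below is an arbitrary choice, not part of
# the original script):
#
#   python scripts/convert_dit_to_diffusers.py --image_size 256 --checkpoint_path DiT-XL-2-256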
import argparse
import os

import torch
from torchvision.datasets.utils import download_url

from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel


pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}


def download_model(model_name):
    """
    Downloads a pre-trained DiT model from the web.
    """
    local_path = f"pretrained_models/{model_name}"
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
        download_url(web_path, "pretrained_models")
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


def main(args):
    state_dict = download_model(pretrained_models[args.image_size])

    # Remap the patch embedding: x_embedder -> pos_embed.proj.
    state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
    state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
    state_dict.pop("x_embedder.proj.weight")
    state_dict.pop("x_embedder.proj.bias")

    for depth in range(28):
        # The shared DiT timestep/class embedders are copied into every block's
        # adaLN-zero norm (norm1.emb) in the diffusers layout.
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
            "t_embedder.mlp.0.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
            "t_embedder.mlp.0.bias"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
            "t_embedder.mlp.2.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
            "t_embedder.mlp.2.bias"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
            "y_embedder.embedding_table.weight"
        ]

        state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
            f"blocks.{depth}.adaLN_modulation.1.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
            f"blocks.{depth}.adaLN_modulation.1.bias"
        ]

        # Split the fused qkv projection into separate q/k/v weights and biases.
        q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)

        state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias

        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
            f"blocks.{depth}.attn.proj.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]

        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
        state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
        state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]

        state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
        state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
        state_dict.pop(f"blocks.{depth}.attn.proj.weight")
        state_dict.pop(f"blocks.{depth}.attn.proj.bias")
        state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
        state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
        state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
        state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
        state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
        state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")

    state_dict.pop("t_embedder.mlp.0.weight")
    state_dict.pop("t_embedder.mlp.0.bias")
    state_dict.pop("t_embedder.mlp.2.weight")
    state_dict.pop("t_embedder.mlp.2.bias")
    state_dict.pop("y_embedder.embedding_table.weight")

    # Final adaLN modulation and linear projection map to proj_out_1 / proj_out_2.
    state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
    state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
    state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
    state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]

    state_dict.pop("final_layer.linear.weight")
    state_dict.pop("final_layer.linear.bias")
    state_dict.pop("final_layer.adaLN_modulation.1.weight")
    state_dict.pop("final_layer.adaLN_modulation.1.bias")

    # DiT XL/2
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_zero",
        norm_elementwise_affine=False,
    )
    transformer.load_state_dict(state_dict, strict=True)

    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="linear",
        prediction_type="epsilon",
        clip_sample=False,
    )

    vae = AutoencoderKL.from_pretrained(args.vae_model)

    pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)

    if args.save:
        pipeline.save_pretrained(args.checkpoint_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--image_size",
        default=256,
        type=int,
        required=False,
        help="Image size of pretrained model, either 256 or 512.",
    )
    parser.add_argument(
        "--vae_model",
        default="stabilityai/sd-vae-ft-ema",
        type=str,
        required=False,
        help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
    )
    parser.add_argument(
        "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
    )

    args = parser.parse_args()
    main(args)