Path: blob/main/scripts/convert_vq_diffusion_to_diffusers.py
"""
This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.

It currently only supports porting the ITHQ dataset.

ITHQ dataset:
```sh
# From the root directory of diffusers.

# Download the VQVAE checkpoint
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth

# Download the VQVAE config
# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml

# Download the main model checkpoint
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth

# Download the main model config
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml

# Run the conversion script
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
    --checkpoint_path ./ithq_learnable.pth \
    --original_config_file ./ithq.yaml \
    --vqvae_checkpoint_path ./ithq_vqvae.pth \
    --vqvae_original_config_file ./ithq_vqvae.yaml \
    --dump_path <path to save pre-trained `VQDiffusionPipeline`>
```
"""

import argparse
import tempfile

import torch
import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings


try:
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError(
        "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
        " OmegaConf`."
    )

# vqvae model

PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]


def vqvae_model_from_original_config(original_config):
    assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."

    original_config = original_config.params

    original_encoder_config = original_config.encoder_config.params
    original_decoder_config = original_config.decoder_config.params

    in_channels = original_encoder_config.in_channels
    out_channels = original_decoder_config.out_ch

    down_block_types = get_down_block_types(original_encoder_config)
    up_block_types = get_up_block_types(original_decoder_config)

    assert original_encoder_config.ch == original_decoder_config.ch
    assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
    block_out_channels = tuple(
        [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
    )

    assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
    layers_per_block = original_encoder_config.num_res_blocks

    assert original_encoder_config.z_channels == original_decoder_config.z_channels
    latent_channels = original_encoder_config.z_channels

    num_vq_embeddings = original_config.n_embed

    # Hard coded value for ResnetBlock.GroupNorm(num_groups) in VQ-diffusion
    norm_num_groups = 32

    e_dim = original_config.embed_dim

    model = VQModel(
        in_channels=in_channels,
        out_channels=out_channels,
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        latent_channels=latent_channels,
        num_vq_embeddings=num_vq_embeddings,
        norm_num_groups=norm_num_groups,
        vq_embed_dim=e_dim,
    )

    return model


def get_down_block_types(original_encoder_config):
    attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
    num_resolutions = len(original_encoder_config.ch_mult)
    resolution = coerce_resolution(original_encoder_config.resolution)

    curr_res = resolution
    down_block_types = []

    for _ in range(num_resolutions):
        if curr_res in attn_resolutions:
            down_block_type = "AttnDownEncoderBlock2D"
        else:
            down_block_type = "DownEncoderBlock2D"

        down_block_types.append(down_block_type)

        curr_res = [r // 2 for r in curr_res]

    return down_block_types


def get_up_block_types(original_decoder_config):
    attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
    num_resolutions = len(original_decoder_config.ch_mult)
    resolution = coerce_resolution(original_decoder_config.resolution)

    curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
    up_block_types = []

    for _ in reversed(range(num_resolutions)):
        if curr_res in attn_resolutions:
            up_block_type = "AttnUpDecoderBlock2D"
        else:
            up_block_type = "UpDecoderBlock2D"

        up_block_types.append(up_block_type)

        curr_res = [r * 2 for r in curr_res]

    return up_block_types


def coerce_attn_resolutions(attn_resolutions):
    attn_resolutions = OmegaConf.to_object(attn_resolutions)
    attn_resolutions_ = []
    for ar in attn_resolutions:
        if isinstance(ar, (list, tuple)):
            attn_resolutions_.append(list(ar))
        else:
            attn_resolutions_.append([ar, ar])
    return attn_resolutions_


def coerce_resolution(resolution):
    resolution = OmegaConf.to_object(resolution)
    if isinstance(resolution, int):
        resolution = [resolution, resolution]  # H, W
    elif isinstance(resolution, (tuple, list)):
        resolution = list(resolution)
    else:
        raise ValueError("Unknown type of resolution:", resolution)
    return resolution
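
# Illustrative example of the block-type selection implemented by get_down_block_types /
# get_up_block_types above (hypothetical config values, not taken from the ITHQ configs):
# with resolution=256, ch_mult=(1, 2, 4), and attn_resolutions=[32], get_down_block_types walks
# curr_res through [256, 256] -> [128, 128] -> [64, 64]; none of these match [32, 32], so every
# level becomes a plain "DownEncoderBlock2D". If attn_resolutions were [64] instead, the last
# level would become "AttnDownEncoderBlock2D". get_up_block_types applies the same rule while
# walking the same resolutions in the opposite direction.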


# done vqvae model

# vqvae checkpoint


def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))

    # quant_conv

    diffusers_checkpoint.update(
        {
            "quant_conv.weight": checkpoint["quant_conv.weight"],
            "quant_conv.bias": checkpoint["quant_conv.bias"],
        }
    )

    # quantize
    diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})

    # post_quant_conv
    diffusers_checkpoint.update(
        {
            "post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
            "post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
        }
    )

    # decoder
    diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))

    return diffusers_checkpoint


def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # conv_in
    diffusers_checkpoint.update(
        {
            "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
            "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
        }
    )

    # down_blocks
    for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
        diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
        down_block_prefix = f"encoder.down.{down_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(down_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                vqvae_resnet_to_diffusers_checkpoint(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # downsample

        # There is no downsample on the last down block, so do not include one for it
        if down_block_idx != len(model.encoder.down_blocks) - 1:
            # There's a single downsample in the original checkpoint but a list of downsamples
            # in the diffusers model.
            diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
            downsample_prefix = f"{down_block_prefix}.downsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
                    f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
                }
            )

        # attentions

        if hasattr(down_block, "attentions"):
            for attention_idx, _ in enumerate(down_block.attentions):
                diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    vqvae_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # mid block

    # mid block attentions

    # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
    diffusers_attention_prefix = "encoder.mid_block.attentions.0"
    attention_prefix = "encoder.mid.attn_1"
    diffusers_checkpoint.update(
        vqvae_attention_to_diffusers_checkpoint(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets

    for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
        diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"

        # the original checkpoint's mid-block resnets are 1-indexed: `block_1` and `block_2`
        orig_resnet_idx = diffusers_resnet_idx + 1
        # There are two hardcoded resnets in the middle of the VQ-diffusion encoder
        resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            vqvae_resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
            "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
            # conv_out
            "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
            "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint
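
# For reference, the prefixes built above mean that an original encoder key such as
#   "encoder.down.0.block.1.conv1.weight"
# ends up in the diffusers state dict as
#   "encoder.down_blocks.0.resnets.1.conv1.weight"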


def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # conv in
    diffusers_checkpoint.update(
        {
            "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
            "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
        }
    )

    # up_blocks

    for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
        # up_blocks are stored in reverse order in the VQ-diffusion checkpoint
        orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx

        diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
        up_block_prefix = f"decoder.up.{orig_up_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(up_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                vqvae_resnet_to_diffusers_checkpoint(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # upsample

        # There is no upsample on the last up block, so do not include one for it
        if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
            # There's a single upsample in the VQ-diffusion checkpoint but a list of upsamples
            # in the diffusers model.
            diffusers_upsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
            upsample_prefix = f"{up_block_prefix}.upsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_upsample_prefix}.weight": checkpoint[f"{upsample_prefix}.weight"],
                    f"{diffusers_upsample_prefix}.bias": checkpoint[f"{upsample_prefix}.bias"],
                }
            )

        # attentions

        if hasattr(up_block, "attentions"):
            for attention_idx, _ in enumerate(up_block.attentions):
                diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    vqvae_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # mid block

    # mid block attentions

    # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
    diffusers_attention_prefix = "decoder.mid_block.attentions.0"
    attention_prefix = "decoder.mid.attn_1"
    diffusers_checkpoint.update(
        vqvae_attention_to_diffusers_checkpoint(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets

    for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
        diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"

        # the original checkpoint's mid-block resnets are 1-indexed: `block_1` and `block_2`
        orig_resnet_idx = diffusers_resnet_idx + 1
        # There are two hardcoded resnets in the middle of the VQ-diffusion decoder
        resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            vqvae_resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
            "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
            # conv_out
            "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
            "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint
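
# For reference: because the original decoder stores its up blocks in the reverse of the diffusers
# ordering, with four up blocks the keys under "decoder.up_blocks.0" are read from "decoder.up.3.*",
# those under "decoder.up_blocks.1" from "decoder.up.2.*", and so on.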


def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
        f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
        f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
                f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
            }
        )

    return rv


def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    return {
        # group_norm
        f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
        f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
        # query
        f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
        # key
        f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
        # value
        f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
        # proj_attn
        f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
            :, :, 0, 0
        ],
        f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
    }
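
# Note on the `[:, :, 0, 0]` slicing above: the original attention block implements q/k/v and the
# output projection as 1x1 convolutions, so their weights have shape [out_channels, in_channels, 1, 1],
# while the corresponding diffusers modules are linear layers expecting [out_channels, in_channels].
# Dropping the two trailing singleton dimensions converts between the two, e.g. (illustrative):
#
#   conv_weight = checkpoint["encoder.mid.attn_1.q.weight"]  # shape [C, C, 1, 1]
#   linear_weight = conv_weight[:, :, 0, 0]                  # shape [C, C]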


# done vqvae checkpoint

# transformer model

PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]


def transformer_model_from_original_config(
    original_diffusion_config, original_transformer_config, original_content_embedding_config
):
    assert (
        original_diffusion_config.target in PORTED_DIFFUSIONS
    ), f"{original_diffusion_config.target} has not yet been ported to diffusers."
    assert (
        original_transformer_config.target in PORTED_TRANSFORMERS
    ), f"{original_transformer_config.target} has not yet been ported to diffusers."
    assert (
        original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
    ), f"{original_content_embedding_config.target} has not yet been ported to diffusers."

    original_diffusion_config = original_diffusion_config.params
    original_transformer_config = original_transformer_config.params
    original_content_embedding_config = original_content_embedding_config.params

    inner_dim = original_transformer_config["n_embd"]

    n_heads = original_transformer_config["n_head"]

    # VQ-Diffusion gives the dimension of the multi-headed attention layers as the number of
    # attention heads times the dimension of a single head. We want to specify our attention
    # blocks with those two values given separately.
    assert inner_dim % n_heads == 0
    d_head = inner_dim // n_heads

    depth = original_transformer_config["n_layer"]
    context_dim = original_transformer_config["condition_dim"]

    num_embed = original_content_embedding_config["num_embed"]
    # the number of embeddings in the transformer includes the mask embedding.
    # the content embedding (the vqvae) does not include the mask embedding.
    num_embed = num_embed + 1

    height = original_transformer_config["content_spatial_size"][0]
    width = original_transformer_config["content_spatial_size"][1]

    assert width == height, "width has to be equal to height"
    dropout = original_transformer_config["resid_pdrop"]
    num_embeds_ada_norm = original_diffusion_config["diffusion_step"]

    model_kwargs = {
        "attention_bias": True,
        "cross_attention_dim": context_dim,
        "attention_head_dim": d_head,
        "num_layers": depth,
        "dropout": dropout,
        "num_attention_heads": n_heads,
        "num_vector_embeds": num_embed,
        "num_embeds_ada_norm": num_embeds_ada_norm,
        "norm_num_groups": 32,
        "sample_size": width,
        "activation_fn": "geglu-approximate",
    }

    model = Transformer2DModel(**model_kwargs)
    return model


# done transformer model
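
# Illustrative example of the split above (hypothetical values, not necessarily the ITHQ config):
# a config with n_embd=1024 and n_head=16 yields attention_head_dim = 1024 // 16 = 64 and
# num_attention_heads = 16 in the diffusers Transformer2DModel.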

# transformer checkpoint


def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    transformer_prefix = "transformer.transformer"

    diffusers_latent_image_embedding_prefix = "latent_image_embedding"
    latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"

    # DalleMaskImageEmbedding
    diffusers_checkpoint.update(
        {
            f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.emb.weight"
            ],
            f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.height_emb.weight"
            ],
            f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.width_emb.weight"
            ],
        }
    )

    # transformer blocks
    for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
        diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
        transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"

        # ada norm block
        diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
        ada_norm_prefix = f"{transformer_block_prefix}.ln1"

        diffusers_checkpoint.update(
            transformer_ada_norm_to_diffusers_checkpoint(
                checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
            )
        )

        # attention block
        diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
        attention_prefix = f"{transformer_block_prefix}.attn1"

        diffusers_checkpoint.update(
            transformer_attention_to_diffusers_checkpoint(
                checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
            )
        )

        # ada norm block
        diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
        ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"

        diffusers_checkpoint.update(
            transformer_ada_norm_to_diffusers_checkpoint(
                checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
            )
        )

        # attention block
        diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
        attention_prefix = f"{transformer_block_prefix}.attn2"

        diffusers_checkpoint.update(
            transformer_attention_to_diffusers_checkpoint(
                checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
            )
        )

        # norm block
        diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
        norm_block_prefix = f"{transformer_block_prefix}.ln2"

        diffusers_checkpoint.update(
            {
                f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
                f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
            }
        )

        # feedforward block
        diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
        feedforward_prefix = f"{transformer_block_prefix}.mlp"

        diffusers_checkpoint.update(
            transformer_feedforward_to_diffusers_checkpoint(
                checkpoint,
                diffusers_feedforward_prefix=diffusers_feedforward_prefix,
                feedforward_prefix=feedforward_prefix,
            )
        )

    # to logits

    diffusers_norm_out_prefix = "norm_out"
    norm_out_prefix = f"{transformer_prefix}.to_logits.0"

    diffusers_checkpoint.update(
        {
            f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
            f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
        }
    )

    diffusers_out_prefix = "out"
    out_prefix = f"{transformer_prefix}.to_logits.1"

    diffusers_checkpoint.update(
        {
            f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
            f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
        }
    )

    return diffusers_checkpoint


def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
    return {
        f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
        f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
        f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
    }
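
# For reference, with the prefixes constructed above, the first ada norm of transformer block 0 is
# read from "transformer.transformer.blocks.0.ln1.linear.weight" and stored in the diffusers
# Transformer2DModel state dict as "transformer_blocks.0.norm1.linear.weight".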


def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    return {
        # key
        f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
        f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
        # query
        f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
        f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
        # value
        f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
        f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
        # linear out
        f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
        f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
    }


def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
    return {
        f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
        f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
        f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
        f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
    }


# done transformer checkpoint


def read_config_file(filename):
    # The yaml file contains annotations that certain values should be loaded
    # as tuples. By default, OmegaConf will panic when reading these. Instead,
    # we manually read the yaml with the FullLoader and then construct the
    # OmegaConf object.
    with open(filename) as f:
        original_config = yaml.load(f, FullLoader)

    return OmegaConf.create(original_config)
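
# For example (assuming ./ithq.yaml from the module docstring has been downloaded), the nested
# configs consumed in __main__ below can be inspected like this:
#
#   cfg = read_config_file("./ithq.yaml").model
#   diffusion_cfg = cfg.params.diffusion_config                 # checked against PORTED_DIFFUSIONS
#   transformer_cfg = diffusion_cfg.params.transformer_config   # checked against PORTED_TRANSFORMERS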


# We take separate arguments for the vqvae because the ITHQ vqvae config file
# is separate from the config file for the rest of the model.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--vqvae_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the vqvae checkpoint to convert.",
    )

    parser.add_argument(
        "--vqvae_original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture for the vqvae.",
    )

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--checkpoint_load_device",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading checkpoints.",
    )

    # See link for how ema weights are always selected
    # https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65
    parser.add_argument(
        "--no_use_ema",
        action="store_true",
        required=False,
        help=(
            "Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to"
            " set it, as the original VQ-Diffusion always uses the ema weights when loading models."
        ),
    )

    args = parser.parse_args()

    use_ema = not args.no_use_ema

    print(f"loading checkpoints to {args.checkpoint_load_device}")

    checkpoint_map_location = torch.device(args.checkpoint_load_device)

    # vqvae_model

    print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")

    vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
    vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]

    with init_empty_weights():
        vqvae_model = vqvae_model_from_original_config(vqvae_original_config)

    vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)

    with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
        torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
        del vqvae_diffusers_checkpoint
        del vqvae_checkpoint
        load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")

    print("done loading vqvae")

    # done vqvae_model
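
    # Note on the tempfile round trips used above and below: the models are instantiated under
    # init_empty_weights(), so their parameters live on the "meta" device and cannot be filled with a
    # plain load_state_dict(). Saving the converted state dict to disk and reading it back with
    # accelerate's load_checkpoint_and_dispatch() materializes the weights onto real devices.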

    # transformer_model

    print(
        f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
        f" {use_ema}"
    )

    original_config = read_config_file(args.original_config_file).model

    diffusion_config = original_config.params.diffusion_config
    transformer_config = original_config.params.diffusion_config.params.transformer_config
    content_embedding_config = original_config.params.diffusion_config.params.content_emb_config

    pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)

    if use_ema:
        if "ema" in pre_checkpoint:
            checkpoint = {}
            for k, v in pre_checkpoint["model"].items():
                checkpoint[k] = v

            for k, v in pre_checkpoint["ema"].items():
                # The ema weights are only used on the transformer. To mimic their key as if they came
                # from the state_dict for the top level model, we prefix with an additional "transformer."
                # See the source linked above the `--no_use_ema` arg for more information.
                checkpoint[f"transformer.{k}"] = v
        else:
            print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
            checkpoint = pre_checkpoint["model"]
    else:
        checkpoint = pre_checkpoint["model"]

    del pre_checkpoint

    with init_empty_weights():
        transformer_model = transformer_model_from_original_config(
            diffusion_config, transformer_config, content_embedding_config
        )

    diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
        transformer_model, checkpoint
    )

    # classifier free sampling embeddings interlude

    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them
    # on a separate model, so we pull them off the checkpoint before the checkpoint is deleted.

    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf

    if learnable_classifier_free_sampling_embeddings:
        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
    else:
        learned_classifier_free_sampling_embeddings_embeddings = None

    # done classifier free sampling embeddings interlude

    with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
        torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
        del diffusers_transformer_checkpoint
        del checkpoint
        load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")

    print("done loading transformer")

    # done transformer_model

    # text encoder

    print("loading CLIP text encoder")

    clip_name = "openai/clip-vit-base-patch32"

    # The original VQ-Diffusion specifies the pad token by the integer id it maps to in the
    # tokenized output, and each model uses `0` as that pad id. The transformers CLIP API instead
    # specifies the pad value via the token string before tokenization. The `!` token tokenizes to
    # id `0`, so padding with `!` is the same as padding with the `0` pad value.
    pad_token = "!"

    tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")

    assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0

    text_encoder_model = CLIPTextModel.from_pretrained(
        clip_name,
        # `CLIPTextModel` does not support device_map="auto"
        # device_map="auto"
    )

    print("done loading CLIP text encoder")

    # done text encoder

    # scheduler

    scheduler_model = VQDiffusionScheduler(
        # the scheduler has the same number of embeddings as the transformer
        num_vec_classes=transformer_model.num_vector_embeds
    )

    # done scheduler

    # learned classifier free sampling embeddings

    with init_empty_weights():
        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
            learnable_classifier_free_sampling_embeddings,
            hidden_size=text_encoder_model.config.hidden_size,
            length=tokenizer_model.model_max_length,
        )

    learned_classifier_free_sampling_checkpoint = {
        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
    }

    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
        del learned_classifier_free_sampling_checkpoint
        del learned_classifier_free_sampling_embeddings_embeddings
        load_checkpoint_and_dispatch(
            learned_classifier_free_sampling_embeddings_model,
            learned_classifier_free_sampling_checkpoint_file.name,
            device_map="auto",
        )

    # done learned classifier free sampling embeddings

    print(f"saving VQ diffusion model, path: {args.dump_path}")

    pipe = VQDiffusionPipeline(
        vqvae=vqvae_model,
        transformer=transformer_model,
        tokenizer=tokenizer_model,
        text_encoder=text_encoder_model,
        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
        scheduler=scheduler_model,
    )
    pipe.save_pretrained(args.dump_path)

    print("done writing VQ diffusion model")
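
# Example usage of the converted weights (illustrative sketch; assumes the conversion was run with
# `--dump_path ./vq-diffusion-ithq`, a path chosen here only for the example):
#
#   from diffusers import VQDiffusionPipeline
#
#   pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-ithq").to("cuda")
#   image = pipe("a teddy bear playing in the pool").images[0]
#   image.save("teddy_bear.png")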