Path: blob/main/scripts/convert_vae_pt_to_diffusers.py
"""Convert a standalone Stable Diffusion VAE checkpoint (.pt) into the diffusers AutoencoderKL format."""

import argparse
import io

import requests
import torch
from omegaconf import OmegaConf

from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    assign_to_checkpoint,
    conv_attn_to_linear,
    create_vae_diffusers_config,
    renew_vae_attention_paths,
    renew_vae_resnet_paths,
)


def custom_convert_ldm_vae_checkpoint(checkpoint, config):
    """Rename the keys of an LDM VAE state dict to the diffusers AutoencoderKL layout."""
    vae_state_dict = checkpoint

    new_checkpoint = {}

    # Encoder and decoder stem/normalization layers map one-to-one.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    # Quantization convolutions also map one-to-one.
    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    # Encoder down blocks: resnets plus an optional downsampler per block.
    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid-block resnets and attention.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    # Decoder up blocks: LDM indexes them in reverse order relative to diffusers.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid-block resnets and attention.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def vae_pt_to_vae_diffuser(
    checkpoint_path: str,
    output_path: str,
):
    # Only Stable Diffusion v1 configs are supported.
    r = requests.get(
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
    )
    io_obj = io.BytesIO(r.content)

    original_config = OmegaConf.load(io_obj)
    image_size = 512
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint["state_dict"], vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    vae.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE .pt to convert.")
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to save the converted diffusers VAE to.")

    args = parser.parse_args()

    vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
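
# Example usage (not part of the original script; the paths below are hypothetical):
#
#   python convert_vae_pt_to_diffusers.py --vae_pt_path ./my_vae.pt --dump_path ./my_vae_diffusers
#
# The converted weights can then be loaded like any diffusers model, e.g.:
#
#   from diffusers import AutoencoderKL
#   vae = AutoencoderKL.from_pretrained("./my_vae_diffusers")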