# Path: blob/main/Dreambooth/convertosd.py
# 540 views
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
# Written by jachiam

import argparse
import os.path as osp

import torch


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    """Rename Diffusers UNet state-dict keys to the original Stable Diffusion names.

    Only the keys are rewritten; the tensor values are passed through as-is.
    Renaming happens in three passes over a key -> new-key mapping:
    exact top-level names (``unet_conversion_map``), resnet sub-layer names
    for keys containing "resnets" (``unet_conversion_map_resnet``), and
    finally block prefixes (``unet_conversion_map_layer``).

    Args:
        unet_state_dict: mapping of Diffusers UNet parameter names to tensors.
            A subset of the full UNet is accepted; only present keys are renamed.

    Returns:
        A new dict keyed by Stable Diffusion parameter names.
    """
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        # Only rename names that are actually present: inserting a missing one
        # would make the final `unet_state_dict[k]` lookup raise KeyError.
        if hf_name in mapping:
            mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    # legacy Diffusers attention names
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
    # newer Diffusers attention names (disjoint from the legacy ones above)
    ("q.", "to_q."),
    ("k.", "to_k."),
    ("v.", "to_v."),
    ("proj_out.", "to_out.0."),
]


def reshape_weight_for_sd(w):
    """Convert a HF linear weight to an SD 1x1 conv2d weight.

    A 2-D ``(out, in)`` weight gets two trailing singleton dims,
    giving ``(out, in, 1, 1)``. Any other rank (for example a weight that is
    already in conv layout) is returned unchanged, so the conversion cannot
    be applied twice.
    """
    if w.ndim != 2:
        return w
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    """Rename Diffusers VAE state-dict keys to Stable Diffusion names.

    Applies ``vae_conversion_map`` to every key and ``vae_conversion_map_attn``
    to keys containing "attentions". The mid-block attention q/k/v/proj_out
    weights are then reshaped with ``reshape_weight_for_sd``. Prints a
    progress message to stdout.

    Args:
        vae_state_dict: mapping of Diffusers VAE parameter names to tensors.

    Returns:
        A new dict keyed by Stable Diffusion parameter names.
    """
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        # Attention renames are gated on the *original* Diffusers key.
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    # ANSI bold green, reset afterwards so later terminal output is not colored.
    print("\033[1;32mConverting to CKPT ...\033[0m")
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    """Return the text-encoder state dict unchanged (its key names need no renaming)."""
    return text_enc_dict


def _load_state_dict(path):
    """Load a saved PyTorch state dict from *path* onto the CPU."""
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only convert models from sources you trust.
    return torch.load(path, map_location="cpu")


def _cast_to_half(state_dict):
    """Return a copy of *state_dict* with floating-point tensors cast to fp16.

    Non-floating tensors (integer buffers, for instance) are passed through
    unchanged instead of being converted to fp16.
    """
    return {k: (v.half() if v.is_floating_point() else v) for k, v in state_dict.items()}


def main():
    """Convert a Diffusers pipeline directory into a single SD ``.ckpt`` file."""
    # In-file defaults (may be edited directly); command-line flags override them.
    # An empty model_path resolves to the current directory via osp.join.
    model_path = ""
    checkpoint_path = ""

    parser = argparse.ArgumentParser(
        description="Convert a HF Diffusers pipeline to a Stable Diffusion checkpoint."
    )
    parser.add_argument(
        "--model_path",
        default=model_path,
        help="Directory of the Diffusers pipeline (empty means the current directory).",
    )
    parser.add_argument(
        "--checkpoint_path",
        default=checkpoint_path,
        help="Output path of the Stable Diffusion checkpoint.",
    )
    parser.add_argument(
        "--no_half",
        action="store_true",
        help="Keep floating-point tensors in their stored precision instead of fp16.",
    )
    args = parser.parse_args()

    if not args.checkpoint_path:
        # Fail before the expensive loading work rather than at torch.save time.
        parser.error("no output path: set checkpoint_path or pass --checkpoint_path")

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = _load_state_dict(unet_path)
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = _load_state_dict(vae_path)
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = _load_state_dict(text_enc_path)
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if not args.no_half:
        state_dict = _cast_to_half(state_dict)
    torch.save({"state_dict": state_dict}, args.checkpoint_path)


if __name__ == "__main__":
    main()