# Path: blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# (1440 views)
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.

import argparse
import os.path as osp
import re

import torch

try:
    from safetensors.torch import load_file, save_file
except ImportError as _import_error:
    # safetensors is only needed to read/write *.safetensors files. Keep the
    # conversion helpers importable (and pickle-based pipelines convertible)
    # without it, and fail with a clear message only when it is actually used.
    _SAFETENSORS_IMPORT_ERROR = _import_error

    def load_file(*args, **kwargs):
        """Placeholder used when safetensors is not installed; always raises ImportError."""
        raise ImportError("safetensors is required to load .safetensors weights") from _SAFETENSORS_IMPORT_ERROR

    def save_file(*args, **kwargs):
        """Placeholder used when safetensors is not installed; always raises ImportError."""
        raise ImportError("safetensors is required to save .safetensors weights") from _SAFETENSORS_IMPORT_ERROR


# =================#
# UNet Conversion #
# =================#

# Whole-key renames for the UNet's top-level layers.
unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

# Substring renames applied only to keys that contain "resnets".
unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

# Block-prefix renames, applied to every key in list order.
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    """Rename a Diffusers UNet state dict to original Stable Diffusion key names.

    Args:
        unet_state_dict: mapping from Diffusers UNet parameter names to values.
            Values are passed through untouched; only keys are rewritten.

    Returns:
        A new dict with the same values keyed by SD names (without the
        ``model.diffusion_model.`` prefix). Keys that match no rule are kept
        unchanged, and rules whose Diffusers key is absent are skipped.
    """
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        # Only rename keys that exist; inserting absent ones would make the final
        # lookup into unet_state_dict raise KeyError for partial state dicts.
        if hf_name in mapping:
            mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict
# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
    # Newer diffusers releases name the mid-block attention projections
    # to_q / to_k / to_v / to_out.0; map those onto the same SD names.
    ("q.", "to_q."),
    ("k.", "to_k."),
    ("v.", "to_v."),
    ("proj_out.", "to_out.0."),
]


def reshape_weight_for_sd(w):
    """Convert an HF linear weight to the layout of an SD 1x1 conv2d weight.

    A 2-D ``(out_features, in_features)`` tensor gains two trailing singleton
    dims. Any other tensor (e.g. one that is already conv-shaped) is returned
    unchanged, so applying the conversion twice is harmless.
    """
    if w.ndim == 2:
        return w.reshape(*w.shape, 1, 1)
    return w


def convert_vae_state_dict(vae_state_dict):
    """Rename a Diffusers VAE state dict to original Stable Diffusion key names.

    Keys are rewritten with ``vae_conversion_map``. Keys whose original name
    contains "attentions" are further rewritten with
    ``vae_conversion_map_attn``. Afterwards the mid-block attention q/k/v/proj_out
    weights are reshaped with ``reshape_weight_for_sd``.

    Returns:
        A new dict keyed by SD names (without the ``first_stage_model.`` prefix).
    """
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#


textenc_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))

# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}

# Matches "<prefix>.self_attn.{q,k,v}_proj.{weight,bias}" keys; equivalent to
# the endswith() checks on ".self_attn.q_proj.weight" and friends.
_QKV_KEY_PATTERN = re.compile(r"(?P<prefix>.*\.self_attn)\.(?P<code>[qkv])_proj\.(?P<param>weight|bias)")


def _relabel_text_enc_key(key):
    """Rewrite every HF substring of ``key`` to its SD name in one regex pass."""
    return textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], key)


def convert_text_enc_state_dict_v20(text_enc_dict):
    """Convert a v2.x (OpenCLIP-style) HF text-encoder state dict to SD key names.

    Separate q/k/v projection weights and biases are concatenated (in q, k, v
    order) into ``in_proj_weight`` / ``in_proj_bias``. All other keys are
    relabelled through ``textenc_conversion_lst``. Keys are expected to already
    carry the ``transformer.`` prefix added by the caller.

    Raises:
        ValueError: if any of q/k/v is missing for an attention layer.
    """
    new_state_dict = {}
    # param ("weight"/"bias") -> attention-key prefix -> [q, k, v] slots
    captured = {"weight": {}, "bias": {}}
    for k, v in text_enc_dict.items():
        match = _QKV_KEY_PATTERN.fullmatch(k)
        if match is None:
            new_state_dict[_relabel_text_enc_key(k)] = v
            continue
        slots = captured[match["param"]].setdefault(match["prefix"], [None, None, None])
        slots[code2idx[match["code"]]] = v

    for param, by_prefix in captured.items():
        for k_pre, tensors in by_prefix.items():
            # Identity test: `None in tensors` would invoke tensor equality.
            if any(t is None for t in tensors):
                raise ValueError("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
            new_state_dict[_relabel_text_enc_key(k_pre) + f".in_proj_{param}"] = torch.cat(tensors)

    return new_state_dict


def convert_text_enc_state_dict(text_enc_dict):
    """Return the v1.x text-encoder state dict unchanged; its keys already match SD."""
    return text_enc_dict


def _load_state_dict(model_path, subfolder, safetensors_name, bin_name):
    """Load one pipeline component, preferring ``safetensors_name`` over ``bin_name``."""
    safetensors_path = osp.join(model_path, subfolder, safetensors_name)
    if osp.exists(safetensors_path):
        return load_file(safetensors_path, device="cpu")
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only convert pipelines from trusted sources.
    return torch.load(osp.join(model_path, subfolder, bin_name), map_location="cpu")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
    )

    # required=True makes argparse exit with a usage error when a path is
    # missing, so no extra validation is needed here.
    args = parser.parse_args()

    # Load each component from safetensors if it exists, otherwise from the pickled .bin
    unet_state_dict = _load_state_dict(
        args.model_path, "unet", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"
    )
    vae_state_dict = _load_state_dict(
        args.model_path, "vae", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"
    )
    text_enc_dict = _load_state_dict(args.model_path, "text_encoder", "model.safetensors", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        # Only cast floating-point tensors; integer buffers (e.g. position ids,
        # if present) must keep their dtype.
        state_dict = {k: v.half() if v.is_floating_point() else v for k, v in state_dict.items()}

    if args.use_safetensors:
        save_file(state_dict, args.checkpoint_path)
    else:
        state_dict = {"state_dict": state_dict}
        torch.save(state_dict, args.checkpoint_path)