Path: blob/main/Dreambooth/convertodiffv1.py
import argparse
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel


# Model parameters of the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 512
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024


# region StableDiffusion -> Diffusers conversion code
# copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def linear_transformer_to_conv(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim == 2:
                checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)


def convert_ldm_unet_checkpoint(v2, checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    # In SD v2 the 1x1 conv2d layers have been changed to linear layers, so convert linear -> conv here
    if v2:
        linear_transformer_to_conv(new_checkpoint)

    return new_checkpoint


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
    # if len(vae_state_dict) == 0:
    #     # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from .ckpt
    #     vae_state_dict = checkpoint

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def create_unet_diffusers_config(v2):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # unet_params = original_config.model.params.unet_config.params

    block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    config = dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
    )

    return config


def create_vae_diffusers_config():
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # vae_params = original_config.model.params.first_stage_config.params.ddconfig
    # _ = original_config.model.params.first_stage_config.params.embed_dim
    block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=VAE_PARAMS_RESOLUTION,
        in_channels=VAE_PARAMS_IN_CHANNELS,
        out_channels=VAE_PARAMS_OUT_CH,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=VAE_PARAMS_Z_CHANNELS,
        layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
    )
    return config


def convert_ldm_clip_checkpoint_v1(checkpoint):
    keys = list(checkpoint.keys())
    text_model_dict = {}
    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
    return text_model_dict


def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
    # the v2 key layout is maddeningly different from v1!
    def convert_key(key):
        if not key.startswith("cond_stage_model"):
            return None

        # common conversion
        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
        key = key.replace("cond_stage_model.model.", "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif '.attn.out_proj' in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif '.attn.in_proj' in key:
                key = None  # special case, handled later
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif '.positional_embedding' in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif '.text_projection' in key:
            key = None  # unused???
        elif '.logit_scale' in key:
            key = None  # unused???
        elif '.token_embedding' in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif '.ln_final' in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        # remove resblocks 23
        if '.resblocks.23.' in key:
            continue
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert the attention weights
    for key in keys:
        if '.resblocks.23.' in key:
            continue
        if '.resblocks' in key and '.attn.in_proj_' in key:
            # split into three (q, k, v)
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # add position_ids
    new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
    return new_sd

# endregion


# region Diffusers -> StableDiffusion conversion code
# copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)

def conv_transformer_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]


def convert_unet_state_dict_to_sd(v2, unet_state_dict):
    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
        ("input_blocks.0.0.weight", "conv_in.weight"),
        ("input_blocks.0.0.bias", "conv_in.bias"),
        ("out.0.weight", "conv_norm_out.weight"),
        ("out.0.bias", "conv_norm_out.bias"),
        ("out.2.weight", "conv_out.weight"),
        ("out.2.bias", "conv_out.bias"),
    ]

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    unet_conversion_map_layer = []
    for i in range(4):
        # loop over downblocks/upblocks

        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}

    if v2:
        conv_transformer_to_linear(new_state_dict)

    return new_state_dict


# ================#
# VAE Conversion #
# ================#

def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]

    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
            sd_down_prefix = f"encoder.down.{i}.block.{j}."
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

        if i < 3:
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
            sd_downsample_prefix = f"down.{i}.downsample."
            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"up.{3-i}.upsample."
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{i}."
        sd_mid_res_prefix = f"mid.block_{i+1}."
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ("norm.", "group_norm."),
        ("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]

    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]

    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                new_state_dict[k] = reshape_weight_for_sd(v)

    return new_state_dict


# endregion


def load_checkpoint_with_text_encoder_conversion(ckpt_path):
    # handle models that store the text encoder in a different layout (without 'text_model')
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
        ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
        ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
    ]

    checkpoint = torch.load(ckpt_path, map_location="cuda")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from):]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint


# TODO the behaviour of the dtype argument is questionable and should be checked; creating the text_encoder in the requested format is unverified
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    if dtype is not None:
        for k, v in state_dict.items():
            if type(v) is torch.Tensor:
                state_dict[k] = v.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(v2)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    info = unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)

    # convert text_model
    if v2:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=512,
            torch_dtype="float32",
            transformers_version="4.25.0.dev0",
        )
        text_model = CLIPTextModel._from_config(cfg)
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    else:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=768,
            intermediate_size=3072,
            num_hidden_layers=12,
            num_attention_heads=12,
            max_position_embeddings=77,
            hidden_act="quick_gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=768,
            torch_dtype="float32",
            transformers_version="4.16.0.dev0",
        )

        text_model = CLIPTextModel._from_config(cfg)
        # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)

    return text_model, vae, unet


def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
    def convert_key(key):
        # remove position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in key:
                key = None  # special case, handled later
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif '.position_embedding' in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert the attention weights
    for key in keys:
        if 'layers' in key and 'q_proj' in key:
            # concatenate q, k, v back into a single in_proj tensor
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    # optionally fabricate the last layer and other missing weights
    if make_dummy_weights:
        keys = list(new_sd.keys())
        for key in keys:
            if key.startswith("transformer.resblocks.22."):
                new_sd[key.replace(".22.", ".23.")] = new_sd[key]

        # create weights that are not included in the Diffusers model
        new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
        new_sd['logit_scale'] = torch.tensor(1)

    return new_sd


def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
    if ckpt_path is not None:
        # reuse epoch/step from the source checkpoint, and reload it (including the VAE) e.g. when the VAE is not in memory
        checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
        state_dict = checkpoint["state_dict"]
        strict = True
    else:
        # build a new checkpoint from scratch
        checkpoint = {}
        state_dict = {}
        strict = False

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
    update_sd("model.diffusion_model.", unet_state_dict)

    # Convert the text encoder model
    if v2:
        # when there is no source checkpoint, insert dummy weights (e.g. duplicate the last layer from the previous one)
        make_dummy = ckpt_path is None
        text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
        update_sd("cond_stage_model.model.", text_enc_dict)
    else:
        text_enc_dict = text_encoder.state_dict()
        update_sd("cond_stage_model.transformer.", text_enc_dict)

    # Convert the VAE
    if vae is not None:
        vae_dict = convert_vae_state_dict(vae.state_dict())
        update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {'state_dict': state_dict}

    if 'epoch' in checkpoint:
        epochs += checkpoint['epoch']
    if 'global_step' in checkpoint:
        steps += checkpoint['global_step']

    new_ckpt['epoch'] = epochs
    new_ckpt['global_step'] = steps

    torch.save(new_ckpt, output_file)

    return key_count


def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, vae=None):
    if vae is None:
        # no VAE was passed in; fall back to the reference SD v1.5 VAE
        vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

    pipeline = StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler"),
        tokenizer=CLIPTokenizer.from_pretrained("refmdl", subfolder="tokenizer"),
    )
    pipeline.save_pretrained(output_dir)


def convert(args):
    print("\033[1;32mConverting to Diffusers ...")
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16:
        save_dtype = torch.float16
    elif args.bf16:
        save_dtype = torch.bfloat16
    elif args.float:
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint"
    assert is_save_ckpt is not None, "reference model is required to save as Diffusers"

    # load the model
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))

    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # auto-detect the model version
            v2_model = unet.config.cross_attention_dim == 1024
            # print("checking model version: model is " + ('v2' if v2_model else 'v1'))
        else:
            v2_model = args.v1

    # convert and save
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"

    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        key_count = save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
                                                     original_model, args.epoch, args.global_step, save_dtype, vae)
    else:
        save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, vae)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--v1", action='store_true',
                        help='load v1.x model (v1 or v2 is required to load checkpoint)')
    parser.add_argument("--v2", action='store_true',
                        help='load v2.0 model (v1 or v2 is required to load checkpoint)')
    parser.add_argument("--fp16", action='store_true',
                        help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only)')
    parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only)')
    parser.add_argument("--float", action='store_true',
                        help='save as float (float32, checkpoint only)')
    parser.add_argument("--epoch", type=int, default=0, help='epoch number to write to the checkpoint')
    parser.add_argument("--global_step", type=int, default=0,
                        help='global_step value to write to the checkpoint')

    parser.add_argument("model_to_load", type=str, default=None,
                        help="model to load: checkpoint file or Diffusers model's directory")
    parser.add_argument("model_to_save", type=str, default=None,
                        help="model to save: saved as a checkpoint if the path has an extension, otherwise as a Diffusers model directory")

    args = parser.parse_args()
    convert(args)
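
# Illustrative usage sketch (not part of the original script); paths below are placeholders.
# A path without an extension for model_to_save means "save as a Diffusers directory",
# per the model_to_save help text above:
#
#   python convertodiffv1.py --v1 /path/to/model.ckpt /path/to/output_dir
#
# The same conversion could also be driven from Python by importing this module, assuming
# it is importable as convertodiffv1 from the working directory:
#
#   from convertodiffv1 import load_models_from_stable_diffusion_checkpoint, save_diffusers_checkpoint
#   text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(False, "/path/to/model.ckpt")
#   save_diffusers_checkpoint(False, "/path/to/output_dir", text_encoder, unet, vae)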