Path: blob/main/Dreambooth/convertodiffv2-768.py
import argparse
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel


# Model parameters of the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 96
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 768
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024


# region StableDiffusion -> Diffusers conversion code
# copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def linear_transformer_to_conv(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim == 2:
                checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)


def convert_ldm_unet_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."

    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
    # if len(vae_state_dict) == 0:
    #     # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from .ckpt
    #     vae_state_dict = checkpoint

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def create_unet_diffusers_config():
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # unet_params = original_config.model.params.unet_config.params

    # vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    config = dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        use_linear_projection=True,
        cross_attention_dim=V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
    )

    return config


def create_vae_diffusers_config():
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # vae_params = original_config.model.params.first_stage_config.params.ddconfig
    # _ = original_config.model.params.first_stage_config.params.embed_dim
    block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=VAE_PARAMS_RESOLUTION,
        in_channels=VAE_PARAMS_IN_CHANNELS,
        out_channels=VAE_PARAMS_OUT_CH,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=VAE_PARAMS_Z_CHANNELS,
        layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
    )
    return config


def convert_ldm_clip_checkpoint_v1(checkpoint):
    keys = list(checkpoint.keys())
    text_model_dict = {}
    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
    return text_model_dict


def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
    # the layouts are maddeningly different!
    def convert_key(key):
        if not key.startswith("cond_stage_model"):
            return None

        # common conversion
        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
        key = key.replace("cond_stage_model.model.", "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif '.attn.out_proj' in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif '.attn.in_proj' in key:
                key = None  # special case, handled later
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif '.positional_embedding' in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif '.text_projection' in key:
            key = None  # not used???
        elif '.logit_scale' in key:
            key = None  # not used???
        elif '.token_embedding' in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif '.ln_final' in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        # remove resblocks 23
        if '.resblocks.23.' in key:
            continue
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert attn
    for key in keys:
        if '.resblocks.23.' in key:
            continue
        if '.resblocks' in key and '.attn.in_proj_' in key:
            # split into three
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # add position_ids
    new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
    return new_sd

# endregion


# region Diffusers -> StableDiffusion conversion code
# copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)

def conv_transformer_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]


def convert_unet_state_dict_to_sd(v2, unet_state_dict):
    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
        ("input_blocks.0.0.weight", "conv_in.weight"),
        ("input_blocks.0.0.bias", "conv_in.bias"),
        ("out.0.weight", "conv_norm_out.weight"),
        ("out.0.bias", "conv_norm_out.bias"),
        ("out.2.weight", "conv_out.weight"),
        ("out.2.bias", "conv_out.bias"),
    ]

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    unet_conversion_map_layer = []
    for i in range(4):
        # loop over downblocks/upblocks

        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}

    if v2:
        conv_transformer_to_linear(new_state_dict)

    return new_state_dict


# ================#
# VAE Conversion #
# ================#

def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]

    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
            sd_down_prefix = f"encoder.down.{i}.block.{j}."
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

        if i < 3:
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
            sd_downsample_prefix = f"down.{i}.downsample."
            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"up.{3-i}.upsample."
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{i}."
        sd_mid_res_prefix = f"mid.block_{i+1}."
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ("norm.", "group_norm."),
        ("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]

    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]

    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                new_state_dict[k] = reshape_weight_for_sd(v)

    return new_state_dict


# endregion


def load_checkpoint_with_text_encoder_conversion(ckpt_path):
    # support models that store the text encoder in a different layout (missing 'text_model')
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
        ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
        ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
    ]

    if args.from_safetensors:
        from safetensors import safe_open

        checkpoint = {}
        with safe_open(ckpt_path, framework="pt", device="cuda") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
        state_dict = checkpoint
    else:
        checkpoint = torch.load(ckpt_path, map_location="cuda")
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    # while "state_dict" in checkpoint:
    #     checkpoint = checkpoint["state_dict"]
    # else:
    #     state_dict = checkpoint

    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from):]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint


# TODO the dtype handling is questionable and needs checking; creating the text_encoder in the specified dtype is unverified
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)

    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    if dtype is not None:
        for k, v in state_dict.items():
            if type(v) is torch.Tensor:
                state_dict[k] = v.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config()
    unet_config["upcast_attention"] = True
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    info = unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)

    # convert text_model
    if v2:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=512,
            torch_dtype="float32",
            transformers_version="4.25.0.dev0",
        )
        text_model = CLIPTextModel._from_config(cfg)
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    else:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)

    return text_model, vae, unet


def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
    def convert_key(key):
        # remove position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in key:
                key = None  # special case, handled later
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif '.position_embedding' in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert attn
    for key in keys:
        if 'layers' in key and 'q_proj' in key:
            # concatenate the three projections
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    # whether to fabricate the last layer etc.
    if make_dummy_weights:
        keys = list(new_sd.keys())
        for key in keys:
            if key.startswith("transformer.resblocks.22."):
                new_sd[key.replace(".22.", ".23.")] = new_sd[key]

        # create weights that are not included in Diffusers
        new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
        new_sd['logit_scale'] = torch.tensor(1)

    return new_sd


def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
    if ckpt_path is not None:
        # read epoch/step from it; also reload including the VAE, e.g. when the VAE is not in memory
        checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
        state_dict = checkpoint["state_dict"]
        strict = True
    else:
        # create a new one
        checkpoint = {}
        state_dict = {}
        strict = False

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
    update_sd("model.diffusion_model.", unet_state_dict)

    # Convert the text encoder model
    if v2:
        make_dummy = ckpt_path is None  # if there is no reference checkpoint, insert dummy weights (e.g. duplicate the last layer from the previous one)
        text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
        update_sd("cond_stage_model.model.", text_enc_dict)
    else:
        text_enc_dict = text_encoder.state_dict()
        update_sd("cond_stage_model.transformer.", text_enc_dict)

    # Convert the VAE
    if vae is not None:
        vae_dict = convert_vae_state_dict(vae.state_dict())
        update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {'state_dict': state_dict}

    if 'epoch' in checkpoint:
        epochs += checkpoint['epoch']
    if 'global_step' in checkpoint:
        steps += checkpoint['global_step']

    new_ckpt['epoch'] = epochs
    new_ckpt['global_step'] = steps

    torch.save(new_ckpt, output_file)

    return key_count


def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None):
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    pipeline = StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
        tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
        safety_checker=None,
        feature_extractor=None,
    )
    pipeline.save_pretrained(output_dir)


def convert(args):
    print("\033[1;32mConverting to Diffusers ...")
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16:
        save_dtype = torch.float16
    elif args.bf16:
        save_dtype = torch.bfloat16
    elif args.float:
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
    assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"

    # load the model
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))

    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # determine automatically
            v2_model = unet.config.cross_attention_dim == 1024
            # print("checking model version: model is " + ('v2' if v2_model else 'v1'))
        else:
            v2_model = args.v1

    # convert and save
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"

    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        key_count = save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
                                                     original_model, args.epoch, args.global_step, save_dtype, vae)
    else:
        save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--v1", action='store_true',
                        help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
    parser.add_argument("--v2", action='store_true',
                        help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
    parser.add_argument("--fp16", action='store_true',
                        help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
    parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
    parser.add_argument("--float", action='store_true',
                        help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
    parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
    parser.add_argument("--global_step", type=int, default=0,
                        help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
    parser.add_argument("--reference_model", type=str, default=None,
                        help="reference model for scheduler/tokenizer, required in saving Diffusers, copy scheduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")

    parser.add_argument("model_to_load", type=str, default=None,
                        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
    parser.add_argument("model_to_save", type=str, default=None,
                        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存")
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    args = parser.parse_args()
    convert(args)
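

# Example invocations (illustrative only; the file paths and the reference model name
# below are assumptions, not part of the original script):
#
#   # SD 2.x (768) .ckpt -> Diffusers directory; --reference_model supplies the
#   # scheduler/tokenizer to copy, and --v2 selects the 2.x conversion path above
#   python convertodiffv2-768.py --v2 --reference_model stabilityai/stable-diffusion-2-1 \
#       model.ckpt /content/converted-model
#
#   # Diffusers directory -> single .ckpt file; v1/v2 is auto-detected from
#   # unet.config.cross_attention_dim when neither --v1 nor --v2 is given
#   python convertodiffv2-768.py --float /content/converted-model model-converted.ckpt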