Path: blob/main/scripts/convert_ddpm_original_checkpoint_to_diffusers.py
1440 views
import argparse1import json23import torch45from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel678def shave_segments(path, n_shave_prefix_segments=1):9"""10Removes segments. Positive values shave the first segments, negative shave the last segments.11"""12if n_shave_prefix_segments >= 0:13return ".".join(path.split(".")[n_shave_prefix_segments:])14else:15return ".".join(path.split(".")[:n_shave_prefix_segments])161718def renew_resnet_paths(old_list, n_shave_prefix_segments=0):19mapping = []20for old_item in old_list:21new_item = old_item22new_item = new_item.replace("block.", "resnets.")23new_item = new_item.replace("conv_shorcut", "conv1")24new_item = new_item.replace("in_shortcut", "conv_shortcut")25new_item = new_item.replace("temb_proj", "time_emb_proj")2627new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)2829mapping.append({"old": old_item, "new": new_item})3031return mapping323334def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):35mapping = []36for old_item in old_list:37new_item = old_item3839# In `model.mid`, the layer is called `attn`.40if not in_mid:41new_item = new_item.replace("attn", "attentions")42new_item = new_item.replace(".k.", ".key.")43new_item = new_item.replace(".v.", ".value.")44new_item = new_item.replace(".q.", ".query.")4546new_item = new_item.replace("proj_out", "proj_attn")47new_item = new_item.replace("norm", "group_norm")4849new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)50mapping.append({"old": old_item, "new": new_item})5152return mapping535455def assign_to_checkpoint(56paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None57):58assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."5960if attention_paths_to_split is not None:61if config is None:62raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")6364for path, path_map in attention_paths_to_split.items():65old_tensor = old_checkpoint[path]66channels = old_tensor.shape[0] // 36768target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)6970num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 37172old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])73query, key, value = old_tensor.split(channels // num_heads, dim=1)7475checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()76checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()77checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()7879for path in paths:80new_path = path["new"]8182if attention_paths_to_split is not None and new_path in attention_paths_to_split:83continue8485new_path = new_path.replace("down.", "down_blocks.")86new_path = new_path.replace("up.", "up_blocks.")8788if additional_replacements is not None:89for replacement in additional_replacements:90new_path = new_path.replace(replacement["old"], replacement["new"])9192if "attentions" in new_path:93checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()94else:95checkpoint[new_path] = old_checkpoint[path["old"]]969798def convert_ddpm_checkpoint(checkpoint, config):99"""100Takes a state dict and a config, and returns a converted checkpoint.101"""102new_checkpoint = {}103104new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]105new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]106new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]107new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]108109new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]110new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]111112new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]113new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]114new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]115new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]116117num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})118down_blocks = {119layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)120}121122num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})123up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}124125for i in range(num_down_blocks):126block_id = (i - 1) // (config["layers_per_block"] + 1)127128if any("downsample" in layer for layer in down_blocks[i]):129new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[130f"down.{i}.downsample.op.weight"131]132new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]133# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']134# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']135136if any("block" in layer for layer in down_blocks[i]):137num_blocks = len(138{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}139)140blocks = {141layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]142for layer_id in range(num_blocks)143}144145if num_blocks > 0:146for j in range(config["layers_per_block"]):147paths = renew_resnet_paths(blocks[j])148assign_to_checkpoint(paths, new_checkpoint, checkpoint)149150if any("attn" in layer for layer in down_blocks[i]):151num_attn = len(152{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}153)154attns = {155layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]156for layer_id in range(num_blocks)157}158159if num_attn > 0:160for j in range(config["layers_per_block"]):161paths = renew_attention_paths(attns[j])162assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)163164mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]165mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]166mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]167168# Mid new 2169paths = renew_resnet_paths(mid_block_1_layers)170assign_to_checkpoint(171paths,172new_checkpoint,173checkpoint,174additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],175)176177paths = renew_resnet_paths(mid_block_2_layers)178assign_to_checkpoint(179paths,180new_checkpoint,181checkpoint,182additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],183)184185paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)186assign_to_checkpoint(187paths,188new_checkpoint,189checkpoint,190additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],191)192193for i in range(num_up_blocks):194block_id = num_up_blocks - 1 - i195196if any("upsample" in layer for layer in up_blocks[i]):197new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[198f"up.{i}.upsample.conv.weight"199]200new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]201202if any("block" in layer for layer in up_blocks[i]):203num_blocks = len(204{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}205)206blocks = {207layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)208}209210if num_blocks > 0:211for j in range(config["layers_per_block"] + 1):212replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}213paths = renew_resnet_paths(blocks[j])214assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])215216if any("attn" in layer for layer in up_blocks[i]):217num_attn = len(218{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}219)220attns = {221layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)222}223224if num_attn > 0:225for j in range(config["layers_per_block"] + 1):226replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}227paths = renew_attention_paths(attns[j])228assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])229230new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}231return new_checkpoint232233234def convert_vq_autoenc_checkpoint(checkpoint, config):235"""236Takes a state dict and a config, and returns a converted checkpoint.237"""238new_checkpoint = {}239240new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]241new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]242243new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]244new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]245new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]246new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]247248new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]249new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]250251new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]252new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]253new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]254new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]255256num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})257down_blocks = {258layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)259}260261num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})262up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}263264for i in range(num_down_blocks):265block_id = (i - 1) // (config["layers_per_block"] + 1)266267if any("downsample" in layer for layer in down_blocks[i]):268new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[269f"encoder.down.{i}.downsample.conv.weight"270]271new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[272f"encoder.down.{i}.downsample.conv.bias"273]274275if any("block" in layer for layer in down_blocks[i]):276num_blocks = len(277{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}278)279blocks = {280layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]281for layer_id in range(num_blocks)282}283284if num_blocks > 0:285for j in range(config["layers_per_block"]):286paths = renew_resnet_paths(blocks[j])287assign_to_checkpoint(paths, new_checkpoint, checkpoint)288289if any("attn" in layer for layer in down_blocks[i]):290num_attn = len(291{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}292)293attns = {294layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]295for layer_id in range(num_blocks)296}297298if num_attn > 0:299for j in range(config["layers_per_block"]):300paths = renew_attention_paths(attns[j])301assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)302303mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]304mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]305mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]306307# Mid new 2308paths = renew_resnet_paths(mid_block_1_layers)309assign_to_checkpoint(310paths,311new_checkpoint,312checkpoint,313additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],314)315316paths = renew_resnet_paths(mid_block_2_layers)317assign_to_checkpoint(318paths,319new_checkpoint,320checkpoint,321additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],322)323324paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)325assign_to_checkpoint(326paths,327new_checkpoint,328checkpoint,329additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],330)331332for i in range(num_up_blocks):333block_id = num_up_blocks - 1 - i334335if any("upsample" in layer for layer in up_blocks[i]):336new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[337f"decoder.up.{i}.upsample.conv.weight"338]339new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[340f"decoder.up.{i}.upsample.conv.bias"341]342343if any("block" in layer for layer in up_blocks[i]):344num_blocks = len(345{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}346)347blocks = {348layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)349}350351if num_blocks > 0:352for j in range(config["layers_per_block"] + 1):353replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}354paths = renew_resnet_paths(blocks[j])355assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])356357if any("attn" in layer for layer in up_blocks[i]):358num_attn = len(359{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}360)361attns = {362layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)363}364365if num_attn > 0:366for j in range(config["layers_per_block"] + 1):367replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}368paths = renew_attention_paths(attns[j])369assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])370371new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}372new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]373new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]374if "quantize.embedding.weight" in checkpoint:375new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]376new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]377new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]378379return new_checkpoint380381382if __name__ == "__main__":383parser = argparse.ArgumentParser()384385parser.add_argument(386"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."387)388389parser.add_argument(390"--config_file",391default=None,392type=str,393required=True,394help="The config json file corresponding to the architecture.",395)396397parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")398399args = parser.parse_args()400checkpoint = torch.load(args.checkpoint_path)401402with open(args.config_file) as f:403config = json.loads(f.read())404405# unet case406key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())407if "encoder" in key_prefix_set and "decoder" in key_prefix_set:408converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)409else:410converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)411412if "ddpm" in config:413del config["ddpm"]414415if config["_class_name"] == "VQModel":416model = VQModel(**config)417model.load_state_dict(converted_checkpoint)418model.save_pretrained(args.dump_path)419elif config["_class_name"] == "AutoencoderKL":420model = AutoencoderKL(**config)421model.load_state_dict(converted_checkpoint)422model.save_pretrained(args.dump_path)423else:424model = UNet2DModel(**config)425model.load_state_dict(converted_checkpoint)426427scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))428429pipe = DDPMPipeline(unet=model, scheduler=scheduler)430pipe.save_pretrained(args.dump_path)431432433