Path: blob/main/scripts/convert_ldm_original_checkpoint_to_diffusers.py
1440 views
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """

import argparse
import json
import os
import sys

import torch

from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes dot-separated segments from `path`.

    A non-negative `n_shave_prefix_segments` drops that many leading segments; a negative value drops that many
    trailing segments. The remaining segments are re-joined with ".".
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        return ".".join(segments[n_shave_prefix_segments:])
    return ".".join(segments[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    Returns a list of `{"old": original_key, "new": renamed_key}` dicts, one per entry of `old_list`. The renamed key
    additionally has `n_shave_prefix_segments` segments removed via `shave_segments`.
    """
    # Applied in this exact order to every key.
    replacements = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for old_name, new_name in replacements:
            new_item = new_item.replace(old_name, new_name)

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming).

    Returns a list of `{"old": original_key, "new": renamed_key}` dicts, one per entry of `old_list`. The renamed key
    additionally has `n_shave_prefix_segments` segments removed via `shave_segments`.
    """
    # Applied in this exact order to every key.
    replacements = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("proj_out.weight", "proj_attn.weight"),
        ("proj_out.bias", "proj_attn.bias"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for old_name, new_name in replacements:
            new_item = new_item.replace(old_name, new_name)

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.

    Args:
        paths: list of `{"old": ..., "new": ...}` dicts as produced by the `renew_*_paths` helpers.
        checkpoint: destination dict, modified in place.
        old_checkpoint: source dict the weights are read from.
        attention_paths_to_split: optional dict mapping a fused qkv key of `old_checkpoint` to a dict with the
            destination keys under `"query"`, `"key"` and `"value"`.
        additional_replacements: optional list of `{"old": ..., "new": ...}` substring replacements applied to
            every destination key after the built-in `middle_block` renaming.
        config: dict; `config["num_head_channels"]` is read when `attention_paths_to_split` is given.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            # 3-dimensional tensors keep a trailing `channels` axis; anything else is flattened to 1D.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1,)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # View the fused tensor per head, then cut dim 1 into three equal chunks (query, key, value order).
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def _group_block_keys(checkpoint, prefix):
    """
    Groups the keys of `checkpoint` that contain `prefix` by their block index.

    The number of blocks is the number of distinct two-segment heads (e.g. `input_blocks.3`) among the matching
    keys. Matching uses a trailing dot so that block 1 does not also collect the keys of blocks 10, 11, ...
    """
    num_blocks = len({".".join(key.split(".")[:2]) for key in checkpoint if prefix in key})
    return {
        block_index: [key for key in checkpoint if f"{prefix}.{block_index}." in key]
        for block_index in range(num_blocks)
    }


def _qkv_split_map(old_prefix, new_prefix):
    """
    Builds the `attention_paths_to_split` argument of `assign_to_checkpoint` for one attention layer.

    Maps `{old_prefix}.qkv.bias` / `{old_prefix}.qkv.weight` to their `key`, `query` and `value` destinations under
    `new_prefix`.
    """
    return {
        f"{old_prefix}.qkv.{param}": {
            name: f"{new_prefix}.{name}.{param}" for name in ("key", "query", "value")
        }
        for param in ("bias", "weight")
    }


def convert_ldm_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Args:
        checkpoint: dict whose keys follow the original naming (`time_embed.*`, `input_blocks.*`, `middle_block.*`,
            `output_blocks.*`, `out.*`).
        config: dict; `config["num_res_blocks"]` is used to derive block/layer indices and
            `config["num_head_channels"]` is used when splitting fused qkv tensors.

    Returns:
        A new dict with the same values stored under the renamed keys (`time_embedding.*`, `conv_in.*`,
        `down_blocks.*`, `mid_block.*`, `up_blocks.*`, `conv_norm_out.*`, `conv_out.*`).
    """
    new_checkpoint = {}

    # Layers that are copied one-to-one under a new name.
    direct_renames = (
        ("time_embedding.linear_1", "time_embed.0"),
        ("time_embedding.linear_2", "time_embed.2"),
        ("conv_in", "input_blocks.0.0"),
        ("conv_norm_out", "out.0"),
        ("conv_out", "out.2"),
    )
    for new_name, old_name in direct_renames:
        for param in ("weight", "bias"):
            new_checkpoint[f"{new_name}.{param}"] = checkpoint[f"{old_name}.{param}"]

    # Retrieves the keys for the input, middle and output blocks, grouped by block index
    input_blocks = _group_block_keys(checkpoint, "input_blocks")
    middle_blocks = _group_block_keys(checkpoint, "middle_block")
    output_blocks = _group_block_keys(checkpoint, "output_blocks")

    num_input_blocks = len(input_blocks)
    num_output_blocks = len(output_blocks)
    layers_per_block = config["num_res_blocks"] + 1

    # Input block 0 is conv_in (handled above), hence the shift by one.
    for i in range(1, num_input_blocks):
        block_id, layer_in_block_id = divmod(i - 1, layers_per_block)

        if f"input_blocks.{i}.0.op.weight" in checkpoint:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
                f"input_blocks.{i}.0.op.weight"
            ]
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
                f"input_blocks.{i}.0.op.bias"
            ]
            continue

        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
        assign_to_checkpoint(
            paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
        )

        if attentions:
            new_attention_prefix = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": new_attention_prefix}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                checkpoint,
                additional_replacements=[meta_path],
                attention_paths_to_split=_qkv_split_map(f"input_blocks.{i}.1", new_attention_prefix),
                config=config,
            )

    # The middle block is read as resnet (0), attention (1), resnet (2).
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)

    attentions_paths = renew_attention_paths(attentions)
    assign_to_checkpoint(
        attentions_paths,
        new_checkpoint,
        checkpoint,
        attention_paths_to_split=_qkv_split_map("middle_block.1", "mid_block.attentions.0"),
        config=config,
    )

    for i in range(num_output_blocks):
        block_id, layer_in_block_id = divmod(i, layers_per_block)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]

        # Sub-layer index (as a string) -> parameter names below it, in checkpoint order.
        output_block_list = {}
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            output_block_list.setdefault(layer_id, []).append(layer_name)

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)

            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if attentions:
                new_attention_prefix = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"
                paths = renew_attention_paths(attentions)
                meta_path = {"old": f"output_blocks.{i}.1", "new": new_attention_prefix}
                has_qkv = any("qkv" in key for key in attentions)
                to_split = _qkv_split_map(f"output_blocks.{i}.1", new_attention_prefix)
                assign_to_checkpoint(
                    paths,
                    new_checkpoint,
                    checkpoint,
                    additional_replacements=[meta_path],
                    attention_paths_to_split=to_split if has_qkv else None,
                    config=config,
                )
        else:
            # Only one sub-layer: every key of this block is treated as a resnet parameter.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = checkpoint[old_path]

    return new_checkpoint


def main():
    """
    Command-line entry point: converts the checkpoint and saves the result to `--dump_path`.

    A full `LDMPipeline` is saved when a scheduler config and a VQ model can be loaded from the directory that
    contains the checkpoint; otherwise only the converted UNet is saved.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    # NOTE: torch.load unpickles the file; only run this script on checkpoints from a trusted source.
    # map_location="cpu" lets the conversion run on machines without the device the checkpoint was saved from.
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    with open(args.config_file, encoding="utf-8") as f:
        config = json.load(f)

    converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)

    # "ldm" is dropped from the config before it is passed to UNet2DModel.
    config.pop("ldm", None)

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    checkpoint_dir = os.path.dirname(args.checkpoint_path)

    try:
        scheduler = DDPMScheduler.from_config(checkpoint_dir)
        vqvae = VQModel.from_pretrained(checkpoint_dir)

        pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
        pipe.save_pretrained(args.dump_path)
    except Exception as err:
        # Best-effort fallback: keep the converted UNet even when the full pipeline cannot be assembled,
        # but say why instead of failing silently (and let KeyboardInterrupt/SystemExit propagate).
        print(f"Could not save a full LDMPipeline ({err!r}); saving the UNet only.", file=sys.stderr)
        model.save_pretrained(args.dump_path)


if __name__ == "__main__":
    main()