Path: blob/main/scripts/change_naming_configs_and_checkpoints.py
1440 views
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for the LDM checkpoints.

Reads ``config.json`` from ``--repo_path`` (or from its ``unet`` subfolder when
the file is not found at the top level), instantiates the matching UNet class
from that config and, depending on the module-level ``do_only_*`` switches,
re-saves the config and/or reloads the weights with renamed top-level
parameter prefixes and saves the model back into the same folder.
"""

import argparse
import json
import os

import torch
from transformers.file_utils import has_file

from diffusers import UNet2DConditionModel, UNet2DModel


# Switches selecting which step(s) of the script run; edit them by hand.
do_only_config = False
do_only_weights = True
do_only_renaming = False

# Old config attribute name -> new config attribute name.
CONFIG_PARAMETERS_TO_CHANGE = {
    "image_size": "sample_size",
    "num_res_blocks": "layers_per_block",
    "block_channels": "block_out_channels",
    "down_blocks": "down_block_types",
    "up_blocks": "up_block_types",
    "downscale_freq_shift": "freq_shift",
    "resnet_num_groups": "norm_num_groups",
    "resnet_act_fn": "act_fn",
    "resnet_eps": "norm_eps",
    "num_head_channels": "attention_head_dim",
}

# Old first component of a state-dict key -> new first component.
KEY_PARAMETERS_TO_CHANGE = {
    "time_steps": "time_proj",
    "mid": "mid_block",
    "downsample_blocks": "down_blocks",
    "upsample_blocks": "up_blocks",
}


def _rename_state_dict_keys(state_dict, key_map):
    """Return a copy of ``state_dict`` with renamed top-level key prefixes.

    Args:
        state_dict: Mapping of dotted parameter names to parameter values.
        key_map: Mapping of old first key component to its replacement.

    Returns:
        A new dict. Entries whose key ends in ``.op.bias`` or ``.op.weight``
        are dropped; every other entry is kept, with the first dot-separated
        component replaced when it appears in ``key_map``.
    """
    renamed = {}
    for param_key, param_value in state_dict.items():
        # These entries are intentionally not carried over to the new dict.
        if param_key.endswith((".op.bias", ".op.weight")):
            continue
        head, *tail = param_key.split(".")
        # Keys whose first component is not in the map are kept verbatim.
        renamed[".".join([key_map.get(head, head), *tail])] = param_value
    return renamed


def main():
    """Parse the command line and run the steps enabled by the switches."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo_path",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the architecture.",
    )

    # NOTE(review): --dump_path is required but never read below; all output
    # is written back under --repo_path. Kept for CLI compatibility.
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    # Evaluate once: the result drives both the subfolder and the model class.
    has_root_config = has_file(args.repo_path, "config.json")
    subfolder = "" if has_root_config else "unet"
    model_dir = os.path.join(args.repo_path, subfolder)

    with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as reader:
        config = json.load(reader)

    if do_only_config:
        # Drop the old-style names so they are not passed to the constructor.
        for old_name in CONFIG_PARAMETERS_TO_CHANGE:
            config.pop(old_name, None)

    if not has_root_config and "ldm-text2im-large-256" in args.repo_path:
        model_class = UNet2DConditionModel
    else:
        model_class = UNet2DModel
    model = model_class(**config)

    if do_only_config:
        model.save_config(model_dir)

    config = dict(model.config)

    if do_only_renaming:
        for old_name, new_name in CONFIG_PARAMETERS_TO_CHANGE.items():
            if old_name in config:
                config[new_name] = config.pop(old_name)

        config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
        config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
        # NOTE(review): the renamed ``config`` dict is not written anywhere
        # after this point in this script — confirm whether that is intended.

    if do_only_weights:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        # map_location="cpu" lets checkpoints saved from GPU tensors load on
        # machines without that device; load_state_dict copies the values
        # into the (CPU-constructed) model either way.
        state_dict = torch.load(
            os.path.join(model_dir, "diffusion_pytorch_model.bin"),
            map_location="cpu",
        )

        new_state_dict = _rename_state_dict_keys(state_dict, KEY_PARAMETERS_TO_CHANGE)

        model.load_state_dict(new_state_dict)
        model.save_pretrained(model_dir)


if __name__ == "__main__":
    main()