Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/change_naming_configs_and_checkpoints.py
1440 views
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Rename legacy UNet config parameters and checkpoint keys in a diffusers model repository."""
16
17
import argparse
18
import json
19
import os
20
21
import torch
22
from transformers.file_utils import has_file
23
24
from diffusers import UNet2DConditionModel, UNet2DModel
25
26
27
# Mode switches for the script below. They are plain module-level constants
# (not CLI options); edit them here to select what the script does.
#
# do_only_config:   drop the legacy config keys before instantiating the model
#                   and write the resulting config back with `save_config`.
# do_only_weights:  load the checkpoint, rename its top-level key prefixes and
#                   save the model with `save_pretrained`.
# do_only_renaming: rename legacy keys in the in-memory config dict only.
do_only_config = False
do_only_weights = True
do_only_renaming = False


def rename_state_dict_keys(state_dict, key_mapping):
    """Return a copy of *state_dict* with legacy top-level key prefixes renamed.

    Args:
        state_dict: Mapping of dotted parameter names to values.
        key_mapping: Mapping of legacy first path component -> new first
            path component (e.g. ``{"mid": "mid_block"}``).

    Returns:
        A new dict, in the same iteration order as *state_dict*, where:

        * entries whose name ends with ``.op.bias`` or ``.op.weight`` are dropped;
        * the first dot-separated component of every remaining name is replaced
          when it appears in *key_mapping*; all other names are kept unchanged.

    Only the first component is inspected, so ``"foo.mid.weight"`` is left alone
    even if ``"mid"`` is a key of *key_mapping*.
    """
    renamed = {}
    for param_key, param_value in state_dict.items():
        # Same filter as the original script: `.op.*` parameters are skipped.
        if param_key.endswith((".op.bias", ".op.weight")):
            continue
        prefix, *rest = param_key.split(".")
        renamed[".".join([key_mapping.get(prefix, prefix), *rest])] = param_value
    return renamed


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo_path",
        type=str,
        required=True,
        help="Path to the model repository containing the config.json and weights to convert.",
    )

    # NOTE(review): --dump_path is required but never read below; all output is
    # written back into --repo_path. Kept as-is for CLI compatibility.
    parser.add_argument("--dump_path", type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    # Legacy config parameter name -> current parameter name.
    config_parameters_to_change = {
        "image_size": "sample_size",
        "num_res_blocks": "layers_per_block",
        "block_channels": "block_out_channels",
        "down_blocks": "down_block_types",
        "up_blocks": "up_block_types",
        "downscale_freq_shift": "freq_shift",
        "resnet_num_groups": "norm_num_groups",
        "resnet_act_fn": "act_fn",
        "resnet_eps": "norm_eps",
        "num_head_channels": "attention_head_dim",
    }

    # Legacy first component of a state-dict key -> current first component.
    key_parameters_to_change = {
        "time_steps": "time_proj",
        "mid": "mid_block",
        "downsample_blocks": "down_blocks",
        "upsample_blocks": "up_blocks",
    }

    # Evaluate the lookup once: it decides both the subfolder and the model class.
    has_root_config = has_file(args.repo_path, "config.json")
    subfolder = "" if has_root_config else "unet"
    model_dir = os.path.join(args.repo_path, subfolder)

    with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as reader:
        config = json.load(reader)

    if do_only_config:
        # Remove the legacy names so they are not passed to the model constructor.
        for legacy_key in config_parameters_to_change:
            config.pop(legacy_key, None)

    # The conditional UNet is only used for the "ldm-text2im-large-256" repo when
    # the config lives in the "unet" subfolder; everything else is a UNet2DModel.
    if not has_root_config and "ldm-text2im-large-256" in args.repo_path:
        model_class = UNet2DConditionModel
    else:
        model_class = UNet2DModel
    model = model_class(**config)

    if do_only_config:
        model.save_config(model_dir)

    config = dict(model.config)

    if do_only_renaming:
        for legacy_key, current_key in config_parameters_to_change.items():
            if legacy_key in config:
                config[current_key] = config.pop(legacy_key)

        config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
        config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
        # NOTE(review): the renamed `config` dict is not used or saved by any
        # code that follows in this script.

    if do_only_weights:
        # map_location="cpu" lets checkpoints saved from a CUDA device load on
        # hosts without a GPU; the freshly built model lives on the CPU anyway.
        # SECURITY: torch.load unpickles the file, so only run this on
        # checkpoints from a trusted source.
        state_dict = torch.load(
            os.path.join(model_dir, "diffusion_pytorch_model.bin"),
            map_location="cpu",
        )

        new_state_dict = rename_state_dict_keys(state_dict, key_parameters_to_change)

        model.load_state_dict(new_state_dict)
        model.save_pretrained(model_dir)