Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
1440 views
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
""" Conversion script for the NCSNPP checkpoints. """
16
17
import argparse
18
import json
19
20
import torch
21
22
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
23
24
25
def convert_ncsnpp_checkpoint(checkpoint, config):
    """Convert an original NCSNPP checkpoint into a diffusers ``UNet2DModel`` state dict.

    The original checkpoint stores all layers in a flat, sequentially numbered
    ``all_modules.<index>.<param>`` namespace. This function instantiates a fresh
    ``UNet2DModel`` from ``config`` and copies the tensors over in the same
    sequential order, tracking the current position with ``module_index``.

    Args:
        checkpoint: State dict of the original NCSNPP model; keys follow the
            ``all_modules.<i>...`` convention in model-definition order.
        config: Keyword-argument dict used to construct ``UNet2DModel``.

    Returns:
        The populated ``UNet2DModel``'s ``state_dict()``.
    """
    new_model_architecture = UNet2DModel(**config)
    # Gaussian Fourier time-step projection.
    # NOTE(review): both ``W`` and ``weight`` are filled from the same tensor —
    # presumably to cover both attribute names used across diffusers versions;
    # confirm against the installed UNet2DModel implementation.
    new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
    new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

    new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
    new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data

    new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
    new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data

    # Output norm/conv taken from the last four checkpoint keys.
    # NOTE(review): these four tensors are assigned again at the bottom of this
    # function via ``module_index`` — the assignments here appear redundant.
    new_model_architecture.conv_norm_out.weight.data = checkpoint[list(checkpoint.keys())[-4]].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[list(checkpoint.keys())[-3]].data
    new_model_architecture.conv_out.weight.data = checkpoint[list(checkpoint.keys())[-2]].data
    new_model_architecture.conv_out.bias.data = checkpoint[list(checkpoint.keys())[-1]].data

    # Indices 0-3 were consumed above (time proj, two time-embedding linears,
    # conv_in); the sequential walk starts at 4.
    module_index = 4

    def set_attention_weights(new_layer, old_checkpoint, index):
        # Copy one attention block. The original NIN (1x1 "network-in-network")
        # layers store weights as W with shape transposed relative to
        # nn.Linear, hence the ``.T`` on every weight copy.
        new_layer.query.weight.data = old_checkpoint[f"all_modules.{index}.NIN_0.W"].data.T
        new_layer.key.weight.data = old_checkpoint[f"all_modules.{index}.NIN_1.W"].data.T
        new_layer.value.weight.data = old_checkpoint[f"all_modules.{index}.NIN_2.W"].data.T

        new_layer.query.bias.data = old_checkpoint[f"all_modules.{index}.NIN_0.b"].data
        new_layer.key.bias.data = old_checkpoint[f"all_modules.{index}.NIN_1.b"].data
        new_layer.value.bias.data = old_checkpoint[f"all_modules.{index}.NIN_2.b"].data

        # NIN_3 is the output projection.
        new_layer.proj_attn.weight.data = old_checkpoint[f"all_modules.{index}.NIN_3.W"].data.T
        new_layer.proj_attn.bias.data = old_checkpoint[f"all_modules.{index}.NIN_3.b"].data

        new_layer.group_norm.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.group_norm.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

    def set_resnet_weights(new_layer, old_checkpoint, index):
        # Copy one resnet block: two conv/norm pairs plus the time-embedding
        # projection (Dense_0 in the original naming).
        new_layer.conv1.weight.data = old_checkpoint[f"all_modules.{index}.Conv_0.weight"].data
        new_layer.conv1.bias.data = old_checkpoint[f"all_modules.{index}.Conv_0.bias"].data
        new_layer.norm1.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.norm1.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

        new_layer.conv2.weight.data = old_checkpoint[f"all_modules.{index}.Conv_1.weight"].data
        new_layer.conv2.bias.data = old_checkpoint[f"all_modules.{index}.Conv_1.bias"].data
        new_layer.norm2.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.weight"].data
        new_layer.norm2.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.bias"].data

        new_layer.time_emb_proj.weight.data = old_checkpoint[f"all_modules.{index}.Dense_0.weight"].data
        new_layer.time_emb_proj.bias.data = old_checkpoint[f"all_modules.{index}.Dense_0.bias"].data

        # A 1x1 shortcut conv (Conv_2) only exists when the block changes
        # channel count or resamples.
        if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down:
            new_layer.conv_shortcut.weight.data = old_checkpoint[f"all_modules.{index}.Conv_2.weight"].data
            new_layer.conv_shortcut.bias.data = old_checkpoint[f"all_modules.{index}.Conv_2.bias"].data

    # Down path: resnet (+ optional attention) per layer, then the
    # downsampling resnet and its skip connection conv.
    for i, block in enumerate(new_model_architecture.downsample_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
            if has_attentions:
                set_attention_weights(block.attentions[j], checkpoint, module_index)
                module_index += 1

        if hasattr(block, "downsamplers") and block.downsamplers is not None:
            set_resnet_weights(block.resnet_down, checkpoint, module_index)
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
            module_index += 1

    # Middle block: resnet -> attention -> resnet.
    set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
    module_index += 1
    set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
    module_index += 1
    set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
    module_index += 1

    # Up path: mirrors the down path, plus skip norm/conv before the
    # upsampling resnet.
    for i, block in enumerate(new_model_architecture.up_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
        if has_attentions:
            # NOTE(review): unlike the down path, only attentions[0] is filled,
            # and outside the resnet loop — original author flagged this too.
            set_attention_weights(
                block.attentions[0], checkpoint, module_index
            )  # why can there only be a single attention layer for up?
            module_index += 1

        if hasattr(block, "resnet_up") and block.resnet_up is not None:
            block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            set_resnet_weights(block.resnet_up, checkpoint, module_index)
            module_index += 1

    # Final output norm and conv (last two entries of the sequential walk).
    new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
    module_index += 1
    new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data

    return new_model_architecture.state_dict()
133
134
if __name__ == "__main__":
    # CLI entry point: load an original NCSNPP checkpoint + config, convert it,
    # and save a diffusers pipeline (or a bare model as fallback).
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the checkpoint to convert.",
    )

    parser.add_argument(
        "--config_file",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/config.json",
        type=str,
        required=False,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument(
        "--dump_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
        type=str,
        required=False,
        help="Path to the output model.",
    )

    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    # Idiom fix: json.load reads the file object directly.
    with open(args.config_file) as f:
        config = json.load(f)

    converted_checkpoint = convert_ncsnpp_checkpoint(
        checkpoint,
        config,
    )

    # "sde" is not a UNet2DModel constructor argument, so drop it before
    # instantiating.
    if "sde" in config:
        del config["sde"]

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    try:
        # Preferred path: build a full pipeline with the scheduler config that
        # lives next to the checkpoint.
        scheduler = ScoreSdeVeScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
        pipe.save_pretrained(args.dump_path)
    except Exception:
        # Bug fix: the bare ``except:`` also swallowed KeyboardInterrupt and
        # SystemExit. Catching Exception keeps the deliberate best-effort
        # fallback (save the model alone) without hiding interpreter exits.
        model.save_pretrained(args.dump_path)