GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_ldm_original_checkpoint_to_diffusers.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """

import argparse
import json

import torch

from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])
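
# A couple of illustrative calls on a hypothetical key path:
#   shave_segments("output_blocks.2.1.norm.weight", 2)  -> "1.norm.weight"
#   shave_segments("output_blocks.2.1.norm.weight", -1) -> "output_blocks.2.1.norm"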


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
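
# For a hypothetical resnet key, the resulting mapping looks like:
#   renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
#   -> [{"old": "input_blocks.1.0.in_layers.0.weight", "new": "input_blocks.1.0.norm1.weight"}]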


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
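
# Analogously, for a hypothetical attention key:
#   renew_attention_paths(["middle_block.1.norm.weight"])
#   -> [{"old": "middle_block.1.norm.weight", "new": "middle_block.1.group_norm.weight"}]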


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: it takes the locally converted weights and applies a global
    renaming to them. It splits attention layers and takes into account any additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
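
# Shape bookkeeping for the qkv split above, assuming num_head_channels divides the channel
# count C: a fused "qkv.weight" of shape (3*C, C, 1) is reshaped to
# (num_heads, 3*C // num_heads, C, 1), split along dim=1 into query/key/value, and each chunk
# is reshaped back to (C, C); the matching "qkv.bias" of shape (3*C,) becomes three (C,) tensors.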


def convert_ldm_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
    new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
    new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]

    # Retrieves the keys for the input blocks only (the trailing dot keeps e.g. block 1
    # from also matching block 10)
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in checkpoint if f"middle_block.{layer_id}." in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }
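
    # With a typical LDM UNet, input_blocks[1] now holds keys such as
    # "input_blocks.1.0.in_layers.0.weight", and middle_blocks has exactly three entries:
    # a resnet (0), an attention block (1) and a second resnet (2).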

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["num_res_blocks"] + 1)
        layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)

        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in checkpoint:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
                f"input_blocks.{i}.0.op.weight"
            ]
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
                f"input_blocks.{i}.0.op.bias"
            ]
            continue

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
        assign_to_checkpoint(
            paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {
                "old": f"input_blocks.{i}.1",
                "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
            }
            to_split = {
                f"input_blocks.{i}.1.qkv.bias": {
                    "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
                    "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
                    "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                },
                f"input_blocks.{i}.1.qkv.weight": {
                    "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
                    "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
                    "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                },
            }
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                checkpoint,
                additional_replacements=[meta_path],
                attention_paths_to_split=to_split,
                config=config,
            )
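
    # Index arithmetic above: with num_res_blocks resnets plus one downsampler per level,
    # original input block i maps to diffusers level (i - 1) // (num_res_blocks + 1) and to
    # position (i - 1) % (num_res_blocks + 1) within that level.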

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)

    attentions_paths = renew_attention_paths(attentions)
    to_split = {
        "middle_block.1.qkv.bias": {
            "key": "mid_block.attentions.0.key.bias",
            "query": "mid_block.attentions.0.query.bias",
            "value": "mid_block.attentions.0.value.bias",
        },
        "middle_block.1.qkv.weight": {
            "key": "mid_block.attentions.0.key.weight",
            "query": "mid_block.attentions.0.query.weight",
            "value": "mid_block.attentions.0.value.weight",
        },
    }
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["num_res_blocks"] + 1)
        layer_in_block_id = i % (config["num_res_blocks"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)

            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                to_split = {
                    f"output_blocks.{i}.1.qkv.bias": {
                        "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
                        "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
                        "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                    },
                    f"output_blocks.{i}.1.qkv.weight": {
                        "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
                        "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
                        "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                    },
                }
                assign_to_checkpoint(
                    paths,
                    new_checkpoint,
                    checkpoint,
                    additional_replacements=[meta_path],
                    attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
                    config=config,
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = checkpoint[old_path]

    return new_checkpoint
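
# A minimal sketch of calling the converter directly (file names are placeholders):
#
#   state_dict = torch.load("model.ckpt", map_location="cpu")
#   with open("unet_config.json") as f:
#       unet_config = json.load(f)
#   diffusers_state_dict = convert_ldm_checkpoint(state_dict, unet_config)
#   unet_config.pop("ldm", None)
#   UNet2DModel(**unet_config).load_state_dict(diffusers_state_dict)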


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path)

    with open(args.config_file) as f:
        config = json.loads(f.read())

    converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)

    if "ldm" in config:
        del config["ldm"]

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    try:
        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
        vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = LDMPipeline(unet=model, scheduler=scheduler, vqvae=vqvae)
        pipe.save_pretrained(args.dump_path)
    except:  # noqa: E722
        model.save_pretrained(args.dump_path)
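
# Example invocation (the paths below are placeholders):
#
#   python scripts/convert_ldm_original_checkpoint_to_diffusers.py \
#       --checkpoint_path /path/to/ldm_unet.ckpt \
#       --config_file /path/to/unet_config.json \
#       --dump_path /path/to/converted_model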