GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_ddpm_original_checkpoint_to_diffusers.py

import argparse
import json

import torch

from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    mapping = []
    for old_item in old_list:
        new_item = old_item
        new_item = new_item.replace("block.", "resnets.")
        new_item = new_item.replace("conv_shorcut", "conv1")
        new_item = new_item.replace("in_shortcut", "conv_shortcut")
        new_item = new_item.replace("temb_proj", "time_emb_proj")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
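# For illustration, with a hypothetical key: renew_resnet_paths(["down.0.block.0.temb_proj.weight"])
# returns [{"old": "down.0.block.0.temb_proj.weight", "new": "down.0.resnets.0.time_emb_proj.weight"}].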


def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # In `model.mid`, the layer is called `attn`.
        if not in_mid:
            new_item = new_item.replace("attn", "attentions")
        new_item = new_item.replace(".k.", ".key.")
        new_item = new_item.replace(".v.", ".value.")
        new_item = new_item.replace(".q.", ".query.")

        new_item = new_item.replace("proj_out", "proj_attn")
        new_item = new_item.replace("norm", "group_norm")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping
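# For illustration, with a hypothetical key: renew_attention_paths(["down.1.attn.0.q.weight"])
# returns [{"old": "down.1.attn.0.q.weight", "new": "down.1.attentions.0.query.weight"}];
# the "down." -> "down_blocks." renaming is applied later, in assign_to_checkpoint.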


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    if attention_paths_to_split is not None:
        if config is None:
            raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")

        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
            checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
            checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()

    for path in paths:
        new_path = path["new"]

        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        new_path = new_path.replace("down.", "down_blocks.")
        new_path = new_path.replace("up.", "up_blocks.")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        if "attentions" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
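# For illustration, a call such as
#     assign_to_checkpoint(paths, new_checkpoint, checkpoint,
#                          additional_replacements=[{"old": "block_1", "new": "resnets.0"}])
# copies each tensor from `checkpoint` into `new_checkpoint` under its renamed key
# (the mid-block calls below use exactly this pattern).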


def convert_ddpm_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]

    new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
    new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]

    new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
    new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
    new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
    new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]

    num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
    down_blocks = {
        layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
    up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
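    # For illustration, with hypothetical keys: "down.0.block.1.conv1.weight" is grouped into
    # down_blocks[0], and "up.2.attn.0.q.weight" into up_blocks[2].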

    for i in range(num_down_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)

        if any("downsample" in layer for layer in down_blocks[i]):
            new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
                f"down.{i}.downsample.op.weight"
            ]
            new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

        if any("block" in layer for layer in down_blocks[i]):
            num_blocks = len(
                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
            )
            blocks = {
                layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
                for layer_id in range(num_blocks)
            }

            if num_blocks > 0:
                for j in range(config["layers_per_block"]):
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)

        if any("attn" in layer for layer in down_blocks[i]):
            num_attn = len(
                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
            )
            attns = {
                layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
                for layer_id in range(num_blocks)
            }

            if num_attn > 0:
                for j in range(config["layers_per_block"]):
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

    mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
    mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
    mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]

    # Mid block: stage its keys under the temporary "mid_new_2" prefix, which is renamed
    # to "mid_block" at the end of this function.
    paths = renew_resnet_paths(mid_block_1_layers)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
    )

    paths = renew_resnet_paths(mid_block_2_layers)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
    )

    paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
    )

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i

        if any("upsample" in layer for layer in up_blocks[i]):
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
                f"up.{i}.upsample.conv.weight"
            ]
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]

        if any("block" in layer for layer in up_blocks[i]):
            num_blocks = len(
                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
            )
            blocks = {
                layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
            }

            if num_blocks > 0:
                for j in range(config["layers_per_block"] + 1):
                    replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

        if any("attn" in layer for layer in up_blocks[i]):
            num_attn = len(
                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
            )
            attns = {
                layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
            }

            if num_attn > 0:
                for j in range(config["layers_per_block"] + 1):
                    replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

    new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
    return new_checkpoint
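# For illustration, typical renames produced above (hypothetical but standard DDPM-style keys):
#     temb.dense.0.weight             -> time_embedding.linear_1.weight
#     down.0.block.0.temb_proj.weight -> down_blocks.0.resnets.0.time_emb_proj.weight
#     mid.attn_1.q.weight             -> mid_block.attentions.0.query.weight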


def convert_vq_autoenc_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]

    new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]

    new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]

    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
    down_blocks = {
        layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
    up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}

    for i in range(num_down_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)

        if any("downsample" in layer for layer in down_blocks[i]):
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
                f"encoder.down.{i}.downsample.conv.weight"
            ]
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
                f"encoder.down.{i}.downsample.conv.bias"
            ]

        if any("block" in layer for layer in down_blocks[i]):
            num_blocks = len(
                {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
            )
            blocks = {
                layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
                for layer_id in range(num_blocks)
            }

            if num_blocks > 0:
                for j in range(config["layers_per_block"]):
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)

        if any("attn" in layer for layer in down_blocks[i]):
            num_attn = len(
                {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
            )
            attns = {
                layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
                for layer_id in range(num_blocks)
            }

            if num_attn > 0:
                for j in range(config["layers_per_block"]):
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

    mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
    mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
    mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]

    # Mid block: stage its keys under the temporary "mid_new_2" prefix, which is renamed
    # to "mid_block" at the end of this function.
    paths = renew_resnet_paths(mid_block_1_layers)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
    )

    paths = renew_resnet_paths(mid_block_2_layers)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
    )

    paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
    )

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i

        if any("upsample" in layer for layer in up_blocks[i]):
            new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
                f"decoder.up.{i}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
                f"decoder.up.{i}.upsample.conv.bias"
            ]

        if any("block" in layer for layer in up_blocks[i]):
            num_blocks = len(
                {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
            )
            blocks = {
                layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
            }

            if num_blocks > 0:
                for j in range(config["layers_per_block"] + 1):
                    replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

        if any("attn" in layer for layer in up_blocks[i]):
            num_attn = len(
                {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
            )
            attns = {
                layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
            }

            if num_attn > 0:
                for j in range(config["layers_per_block"] + 1):
                    replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

    new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
    new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
    if "quantize.embedding.weight" in checkpoint:
        new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
    new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]

    return new_checkpoint
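# For illustration, with a hypothetical key: "encoder.down.0.block.0.conv1.weight" ends up as
# "encoder.down_blocks.0.resnets.0.conv1.weight" in the converted VQ/autoencoder checkpoint.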


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()
    checkpoint = torch.load(args.checkpoint_path)

    with open(args.config_file) as f:
        config = json.loads(f.read())
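    # The config JSON is assumed to hold the keyword arguments for the target diffusers model
    # class; this script additionally reads config["_class_name"] and config["layers_per_block"],
    # and drops an optional "ddpm" entry below.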

    # VQ / autoencoder checkpoints carry top-level "encoder" and "decoder" key prefixes;
    # anything else is treated as a plain UNet (DDPM) checkpoint.
    key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
    if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
        converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
    else:
        converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)

    if "ddpm" in config:
        del config["ddpm"]

    if config["_class_name"] == "VQModel":
        model = VQModel(**config)
        model.load_state_dict(converted_checkpoint)
        model.save_pretrained(args.dump_path)
    elif config["_class_name"] == "AutoencoderKL":
        model = AutoencoderKL(**config)
        model.load_state_dict(converted_checkpoint)
        model.save_pretrained(args.dump_path)
    else:
        model = UNet2DModel(**config)
        model.load_state_dict(converted_checkpoint)

        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = DDPMPipeline(unet=model, scheduler=scheduler)
        pipe.save_pretrained(args.dump_path)
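
# Example invocation (paths are hypothetical):
#     python convert_ddpm_original_checkpoint_to_diffusers.py \
#         --checkpoint_path ./model.ckpt \
#         --config_file ./config.json \
#         --dump_path ./converted_model
# For the UNet/DDPM case, the checkpoint's directory is also expected to contain a scheduler
# config, since DDPMScheduler.from_config is pointed at that directory above.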