GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_ms_text_to_video_to_diffusers.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for the ModelScope text-to-video LDM UNet checkpoints."""

import argparse

import torch

from diffusers import UNet3DConditionModel


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        weight = old_checkpoint[path["old"]]
        names = ["proj_attn.weight"]
        names_2 = ["proj_out.weight", "proj_in.weight"]
        if any(k in new_path for k in names):
            checkpoint[new_path] = weight[:, :, 0]
        elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path:
            checkpoint[new_path] = weight[:, :, 0]
        else:
            checkpoint[new_path] = weight

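# A minimal usage sketch for assign_to_checkpoint (illustrative key names, assuming a
# checkpoint that actually contains them): given
#     paths = [{"old": "input_blocks.1.0.in_layers.0.weight", "new": "input_blocks.1.0.norm1.weight"}]
#     meta_path = {"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}
#     assign_to_checkpoint(paths, new_checkpoint, unet_state_dict,
#                          additional_replacements=[meta_path], config=config)
# the weight stored under "input_blocks.1.0.in_layers.0.weight" ends up at
# "down_blocks.0.resnets.0.norm1.weight" in new_checkpoint.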

def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])

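# For example (illustrative key name), shave_segments("output_blocks.1.0.in_layers.0.weight", 2)
# returns "0.in_layers.0.weight"; this is how the "output_blocks.{i}." prefix is stripped
# inside convert_ldm_unet_checkpoint below.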

def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside temporal convolutions to the new naming scheme (local renaming).
    Currently an identity mapping; the actual renaming happens via `additional_replacements`.
    """
    mapping = []
    for old_item in old_list:
        mapping.append({"old": old_item, "new": old_item})

    return mapping


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        # NOTE: "temopral_conv" (sic) matches the misspelled key name used in the source checkpoint.
        if "temopral_conv" not in old_item:
            mapping.append({"old": old_item, "new": new_item})

    return mapping

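# For example (illustrative key name), renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
# returns [{"old": "input_blocks.1.0.in_layers.0.weight", "new": "input_blocks.1.0.norm1.weight"}];
# the block-level prefix is then rewritten by assign_to_checkpoint via additional_replacements.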

def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."

    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")]
    paths = renew_attention_paths(first_temp_attention)
    meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"}
    assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

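    # Each of the dicts above groups checkpoint keys by block index, e.g. input_blocks[3]
    # holds every key containing the substring "input_blocks.3". The loops below walk
    # those groups and rename them to the diffusers down/mid/up block layout.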
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key]

        if f"input_blocks.{i}.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        temporal_convs = [key for key in resnets if "temopral_conv" in key]
        paths = renew_temp_conv_paths(temporal_convs)
        meta_path = {
            "old": f"input_blocks.{i}.0.temopral_conv",
            "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}",
        }
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

        if len(temp_attentions):
            paths = renew_attention_paths(temp_attentions)
            meta_path = {
                "old": f"input_blocks.{i}.2",
                "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

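    # In the source checkpoint the middle block is laid out as
    # [resnet, attention, temporal attention, resnet]; map it onto diffusers'
    # mid_block resnets / attentions / temp_attentions below.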
    resnet_0 = middle_blocks[0]
    temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key]
    attentions = middle_blocks[1]
    temp_attentions = middle_blocks[2]
    resnet_1 = middle_blocks[3]
    temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0)
    meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"}
    assign_to_checkpoint(
        temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    resnet_1_paths = renew_resnet_paths(resnet_1)
    meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"}
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1)
    meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"}
    assign_to_checkpoint(
        temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    temp_attentions_paths = renew_attention_paths(temp_attentions)
    meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"}
    assign_to_checkpoint(
        temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
            temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            temporal_convs = [key for key in resnets if "temopral_conv" in key]
            paths = renew_temp_conv_paths(temporal_convs)
            meta_path = {
                "old": f"output_blocks.{i}.0.temopral_conv",
                "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            if len(temp_attentions):
                paths = renew_attention_paths(temp_attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.2",
                    "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
                new_checkpoint[new_path] = unet_state_dict[old_path]

            temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l]
            for path in temopral_conv_paths:
                pruned_path = path.split("temopral_conv.")[-1]
                old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path])
                new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path])
                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    unet_checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
    unet = UNet3DConditionModel()

    converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config)

    diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys())
    diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys())

    assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match"

    # load state_dict
    unet.load_state_dict(converted_ckpt)

    unet.save_pretrained(args.dump_path)

    # -- finish converting the unet --
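    # Example invocation (paths are illustrative):
    #   python convert_ms_text_to_video_to_diffusers.py \
    #       --checkpoint_path /path/to/text2video_pytorch_model.pth \
    #       --dump_path ./text_to_video_unet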