GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_original_audioldm_to_diffusers.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for the AudioLDM checkpoints."""

import argparse
import re

import torch
from transformers import (
    AutoTokenizer,
    ClapTextConfig,
    ClapTextModelWithProjection,
    SpeechT5HifiGan,
    SpeechT5HifiGanConfig,
)

from diffusers import (
    AudioLDMPipeline,
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import is_omegaconf_available, is_safetensors_available
from diffusers.utils.import_utils import BACKENDS_MAPPING


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
def renew_attention_paths(old_list):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def create_unet_diffusers_config(original_config, image_size: int):
    """
    Creates a UNet config for diffusers based on the config of the original AudioLDM model.
    """
    unet_params = original_config.model.params.unet_config.params
    vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    cross_attention_dim = (
        unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels
    )

    class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
    projection_class_embeddings_input_dim = (
        unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None
    )
    class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None

    config = dict(
        sample_size=image_size // vae_scale_factor,
        in_channels=unet_params.in_channels,
        out_channels=unet_params.out_channels,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=unet_params.num_res_blocks,
        cross_attention_dim=cross_attention_dim,
        class_embed_type=class_embed_type,
        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
        class_embeddings_concat=class_embeddings_concat,
    )

    return config


# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
    """
    Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original
    Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215

    config = dict(
        sample_size=image_size,
        in_channels=vae_params.in_channels,
        out_channels=vae_params.out_ch,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=vae_params.z_channels,
        layers_per_block=vae_params.num_res_blocks,
        scaling_factor=float(scaling_factor),
    )
    return config


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
def create_diffusers_schedular(original_config):
    schedular = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return schedular


# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion
    conversion, this function additionally converts the learnt film embedding linear layer.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."
    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"]
    new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


CLAP_KEYS_TO_MODIFY_MAPPING = {
    "text_branch": "text_model",
    "attn": "attention.self",
    "self.proj": "output.dense",
    "attention.self_mask": "attn_mask",
    "mlp.fc1": "intermediate.dense",
    "mlp.fc2": "output.dense",
    "norm1": "layernorm_before",
    "norm2": "layernorm_after",
    "bn0": "batch_norm",
}

CLAP_KEYS_TO_IGNORE = ["text_transform"]

CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"]


def convert_open_clap_checkpoint(checkpoint):
    """
    Takes a state dict and returns a converted CLAP checkpoint.
    """
    # extract state dict for CLAP text embedding model, discarding the audio component
    model_state_dict = {}
    model_key = "cond_stage_model.model.text_"
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(model_key):
            model_state_dict[key.replace(model_key, "text_")] = checkpoint.get(key)

    new_checkpoint = {}

    sequential_layers_pattern = r".*sequential.(\d+).*"
    text_projection_pattern = r".*_projection.(\d+).*"

    for key, value in model_state_dict.items():
        # check if key should be ignored in mapping
        if key.split(".")[0] in CLAP_KEYS_TO_IGNORE:
            continue

        # check if any key needs to be modified
        for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        if re.match(sequential_layers_pattern, key):
            # replace sequential layers with list
            sequential_layer = re.match(sequential_layers_pattern, key).group(1)

            key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
        elif re.match(text_projection_pattern, key):
            projection_layer = int(re.match(text_projection_pattern, key).group(1))

            # Because in CLAP they use `nn.Sequential`...
            transformers_projection_layer = 1 if projection_layer == 0 else 2

            key = key.replace(f"_projection.{projection_layer}.", f"_projection.linear{transformers_projection_layer}.")

        if "audio" in key and "qkv" in key:
            # split qkv into query, key and value
            mixed_qkv = value
            qkv_dim = mixed_qkv.size(0) // 3

            query_layer = mixed_qkv[:qkv_dim]
            key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
            value_layer = mixed_qkv[qkv_dim * 2 :]

            new_checkpoint[key.replace("qkv", "query")] = query_layer
            new_checkpoint[key.replace("qkv", "key")] = key_layer
            new_checkpoint[key.replace("qkv", "value")] = value_layer
        else:
            new_checkpoint[key] = value

    return new_checkpoint


def create_transformers_vocoder_config(original_config):
    """
    Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
    """
    vocoder_params = original_config.model.params.vocoder_config.params

    config = dict(
        model_in_dim=vocoder_params.num_mels,
        sampling_rate=vocoder_params.sampling_rate,
        upsample_initial_channel=vocoder_params.upsample_initial_channel,
        upsample_rates=list(vocoder_params.upsample_rates),
        upsample_kernel_sizes=list(vocoder_params.upsample_kernel_sizes),
        resblock_kernel_sizes=list(vocoder_params.resblock_kernel_sizes),
        resblock_dilation_sizes=[
            list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
        ],
        normalize_before=False,
    )

    return config


def convert_hifigan_checkpoint(checkpoint, config):
    """
    Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint.
    """
    # extract state dict for vocoder
    vocoder_state_dict = {}
    vocoder_key = "first_stage_model.vocoder."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vocoder_key):
            vocoder_state_dict[key.replace(vocoder_key, "")] = checkpoint.get(key)

    # fix upsampler keys, everything else is correct already
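    # e.g. "ups.0.weight" becomes "upsampler.0.weight" (illustrative key)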
    for i in range(len(config.upsample_rates)):
        vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight")
        vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias")

    if not config.normalize_before:
        # if we don't set normalize_before then these variables are unused, so we set them to their initialised values
        vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim)
        vocoder_state_dict["scale"] = torch.ones(config.model_in_dim)

    return vocoder_state_dict


# Adapted from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/audioldm/utils.py#L72-L73
DEFAULT_CONFIG = {
    "model": {
        "params": {
            "linear_start": 0.0015,
            "linear_end": 0.0195,
            "timesteps": 1000,
            "channels": 8,
            "scale_by_std": True,
            "unet_config": {
                "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
                "params": {
                    "extra_film_condition_dim": 512,
                    "extra_film_use_concat": True,
                    "in_channels": 8,
                    "out_channels": 8,
                    "model_channels": 128,
                    "attention_resolutions": [8, 4, 2],
                    "num_res_blocks": 2,
                    "channel_mult": [1, 2, 3, 5],
                    "num_head_channels": 32,
                },
            },
            "first_stage_config": {
                "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
                "params": {
                    "embed_dim": 8,
                    "ddconfig": {
                        "z_channels": 8,
                        "resolution": 256,
                        "in_channels": 1,
                        "out_ch": 1,
                        "ch": 128,
                        "ch_mult": [1, 2, 4],
                        "num_res_blocks": 2,
                    },
                },
            },
            "vocoder_config": {
                "target": "audioldm.first_stage_model.vocoder",
                "params": {
                    "upsample_rates": [5, 4, 2, 2, 2],
                    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
                    "upsample_initial_channel": 1024,
                    "resblock_kernel_sizes": [3, 7, 11],
                    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    "num_mels": 64,
                    "sampling_rate": 16000,
                },
            },
        },
    },
}


def load_pipeline_from_original_audioldm_ckpt(
    checkpoint_path: str,
    original_config_file: str = None,
    image_size: int = 512,
    prediction_type: str = None,
    extract_ema: bool = False,
    scheduler_type: str = "ddim",
    num_in_channels: int = None,
    device: str = None,
    from_safetensors: bool = False,
) -> AudioLDMPipeline:
    """
    Load an AudioLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file.

    Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.

    :param checkpoint_path: Path to the `.ckpt` file.
    :param original_config_file: Path to the `.yaml` config file corresponding to the original architecture. If
        `None`, the config will be instantiated from default values.
    :param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints.
    :param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original AudioLDM
        checkpoints.
    :param num_in_channels: The number of input channels. If `None`, the number of input channels will be
        automatically inferred.
    :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
        "euler-ancestral", "dpm", "ddim"]`.
    :param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the
        EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield
        higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
    :param device: The device to use. Pass `None` to determine automatically.
    :param from_safetensors: If `checkpoint_path` is in `safetensors` format, load the checkpoint with safetensors
        instead of PyTorch.
    :return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """

    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

    from omegaconf import OmegaConf

    if from_safetensors:
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            checkpoint = torch.load(checkpoint_path, map_location=device)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device)

    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if original_config_file is None:
        original_config = DEFAULT_CONFIG
        original_config = OmegaConf.create(original_config)
    else:
        original_config = OmegaConf.load(original_config_file)

    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
    ):
        if prediction_type is None:
            prediction_type = "v_prediction"
    else:
        if prediction_type is None:
            prediction_type = "epsilon"

    if image_size is None:
        image_size = 512

    num_train_timesteps = original_config.model.params.timesteps
    beta_start = original_config.model.params.linear_start
    beta_end = original_config.model.params.linear_end

    scheduler = DDIMScheduler(
        beta_end=beta_end,
        beta_schedule="scaled_linear",
        beta_start=beta_start,
        num_train_timesteps=num_train_timesteps,
        steps_offset=1,
        clip_sample=False,
        set_alpha_to_one=False,
        prediction_type=prediction_type,
    )
    # make sure scheduler works correctly with DDIM
    scheduler.register_to_config(clip_sample=False)

    if scheduler_type == "pndm":
        config = dict(scheduler.config)
        config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(config)
    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
    elif scheduler_type == "ddim":
        scheduler = scheduler
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    # Convert the UNet2DConditionModel
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet = UNet2DConditionModel(**unet_config)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
    )

    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model
    vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model
    # AudioLDM uses the same configuration and tokenizer as the original CLAP model
    config = ClapTextConfig.from_pretrained("laion/clap-htsat-unfused")
    tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

    converted_text_model = convert_open_clap_checkpoint(checkpoint)
    text_model = ClapTextModelWithProjection(config)

    missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False)
    # we expect not to have token_type_ids in our original state dict so let's ignore them
    missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS))

    if len(unexpected_keys) > 0:
        raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}")

    if len(missing_keys) > 0:
        raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}")

    # Convert the vocoder model
    vocoder_config = create_transformers_vocoder_config(original_config)
    vocoder_config = SpeechT5HifiGanConfig(**vocoder_config)
    converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config)

    vocoder = SpeechT5HifiGan(vocoder_config)
    vocoder.load_state_dict(converted_vocoder_checkpoint)

    # Instantiate the diffusers pipeline
    pipe = AudioLDMPipeline(
        vae=vae,
        text_encoder=text_model,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        vocoder=vocoder,
    )

    return pipe


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--num_in_channels",
        default=None,
        type=int,
        help="The number of input channels. If `None`, the number of input channels will be automatically inferred.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="ddim",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'heun', 'euler', 'euler-ancestral', 'dpm', 'ddim'].",
    )
    parser.add_argument(
        "--image_size",
        default=None,
        type=int,
        help=("The image size that the model was trained on."),
    )
    parser.add_argument(
        "--prediction_type",
        default=None,
        type=str,
        help=("The prediction type that the model was trained on."),
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to store pipeline in safetensors format or not.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    args = parser.parse_args()

    pipe = load_pipeline_from_original_audioldm_ckpt(
        checkpoint_path=args.checkpoint_path,
        original_config_file=args.original_config_file,
        image_size=args.image_size,
        prediction_type=args.prediction_type,
        extract_ema=args.extract_ema,
        scheduler_type=args.scheduler_type,
        num_in_channels=args.num_in_channels,
        from_safetensors=args.from_safetensors,
        device=args.device,
    )
    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)