GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_dance_diffusion_to_diffusers.py

#!/usr/bin/env python3
import argparse
import math
import os
from copy import deepcopy

import torch
from audio_diffusion.models import DiffusionAttnUnet1D
from diffusion import sampling
from torch import nn

from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel


MODELS_MAP = {
    "gwf-440k": {
        "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-small-190k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-large-580k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
        "sample_rate": 48000,
        "sample_size": 131072,
    },
    "maestro-uncond-150k": {
        "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "unlocked-uncond-250k": {
        "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "honk-140k": {
        "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
}

def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise."""
    return torch.atan2(sigma, alpha) / math.pi * 2


def get_crash_schedule(t):
    sigma = torch.sin(t * math.pi / 2) ** 2
    alpha = (1 - sigma**2) ** 0.5
    return alpha_sigma_to_t(alpha, sigma)

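
# Sanity check of the schedule math above (values traced by hand): t = 0 gives
# sigma = 0, alpha = 1 (timestep 0); t = 1 gives sigma = 1, alpha = 0
# (timestep 1); t = 0.5 gives sigma = 0.5, alpha = sqrt(0.75), i.e. timestep 1/3:
#
#   get_crash_schedule(torch.tensor([0.0, 0.5, 1.0]))
#   # tensor([0.0000, 0.3333, 1.0000])
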

class Object(object):
    # bare namespace used as a stand-in for the original training config
    pass


class DiffusionUncond(nn.Module):
    def __init__(self, global_args):
        super().__init__()

        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion_ema = deepcopy(self.diffusion)
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

def download(model_name):
    url = MODELS_MAP[model_name]["url"]
    # -O pins the output name so it is guaranteed to match the path returned below
    os.system(f"wget {url} -O ./{model_name}.ckpt")

    return f"./{model_name}.ckpt"


DOWN_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
}
UP_NUM_TO_LAYER = {
    "8": "resnets.0",
    "9": "attentions.0",
    "10": "resnets.1",
    "11": "attentions.1",
    "12": "resnets.2",
    "13": "attentions.2",
}
MID_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
    "8": "resnets.3",
    "9": "attentions.3",
    "10": "resnets.4",
    "11": "attentions.4",
    "12": "resnets.5",
    "13": "attentions.5",
}
DEPTH_0_TO_LAYER = {
    "0": "resnets.0",
    "1": "resnets.1",
    "2": "resnets.2",
    "4": "resnets.0",
    "5": "resnets.1",
    "6": "resnets.2",
}

RES_CONV_MAP = {
    "skip": "conv_skip",
    "main.0": "conv_1",
    "main.1": "group_norm_1",
    "main.3": "conv_2",
    "main.4": "group_norm_2",
}

ATTN_MAP = {
    "norm": "group_norm",
    "qkv_proj": ["query", "key", "value"],
    "out_proj": ["proj_attn"],
}

def convert_resconv_naming(name):
    if name.startswith("skip"):
        return name.replace("skip", RES_CONV_MAP["skip"])

    # name has to be of format main.{digit}
    if not name.startswith("main."):
        raise ValueError(f"ResConvBlock error with {name}")

    return name.replace(name[:6], RES_CONV_MAP[name[:6]])


def convert_attn_naming(name):
    for key, value in ATTN_MAP.items():
        if name.startswith(key) and not isinstance(value, list):
            return name.replace(key, value)
        elif name.startswith(key):
            return [name.replace(key, v) for v in value]
    raise ValueError(f"Attn error with {name}")

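
# Illustrative checks of the renaming helpers above (inputs are hypothetical
# parameter names following the patterns the helpers expect):
#
#   convert_resconv_naming("main.0.weight")  # -> "conv_1.weight"
#   convert_attn_naming("norm.weight")       # -> "group_norm.weight"
#   convert_attn_naming("qkv_proj.weight")   # -> ["query.weight", "key.weight", "value.weight"]
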

def rename(input_string, max_depth=13):
    string = input_string

    if string.split(".")[0] == "timestep_embed":
        return string.replace("timestep_embed", "time_proj")

    depth = 0
    if string.startswith("net.3."):
        depth += 1
        string = string[6:]
    elif string.startswith("net."):
        string = string[4:]

    while string.startswith("main.7."):
        depth += 1
        string = string[7:]

    if string.startswith("main."):
        string = string[5:]

    # the layer number can be one or two digits
    if string[:2].isdigit():
        layer_num = string[:2]
        string_left = string[2:]
    else:
        layer_num = string[0]
        string_left = string[1:]

    if depth == max_depth:
        new_layer = MID_NUM_TO_LAYER[layer_num]
        prefix = "mid_block"
    elif depth > 0 and int(layer_num) < 7:
        new_layer = DOWN_NUM_TO_LAYER[layer_num]
        prefix = f"down_blocks.{depth}"
    elif depth > 0 and int(layer_num) > 7:
        new_layer = UP_NUM_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - depth - 1}"
    elif depth == 0:
        new_layer = DEPTH_0_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"

    if not string_left.startswith("."):
        raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")

    string_left = string_left[1:]

    if "resnets" in new_layer:
        string_left = convert_resconv_naming(string_left)
    elif "attentions" in new_layer:
        new_string_left = convert_attn_naming(string_left)
        string_left = new_string_left

    if not isinstance(string_left, list):
        new_string = prefix + "." + new_layer + "." + string_left
    else:
        new_string = [prefix + "." + new_layer + "." + s for s in string_left]
    return new_string

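
# Worked example of the renaming rules above (the second key is hypothetical
# but follows the layout rename() expects):
#
#   rename("timestep_embed.weight")  # -> "time_proj.weight"
#   rename("net.0.main.0.weight")    # -> "down_blocks.0.resnets.0.conv_1.weight"
#
# In the second case depth stays 0, layer_num "0" maps through
# DEPTH_0_TO_LAYER to "resnets.0", and convert_resconv_naming turns "main.0"
# into "conv_1".
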

def rename_orig_weights(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.endswith("kernel"):
            # up- and downsample layers don't have trainable weights
            continue

        new_k = rename(k)

        # check if we need to transform from Conv => Linear for attention
        if isinstance(new_k, list):
            new_state_dict = transform_conv_attns(new_state_dict, new_k, v)
        else:
            new_state_dict[new_k] = v

    return new_state_dict


def transform_conv_attns(new_state_dict, new_k, v):
    if len(new_k) == 1:
        if len(v.shape) == 3:
            # weight
            new_state_dict[new_k[0]] = v[:, :, 0]
        else:
            # bias
            new_state_dict[new_k[0]] = v
    else:
        # qkv matrices
        tripled_shape = v.shape[0]
        single_shape = tripled_shape // 3
        for i in range(3):
            if len(v.shape) == 3:
                new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
            else:
                new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
    return new_state_dict

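
# Minimal sketch of the fused-qkv split performed above (shapes are
# assumptions for illustration): a Conv1d qkv weight of shape [3 * C, C, 1]
# becomes three Linear weights of shape [C, C], e.g.:
#
#   v = torch.randn(3 * 8, 8, 1)
#   out = transform_conv_attns({}, ["q.weight", "k.weight", "v.weight"], v)
#   assert out["q.weight"].shape == (8, 8)
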

def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = args.model_path.split("/")[-1].split(".")[0]
    if not os.path.isfile(args.model_path):
        assert (
            model_name == args.model_path
        ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
        args.model_path = download(model_name)

    sample_rate = MODELS_MAP[model_name]["sample_rate"]
    sample_size = MODELS_MAP[model_name]["sample_size"]

    config = Object()
    config.sample_size = sample_size
    config.sample_rate = sample_rate
    config.latent_dim = 0

    diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
    diffusers_state_dict = diffusers_model.state_dict()

    orig_model = DiffusionUncond(config)
    orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
    orig_model = orig_model.diffusion_ema.eval()
    orig_model_state_dict = orig_model.state_dict()
    renamed_state_dict = rename_orig_weights(orig_model_state_dict)

    # every renamed key must exist in the diffusers model; the only diffusers
    # keys allowed to be missing are the non-trainable up-/downsample kernels
    renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
    diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())

    assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
    assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"

    for key, value in renamed_state_dict.items():
        assert (
            diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
        ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
        if key == "time_proj.weight":
            value = value.squeeze()

        diffusers_state_dict[key] = value

    diffusers_model.load_state_dict(diffusers_state_dict)

    steps = 100
    seed = 33

    diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)

    generator = torch.manual_seed(seed)
    noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)

    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)

    pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)

    # sample with the converted pipeline and with the original model from the
    # same seed, then compare the outputs
    generator = torch.manual_seed(seed)
    audio = pipe(num_inference_steps=steps, generator=generator).audios

    generated = sampling.iplms_sample(orig_model, noise, step_list, {})
    generated = generated.clamp(-1, 1)

    diff_sum = (generated - audio).abs().sum()
    diff_max = (generated - audio).abs().max()

    if args.save:
        pipe.save_pretrained(args.checkpoint_path)

    print("Diff sum", diff_sum)
    print("Diff max", diff_max)

    assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"

    print(f"Conversion for {model_name} successful!")

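
# Example invocation (the checkpoint path is illustrative; --model_path may be
# a local .ckpt file or one of the official names in MODELS_MAP):
#
#   python convert_dance_diffusion_to_diffusers.py \
#       --model_path gwf-440k \
#       --checkpoint_path ./gwf-440k-diffusers
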

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    # note: type=bool would treat any non-empty string (including "False") as
    # True, so the flag is parsed explicitly instead
    parser.add_argument(
        "--save",
        default=True,
        type=lambda s: str(s).lower() in ("true", "1"),
        required=False,
        help="Whether to save the converted model or not.",
    )
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
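
# Once saved, the converted checkpoint can be reloaded like any diffusers
# pipeline (a sketch; the path assumes the invocation above):
#
#   from diffusers import DanceDiffusionPipeline
#
#   pipe = DanceDiffusionPipeline.from_pretrained("./gwf-440k-diffusers")
#   audio = pipe(num_inference_steps=100).audios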