#!/usr/bin/env python3
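"""Convert a T5X music_spectrogram_diffusion checkpoint into a diffusers
SpectrogramDiffusionPipeline: a notes encoder, a continuous (spectrogram
context) encoder, a FiLM decoder, a DDPM scheduler, and an ONNX MELGAN-style
vocoder."""
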
import argparse
import os

import jax
import numpy as onp
import torch
import torch.nn as nn
from music_spectrogram_diffusion import inference
from t5x import checkpoints

from diffusers import DDPMScheduler, OnnxRuntimeModel, SpectrogramDiffusionPipeline
from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder


MODEL = "base_with_context"


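# Note on the `.T` calls throughout the loaders below: Flax/T5X Dense layers
# store their kernel as (in_features, out_features), while torch.nn.Linear
# keeps its weight as (out_features, in_features), so every dense kernel is
# transposed on the way over. LayerNorm "scale" vectors and embedding tables
# copy across without transposition.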
def load_notes_encoder(weights, model):
    """Copy the T5X note-encoder weights into a SpectrogramNotesEncoder."""
    model.token_embedder.weight = nn.Parameter(torch.FloatTensor(weights["token_embedder"]["embedding"]))
    # The positional embedding is copied over but kept frozen.
    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )
    for lyr_num, lyr in enumerate(model.encoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
        )

        attention_weights = ly_weight["attention"]
        lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))

        lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))

        lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))

    model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))
    return model


def load_continuous_encoder(weights, model):
    """Copy the T5X continuous-context encoder weights into a SpectrogramContEncoder."""
    model.input_proj.weight = nn.Parameter(torch.FloatTensor(weights["input_proj"]["kernel"].T))

    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )

    for lyr_num, lyr in enumerate(model.encoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        attention_weights = ly_weight["attention"]

        lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
        )

        lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
        lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))

    model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))

    return model


def load_decoder(weights, model):
    """Copy the T5X FiLM decoder weights into a T5FilmDecoder."""
    # Diffusion-time embedding MLP (indices 0 and 2 are its two Linear layers).
    model.conditioning_emb[0].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense0"]["kernel"].T))
    model.conditioning_emb[2].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense1"]["kernel"].T))

    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )

    model.continuous_inputs_projection.weight = nn.Parameter(
        torch.FloatTensor(weights["continuous_inputs_projection"]["kernel"].T)
    )

    for lyr_num, lyr in enumerate(model.decoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
        )

        lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
        )

        attention_weights = ly_weight["self_attention"]
        lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))

        # Cross-attention over the encoder outputs.
        attention_weights = ly_weight["MultiHeadDotProductAttention_0"]
        lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
        lyr.layer[1].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
        )

        lyr.layer[2].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
        lyr.layer[2].film.scale_bias.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
        )
        lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))

    model.decoder_norm.weight = nn.Parameter(torch.FloatTensor(weights["decoder_norm"]["scale"]))

    model.spec_out.weight = nn.Parameter(torch.FloatTensor(weights["spec_out_dense"]["kernel"].T))

    return model


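# The loaded T5X checkpoint is a nested dict of arrays; everything this script
# needs lives under t5_checkpoint["target"], in the "token_encoder",
# "continuous_encoder", and "decoder" sub-trees consumed by the loaders above.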
def main(args):
    t5_checkpoint = checkpoints.load_t5x_checkpoint(args.checkpoint_path)
    # Materialize the whole checkpoint tree as ordinary numpy arrays.
    t5_checkpoint = jax.tree_util.tree_map(onp.array, t5_checkpoint)

    gin_overrides = [
        "from __gin__ import dynamic_registration",
        "from music_spectrogram_diffusion.models.diffusion import diffusion_utils",
        "diffusion_utils.ClassifierFreeGuidanceConfig.eval_condition_weight = 2.0",
        "diffusion_utils.DiffusionConfig.classifier_free_guidance = @diffusion_utils.ClassifierFreeGuidanceConfig()",
    ]

    gin_file = os.path.join(args.checkpoint_path, "..", "config.gin")
    gin_config = inference.parse_training_gin_file(gin_file, gin_overrides)
    synth_model = inference.InferenceModel(args.checkpoint_path, gin_config)

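    # "squaredcos_cap_v2" is diffusers' cosine beta schedule; "fixed_large"
    # pins the reverse-process variance to the betas rather than learning it.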
    scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", variance_type="fixed_large")

    notes_encoder = SpectrogramNotesEncoder(
        max_length=synth_model.sequence_length["inputs"],
        vocab_size=synth_model.model.module.config.vocab_size,
        d_model=synth_model.model.module.config.emb_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
        num_layers=synth_model.model.module.config.num_encoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        feed_forward_proj="gated-gelu",
    )

    continuous_encoder = SpectrogramContEncoder(
        input_dims=synth_model.audio_codec.n_dims,
        targets_context_length=synth_model.sequence_length["targets_context"],
        d_model=synth_model.model.module.config.emb_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
        num_layers=synth_model.model.module.config.num_encoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        feed_forward_proj="gated-gelu",
    )

    decoder = T5FilmDecoder(
        input_dims=synth_model.audio_codec.n_dims,
        targets_length=synth_model.sequence_length["targets_context"],
        max_decoder_noise_time=synth_model.model.module.config.max_decoder_noise_time,
        d_model=synth_model.model.module.config.emb_dim,
        num_layers=synth_model.model.module.config.num_decoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
    )

    notes_encoder = load_notes_encoder(t5_checkpoint["target"]["token_encoder"], notes_encoder)
    continuous_encoder = load_continuous_encoder(t5_checkpoint["target"]["continuous_encoder"], continuous_encoder)
    decoder = load_decoder(t5_checkpoint["target"]["decoder"], decoder)

    # The vocoder that turns spectrograms back into audio ships as an ONNX model.
    melgan = OnnxRuntimeModel.from_pretrained("kashif/soundstream_mel_decoder")

    pipe = SpectrogramDiffusionPipeline(
        notes_encoder=notes_encoder,
        continuous_encoder=continuous_encoder,
        decoder=decoder,
        scheduler=scheduler,
        melgan=melgan,
    )
    if args.save:
        pipe.save_pretrained(args.output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_path", default=None, type=str, required=True, help="Path to the converted model.")
    # Note: argparse's `type=bool` would treat any non-empty string (including
    # "False") as True, so parse the flag value explicitly instead.
    parser.add_argument(
        "--save",
        default=True,
        type=lambda s: str(s).lower() in ("true", "t", "1", "yes"),
        required=False,
        help="Whether to save the converted model or not.",
    )
    parser.add_argument(
        "--checkpoint_path",
        default=f"{MODEL}/checkpoint_500000",
        type=str,
        required=False,
        help="Path to the original jax model checkpoint.",
    )
    args = parser.parse_args()

    main(args)

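# Example invocation (the output directory name is illustrative):
#   python convert_music_spectrogram_to_diffusers.py \
#       --checkpoint_path base_with_context/checkpoint_500000 \
#       --output_path ./music_spectrogram_diffusion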