Path: scripts/convert_music_spectrogram_to_diffusers.py
#!/usr/bin/env python3
"""Convert a Music Spectrogram Diffusion (T5X/Flax) checkpoint into a diffusers SpectrogramDiffusionPipeline."""
import argparse
import os

import jax
import numpy as onp
import torch
import torch.nn as nn
from music_spectrogram_diffusion import inference
from t5x import checkpoints

from diffusers import DDPMScheduler, OnnxRuntimeModel, SpectrogramDiffusionPipeline
from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder


MODEL = "base_with_context"


def load_notes_encoder(weights, model):
    """Copy the T5X note-encoder weights into the PyTorch SpectrogramNotesEncoder."""
    model.token_embedder.weight = nn.Parameter(torch.FloatTensor(weights["token_embedder"]["embedding"]))
    # Positional embeddings are fixed, so they are loaded frozen.
    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )
    for lyr_num, lyr in enumerate(model.encoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
        )

        # Self-attention projections: Flax kernels are transposed for nn.Linear.
        attention_weights = ly_weight["attention"]
        lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))

        lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))

        # Gated-GELU MLP: two input projections (wi_0, wi_1) and one output projection (wo).
        lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))

    model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))
    return model


def load_continuous_encoder(weights, model):
    """Copy the T5X continuous (spectrogram-context) encoder weights into the PyTorch SpectrogramContEncoder."""
    model.input_proj.weight = nn.Parameter(torch.FloatTensor(weights["input_proj"]["kernel"].T))

    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )

    for lyr_num, lyr in enumerate(model.encoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        attention_weights = ly_weight["attention"]

        lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
        )

        lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
        lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))

    model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))

    return model
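

# Every Flax "kernel" above is transposed before it becomes a torch Parameter:
# flax.linen.Dense stores kernels as (in_features, out_features), while
# torch.nn.Linear.weight expects (out_features, in_features); layer-norm
# "scale" vectors are 1-D and copied as-is. A minimal sketch of that porting
# step (the name `_port_dense_kernel` is hypothetical; the loaders in this
# script simply inline the same logic):
def _port_dense_kernel(kernel: onp.ndarray) -> nn.Parameter:
    # (in, out) Flax kernel -> (out, in) torch Linear weight
    assert kernel.ndim == 2, "expected a 2-D dense kernel"
    return nn.Parameter(torch.FloatTensor(kernel.T))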


def load_decoder(weights, model):
    """Copy the T5X FiLM-conditioned decoder weights into the PyTorch T5FilmDecoder."""
    # Diffusion-time embedding MLP (two dense layers around an activation).
    model.conditioning_emb[0].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense0"]["kernel"].T))
    model.conditioning_emb[2].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense1"]["kernel"].T))

    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )

    model.continuous_inputs_projection.weight = nn.Parameter(
        torch.FloatTensor(weights["continuous_inputs_projection"]["kernel"].T)
    )

    for lyr_num, lyr in enumerate(model.decoders):
        ly_weight = weights[f"layers_{lyr_num}"]

        # Sub-layer 0: FiLM-modulated self-attention.
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
        )

        lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
        )

        attention_weights = ly_weight["self_attention"]
        lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))

        # Sub-layer 1: cross-attention over the encoder outputs.
        attention_weights = ly_weight["MultiHeadDotProductAttention_0"]
        lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
        lyr.layer[1].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
        )

        # Sub-layer 2: FiLM-modulated gated-GELU MLP.
        lyr.layer[2].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
        lyr.layer[2].film.scale_bias.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
        )
        lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))

    model.decoder_norm.weight = nn.Parameter(torch.FloatTensor(weights["decoder_norm"]["scale"]))

    model.spec_out.weight = nn.Parameter(torch.FloatTensor(weights["spec_out_dense"]["kernel"].T))

    return model


def main(args):
    t5_checkpoint = checkpoints.load_t5x_checkpoint(args.checkpoint_path)
    # Materialize every leaf of the nested Flax param dict as a host numpy array.
    t5_checkpoint = jax.tree_util.tree_map(onp.array, t5_checkpoint)

    gin_overrides = [
        "from __gin__ import dynamic_registration",
        "from music_spectrogram_diffusion.models.diffusion import diffusion_utils",
        "diffusion_utils.ClassifierFreeGuidanceConfig.eval_condition_weight = 2.0",
        "diffusion_utils.DiffusionConfig.classifier_free_guidance = @diffusion_utils.ClassifierFreeGuidanceConfig()",
    ]

    gin_file = os.path.join(args.checkpoint_path, "..", "config.gin")
    gin_config = inference.parse_training_gin_file(gin_file, gin_overrides)
    synth_model = inference.InferenceModel(args.checkpoint_path, gin_config)
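    # The original model is a discrete-time DDPM; the scheduler below recreates
    # its noise schedule in diffusers: "squaredcos_cap_v2" is the cosine beta
    # schedule and "fixed_large" uses beta_t itself as the sampling variance.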
    scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", variance_type="fixed_large")

    notes_encoder = SpectrogramNotesEncoder(
        max_length=synth_model.sequence_length["inputs"],
        vocab_size=synth_model.model.module.config.vocab_size,
        d_model=synth_model.model.module.config.emb_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
        num_layers=synth_model.model.module.config.num_encoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        feed_forward_proj="gated-gelu",
    )

    continuous_encoder = SpectrogramContEncoder(
        input_dims=synth_model.audio_codec.n_dims,
        targets_context_length=synth_model.sequence_length["targets_context"],
        d_model=synth_model.model.module.config.emb_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
        num_layers=synth_model.model.module.config.num_encoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        feed_forward_proj="gated-gelu",
    )

    decoder = T5FilmDecoder(
        input_dims=synth_model.audio_codec.n_dims,
        targets_length=synth_model.sequence_length["targets_context"],
        max_decoder_noise_time=synth_model.model.module.config.max_decoder_noise_time,
        d_model=synth_model.model.module.config.emb_dim,
        num_layers=synth_model.model.module.config.num_decoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
    )

    # Port the Flax weights into the freshly instantiated PyTorch modules.
    notes_encoder = load_notes_encoder(t5_checkpoint["target"]["token_encoder"], notes_encoder)
    continuous_encoder = load_continuous_encoder(t5_checkpoint["target"]["continuous_encoder"], continuous_encoder)
    decoder = load_decoder(t5_checkpoint["target"]["decoder"], decoder)

    # The MelGAN vocoder that decodes spectrograms to waveforms ships as an ONNX model.
    melgan = OnnxRuntimeModel.from_pretrained("kashif/soundstream_mel_decoder")

    pipe = SpectrogramDiffusionPipeline(
        notes_encoder=notes_encoder,
        continuous_encoder=continuous_encoder,
        decoder=decoder,
        scheduler=scheduler,
        melgan=melgan,
    )
    if args.save:
        pipe.save_pretrained(args.output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_path", default=None, type=str, required=True, help="Path to the converted model.")
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so parse the flag value explicitly.
    parser.add_argument(
        "--save",
        default=True,
        type=lambda s: s.lower() in ("1", "true", "yes"),
        required=False,
        help="Whether to save the converted model or not.",
    )
    parser.add_argument(
        "--checkpoint_path",
        default=f"{MODEL}/checkpoint_500000",
        type=str,
        required=False,
        help="Path to the original JAX model checkpoint.",
    )
    args = parser.parse_args()

    main(args)
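
# A minimal usage sketch for the converted pipeline, assuming note_seq is
# installed (MidiProcessor depends on it); "some_song.mid" and "path/to/output"
# are placeholders:
#
#     from diffusers import SpectrogramDiffusionPipeline
#     from diffusers.pipelines.spectrogram_diffusion import MidiProcessor
#
#     pipe = SpectrogramDiffusionPipeline.from_pretrained("path/to/output")
#     processor = MidiProcessor()  # tokenizes MIDI into note representations
#     output = pipe(processor("some_song.mid"))
#     audio = output.audios[0]  # waveform decoded by the MelGAN vocoder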