Path: blob/main/scripts/convert_dance_diffusion_to_diffusers.py
#!/usr/bin/env python3
# Convert an original Dance Diffusion ("audio-diffusion") checkpoint into a diffusers DanceDiffusionPipeline.
import argparse
import math
import os
from copy import deepcopy

import torch
from audio_diffusion.models import DiffusionAttnUnet1D
from diffusion import sampling
from torch import nn

from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel


MODELS_MAP = {
    "gwf-440k": {
        "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-small-190k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-large-580k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
        "sample_rate": 48000,
        "sample_size": 131072,
    },
    "maestro-uncond-150k": {
        "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "unlocked-uncond-250k": {
        "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "honk-140k": {
        "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
}


def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise."""
    return torch.atan2(sigma, alpha) / math.pi * 2


def get_crash_schedule(t):
    sigma = torch.sin(t * math.pi / 2) ** 2
    alpha = (1 - sigma**2) ** 0.5
    return alpha_sigma_to_t(alpha, sigma)


class Object(object):
    # minimal namespace used to mimic the original model's config/args object
    pass


class DiffusionUncond(nn.Module):
    def __init__(self, global_args):
        super().__init__()

        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion_ema = deepcopy(self.diffusion)
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)


def download(model_name):
    url = MODELS_MAP[model_name]["url"]
    os.system(f"wget {url}")

    return f"./{model_name}.ckpt"


# Maps from layer indices in the original checkpoint to module names in diffusers' UNet1DModel.
DOWN_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
}
UP_NUM_TO_LAYER = {
    "8": "resnets.0",
    "9": "attentions.0",
    "10": "resnets.1",
    "11": "attentions.1",
    "12": "resnets.2",
    "13": "attentions.2",
}
MID_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
    "8": "resnets.3",
    "9": "attentions.3",
    "10": "resnets.4",
    "11": "attentions.4",
    "12": "resnets.5",
    "13": "attentions.5",
}
DEPTH_0_TO_LAYER = {
    "0": "resnets.0",
    "1": "resnets.1",
    "2": "resnets.2",
    "4": "resnets.0",
    "5": "resnets.1",
    "6": "resnets.2",
}

RES_CONV_MAP = {
    "skip": "conv_skip",
    "main.0": "conv_1",
    "main.1": "group_norm_1",
    "main.3": "conv_2",
    "main.4": "group_norm_2",
}

ATTN_MAP = {
    "norm": "group_norm",
    "qkv_proj": ["query", "key", "value"],
    "out_proj": ["proj_attn"],
}


def convert_resconv_naming(name):
    if name.startswith("skip"):
        return name.replace("skip", RES_CONV_MAP["skip"])

    # name has to be of format main.{digit}
    if not name.startswith("main."):
        raise ValueError(f"ResConvBlock error with {name}")

    return name.replace(name[:6], RES_CONV_MAP[name[:6]])


def convert_attn_naming(name):
    for key, value in ATTN_MAP.items():
        if name.startswith(key) and not isinstance(value, list):
            return name.replace(key, value)
        elif name.startswith(key):
            return [name.replace(key, v) for v in value]
    raise ValueError(f"Attn error with {name}")


def rename(input_string, max_depth=13):
    string = input_string

    if string.split(".")[0] == "timestep_embed":
        return string.replace("timestep_embed", "time_proj")

    depth = 0
    if string.startswith("net.3."):
        depth += 1
        string = string[6:]
    elif string.startswith("net."):
        string = string[4:]

    while string.startswith("main.7."):
        depth += 1
        string = string[7:]

    if string.startswith("main."):
        string = string[5:]

    # mid block
    if string[:2].isdigit():
        layer_num = string[:2]
        string_left = string[2:]
    else:
        layer_num = string[0]
        string_left = string[1:]

    if depth == max_depth:
        new_layer = MID_NUM_TO_LAYER[layer_num]
        prefix = "mid_block"
    elif depth > 0 and int(layer_num) < 7:
        new_layer = DOWN_NUM_TO_LAYER[layer_num]
        prefix = f"down_blocks.{depth}"
    elif depth > 0 and int(layer_num) > 7:
        new_layer = UP_NUM_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - depth - 1}"
    elif depth == 0:
        new_layer = DEPTH_0_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"

    if not string_left.startswith("."):
        raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")

    string_left = string_left[1:]

    if "resnets" in new_layer:
        string_left = convert_resconv_naming(string_left)
    elif "attentions" in new_layer:
        new_string_left = convert_attn_naming(string_left)
        string_left = new_string_left

    if not isinstance(string_left, list):
        new_string = prefix + "." + new_layer + "." + string_left
    else:
        new_string = [prefix + "." + new_layer + "." + s for s in string_left]
    return new_string


def rename_orig_weights(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.endswith("kernel"):
            # up- and downsample layers don't have trainable weights
            continue

        new_k = rename(k)

        # check if we need to transform from Conv => Linear for attention
        if isinstance(new_k, list):
            new_state_dict = transform_conv_attns(new_state_dict, new_k, v)
        else:
            new_state_dict[new_k] = v

    return new_state_dict


def transform_conv_attns(new_state_dict, new_k, v):
    if len(new_k) == 1:
        if len(v.shape) == 3:
            # weight
            new_state_dict[new_k[0]] = v[:, :, 0]
        else:
            # bias
            new_state_dict[new_k[0]] = v
    else:
        # qkv matrices
        tripled_shape = v.shape[0]
        single_shape = tripled_shape // 3
        for i in range(3):
            if len(v.shape) == 3:
                new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
            else:
                new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
    return new_state_dict


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = args.model_path.split("/")[-1].split(".")[0]
    if not os.path.isfile(args.model_path):
        assert (
            model_name == args.model_path
        ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
        args.model_path = download(model_name)

    sample_rate = MODELS_MAP[model_name]["sample_rate"]
    sample_size = MODELS_MAP[model_name]["sample_size"]

    config = Object()
    config.sample_size = sample_size
    config.sample_rate = sample_rate
    config.latent_dim = 0

    diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
    diffusers_state_dict = diffusers_model.state_dict()

    # Load the original model and use its EMA weights for the conversion.
    orig_model = DiffusionUncond(config)
    orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
    orig_model = orig_model.diffusion_ema.eval()
    orig_model_state_dict = orig_model.state_dict()
    renamed_state_dict = rename_orig_weights(orig_model_state_dict)

    renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
    diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())

    assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
    assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"

    for key, value in renamed_state_dict.items():
        assert (
            diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
        ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
        if key == "time_proj.weight":
            value = value.squeeze()

        diffusers_state_dict[key] = value

    diffusers_model.load_state_dict(diffusers_state_dict)

    steps = 100
    seed = 33

    diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)

    generator = torch.manual_seed(seed)
    noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)

    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)

    pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)

    # Sanity check: the diffusers pipeline and the original IPLMS sampler should produce (nearly) identical audio.
    generator = torch.manual_seed(33)
    audio = pipe(num_inference_steps=steps, generator=generator).audios

    generated = sampling.iplms_sample(orig_model, noise, step_list, {})
    generated = generated.clamp(-1, 1)

    diff_sum = (generated - audio).abs().sum()
    diff_max = (generated - audio).abs().max()

    if args.save:
        pipe.save_pretrained(args.checkpoint_path)

    print("Diff sum", diff_sum)
    print("Diff max", diff_max)

    assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"

    print(f"Conversion for {model_name} successful!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument(
        "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
    )
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
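

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not executed by this script): after running the
# conversion, the pipeline saved via --checkpoint_path can be reloaded with the
# standard diffusers API. The model name and output directory below are
# examples only, not values required by the script.
#
#   python convert_dance_diffusion_to_diffusers.py \
#       --model_path gwf-440k \
#       --checkpoint_path ./gwf-440k-diffusers
#
#   from diffusers import DanceDiffusionPipeline
#
#   pipe = DanceDiffusionPipeline.from_pretrained("./gwf-440k-diffusers")
#   audio = pipe(num_inference_steps=100).audios[0]
# ---------------------------------------------------------------------------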