Path: blob/main/scripts/convert_models_diffuser_to_diffusers.py
1440 views
import json
import os

import torch

from diffusers import UNet1DModel


# Output directories for the converted hopper-medium-v2 checkpoints (created at import time).
os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)

os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)

# Default locations of the original `diffuser` checkpoints; override them with the
# ``checkpoint_path`` argument of ``unet`` / ``value_function``.
_DEFAULT_UNET_CHECKPOINT = "/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch"
_DEFAULT_VALUE_FUNCTION_CHECKPOINT = "/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch"

# Block layout of the temporal UNet for each supported planning horizon.
_UNET_BLOCKS_BY_HORIZON = {
    128: {
        "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        "block_out_channels": (32, 128, 256),
        "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D"),
    },
    32: {
        "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        "block_out_channels": (32, 64, 128, 256),
        "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
    },
}


def _remap_state_dict(source_state_dict, target_model):
    """Return a copy of ``source_state_dict`` whose keys are ``target_model``'s parameter names.

    The conversion assumes both models enumerate their parameters in the same order, so the
    i-th source entry is paired with the i-th target key. A new dict is built (rather than
    renaming keys in place) so that a target name which happens to equal a not-yet-renamed
    source name cannot overwrite that source tensor.

    Raises:
        ValueError: if the two state dicts hold a different number of entries; ``zip`` would
            otherwise silently drop the surplus.
    """
    target_keys = list(target_model.state_dict().keys())
    print(f"length of state dict: {len(source_state_dict.keys())}")
    print(f"length of value function dict: {len(target_keys)}")
    if len(source_state_dict) != len(target_keys):
        raise ValueError(
            f"Cannot map {len(source_state_dict)} source parameters onto {len(target_keys)} "
            "target parameters; the model configurations do not match."
        )
    return dict(zip(target_keys, source_state_dict.values()))


def _save_converted(model, config, output_dir):
    """Write ``model``'s weights and its ``config`` into ``output_dir``."""
    torch.save(model.state_dict(), f"{output_dir}/diffusion_pytorch_model.bin")
    with open(f"{output_dir}/config.json", "w") as f:
        json.dump(config, f)


def unet(hor, checkpoint_path=None):
    """Convert the original temporal UNet for horizon ``hor`` into a diffusers ``UNet1DModel``.

    Args:
        hor: planning horizon; only 32 and 128 have a known block layout.
        checkpoint_path: path of the original pickled model. Defaults to the
            horizon-specific path this script was written against.

    Raises:
        ValueError: if ``hor`` is not a supported horizon.
    """
    try:
        blocks = _UNET_BLOCKS_BY_HORIZON[hor]
    except KeyError:
        raise ValueError(
            f"Unsupported horizon {hor!r}; expected one of {sorted(_UNET_BLOCKS_BY_HORIZON)}"
        ) from None
    if checkpoint_path is None:
        checkpoint_path = _DEFAULT_UNET_CHECKPOINT.format(hor=hor)

    # This checkpoint is a pickled module (``.state_dict()`` is called on it). torch.load
    # unpickles arbitrary objects, so only load trusted files.
    # NOTE(review): recent PyTorch releases default ``weights_only=True``, which may reject a
    # pickled module; ``weights_only=False`` might be needed there — confirm for your version.
    model = torch.load(checkpoint_path)
    config = dict(
        **blocks,
        layers_per_block=1,
        use_timestep_embedding=True,
        out_block_type="OutConv1DBlock",
        norm_num_groups=8,
        downsample_each_block=False,
        in_channels=14,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        flip_sin_to_cos=False,
        freq_shift=1,
        sample_size=65536,
        mid_block_type="MidResTemporalBlock1D",
        act_fn="mish",
    )
    hf_unet = UNet1DModel(**config)
    hf_unet.load_state_dict(_remap_state_dict(model.state_dict(), hf_unet))

    _save_converted(hf_unet, config, f"hub/hopper-medium-v2/unet/hor{hor}")


def value_function(checkpoint_path=None):
    """Convert the original horizon-32 value function into a diffusers ``UNet1DModel``.

    Args:
        checkpoint_path: path of the original checkpoint; defaults to the path this script
            was written against.
    """
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        up_block_types=(),
        out_block_type="ValueFunction",
        mid_block_type="ValueFunctionMidBlock1D",
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
        downsample_each_block=True,
        sample_size=65536,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        use_timestep_embedding=True,
        flip_sin_to_cos=False,
        freq_shift=1,
        norm_num_groups=8,
        act_fn="mish",
    )
    if checkpoint_path is None:
        checkpoint_path = _DEFAULT_VALUE_FUNCTION_CHECKPOINT

    # Unlike the UNet checkpoint, this file is used directly as a state dict.
    # torch.load unpickles arbitrary objects, so only load trusted files.
    state_dict = torch.load(checkpoint_path)
    hf_value_function = UNet1DModel(**config)
    hf_value_function.load_state_dict(_remap_state_dict(state_dict, hf_value_function))

    _save_converted(hf_value_function, config, "hub/hopper-medium-v2/value_function")


if __name__ == "__main__":
    unet(32)
    # unet(128)
    value_function()