Path: blob/main/scripts/conversion_ldm_uncond.py
1440 views
import argparse

from omegaconf import OmegaConf
import torch

from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel


def _extract_sub_state_dict(state_dict, prefix):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``.

    The prefix is removed from each returned key; entries whose key does not
    start with ``prefix`` are dropped.
    """
    # Slice the prefix off instead of using str.replace(): replace() would also
    # delete any later occurrence of the prefix inside the key and thereby
    # silently rename the parameter.
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def convert_ldm_original(checkpoint_path, config_path, output_path):
    """Convert an original LDM checkpoint into a saved ``LDMPipeline``.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` whose
            top-level ``"model"`` entry holds the full state dict.
        config_path: Path to the configuration file read with
            ``OmegaConf.load``. It must provide
            ``model.params.first_stage_config.params``,
            ``model.params.unet_config.params``, ``model.params.timesteps``,
            ``model.params.linear_start`` and ``model.params.linear_end``.
        output_path: Destination passed to ``LDMPipeline.save_pretrained``.
    """
    config = OmegaConf.load(config_path)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only convert checkpoints from a source you trust.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    # The original checkpoint stores both sub-models in one flat state dict,
    # told apart by their key prefix.
    first_stage_dict = _extract_sub_state_dict(state_dict, "first_stage_model.")
    unet_state_dict = _extract_sub_state_dict(state_dict, "model.diffusion_model.")

    vqvae_init_args = config.model.params.first_stage_config.params
    unet_init_args = config.model.params.unet_config.params

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    noise_scheduler = DDIMScheduler(
        timesteps=config.model.params.timesteps,
        beta_schedule="scaled_linear",
        beta_start=config.model.params.linear_start,
        beta_end=config.model.params.linear_end,
        clip_sample=False,
    )

    pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Checkpoint file to convert (loaded with torch.load).",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Configuration file for the checkpoint (loaded with OmegaConf.load).",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Where to save the converted pipeline.",
    )
    args = parser.parse_args()

    convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)