Path: blob/main/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the NCSNPP checkpoints. """

import argparse
import json

import torch

from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel


def convert_ncsnpp_checkpoint(checkpoint, config):
    """
    Takes the original NCSNPP state dict and the model config, and returns a state dict compatible with diffusers'
    UNet2DModel.
    """
    new_model_architecture = UNet2DModel(**config)

    # Time embedding and input convolution (original modules 0-3).
    new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
    new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

    new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
    new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data

    new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
    new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data

    # The output norm and convolution are the last four entries of the original state dict.
    new_model_architecture.conv_norm_out.weight.data = checkpoint[list(checkpoint.keys())[-4]].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[list(checkpoint.keys())[-3]].data
    new_model_architecture.conv_out.weight.data = checkpoint[list(checkpoint.keys())[-2]].data
    new_model_architecture.conv_out.bias.data = checkpoint[list(checkpoint.keys())[-1]].data

    # The remaining modules are numbered sequentially, starting after the four handled above.
    module_index = 4

    def set_attention_weights(new_layer, old_checkpoint, index):
        new_layer.query.weight.data = old_checkpoint[f"all_modules.{index}.NIN_0.W"].data.T
        new_layer.key.weight.data = old_checkpoint[f"all_modules.{index}.NIN_1.W"].data.T
        new_layer.value.weight.data = old_checkpoint[f"all_modules.{index}.NIN_2.W"].data.T

        new_layer.query.bias.data = old_checkpoint[f"all_modules.{index}.NIN_0.b"].data
        new_layer.key.bias.data = old_checkpoint[f"all_modules.{index}.NIN_1.b"].data
        new_layer.value.bias.data = old_checkpoint[f"all_modules.{index}.NIN_2.b"].data

        new_layer.proj_attn.weight.data = old_checkpoint[f"all_modules.{index}.NIN_3.W"].data.T
        new_layer.proj_attn.bias.data = old_checkpoint[f"all_modules.{index}.NIN_3.b"].data

        new_layer.group_norm.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.group_norm.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

    def set_resnet_weights(new_layer, old_checkpoint, index):
        new_layer.conv1.weight.data = old_checkpoint[f"all_modules.{index}.Conv_0.weight"].data
        new_layer.conv1.bias.data = old_checkpoint[f"all_modules.{index}.Conv_0.bias"].data
        new_layer.norm1.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.norm1.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

        new_layer.conv2.weight.data = old_checkpoint[f"all_modules.{index}.Conv_1.weight"].data
        new_layer.conv2.bias.data = old_checkpoint[f"all_modules.{index}.Conv_1.bias"].data
        new_layer.norm2.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.weight"].data
        new_layer.norm2.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.bias"].data

        new_layer.time_emb_proj.weight.data = old_checkpoint[f"all_modules.{index}.Dense_0.weight"].data
        new_layer.time_emb_proj.bias.data = old_checkpoint[f"all_modules.{index}.Dense_0.bias"].data

        if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down:
            new_layer.conv_shortcut.weight.data = old_checkpoint[f"all_modules.{index}.Conv_2.weight"].data
            new_layer.conv_shortcut.bias.data = old_checkpoint[f"all_modules.{index}.Conv_2.bias"].data

    # Down blocks: resnets, optional attentions, then the downsampling resnet and skip convolution.
    for i, block in enumerate(new_model_architecture.downsample_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
            if has_attentions:
                set_attention_weights(block.attentions[j], checkpoint, module_index)
                module_index += 1

        if hasattr(block, "downsamplers") and block.downsamplers is not None:
            set_resnet_weights(block.resnet_down, checkpoint, module_index)
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
            module_index += 1

    # Mid block: resnet -> attention -> resnet.
    set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
    module_index += 1
    set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
    module_index += 1
    set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
    module_index += 1

    # Up blocks: resnets, optional attention, then skip norm/conv and the upsampling resnet.
    for i, block in enumerate(new_model_architecture.up_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
        if has_attentions:
            set_attention_weights(
                block.attentions[0], checkpoint, module_index
            )  # why can there only be a single attention layer for up?
            module_index += 1

        if hasattr(block, "resnet_up") and block.resnet_up is not None:
            block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            set_resnet_weights(block.resnet_up, checkpoint, module_index)
            module_index += 1

    new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
    module_index += 1
    new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data

    return new_model_architecture.state_dict()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the checkpoint to convert.",
    )

    parser.add_argument(
        "--config_file",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/config.json",
        type=str,
        required=False,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument(
        "--dump_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
        type=str,
        required=False,
        help="Path to the output model.",
    )

    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    with open(args.config_file) as f:
        config = json.loads(f.read())

    converted_checkpoint = convert_ncsnpp_checkpoint(
        checkpoint,
        config,
    )

    if "sde" in config:
        del config["sde"]

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    try:
        scheduler = ScoreSdeVeScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
        pipe.save_pretrained(args.dump_path)
    except:  # noqa: E722
        model.save_pretrained(args.dump_path)