Path: blob/main/scripts/convert_k_upscaler_to_diffusers.py
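"""Convert the k-diffusion latent text-conditioned upscaler checkpoint (downloaded from the
pcuenq/k-upscaler Hub repo) into a diffusers `UNet2DConditionModel` and save it with
`save_pretrained` to the directory given by `--dump_path`.
"""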
import argparse

import huggingface_hub
import k_diffusion as K
import torch

from diffusers import UNet2DConditionModel


UPSCALER_REPO = "pcuenq/k-upscaler"


def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
            }
        )

    return rv


def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
    bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
    rv = {
        # norm
        f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
        f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
        # to_q
        f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
        # to_k
        f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
    }

    return rv

def cross_attn_to_diffusers_checkpoint(
    checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
):
    weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
    bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)

    rv = {
        # norm2 (ada groupnorm)
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.weight"
        ],
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.bias"
        ],
        # layernorm on encoder_hidden_state
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
            f"{attention_prefix}.norm_enc.weight"
        ],
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
            f"{attention_prefix}.norm_enc.bias"
        ],
        # to_q
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
            f"{attention_prefix}.q_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
            f"{attention_prefix}.q_proj.bias"
        ],
        # to_k
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
            f"{attention_prefix}.out_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
            f"{attention_prefix}.out_proj.bias"
        ],
    }

    return rv


def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
    block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
    block_prefix = f"{block_prefix}.{block_idx}"

    diffusers_checkpoint = {}

    if not hasattr(block, "attentions"):
        n = 1  # resnet only
    elif not block.attentions[0].add_self_attention:
        n = 2  # resnet -> cross-attention
    else:
        n = 3  # resnet -> self-attention -> cross-attention

    for resnet_idx, resnet in enumerate(block.resnets):
        diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
        idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
        resnet_prefix = f"{block_prefix}.{idx}"

        diffusers_checkpoint.update(
            resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    if hasattr(block, "attentions"):
        for attention_idx, attention in enumerate(block.attentions):
            diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
            idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
            self_attention_prefix = f"{block_prefix}.{idx}"
            cross_attention_index = 1 if not attention.add_self_attention else 2
            idx = (
                n * attention_idx + cross_attention_index
                if block_type == "up"
                else n * attention_idx + cross_attention_index + 1
            )
            cross_attention_prefix = f"{block_prefix}.{idx}"

            diffusers_checkpoint.update(
                cross_attn_to_diffusers_checkpoint(
                    checkpoint,
                    diffusers_attention_prefix=diffusers_attention_prefix,
                    diffusers_attention_index=2,
                    attention_prefix=cross_attention_prefix,
                )
            )

            if attention.add_self_attention is True:
                diffusers_checkpoint.update(
                    self_attn_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=self_attention_prefix,
                    )
                )

    return diffusers_checkpoint

checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),189"time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],190"time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],191"time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],192"time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],193"time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],194}195)196197# down_blocks198for down_block_idx, down_block in enumerate(model.down_blocks):199diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))200201# up_blocks202for up_block_idx, up_block in enumerate(model.up_blocks):203diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))204205# post-processing206diffusers_checkpoint.update(207{208"conv_out.weight": checkpoint["inner_model.proj_out.weight"],209"conv_out.bias": checkpoint["inner_model.proj_out.bias"],210}211)212213return diffusers_checkpoint214215216def unet_model_from_original_config(original_config):217in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]218out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)219220block_out_channels = original_config["channels"]221222assert (223len(set(original_config["depths"])) == 1224), "UNet2DConditionModel currently do not support blocks with different number of layers"225layers_per_block = original_config["depths"][0]226227class_labels_dim = original_config["mapping_cond_dim"]228cross_attention_dim = original_config["cross_cond_dim"]229230attn1_types = []231attn2_types = []232for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):233if s:234a1 = "self"235a2 = "cross" if c else None236elif c:237a1 = "cross"238a2 = None239else:240a1 = None241a2 = None242attn1_types.append(a1)243attn2_types.append(a2)244245unet = UNet2DConditionModel(246in_channels=in_channels,247out_channels=out_channels,248down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),249mid_block_type=None,250up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),251block_out_channels=block_out_channels,252layers_per_block=layers_per_block,253act_fn="gelu",254norm_num_groups=None,255cross_attention_dim=cross_attention_dim,256attention_head_dim=64,257time_cond_proj_dim=class_labels_dim,258resnet_time_scale_shift="scale_shift",259time_embedding_type="fourier",260timestep_post_act="gelu",261conv_in_kernel=1,262conv_out_kernel=1,263)264265return unet266267268def main(args):269device = torch.device("cuda" if torch.cuda.is_available() else "cpu")270271orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")272orig_weights_path = huggingface_hub.hf_hub_download(273UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"274)275print(f"loading original model configuration from {orig_config_path}")276print(f"loading original model checkpoint from {orig_weights_path}")277278print("converting to diffusers unet")279orig_config = K.config.load_config(open(orig_config_path))["model"]280model = unet_model_from_original_config(orig_config)281282orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]283converted_checkpoint = unet_to_diffusers_checkpoint(model, 
orig_checkpoint)284285model.load_state_dict(converted_checkpoint, strict=True)286model.save_pretrained(args.dump_path)287print(f"saving converted unet model in {args.dump_path}")288289290if __name__ == "__main__":291parser = argparse.ArgumentParser()292293parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")294args = parser.parse_args()295296main(args)297298299