Path: blob/main/scripts/convert_kakao_brain_unclip_to_diffusers.py
1440 views
import argparse1import tempfile23import torch4from accelerate import load_checkpoint_and_dispatch5from transformers import CLIPTextModelWithProjection, CLIPTokenizer67from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel8from diffusers.models.prior_transformer import PriorTransformer9from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel10from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler111213"""14Example - From the diffusers root directory:1516Download weights:17```sh18$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt19$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt20$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt21$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th22```2324Convert the model:25```sh26$ python scripts/convert_kakao_brain_unclip_to_diffusers.py \27--decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \28--super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \29--prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \30--clip_stat_path ./ViT-L-14_stats.th \31--dump_path <path where to save model>32```33"""343536# prior3738PRIOR_ORIGINAL_PREFIX = "model"3940# Uses default arguments41PRIOR_CONFIG = {}424344def prior_model_from_original_config():45model = PriorTransformer(**PRIOR_CONFIG)4647return model484950def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):51diffusers_checkpoint = {}5253# <original>.time_embed.0 -> <diffusers>.time_embedding.linear_154diffusers_checkpoint.update(55{56"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"],57"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"],58}59)6061# <original>.clip_img_proj -> <diffusers>.proj_in62diffusers_checkpoint.update(63{64"proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"],65"proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"],66}67)6869# <original>.text_emb_proj -> <diffusers>.embedding_proj70diffusers_checkpoint.update(71{72"embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"],73"embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"],74}75)7677# <original>.text_enc_proj -> <diffusers>.encoder_hidden_states_proj78diffusers_checkpoint.update(79{80"encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"],81"encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"],82}83)8485# <original>.positional_embedding -> <diffusers>.positional_embedding86diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]})8788# <original>.prd_emb -> <diffusers>.prd_embedding89diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]})9091# <original>.time_embed.2 -> <diffusers>.time_embedding.linear_292diffusers_checkpoint.update(93{94"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"],95"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"],96}97)9899# <original>.resblocks.<x> -> <diffusers>.transformer_blocks.<x>100for idx in range(len(model.transformer_blocks)):101diffusers_transformer_prefix = f"transformer_blocks.{idx}"102original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}"103104# <original>.attn -> <diffusers>.attn1105diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"106original_attention_prefix = f"{original_transformer_prefix}.attn"107diffusers_checkpoint.update(108prior_attention_to_diffusers(109checkpoint,110diffusers_attention_prefix=diffusers_attention_prefix,111original_attention_prefix=original_attention_prefix,112attention_head_dim=model.attention_head_dim,113)114)115116# <original>.mlp -> <diffusers>.ff117diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"118original_ff_prefix = f"{original_transformer_prefix}.mlp"119diffusers_checkpoint.update(120prior_ff_to_diffusers(121checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix122)123)124125# <original>.ln_1 -> <diffusers>.norm1126diffusers_checkpoint.update(127{128f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[129f"{original_transformer_prefix}.ln_1.weight"130],131f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],132}133)134135# <original>.ln_2 -> <diffusers>.norm3136diffusers_checkpoint.update(137{138f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[139f"{original_transformer_prefix}.ln_2.weight"140],141f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],142}143)144145# <original>.final_ln -> <diffusers>.norm_out146diffusers_checkpoint.update(147{148"norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"],149"norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"],150}151)152153# <original>.out_proj -> <diffusers>.proj_to_clip_embeddings154diffusers_checkpoint.update(155{156"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"],157"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"],158}159)160161# clip stats162clip_mean, clip_std = clip_stats_checkpoint163clip_mean = clip_mean[None, :]164clip_std = clip_std[None, :]165166diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std})167168return diffusers_checkpoint169170171def prior_attention_to_diffusers(172checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim173):174diffusers_checkpoint = {}175176# <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}177[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(178weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],179bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],180split=3,181chunk_size=attention_head_dim,182)183184diffusers_checkpoint.update(185{186f"{diffusers_attention_prefix}.to_q.weight": q_weight,187f"{diffusers_attention_prefix}.to_q.bias": q_bias,188f"{diffusers_attention_prefix}.to_k.weight": k_weight,189f"{diffusers_attention_prefix}.to_k.bias": k_bias,190f"{diffusers_attention_prefix}.to_v.weight": v_weight,191f"{diffusers_attention_prefix}.to_v.bias": v_bias,192}193)194195# <original>.c_proj -> <diffusers>.to_out.0196diffusers_checkpoint.update(197{198f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],199f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],200}201)202203return diffusers_checkpoint204205206def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):207diffusers_checkpoint = {208# <original>.c_fc -> <diffusers>.net.0.proj209f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],210f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],211# <original>.c_proj -> <diffusers>.net.2212f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],213f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],214}215216return diffusers_checkpoint217218219# done prior220221222# decoder223224DECODER_ORIGINAL_PREFIX = "model"225226# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can227# update then.228DECODER_CONFIG = {229"sample_size": 64,230"layers_per_block": 3,231"down_block_types": (232"ResnetDownsampleBlock2D",233"SimpleCrossAttnDownBlock2D",234"SimpleCrossAttnDownBlock2D",235"SimpleCrossAttnDownBlock2D",236),237"up_block_types": (238"SimpleCrossAttnUpBlock2D",239"SimpleCrossAttnUpBlock2D",240"SimpleCrossAttnUpBlock2D",241"ResnetUpsampleBlock2D",242),243"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",244"block_out_channels": (320, 640, 960, 1280),245"in_channels": 3,246"out_channels": 6,247"cross_attention_dim": 1536,248"class_embed_type": "identity",249"attention_head_dim": 64,250"resnet_time_scale_shift": "scale_shift",251}252253254def decoder_model_from_original_config():255model = UNet2DConditionModel(**DECODER_CONFIG)256257return model258259260def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):261diffusers_checkpoint = {}262263original_unet_prefix = DECODER_ORIGINAL_PREFIX264num_head_channels = DECODER_CONFIG["attention_head_dim"]265266diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))267diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))268269# <original>.input_blocks -> <diffusers>.down_blocks270271original_down_block_idx = 1272273for diffusers_down_block_idx in range(len(model.down_blocks)):274checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(275model,276checkpoint,277diffusers_down_block_idx=diffusers_down_block_idx,278original_down_block_idx=original_down_block_idx,279original_unet_prefix=original_unet_prefix,280num_head_channels=num_head_channels,281)282283original_down_block_idx += num_original_down_blocks284285diffusers_checkpoint.update(checkpoint_update)286287# done <original>.input_blocks -> <diffusers>.down_blocks288289diffusers_checkpoint.update(290unet_midblock_to_diffusers_checkpoint(291model,292checkpoint,293original_unet_prefix=original_unet_prefix,294num_head_channels=num_head_channels,295)296)297298# <original>.output_blocks -> <diffusers>.up_blocks299300original_up_block_idx = 0301302for diffusers_up_block_idx in range(len(model.up_blocks)):303checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(304model,305checkpoint,306diffusers_up_block_idx=diffusers_up_block_idx,307original_up_block_idx=original_up_block_idx,308original_unet_prefix=original_unet_prefix,309num_head_channels=num_head_channels,310)311312original_up_block_idx += num_original_up_blocks313314diffusers_checkpoint.update(checkpoint_update)315316# done <original>.output_blocks -> <diffusers>.up_blocks317318diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))319diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))320321return diffusers_checkpoint322323324# done decoder325326# text proj327328329def text_proj_from_original_config():330# From the conditional unet constructor where the dimension of the projected time embeddings is331# constructed332time_embed_dim = DECODER_CONFIG["block_out_channels"][0] * 4333334cross_attention_dim = DECODER_CONFIG["cross_attention_dim"]335336model = UnCLIPTextProjModel(time_embed_dim=time_embed_dim, cross_attention_dim=cross_attention_dim)337338return model339340341# Note that the input checkpoint is the original decoder checkpoint342def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):343diffusers_checkpoint = {344# <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj345"encoder_hidden_states_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"],346"encoder_hidden_states_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"],347# <original>.text_seq_proj.1 -> <diffusers>.text_encoder_hidden_states_norm348"text_encoder_hidden_states_norm.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"],349"text_encoder_hidden_states_norm.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"],350# <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj351"clip_extra_context_tokens_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"],352"clip_extra_context_tokens_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"],353# <original>.text_feat_proj -> <diffusers>.embedding_proj354"embedding_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"],355"embedding_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"],356# <original>.cf_param -> <diffusers>.learned_classifier_free_guidance_embeddings357"learned_classifier_free_guidance_embeddings": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.cf_param"],358# <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings359"clip_image_embeddings_project_to_time_embeddings.weight": checkpoint[360f"{DECODER_ORIGINAL_PREFIX}.clip_emb.weight"361],362"clip_image_embeddings_project_to_time_embeddings.bias": checkpoint[363f"{DECODER_ORIGINAL_PREFIX}.clip_emb.bias"364],365}366367return diffusers_checkpoint368369370# done text proj371372# super res unet first steps373374SUPER_RES_UNET_FIRST_STEPS_PREFIX = "model_first_steps"375376SUPER_RES_UNET_FIRST_STEPS_CONFIG = {377"sample_size": 256,378"layers_per_block": 3,379"down_block_types": (380"ResnetDownsampleBlock2D",381"ResnetDownsampleBlock2D",382"ResnetDownsampleBlock2D",383"ResnetDownsampleBlock2D",384),385"up_block_types": (386"ResnetUpsampleBlock2D",387"ResnetUpsampleBlock2D",388"ResnetUpsampleBlock2D",389"ResnetUpsampleBlock2D",390),391"block_out_channels": (320, 640, 960, 1280),392"in_channels": 6,393"out_channels": 3,394"add_attention": False,395}396397398def super_res_unet_first_steps_model_from_original_config():399model = UNet2DModel(**SUPER_RES_UNET_FIRST_STEPS_CONFIG)400401return model402403404def super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):405diffusers_checkpoint = {}406407original_unet_prefix = SUPER_RES_UNET_FIRST_STEPS_PREFIX408409diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))410diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))411412# <original>.input_blocks -> <diffusers>.down_blocks413414original_down_block_idx = 1415416for diffusers_down_block_idx in range(len(model.down_blocks)):417checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(418model,419checkpoint,420diffusers_down_block_idx=diffusers_down_block_idx,421original_down_block_idx=original_down_block_idx,422original_unet_prefix=original_unet_prefix,423num_head_channels=None,424)425426original_down_block_idx += num_original_down_blocks427428diffusers_checkpoint.update(checkpoint_update)429430diffusers_checkpoint.update(431unet_midblock_to_diffusers_checkpoint(432model,433checkpoint,434original_unet_prefix=original_unet_prefix,435num_head_channels=None,436)437)438439# <original>.output_blocks -> <diffusers>.up_blocks440441original_up_block_idx = 0442443for diffusers_up_block_idx in range(len(model.up_blocks)):444checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(445model,446checkpoint,447diffusers_up_block_idx=diffusers_up_block_idx,448original_up_block_idx=original_up_block_idx,449original_unet_prefix=original_unet_prefix,450num_head_channels=None,451)452453original_up_block_idx += num_original_up_blocks454455diffusers_checkpoint.update(checkpoint_update)456457# done <original>.output_blocks -> <diffusers>.up_blocks458459diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))460diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))461462return diffusers_checkpoint463464465# done super res unet first steps466467# super res unet last step468469SUPER_RES_UNET_LAST_STEP_PREFIX = "model_last_step"470471SUPER_RES_UNET_LAST_STEP_CONFIG = {472"sample_size": 256,473"layers_per_block": 3,474"down_block_types": (475"ResnetDownsampleBlock2D",476"ResnetDownsampleBlock2D",477"ResnetDownsampleBlock2D",478"ResnetDownsampleBlock2D",479),480"up_block_types": (481"ResnetUpsampleBlock2D",482"ResnetUpsampleBlock2D",483"ResnetUpsampleBlock2D",484"ResnetUpsampleBlock2D",485),486"block_out_channels": (320, 640, 960, 1280),487"in_channels": 6,488"out_channels": 3,489"add_attention": False,490}491492493def super_res_unet_last_step_model_from_original_config():494model = UNet2DModel(**SUPER_RES_UNET_LAST_STEP_CONFIG)495496return model497498499def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):500diffusers_checkpoint = {}501502original_unet_prefix = SUPER_RES_UNET_LAST_STEP_PREFIX503504diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))505diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))506507# <original>.input_blocks -> <diffusers>.down_blocks508509original_down_block_idx = 1510511for diffusers_down_block_idx in range(len(model.down_blocks)):512checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(513model,514checkpoint,515diffusers_down_block_idx=diffusers_down_block_idx,516original_down_block_idx=original_down_block_idx,517original_unet_prefix=original_unet_prefix,518num_head_channels=None,519)520521original_down_block_idx += num_original_down_blocks522523diffusers_checkpoint.update(checkpoint_update)524525diffusers_checkpoint.update(526unet_midblock_to_diffusers_checkpoint(527model,528checkpoint,529original_unet_prefix=original_unet_prefix,530num_head_channels=None,531)532)533534# <original>.output_blocks -> <diffusers>.up_blocks535536original_up_block_idx = 0537538for diffusers_up_block_idx in range(len(model.up_blocks)):539checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(540model,541checkpoint,542diffusers_up_block_idx=diffusers_up_block_idx,543original_up_block_idx=original_up_block_idx,544original_unet_prefix=original_unet_prefix,545num_head_channels=None,546)547548original_up_block_idx += num_original_up_blocks549550diffusers_checkpoint.update(checkpoint_update)551552# done <original>.output_blocks -> <diffusers>.up_blocks553554diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))555diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))556557return diffusers_checkpoint558559560# done super res unet last step561562563# unet utils564565566# <original>.time_embed -> <diffusers>.time_embedding567def unet_time_embeddings(checkpoint, original_unet_prefix):568diffusers_checkpoint = {}569570diffusers_checkpoint.update(571{572"time_embedding.linear_1.weight": checkpoint[f"{original_unet_prefix}.time_embed.0.weight"],573"time_embedding.linear_1.bias": checkpoint[f"{original_unet_prefix}.time_embed.0.bias"],574"time_embedding.linear_2.weight": checkpoint[f"{original_unet_prefix}.time_embed.2.weight"],575"time_embedding.linear_2.bias": checkpoint[f"{original_unet_prefix}.time_embed.2.bias"],576}577)578579return diffusers_checkpoint580581582# <original>.input_blocks.0 -> <diffusers>.conv_in583def unet_conv_in(checkpoint, original_unet_prefix):584diffusers_checkpoint = {}585586diffusers_checkpoint.update(587{588"conv_in.weight": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.weight"],589"conv_in.bias": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.bias"],590}591)592593return diffusers_checkpoint594595596# <original>.out.0 -> <diffusers>.conv_norm_out597def unet_conv_norm_out(checkpoint, original_unet_prefix):598diffusers_checkpoint = {}599600diffusers_checkpoint.update(601{602"conv_norm_out.weight": checkpoint[f"{original_unet_prefix}.out.0.weight"],603"conv_norm_out.bias": checkpoint[f"{original_unet_prefix}.out.0.bias"],604}605)606607return diffusers_checkpoint608609610# <original>.out.2 -> <diffusers>.conv_out611def unet_conv_out(checkpoint, original_unet_prefix):612diffusers_checkpoint = {}613614diffusers_checkpoint.update(615{616"conv_out.weight": checkpoint[f"{original_unet_prefix}.out.2.weight"],617"conv_out.bias": checkpoint[f"{original_unet_prefix}.out.2.bias"],618}619)620621return diffusers_checkpoint622623624# <original>.input_blocks -> <diffusers>.down_blocks625def unet_downblock_to_diffusers_checkpoint(626model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, original_unet_prefix, num_head_channels627):628diffusers_checkpoint = {}629630diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets"631original_down_block_prefix = f"{original_unet_prefix}.input_blocks"632633down_block = model.down_blocks[diffusers_down_block_idx]634635num_resnets = len(down_block.resnets)636637if down_block.downsamplers is None:638downsampler = False639else:640assert len(down_block.downsamplers) == 1641downsampler = True642# The downsample block is also a resnet643num_resnets += 1644645for resnet_idx_inc in range(num_resnets):646full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0"647648if downsampler and resnet_idx_inc == num_resnets - 1:649# this is a downsample block650full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0"651else:652# this is a regular resnet block653full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"654655diffusers_checkpoint.update(656resnet_to_diffusers_checkpoint(657checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix658)659)660661if hasattr(down_block, "attentions"):662num_attentions = len(down_block.attentions)663diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions"664665for attention_idx_inc in range(num_attentions):666full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1"667full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"668669diffusers_checkpoint.update(670attention_to_diffusers_checkpoint(671checkpoint,672attention_prefix=full_attention_prefix,673diffusers_attention_prefix=full_diffusers_attention_prefix,674num_head_channels=num_head_channels,675)676)677678num_original_down_blocks = num_resnets679680return diffusers_checkpoint, num_original_down_blocks681682683# <original>.middle_block -> <diffusers>.mid_block684def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, original_unet_prefix, num_head_channels):685diffusers_checkpoint = {}686687# block 0688689original_block_idx = 0690691diffusers_checkpoint.update(692resnet_to_diffusers_checkpoint(693checkpoint,694diffusers_resnet_prefix="mid_block.resnets.0",695resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",696)697)698699original_block_idx += 1700701# optional block 1702703if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None:704diffusers_checkpoint.update(705attention_to_diffusers_checkpoint(706checkpoint,707diffusers_attention_prefix="mid_block.attentions.0",708attention_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",709num_head_channels=num_head_channels,710)711)712original_block_idx += 1713714# block 1 or block 2715716diffusers_checkpoint.update(717resnet_to_diffusers_checkpoint(718checkpoint,719diffusers_resnet_prefix="mid_block.resnets.1",720resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",721)722)723724return diffusers_checkpoint725726727# <original>.output_blocks -> <diffusers>.up_blocks728def unet_upblock_to_diffusers_checkpoint(729model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, original_unet_prefix, num_head_channels730):731diffusers_checkpoint = {}732733diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets"734original_up_block_prefix = f"{original_unet_prefix}.output_blocks"735736up_block = model.up_blocks[diffusers_up_block_idx]737738num_resnets = len(up_block.resnets)739740if up_block.upsamplers is None:741upsampler = False742else:743assert len(up_block.upsamplers) == 1744upsampler = True745# The upsample block is also a resnet746num_resnets += 1747748has_attentions = hasattr(up_block, "attentions")749750for resnet_idx_inc in range(num_resnets):751if upsampler and resnet_idx_inc == num_resnets - 1:752# this is an upsample block753if has_attentions:754# There is a middle attention block that we skip755original_resnet_block_idx = 2756else:757original_resnet_block_idx = 1758759# we add the `minus 1` because the last two resnets are stuck together in the same output block760full_resnet_prefix = (761f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}"762)763764full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0"765else:766# this is a regular resnet block767full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0"768full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"769770diffusers_checkpoint.update(771resnet_to_diffusers_checkpoint(772checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix773)774)775776if has_attentions:777num_attentions = len(up_block.attentions)778diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions"779780for attention_idx_inc in range(num_attentions):781full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1"782full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"783784diffusers_checkpoint.update(785attention_to_diffusers_checkpoint(786checkpoint,787attention_prefix=full_attention_prefix,788diffusers_attention_prefix=full_diffusers_attention_prefix,789num_head_channels=num_head_channels,790)791)792793num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets794795return diffusers_checkpoint, num_original_down_blocks796797798def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):799diffusers_checkpoint = {800f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"],801f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"],802f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"],803f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"],804f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"],805f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"],806f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"],807f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"],808f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"],809f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"],810}811812skip_connection_prefix = f"{resnet_prefix}.skip_connection"813814if f"{skip_connection_prefix}.weight" in checkpoint:815diffusers_checkpoint.update(816{817f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"],818f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"],819}820)821822return diffusers_checkpoint823824825def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels):826diffusers_checkpoint = {}827828# <original>.norm -> <diffusers>.group_norm829diffusers_checkpoint.update(830{831f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],832f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],833}834)835836# <original>.qkv -> <diffusers>.{query, key, value}837[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(838weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0],839bias=checkpoint[f"{attention_prefix}.qkv.bias"],840split=3,841chunk_size=num_head_channels,842)843844diffusers_checkpoint.update(845{846f"{diffusers_attention_prefix}.to_q.weight": q_weight,847f"{diffusers_attention_prefix}.to_q.bias": q_bias,848f"{diffusers_attention_prefix}.to_k.weight": k_weight,849f"{diffusers_attention_prefix}.to_k.bias": k_bias,850f"{diffusers_attention_prefix}.to_v.weight": v_weight,851f"{diffusers_attention_prefix}.to_v.bias": v_bias,852}853)854855# <original>.encoder_kv -> <diffusers>.{context_key, context_value}856[encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(857weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0],858bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"],859split=2,860chunk_size=num_head_channels,861)862863diffusers_checkpoint.update(864{865f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight,866f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias,867f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight,868f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias,869}870)871872# <original>.proj_out (1d conv) -> <diffusers>.proj_attn (linear)873diffusers_checkpoint.update(874{875f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][876:, :, 0877],878f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],879}880)881882return diffusers_checkpoint883884885# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)886def split_attentions(*, weight, bias, split, chunk_size):887weights = [None] * split888biases = [None] * split889890weights_biases_idx = 0891892for starting_row_index in range(0, weight.shape[0], chunk_size):893row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)894895weight_rows = weight[row_indices, :]896bias_rows = bias[row_indices]897898if weights[weights_biases_idx] is None:899assert weights[weights_biases_idx] is None900weights[weights_biases_idx] = weight_rows901biases[weights_biases_idx] = bias_rows902else:903assert weights[weights_biases_idx] is not None904weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])905biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])906907weights_biases_idx = (weights_biases_idx + 1) % split908909return weights, biases910911912# done unet utils913914915# Driver functions916917918def text_encoder():919print("loading CLIP text encoder")920921clip_name = "openai/clip-vit-large-patch14"922923# sets pad_value to 0924pad_token = "!"925926tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")927928assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0929930text_encoder_model = CLIPTextModelWithProjection.from_pretrained(931clip_name,932# `CLIPTextModel` does not support device_map="auto"933# device_map="auto"934)935936print("done loading CLIP text encoder")937938return text_encoder_model, tokenizer_model939940941def prior(*, args, checkpoint_map_location):942print("loading prior")943944prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)945prior_checkpoint = prior_checkpoint["state_dict"]946947clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location)948949prior_model = prior_model_from_original_config()950951prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(952prior_model, prior_checkpoint, clip_stats_checkpoint953)954955del prior_checkpoint956del clip_stats_checkpoint957958load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True)959960print("done loading prior")961962return prior_model963964965def decoder(*, args, checkpoint_map_location):966print("loading decoder")967968decoder_checkpoint = torch.load(args.decoder_checkpoint_path, map_location=checkpoint_map_location)969decoder_checkpoint = decoder_checkpoint["state_dict"]970971decoder_model = decoder_model_from_original_config()972973decoder_diffusers_checkpoint = decoder_original_checkpoint_to_diffusers_checkpoint(974decoder_model, decoder_checkpoint975)976977# text proj interlude978979# The original decoder implementation includes a set of parameters that are used980# for creating the `encoder_hidden_states` which are what the U-net is conditioned981# on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull982# the parameters into the UnCLIPTextProjModel class983text_proj_model = text_proj_from_original_config()984985text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(decoder_checkpoint)986987load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)988989# done text proj interlude990991del decoder_checkpoint992993load_checkpoint_to_model(decoder_diffusers_checkpoint, decoder_model, strict=True)994995print("done loading decoder")996997return decoder_model, text_proj_model9989991000def super_res_unet(*, args, checkpoint_map_location):1001print("loading super resolution unet")10021003super_res_checkpoint = torch.load(args.super_res_unet_checkpoint_path, map_location=checkpoint_map_location)1004super_res_checkpoint = super_res_checkpoint["state_dict"]10051006# model_first_steps10071008super_res_first_model = super_res_unet_first_steps_model_from_original_config()10091010super_res_first_steps_checkpoint = super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(1011super_res_first_model, super_res_checkpoint1012)10131014# model_last_step1015super_res_last_model = super_res_unet_last_step_model_from_original_config()10161017super_res_last_step_checkpoint = super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(1018super_res_last_model, super_res_checkpoint1019)10201021del super_res_checkpoint10221023load_checkpoint_to_model(super_res_first_steps_checkpoint, super_res_first_model, strict=True)10241025load_checkpoint_to_model(super_res_last_step_checkpoint, super_res_last_model, strict=True)10261027print("done loading super resolution unet")10281029return super_res_first_model, super_res_last_model103010311032def load_checkpoint_to_model(checkpoint, model, strict=False):1033with tempfile.NamedTemporaryFile() as file:1034torch.save(checkpoint, file.name)1035del checkpoint1036if strict:1037model.load_state_dict(torch.load(file.name), strict=True)1038else:1039load_checkpoint_and_dispatch(model, file.name, device_map="auto")104010411042if __name__ == "__main__":1043parser = argparse.ArgumentParser()10441045parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")10461047parser.add_argument(1048"--prior_checkpoint_path",1049default=None,1050type=str,1051required=True,1052help="Path to the prior checkpoint to convert.",1053)10541055parser.add_argument(1056"--decoder_checkpoint_path",1057default=None,1058type=str,1059required=True,1060help="Path to the decoder checkpoint to convert.",1061)10621063parser.add_argument(1064"--super_res_unet_checkpoint_path",1065default=None,1066type=str,1067required=True,1068help="Path to the super resolution checkpoint to convert.",1069)10701071parser.add_argument(1072"--clip_stat_path", default=None, type=str, required=True, help="Path to the clip stats checkpoint to convert."1073)10741075parser.add_argument(1076"--checkpoint_load_device",1077default="cpu",1078type=str,1079required=False,1080help="The device passed to `map_location` when loading checkpoints.",1081)10821083parser.add_argument(1084"--debug",1085default=None,1086type=str,1087required=False,1088help="Only run a specific stage of the convert script. Used for debugging",1089)10901091args = parser.parse_args()10921093print(f"loading checkpoints to {args.checkpoint_load_device}")10941095checkpoint_map_location = torch.device(args.checkpoint_load_device)10961097if args.debug is not None:1098print(f"debug: only executing {args.debug}")10991100if args.debug is None:1101text_encoder_model, tokenizer_model = text_encoder()11021103prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)11041105decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location)11061107super_res_first_model, super_res_last_model = super_res_unet(1108args=args, checkpoint_map_location=checkpoint_map_location1109)11101111prior_scheduler = UnCLIPScheduler(1112variance_type="fixed_small_log",1113prediction_type="sample",1114num_train_timesteps=1000,1115clip_sample_range=5.0,1116)11171118decoder_scheduler = UnCLIPScheduler(1119variance_type="learned_range",1120prediction_type="epsilon",1121num_train_timesteps=1000,1122)11231124super_res_scheduler = UnCLIPScheduler(1125variance_type="fixed_small_log",1126prediction_type="epsilon",1127num_train_timesteps=1000,1128)11291130print(f"saving Kakao Brain unCLIP to {args.dump_path}")11311132pipe = UnCLIPPipeline(1133prior=prior_model,1134decoder=decoder_model,1135text_proj=text_proj_model,1136tokenizer=tokenizer_model,1137text_encoder=text_encoder_model,1138super_res_first=super_res_first_model,1139super_res_last=super_res_last_model,1140prior_scheduler=prior_scheduler,1141decoder_scheduler=decoder_scheduler,1142super_res_scheduler=super_res_scheduler,1143)1144pipe.save_pretrained(args.dump_path)11451146print("done writing Kakao Brain unCLIP")1147elif args.debug == "text_encoder":1148text_encoder_model, tokenizer_model = text_encoder()1149elif args.debug == "prior":1150prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)1151elif args.debug == "decoder":1152decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location)1153elif args.debug == "super_res_unet":1154super_res_first_model, super_res_last_model = super_res_unet(1155args=args, checkpoint_map_location=checkpoint_map_location1156)1157else:1158raise ValueError(f"unknown debug value : {args.debug}")115911601161