# Source: GitHub repository shivamshrirao/diffusers
# Path: blob/main/scripts/convert_vae_pt_to_diffusers.py
1
import argparse
2
import io
3
4
import requests
5
import torch
6
from omegaconf import OmegaConf
7
8
from diffusers import AutoencoderKL
9
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
10
assign_to_checkpoint,
11
conv_attn_to_linear,
12
create_vae_diffusers_config,
13
renew_vae_attention_paths,
14
renew_vae_resnet_paths,
15
)
16
17
18
def _convert_vae_mid_block(vae_state_dict, new_checkpoint, prefix, config):
    """Copy the ``{prefix}.mid`` resnets and attention of an LDM VAE into ``new_checkpoint``.

    ``prefix`` is ``"encoder"`` or ``"decoder"``. LDM ``mid.block_1`` / ``mid.block_2`` become
    ``mid_block.resnets.0`` / ``mid_block.resnets.1`` and ``mid.attn_1`` becomes ``mid_block.attentions.0``.
    """
    mid_resnets = [key for key in vae_state_dict if f"{prefix}.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"{prefix}.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if f"{prefix}.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    # conv_attn_to_linear (from diffusers) runs over the whole new_checkpoint; judging by its name it
    # converts conv-shaped attention weights to linear-layer shape. NOTE(review): confirm it is idempotent,
    # since it is applied once after the encoder mid block and again after the decoder mid block.
    conv_attn_to_linear(new_checkpoint)


def custom_convert_ldm_vae_checkpoint(checkpoint, config):
    """Convert an LDM-format VAE state dict to diffusers ``AutoencoderKL`` parameter names.

    Args:
        checkpoint: Mapping from unprefixed LDM VAE parameter names (e.g. ``"encoder.conv_in.weight"``)
            to tensors. This function reads from it but does not pop any keys.
        config: The diffusers VAE config, forwarded to ``assign_to_checkpoint``.

    Returns:
        A new dict keyed by diffusers ``AutoencoderKL`` parameter names.

    Raises:
        KeyError: if one of the required top-level keys (conv_in/conv_out/norm_out/quant convs) is missing.
    """
    # Alias only: downsample weights below are read, not popped, so the caller's dict keeps all its keys.
    vae_state_dict = checkpoint
    new_checkpoint = {}

    # Top-level layers keep their names except LDM ``norm_out`` -> diffusers ``conv_norm_out``.
    for part in ("encoder", "decoder"):
        for old_name, new_name in (("conv_in", "conv_in"), ("conv_out", "conv_out"), ("norm_out", "conv_norm_out")):
            for param in ("weight", "bias"):
                new_checkpoint[f"{part}.{new_name}.{param}"] = vae_state_dict[f"{part}.{old_name}.{param}"]
    for name in ("quant_conv", "post_quant_conv"):
        for param in ("weight", "bias"):
            new_checkpoint[f"{name}.{param}"] = vae_state_dict[f"{name}.{param}"]

    # Encoder down blocks: count the distinct "encoder.down.<i>" prefixes.
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    for i in range(num_down_blocks):
        # The trailing "." keeps block 1 from also matching blocks 10, 11, ...
        block_prefix = f"encoder.down.{i}."
        downsample_prefix = f"encoder.down.{i}.downsample"
        resnets = [key for key in vae_state_dict if block_prefix in key and downsample_prefix not in key]

        if f"{downsample_prefix}.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict[
                f"{downsample_prefix}.conv.weight"
            ]
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict[
                f"{downsample_prefix}.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    _convert_vae_mid_block(vae_state_dict, new_checkpoint, "encoder", config)

    # Decoder up blocks: index order is reversed, LDM ``up.<n-1-i>`` maps to diffusers ``up_blocks.<i>``.
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        block_prefix = f"decoder.up.{block_id}."
        upsample_prefix = f"decoder.up.{block_id}.upsample"
        resnets = [key for key in vae_state_dict if block_prefix in key and upsample_prefix not in key]

        if f"{upsample_prefix}.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"{upsample_prefix}.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"{upsample_prefix}.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    _convert_vae_mid_block(vae_state_dict, new_checkpoint, "decoder", config)
    return new_checkpoint
117
118
119
def vae_pt_to_vae_diffuser(
    checkpoint_path: str,
    output_path: str,
):
    """Convert an LDM VAE checkpoint file and save it as a diffusers ``AutoencoderKL``.

    The VAE architecture comes from the Stable Diffusion v1 inference config, downloaded from the
    CompVis/stable-diffusion GitHub repository, so only v1-style VAEs are supported.

    Args:
        checkpoint_path: Path to a ``torch.load``-able checkpoint. Its ``"state_dict"`` entry holds the
            LDM VAE weights; a checkpoint without that entry is treated as the state dict itself.
        output_path: Destination passed to ``AutoencoderKL.save_pretrained``.

    Raises:
        requests.HTTPError: if downloading the v1 config returns an error status.
    """
    # Only support V1
    response = requests.get(
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
        timeout=60,
    )
    # Fail here with a clear HTTP error instead of letting OmegaConf try to parse an error page.
    response.raise_for_status()
    original_config = OmegaConf.load(io.BytesIO(response.content))

    image_size = 512
    # SECURITY NOTE: torch.load unpickles arbitrary objects; only convert checkpoints from trusted sources.
    # The weights are only copied into a freshly constructed AutoencoderKL (default device) and saved,
    # so loading them onto the CPU avoids needlessly occupying GPU memory.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    vae.save_pretrained(output_path)
141
142
143
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a Stable Diffusion v1 (LDM) VAE .pt checkpoint to the diffusers AutoencoderKL format."
    )

    parser.add_argument("--vae_pt_path", type=str, required=True, help="Path to the VAE.pt to convert.")
    # The help text previously duplicated --vae_pt_path's; this is where the converted model is written.
    parser.add_argument(
        "--dump_path", type=str, required=True, help="Path to write the converted diffusers VAE to."
    )

    args = parser.parse_args()

    vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
152
153