GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_dit_to_diffusers.py
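# Converts an original DiT-XL/2 checkpoint (released with facebookresearch/DiT) into a
# diffusers DiTPipeline: the transformer weights are renamed onto Transformer2DModel and
# bundled with a pretrained VAE and a DDIM scheduler.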
import argparse
import os

import torch
from torchvision.datasets.utils import download_url

from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel


pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}


def download_model(model_name):
    """
    Downloads a pre-trained DiT checkpoint from the web if it is not already cached locally,
    then loads and returns its state dict.
    """
    local_path = f"pretrained_models/{model_name}"
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
        download_url(web_path, "pretrained_models")
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


def main(args):
    state_dict = download_model(pretrained_models[args.image_size])

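    # DiT's patch embedding ("x_embedder") corresponds to diffusers' "pos_embed" module.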
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
30
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
31
state_dict.pop("x_embedder.proj.weight")
32
state_dict.pop("x_embedder.proj.bias")
33
34
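    # Each of the 28 transformer blocks gets its own copy of the shared timestep and class
    # embedder weights, since diffusers' AdaLayerNormZero ("norm1") holds them per block.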
    for depth in range(28):
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
            "t_embedder.mlp.0.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
            "t_embedder.mlp.0.bias"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
            "t_embedder.mlp.2.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
            "t_embedder.mlp.2.bias"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
            "y_embedder.embedding_table.weight"
        ]

        state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
            f"blocks.{depth}.adaLN_modulation.1.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
            f"blocks.{depth}.adaLN_modulation.1.bias"
        ]

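        # DiT stores attention as a single fused qkv projection; diffusers expects separate
        # to_q / to_k / to_v, so split the fused weight and bias into three equal chunks.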
        q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)

        state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias

        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
            f"blocks.{depth}.attn.proj.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]

        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
        state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
        state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]

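        # Drop the original DiT keys for this block so that load_state_dict(strict=True) succeeds.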
state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
79
state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
80
state_dict.pop(f"blocks.{depth}.attn.proj.weight")
81
state_dict.pop(f"blocks.{depth}.attn.proj.bias")
82
state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
83
state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
84
state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
85
state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
86
state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
87
state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
88
89
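    # The shared embedder weights were copied into every block above; drop the originals once, outside the loop.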
state_dict.pop("t_embedder.mlp.0.weight")
90
state_dict.pop("t_embedder.mlp.0.bias")
91
state_dict.pop("t_embedder.mlp.2.weight")
92
state_dict.pop("t_embedder.mlp.2.bias")
93
state_dict.pop("y_embedder.embedding_table.weight")
94
95
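    # DiT's final layer: its adaLN modulation becomes proj_out_1 and its linear output head becomes proj_out_2.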
state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
96
state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
97
state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
98
state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
99
100
state_dict.pop("final_layer.linear.weight")
101
state_dict.pop("final_layer.linear.bias")
102
state_dict.pop("final_layer.adaLN_modulation.1.weight")
103
state_dict.pop("final_layer.adaLN_modulation.1.bias")
104
105
# DiT XL/2
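    # (28 layers, 16 heads with head dim 72 -> hidden size 1152, patch size 2 on the 4-channel
    # VAE latent; out_channels=8 because DiT predicts both the noise and a per-channel variance.)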
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_zero",
        norm_elementwise_affine=False,
    )
    transformer.load_state_dict(state_dict, strict=True)

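    # Sampling uses DDIM with the linear beta schedule and epsilon prediction of the original DiT
    # setup; clip_sample is disabled since the model works in VAE latent space rather than pixel space.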
    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="linear",
        prediction_type="epsilon",
        clip_sample=False,
    )

    vae = AutoencoderKL.from_pretrained(args.vae_model)

    pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)

    if args.save:
        pipeline.save_pretrained(args.checkpoint_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--image_size",
        default=256,
        type=int,
        required=False,
        help="Image size of pretrained model, either 256 or 512.",
    )
    parser.add_argument(
        "--vae_model",
        default="stabilityai/sd-vae-ft-ema",
        type=str,
        required=False,
        help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
    )
    parser.add_argument(
        # argparse's type=bool treats any non-empty string as True, so parse the flag explicitly.
        "--save", default=True, type=lambda s: str(s).lower() in ("true", "1", "yes"), required=False,
        help="Whether to save the converted pipeline or not."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
    )

    args = parser.parse_args()
    main(args)
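

# Example usage (the output directory name below is just an illustration):
#
#   python scripts/convert_dit_to_diffusers.py --image_size 256 --checkpoint_path DiT-XL-2-256
#
# Minimal sampling sketch for the saved pipeline (assumes a CUDA device and the current
# diffusers DiTPipeline API):
#
#   import torch
#   from diffusers import DiTPipeline
#
#   pipe = DiTPipeline.from_pretrained("DiT-XL-2-256").to("cuda")
#   class_ids = pipe.get_label_ids(["golden retriever"])  # ImageNet class names -> label ids
#   generator = torch.manual_seed(0)
#   image = pipe(class_labels=class_ids, generator=generator, num_inference_steps=25).images[0]
#   image.save("dit_sample.png")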