GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_k_upscaler_to_diffusers.py
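# Converts the k-diffusion latent text-conditioned upscaler checkpoint hosted in the
# Hugging Face Hub repo "pcuenq/k-upscaler" into a diffusers UNet2DConditionModel
# and saves it with save_pretrained().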
import argparse

import huggingface_hub
import k_diffusion as K
import torch

from diffusers import UNet2DConditionModel


UPSCALER_REPO = "pcuenq/k-upscaler"


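# Maps a single k-diffusion residual block onto a diffusers resnet: the adaptive-norm
# mappers under main.0 / main.4 become norm1 / norm2, the convolutions under
# main.2 / main.6 become conv1 / conv2, and an optional skip conv becomes conv_shortcut.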
def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
            }
        )

    return rv


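# Splits the fused qkv_proj of a k-diffusion self-attention block into separate
# to_q / to_k / to_v projections for the diffusers attn1 layer; trailing singleton
# dimensions are squeezed so the weights match diffusers' linear projections.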
def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
    bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
    rv = {
        # norm
        f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
        f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
        # to_q
        f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
        # to_k
        f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
    }

    return rv


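# Same idea for a k-diffusion cross-attention block: the fused kv_proj is split into
# to_k / to_v, while q_proj and out_proj map directly; diffusers_attention_index
# selects which attn{n} / norm{n} pair of the diffusers block receives the weights.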
def cross_attn_to_diffusers_checkpoint(
    checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
):
    weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
    bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)

    rv = {
        # norm2 (ada groupnorm)
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.weight"
        ],
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.bias"
        ],
        # layernorm on encoder_hidden_state
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
            f"{attention_prefix}.norm_enc.weight"
        ],
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
            f"{attention_prefix}.norm_enc.bias"
        ],
        # to_q
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
            f"{attention_prefix}.q_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
            f"{attention_prefix}.q_proj.bias"
        ],
        # to_k
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
            f"{attention_prefix}.out_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
            f"{attention_prefix}.out_proj.bias"
        ],
    }

    return rv


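# Converts one down ("d_blocks") or up ("u_blocks") block. Each k-diffusion block
# stores its sub-modules in a flat, indexed list, so the position of every resnet and
# attention layer is recomputed from the per-layer module count n and the block type.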
def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
    block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
    block_prefix = f"{block_prefix}.{block_idx}"

    diffusers_checkpoint = {}

    if not hasattr(block, "attentions"):
        n = 1  # resnet only
    elif not block.attentions[0].add_self_attention:
        n = 2  # resnet -> cross-attention
    else:
        n = 3  # resnet -> self-attention -> cross-attention

    for resnet_idx, resnet in enumerate(block.resnets):
        diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
        idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
        resnet_prefix = f"{block_prefix}.{idx}"

        diffusers_checkpoint.update(
            resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    if hasattr(block, "attentions"):
        for attention_idx, attention in enumerate(block.attentions):
            diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
            idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
            self_attention_prefix = f"{block_prefix}.{idx}"
            cross_attention_index = 1 if not attention.add_self_attention else 2
            idx = (
                n * attention_idx + cross_attention_index
                if block_type == "up"
                else n * attention_idx + cross_attention_index + 1
            )
            cross_attention_prefix = f"{block_prefix}.{idx}"

            diffusers_checkpoint.update(
                cross_attn_to_diffusers_checkpoint(
                    checkpoint,
                    diffusers_attention_prefix=diffusers_attention_prefix,
                    diffusers_attention_index=2,
                    attention_prefix=cross_attention_prefix,
                )
            )

            if attention.add_self_attention:
                diffusers_checkpoint.update(
                    self_attn_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=self_attention_prefix,
                    )
                )

    return diffusers_checkpoint


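# Assembles the complete diffusers state dict: input conv, timestep and conditioning
# embeddings, all down and up blocks, and the output conv.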
def unet_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # pre-processing
    diffusers_checkpoint.update(
        {
            "conv_in.weight": checkpoint["inner_model.proj_in.weight"],
            "conv_in.bias": checkpoint["inner_model.proj_in.bias"],
        }
    )

    # timestep and class embedding
    diffusers_checkpoint.update(
        {
            "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
            "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
            "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
            "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
            "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
            "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
        }
    )

    # down_blocks
    for down_block_idx, down_block in enumerate(model.down_blocks):
        diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))

    # up_blocks
    for up_block_idx, up_block in enumerate(model.up_blocks):
        diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))

    # post-processing
    diffusers_checkpoint.update(
        {
            "conv_out.weight": checkpoint["inner_model.proj_out.weight"],
            "conv_out.bias": checkpoint["inner_model.proj_out.bias"],
        }
    )

    return diffusers_checkpoint


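# Builds a UNet2DConditionModel that mirrors the original upscaler config: K-style
# blocks with no mid block, Fourier time embedding, and GELU activations; channel
# widths and layers per block are taken directly from the original config.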
def unet_model_from_original_config(original_config):
    in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
    out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)

    block_out_channels = original_config["channels"]

    assert (
        len(set(original_config["depths"])) == 1
    ), "UNet2DConditionModel currently does not support blocks with different numbers of layers"
    layers_per_block = original_config["depths"][0]

    class_labels_dim = original_config["mapping_cond_dim"]
    cross_attention_dim = original_config["cross_cond_dim"]

    attn1_types = []
    attn2_types = []
    for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
        if s:
            a1 = "self"
            a2 = "cross" if c else None
        elif c:
            a1 = "cross"
            a2 = None
        else:
            a1 = None
            a2 = None
        attn1_types.append(a1)
        attn2_types.append(a2)

    unet = UNet2DConditionModel(
        in_channels=in_channels,
        out_channels=out_channels,
        down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
        mid_block_type=None,
        up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        act_fn="gelu",
        norm_num_groups=None,
        cross_attention_dim=cross_attention_dim,
        attention_head_dim=64,
        time_cond_proj_dim=class_labels_dim,
        resnet_time_scale_shift="scale_shift",
        time_embedding_type="fourier",
        timestep_post_act="gelu",
        conv_in_kernel=1,
        conv_out_kernel=1,
    )

    return unet


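# Downloads the original configuration and EMA checkpoint from UPSCALER_REPO on the
# Hugging Face Hub, converts the weights, and writes the diffusers model to --dump_path.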
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
    orig_weights_path = huggingface_hub.hf_hub_download(
        UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
    )
    print(f"loading original model configuration from {orig_config_path}")
    print(f"loading original model checkpoint from {orig_weights_path}")

    print("converting to diffusers unet")
    with open(orig_config_path) as config_file:
        orig_config = K.config.load_config(config_file)["model"]
    model = unet_model_from_original_config(orig_config)

    orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
    converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)

    model.load_state_dict(converted_checkpoint, strict=True)
    model.save_pretrained(args.dump_path)
    print(f"saved converted unet model to {args.dump_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
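# Example usage (illustrative; the output directory name below is an assumption):
#
#   python scripts/convert_k_upscaler_to_diffusers.py --dump_path ./k-upscaler-unet
#
# The converted weights can then be reloaded with diffusers, e.g.:
#
#   from diffusers import UNet2DConditionModel
#   unet = UNet2DConditionModel.from_pretrained("./k-upscaler-unet")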