GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_vq_diffusion_to_diffusers.py
"""
This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.

It currently only supports porting the ITHQ checkpoint.

ITHQ dataset:
```sh
# From the root directory of diffusers.

# Download the VQVAE checkpoint
$ wget 'https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D' -O ithq_vqvae.pth

# Download the VQVAE config
# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml

# Download the main model checkpoint
$ wget 'https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D' -O ithq_learnable.pth

# Download the main model config
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml

# Run the conversion script
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
    --checkpoint_path ./ithq_learnable.pth \
    --original_config_file ./ithq.yaml \
    --vqvae_checkpoint_path ./ithq_vqvae.pth \
    --vqvae_original_config_file ./ithq_vqvae.yaml \
    --dump_path <path to save pre-trained `VQDiffusionPipeline`>
```
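
After conversion, the saved pipeline can be loaded and sampled roughly as follows (a minimal
sketch; the dump path and prompt below are placeholders):

```py
import torch

from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("<path to saved pre-trained `VQDiffusionPipeline`>")
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

image = pipe("a teddy bear playing in the pool").images[0]
image.save("teddy_bear.png")
```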
"""

import argparse
import tempfile

import torch
import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings


try:
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError(
        "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
        " OmegaConf`."
    )

# vqvae model

PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]

def vqvae_model_from_original_config(original_config):
    assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."

    original_config = original_config.params

    original_encoder_config = original_config.encoder_config.params
    original_decoder_config = original_config.decoder_config.params

    in_channels = original_encoder_config.in_channels
    out_channels = original_decoder_config.out_ch

    down_block_types = get_down_block_types(original_encoder_config)
    up_block_types = get_up_block_types(original_decoder_config)

    assert original_encoder_config.ch == original_decoder_config.ch
    assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
    block_out_channels = tuple(
        [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
    )

    assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
    layers_per_block = original_encoder_config.num_res_blocks

    assert original_encoder_config.z_channels == original_decoder_config.z_channels
    latent_channels = original_encoder_config.z_channels

    num_vq_embeddings = original_config.n_embed

    # Hard coded value for ResnetBlock.GroupNorm(num_groups) in VQ-diffusion
    norm_num_groups = 32

    e_dim = original_config.embed_dim

    model = VQModel(
        in_channels=in_channels,
        out_channels=out_channels,
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        latent_channels=latent_channels,
        num_vq_embeddings=num_vq_embeddings,
        norm_num_groups=norm_num_groups,
        vq_embed_dim=e_dim,
    )

    return model


def get_down_block_types(original_encoder_config):
    attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
    num_resolutions = len(original_encoder_config.ch_mult)
    resolution = coerce_resolution(original_encoder_config.resolution)

    curr_res = resolution
    down_block_types = []

    for _ in range(num_resolutions):
        if curr_res in attn_resolutions:
            down_block_type = "AttnDownEncoderBlock2D"
        else:
            down_block_type = "DownEncoderBlock2D"

        down_block_types.append(down_block_type)

        curr_res = [r // 2 for r in curr_res]

    return down_block_types


def get_up_block_types(original_decoder_config):
    attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
    num_resolutions = len(original_decoder_config.ch_mult)
    resolution = coerce_resolution(original_decoder_config.resolution)

    curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
    up_block_types = []

    for _ in reversed(range(num_resolutions)):
        if curr_res in attn_resolutions:
            up_block_type = "AttnUpDecoderBlock2D"
        else:
            up_block_type = "UpDecoderBlock2D"

        up_block_types.append(up_block_type)

        curr_res = [r * 2 for r in curr_res]

    return up_block_types


def coerce_attn_resolutions(attn_resolutions):
    attn_resolutions = OmegaConf.to_object(attn_resolutions)
    attn_resolutions_ = []
    for ar in attn_resolutions:
        if isinstance(ar, (list, tuple)):
            attn_resolutions_.append(list(ar))
        else:
            attn_resolutions_.append([ar, ar])
    return attn_resolutions_


def coerce_resolution(resolution):
    resolution = OmegaConf.to_object(resolution)
    if isinstance(resolution, int):
        resolution = [resolution, resolution]  # H, W
    elif isinstance(resolution, (tuple, list)):
        resolution = list(resolution)
    else:
        raise ValueError("Unknown type of resolution:", resolution)
    return resolution
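
# Both coercion helpers normalize the config values to [height, width] lists. For example, a
# scalar resolution of 256 becomes [256, 256] and an attention resolution entry of 16 becomes
# [16, 16], so the block type selection above compares like with like.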


# done vqvae model

# vqvae checkpoint


def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))

    # quant_conv

    diffusers_checkpoint.update(
        {
            "quant_conv.weight": checkpoint["quant_conv.weight"],
            "quant_conv.bias": checkpoint["quant_conv.bias"],
        }
    )

    # quantize
    diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})

    # post_quant_conv
    diffusers_checkpoint.update(
        {
            "post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
            "post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
        }
    )

    # decoder
    diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))

    return diffusers_checkpoint


def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # conv_in
    diffusers_checkpoint.update(
        {
            "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
            "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
        }
    )

    # down_blocks
    for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
        diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
        down_block_prefix = f"encoder.down.{down_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(down_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                vqvae_resnet_to_diffusers_checkpoint(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # downsample

        # There is no downsample on the last down block, so do not include one there.
        if down_block_idx != len(model.encoder.down_blocks) - 1:
            # There's a single downsample in the original checkpoint but a list of downsamples
            # in the diffusers model.
            diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
            downsample_prefix = f"{down_block_prefix}.downsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
                    f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
                }
            )

        # attentions

        if hasattr(down_block, "attentions"):
            for attention_idx, _ in enumerate(down_block.attentions):
                diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    vqvae_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # mid block

    # mid block attentions

    # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
    diffusers_attention_prefix = "encoder.mid_block.attentions.0"
    attention_prefix = "encoder.mid.attn_1"
    diffusers_checkpoint.update(
        vqvae_attention_to_diffusers_checkpoint(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets

    for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
        diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"

        # the hardcoded `block_` suffixes are 1 and 2
        orig_resnet_idx = diffusers_resnet_idx + 1
        # There are two hardcoded resnets in the middle of the VQ-diffusion encoder
        resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            vqvae_resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
            "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
            # conv_out
            "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
            "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint


def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # conv in
    diffusers_checkpoint.update(
        {
            "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
            "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
        }
    )

    # up_blocks

    for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
        # up_blocks are stored in reverse order in the VQ-diffusion checkpoint
        orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx

        diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
        up_block_prefix = f"decoder.up.{orig_up_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(up_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                vqvae_resnet_to_diffusers_checkpoint(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # upsample

        # There is no upsample on the last up block, so do not include one there.
        if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
            # There's a single upsample in the VQ-diffusion checkpoint but a list of upsamples
            # in the diffusers model.
            diffusers_upsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
            upsample_prefix = f"{up_block_prefix}.upsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_upsample_prefix}.weight": checkpoint[f"{upsample_prefix}.weight"],
                    f"{diffusers_upsample_prefix}.bias": checkpoint[f"{upsample_prefix}.bias"],
                }
            )

        # attentions

        if hasattr(up_block, "attentions"):
            for attention_idx, _ in enumerate(up_block.attentions):
                diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    vqvae_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # mid block

    # mid block attentions

    # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
    diffusers_attention_prefix = "decoder.mid_block.attentions.0"
    attention_prefix = "decoder.mid.attn_1"
    diffusers_checkpoint.update(
        vqvae_attention_to_diffusers_checkpoint(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets

    for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
        diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"

        # the hardcoded `block_` suffixes are 1 and 2
        orig_resnet_idx = diffusers_resnet_idx + 1
        # There are two hardcoded resnets in the middle of the VQ-diffusion decoder
        resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            vqvae_resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
            "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
            # conv_out
            "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
            "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint


def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
        f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
        f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
                f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
            }
        )

    return rv


def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
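    # In the original VQGAN-style attention block, q/k/v and proj_out are 1x1 convolutions,
    # while the diffusers attention block uses linear layers, so the trailing 1x1 spatial
    # dimensions are sliced off the convolution weights below.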
    return {
        # group_norm
        f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
        f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
        # query
        f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
        # key
        f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
        # value
        f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
        # proj_attn
        f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
            :, :, 0, 0
        ],
        f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
    }


# done vqvae checkpoint

# transformer model

PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]


def transformer_model_from_original_config(
    original_diffusion_config, original_transformer_config, original_content_embedding_config
):
    assert (
        original_diffusion_config.target in PORTED_DIFFUSIONS
    ), f"{original_diffusion_config.target} has not yet been ported to diffusers."
    assert (
        original_transformer_config.target in PORTED_TRANSFORMERS
    ), f"{original_transformer_config.target} has not yet been ported to diffusers."
    assert (
        original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
    ), f"{original_content_embedding_config.target} has not yet been ported to diffusers."

    original_diffusion_config = original_diffusion_config.params
    original_transformer_config = original_transformer_config.params
    original_content_embedding_config = original_content_embedding_config.params

    inner_dim = original_transformer_config["n_embd"]

    n_heads = original_transformer_config["n_head"]

    # VQ-Diffusion gives the dimension of the multi-headed attention layers as the number of
    # attention heads times the dimension of a single head. Our attention blocks expect the
    # number of heads and the per-head dimension to be specified separately.
    assert inner_dim % n_heads == 0
    d_head = inner_dim // n_heads
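    # For example (hypothetical numbers), n_embd=512 with n_head=16 gives d_head=32.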

    depth = original_transformer_config["n_layer"]
    context_dim = original_transformer_config["condition_dim"]

    num_embed = original_content_embedding_config["num_embed"]
    # the number of embeddings in the transformer includes the mask embedding.
    # the content embedding (the vqvae) does not include the mask embedding.
    num_embed = num_embed + 1

    height = original_transformer_config["content_spatial_size"][0]
    width = original_transformer_config["content_spatial_size"][1]

    assert width == height, "width has to be equal to height"
    dropout = original_transformer_config["resid_pdrop"]
    num_embeds_ada_norm = original_diffusion_config["diffusion_step"]

    model_kwargs = {
        "attention_bias": True,
        "cross_attention_dim": context_dim,
        "attention_head_dim": d_head,
        "num_layers": depth,
        "dropout": dropout,
        "num_attention_heads": n_heads,
        "num_vector_embeds": num_embed,
        "num_embeds_ada_norm": num_embeds_ada_norm,
        "norm_num_groups": 32,
        "sample_size": width,
        "activation_fn": "geglu-approximate",
    }

    model = Transformer2DModel(**model_kwargs)
    return model


# done transformer model

# transformer checkpoint


def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    transformer_prefix = "transformer.transformer"

    diffusers_latent_image_embedding_prefix = "latent_image_embedding"
    latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"

    # DalleMaskImageEmbedding
    diffusers_checkpoint.update(
        {
            f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.emb.weight"
            ],
            f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.height_emb.weight"
            ],
            f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.width_emb.weight"
            ],
        }
    )

    # transformer blocks
    for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
        diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
        transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"

        # ada norm block
        diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
        ada_norm_prefix = f"{transformer_block_prefix}.ln1"

        diffusers_checkpoint.update(
            transformer_ada_norm_to_diffusers_checkpoint(
                checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
            )
        )

        # attention block
        diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
        attention_prefix = f"{transformer_block_prefix}.attn1"

        diffusers_checkpoint.update(
            transformer_attention_to_diffusers_checkpoint(
                checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
            )
        )

        # ada norm block
        diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
        ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"

        diffusers_checkpoint.update(
            transformer_ada_norm_to_diffusers_checkpoint(
                checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
            )
        )

        # attention block
        diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
        attention_prefix = f"{transformer_block_prefix}.attn2"

        diffusers_checkpoint.update(
            transformer_attention_to_diffusers_checkpoint(
                checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
            )
        )

        # norm block
        diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
        norm_block_prefix = f"{transformer_block_prefix}.ln2"

        diffusers_checkpoint.update(
            {
                f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
                f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
            }
        )

        # feedforward block
        diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
        feedforward_prefix = f"{transformer_block_prefix}.mlp"

        diffusers_checkpoint.update(
            transformer_feedforward_to_diffusers_checkpoint(
                checkpoint,
                diffusers_feedforward_prefix=diffusers_feedforward_prefix,
                feedforward_prefix=feedforward_prefix,
            )
        )

    # to logits

    diffusers_norm_out_prefix = "norm_out"
    norm_out_prefix = f"{transformer_prefix}.to_logits.0"

    diffusers_checkpoint.update(
        {
            f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
            f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
        }
    )

    diffusers_out_prefix = "out"
    out_prefix = f"{transformer_prefix}.to_logits.1"

    diffusers_checkpoint.update(
        {
            f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
            f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
        }
    )

    return diffusers_checkpoint


def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
    return {
        f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
        f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
        f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
    }


def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    return {
        # key
        f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
        f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
        # query
        f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
        f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
        # value
        f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
        f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
        # linear out
        f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
        f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
    }


def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
    return {
        f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
        f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
        f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
        f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
    }


# done transformer checkpoint


def read_config_file(filename):
    # The yaml file contains annotations that certain values should be loaded as tuples.
    # By default, OmegaConf will panic when reading these. Instead, we manually read the
    # yaml with the FullLoader and then construct the OmegaConf object.
    with open(filename) as f:
        original_config = yaml.load(f, FullLoader)

    return OmegaConf.create(original_config)


# We take separate arguments for the vqvae because the ITHQ vqvae config file
# is separate from the config file for the rest of the model.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--vqvae_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the vqvae checkpoint to convert.",
    )

    parser.add_argument(
        "--vqvae_original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture for the vqvae.",
    )

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--checkpoint_load_device",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading checkpoints.",
    )

    # See link for how ema weights are always selected
    # https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65
    parser.add_argument(
        "--no_use_ema",
        action="store_true",
        required=False,
        help=(
            "Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to"
            " set this flag, as the original VQ-Diffusion always uses the ema weights when loading models."
        ),
    )

    args = parser.parse_args()

    use_ema = not args.no_use_ema

    print(f"loading checkpoints to {args.checkpoint_load_device}")

    checkpoint_map_location = torch.device(args.checkpoint_load_device)

    # vqvae_model

    print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")

    vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
    vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]

    with init_empty_weights():
        vqvae_model = vqvae_model_from_original_config(vqvae_original_config)

    vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)

    with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
        torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
        del vqvae_diffusers_checkpoint
        del vqvae_checkpoint
        load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")

    print("done loading vqvae")

    # done vqvae_model

    # transformer_model

    print(
        f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
        f" {use_ema}"
    )

    original_config = read_config_file(args.original_config_file).model

    diffusion_config = original_config.params.diffusion_config
    transformer_config = original_config.params.diffusion_config.params.transformer_config
    content_embedding_config = original_config.params.diffusion_config.params.content_emb_config

    pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)

    if use_ema:
        if "ema" in pre_checkpoint:
            checkpoint = {}
            for k, v in pre_checkpoint["model"].items():
                checkpoint[k] = v

            for k, v in pre_checkpoint["ema"].items():
                # The ema weights are only used on the transformer. To mimic their key as if they came
                # from the state_dict for the top level model, we prefix with an additional "transformer."
                # See the source linked above the `--no_use_ema` argument for more information.
                checkpoint[f"transformer.{k}"] = v
        else:
            print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
            checkpoint = pre_checkpoint["model"]
    else:
        checkpoint = pre_checkpoint["model"]

    del pre_checkpoint

    with init_empty_weights():
        transformer_model = transformer_model_from_original_config(
            diffusion_config, transformer_config, content_embedding_config
        )

    diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
        transformer_model, checkpoint
    )

    # classifier free sampling embeddings interlude

    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
    # model, so we pull them off the checkpoint before the checkpoint is deleted.

    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf

    if learnable_classifier_free_sampling_embeddings:
        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
    else:
        learned_classifier_free_sampling_embeddings_embeddings = None

    # done classifier free sampling embeddings interlude

    with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
        torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
        del diffusers_transformer_checkpoint
        del checkpoint
        load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")

    print("done loading transformer")

    # done transformer_model

    # text encoder

    print("loading CLIP text encoder")

    clip_name = "openai/clip-vit-base-patch32"

    # The original VQ-Diffusion specifies the pad value by the integer id used in the returned
    # tokens, and each model uses `0` as the pad value. The transformers CLIP api specifies the
    # pad value via the pad token itself, before tokenization. The `!` token maps to id `0`, so
    # padding with `!` is equivalent to padding with the `0` pad value.
    pad_token = "!"

    tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")

    assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0

    text_encoder_model = CLIPTextModel.from_pretrained(
        clip_name,
        # `CLIPTextModel` does not support device_map="auto"
        # device_map="auto"
    )

    print("done loading CLIP text encoder")

    # done text encoder

    # scheduler

    scheduler_model = VQDiffusionScheduler(
        # the scheduler has the same number of embeddings as the transformer
        num_vec_classes=transformer_model.num_vector_embeds
    )

    # done scheduler

    # learned classifier free sampling embeddings

    with init_empty_weights():
        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
            learnable_classifier_free_sampling_embeddings,
            hidden_size=text_encoder_model.config.hidden_size,
            length=tokenizer_model.model_max_length,
        )

    learned_classifier_free_sampling_checkpoint = {
        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
    }

    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
        del learned_classifier_free_sampling_checkpoint
        del learned_classifier_free_sampling_embeddings_embeddings
        load_checkpoint_and_dispatch(
            learned_classifier_free_sampling_embeddings_model,
            learned_classifier_free_sampling_checkpoint_file.name,
            device_map="auto",
        )

    # done learned classifier free sampling embeddings

    print(f"saving VQ diffusion model, path: {args.dump_path}")

    pipe = VQDiffusionPipeline(
        vqvae=vqvae_model,
        transformer=transformer_model,
        tokenizer=tokenizer_model,
        text_encoder=text_encoder_model,
        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
        scheduler=scheduler_model,
    )
    pipe.save_pretrained(args.dump_path)

    print("done writing VQ diffusion model")