GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_kakao_brain_unclip_to_diffusers.py
import argparse
import tempfile

import torch
from accelerate import load_checkpoint_and_dispatch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler


"""
Example - From the diffusers root directory:

Download weights:
```sh
$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt
$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt
$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th
```

Convert the model:
```sh
$ python scripts/convert_kakao_brain_unclip_to_diffusers.py \
      --decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \
      --super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \
      --prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \
      --clip_stat_path ./ViT-L-14_stats.th \
      --dump_path <path where to save model>
```
"""


# prior

PRIOR_ORIGINAL_PREFIX = "model"

# Uses default arguments
PRIOR_CONFIG = {}


def prior_model_from_original_config():
    model = PriorTransformer(**PRIOR_CONFIG)

    return model


def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):
    diffusers_checkpoint = {}

    # <original>.time_embed.0 -> <diffusers>.time_embedding.linear_1
    diffusers_checkpoint.update(
        {
            "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"],
            "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"],
        }
    )

    # <original>.clip_img_proj -> <diffusers>.proj_in
    diffusers_checkpoint.update(
        {
            "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"],
            "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"],
        }
    )

    # <original>.text_emb_proj -> <diffusers>.embedding_proj
    diffusers_checkpoint.update(
        {
            "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"],
            "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"],
        }
    )

    # <original>.text_enc_proj -> <diffusers>.encoder_hidden_states_proj
    diffusers_checkpoint.update(
        {
            "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"],
            "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"],
        }
    )

    # <original>.positional_embedding -> <diffusers>.positional_embedding
    diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]})

    # <original>.prd_emb -> <diffusers>.prd_embedding
    diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]})

    # <original>.time_embed.2 -> <diffusers>.time_embedding.linear_2
    diffusers_checkpoint.update(
        {
            "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"],
            "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"],
        }
    )

    # <original>.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
    for idx in range(len(model.transformer_blocks)):
        diffusers_transformer_prefix = f"transformer_blocks.{idx}"
        original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}"

        # <original>.attn -> <diffusers>.attn1
        diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
        original_attention_prefix = f"{original_transformer_prefix}.attn"
        diffusers_checkpoint.update(
            prior_attention_to_diffusers(
                checkpoint,
                diffusers_attention_prefix=diffusers_attention_prefix,
                original_attention_prefix=original_attention_prefix,
                attention_head_dim=model.attention_head_dim,
            )
        )

        # <original>.mlp -> <diffusers>.ff
        diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
        original_ff_prefix = f"{original_transformer_prefix}.mlp"
        diffusers_checkpoint.update(
            prior_ff_to_diffusers(
                checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
            )
        )

        # <original>.ln_1 -> <diffusers>.norm1
        diffusers_checkpoint.update(
            {
                f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
                    f"{original_transformer_prefix}.ln_1.weight"
                ],
                f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
            }
        )

        # <original>.ln_2 -> <diffusers>.norm3
        diffusers_checkpoint.update(
            {
                f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
                    f"{original_transformer_prefix}.ln_2.weight"
                ],
                f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
            }
        )

    # <original>.final_ln -> <diffusers>.norm_out
    diffusers_checkpoint.update(
        {
            "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"],
            "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"],
        }
    )

    # <original>.out_proj -> <diffusers>.proj_to_clip_embeddings
    diffusers_checkpoint.update(
        {
            "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"],
            "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"],
        }
    )

    # clip stats
    clip_mean, clip_std = clip_stats_checkpoint
    clip_mean = clip_mean[None, :]
    clip_std = clip_std[None, :]

    diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std})

    return diffusers_checkpoint


def prior_attention_to_diffusers(
    checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
):
    diffusers_checkpoint = {}

    # <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
    [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
        weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
        bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
        split=3,
        chunk_size=attention_head_dim,
    )

    diffusers_checkpoint.update(
        {
            f"{diffusers_attention_prefix}.to_q.weight": q_weight,
            f"{diffusers_attention_prefix}.to_q.bias": q_bias,
            f"{diffusers_attention_prefix}.to_k.weight": k_weight,
            f"{diffusers_attention_prefix}.to_k.bias": k_bias,
            f"{diffusers_attention_prefix}.to_v.weight": v_weight,
            f"{diffusers_attention_prefix}.to_v.bias": v_bias,
        }
    )

    # <original>.c_proj -> <diffusers>.to_out.0
    diffusers_checkpoint.update(
        {
            f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
            f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
        }
    )

    return diffusers_checkpoint


def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
    diffusers_checkpoint = {
        # <original>.c_fc -> <diffusers>.net.0.proj
        f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
        f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
        # <original>.c_proj -> <diffusers>.net.2
        f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
        f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
    }

    return diffusers_checkpoint


# done prior


# decoder

DECODER_ORIGINAL_PREFIX = "model"

# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
# update then.
DECODER_CONFIG = {
    "sample_size": 64,
    "layers_per_block": 3,
    "down_block_types": (
        "ResnetDownsampleBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
        "SimpleCrossAttnDownBlock2D",
    ),
    "up_block_types": (
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "SimpleCrossAttnUpBlock2D",
        "ResnetUpsampleBlock2D",
    ),
    "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
    "block_out_channels": (320, 640, 960, 1280),
    "in_channels": 3,
    "out_channels": 6,
    "cross_attention_dim": 1536,
    "class_embed_type": "identity",
    "attention_head_dim": 64,
    "resnet_time_scale_shift": "scale_shift",
}
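# Note: "out_channels" is 6 rather than 3 presumably because the decoder predicts a
# learned variance for each image channel alongside the noise prediction, which is
# also why the decoder scheduler in the __main__ block below uses
# variance_type="learned_range".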


def decoder_model_from_original_config():
    model = UNet2DConditionModel(**DECODER_CONFIG)

    return model


def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    original_unet_prefix = DECODER_ORIGINAL_PREFIX
    num_head_channels = DECODER_CONFIG["attention_head_dim"]

    diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
    diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))

    # <original>.input_blocks -> <diffusers>.down_blocks

    original_down_block_idx = 1

    for diffusers_down_block_idx in range(len(model.down_blocks)):
        checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_down_block_idx=diffusers_down_block_idx,
            original_down_block_idx=original_down_block_idx,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=num_head_channels,
        )

        original_down_block_idx += num_original_down_blocks

        diffusers_checkpoint.update(checkpoint_update)

    # done <original>.input_blocks -> <diffusers>.down_blocks

    diffusers_checkpoint.update(
        unet_midblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=num_head_channels,
        )
    )

    # <original>.output_blocks -> <diffusers>.up_blocks

    original_up_block_idx = 0

    for diffusers_up_block_idx in range(len(model.up_blocks)):
        checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_up_block_idx=diffusers_up_block_idx,
            original_up_block_idx=original_up_block_idx,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=num_head_channels,
        )

        original_up_block_idx += num_original_up_blocks

        diffusers_checkpoint.update(checkpoint_update)

    # done <original>.output_blocks -> <diffusers>.up_blocks

    diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
    diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))

    return diffusers_checkpoint


# done decoder

# text proj


def text_proj_from_original_config():
    # From the conditional unet constructor where the dimension of the projected time embeddings is
    # constructed
    time_embed_dim = DECODER_CONFIG["block_out_channels"][0] * 4

    cross_attention_dim = DECODER_CONFIG["cross_attention_dim"]

    model = UnCLIPTextProjModel(time_embed_dim=time_embed_dim, cross_attention_dim=cross_attention_dim)

    return model


# Note that the input checkpoint is the original decoder checkpoint
def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
    diffusers_checkpoint = {
        # <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
        "encoder_hidden_states_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"],
        "encoder_hidden_states_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"],
        # <original>.text_seq_proj.1 -> <diffusers>.text_encoder_hidden_states_norm
        "text_encoder_hidden_states_norm.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"],
        "text_encoder_hidden_states_norm.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"],
        # <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
        "clip_extra_context_tokens_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"],
        "clip_extra_context_tokens_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"],
        # <original>.text_feat_proj -> <diffusers>.embedding_proj
        "embedding_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"],
        "embedding_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"],
        # <original>.cf_param -> <diffusers>.learned_classifier_free_guidance_embeddings
        "learned_classifier_free_guidance_embeddings": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.cf_param"],
        # <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
        "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint[
            f"{DECODER_ORIGINAL_PREFIX}.clip_emb.weight"
        ],
        "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint[
            f"{DECODER_ORIGINAL_PREFIX}.clip_emb.bias"
        ],
    }

    return diffusers_checkpoint


# done text proj

# super res unet first steps

SUPER_RES_UNET_FIRST_STEPS_PREFIX = "model_first_steps"

SUPER_RES_UNET_FIRST_STEPS_CONFIG = {
    "sample_size": 256,
    "layers_per_block": 3,
    "down_block_types": (
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
    ),
    "up_block_types": (
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
    ),
    "block_out_channels": (320, 640, 960, 1280),
    "in_channels": 6,
    "out_channels": 3,
    "add_attention": False,
}


def super_res_unet_first_steps_model_from_original_config():
    model = UNet2DModel(**SUPER_RES_UNET_FIRST_STEPS_CONFIG)

    return model


def super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    original_unet_prefix = SUPER_RES_UNET_FIRST_STEPS_PREFIX

    diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
    diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))

    # <original>.input_blocks -> <diffusers>.down_blocks

    original_down_block_idx = 1

    for diffusers_down_block_idx in range(len(model.down_blocks)):
        checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_down_block_idx=diffusers_down_block_idx,
            original_down_block_idx=original_down_block_idx,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=None,
        )

        original_down_block_idx += num_original_down_blocks

        diffusers_checkpoint.update(checkpoint_update)

    diffusers_checkpoint.update(
        unet_midblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=None,
        )
    )

    # <original>.output_blocks -> <diffusers>.up_blocks

    original_up_block_idx = 0

    for diffusers_up_block_idx in range(len(model.up_blocks)):
        checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_up_block_idx=diffusers_up_block_idx,
            original_up_block_idx=original_up_block_idx,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=None,
        )

        original_up_block_idx += num_original_up_blocks

        diffusers_checkpoint.update(checkpoint_update)

    # done <original>.output_blocks -> <diffusers>.up_blocks

    diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
    diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))

    return diffusers_checkpoint


# done super res unet first steps

# super res unet last step

SUPER_RES_UNET_LAST_STEP_PREFIX = "model_last_step"

SUPER_RES_UNET_LAST_STEP_CONFIG = {
    "sample_size": 256,
    "layers_per_block": 3,
    "down_block_types": (
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
    ),
    "up_block_types": (
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
    ),
    "block_out_channels": (320, 640, 960, 1280),
    "in_channels": 6,
    "out_channels": 3,
    "add_attention": False,
}


def super_res_unet_last_step_model_from_original_config():
    model = UNet2DModel(**SUPER_RES_UNET_LAST_STEP_CONFIG)

    return model


def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    original_unet_prefix = SUPER_RES_UNET_LAST_STEP_PREFIX

    diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
    diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))

    # <original>.input_blocks -> <diffusers>.down_blocks

    original_down_block_idx = 1

    for diffusers_down_block_idx in range(len(model.down_blocks)):
        checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_down_block_idx=diffusers_down_block_idx,
            original_down_block_idx=original_down_block_idx,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=None,
        )

        original_down_block_idx += num_original_down_blocks

        diffusers_checkpoint.update(checkpoint_update)

    diffusers_checkpoint.update(
        unet_midblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=None,
        )
    )

    # <original>.output_blocks -> <diffusers>.up_blocks

    original_up_block_idx = 0

    for diffusers_up_block_idx in range(len(model.up_blocks)):
        checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
            model,
            checkpoint,
            diffusers_up_block_idx=diffusers_up_block_idx,
            original_up_block_idx=original_up_block_idx,
            original_unet_prefix=original_unet_prefix,
            num_head_channels=None,
        )

        original_up_block_idx += num_original_up_blocks

        diffusers_checkpoint.update(checkpoint_update)

    # done <original>.output_blocks -> <diffusers>.up_blocks

    diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
    diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))

    return diffusers_checkpoint


# done super res unet last step


# unet utils


# <original>.time_embed -> <diffusers>.time_embedding
def unet_time_embeddings(checkpoint, original_unet_prefix):
    diffusers_checkpoint = {}

    diffusers_checkpoint.update(
        {
            "time_embedding.linear_1.weight": checkpoint[f"{original_unet_prefix}.time_embed.0.weight"],
            "time_embedding.linear_1.bias": checkpoint[f"{original_unet_prefix}.time_embed.0.bias"],
            "time_embedding.linear_2.weight": checkpoint[f"{original_unet_prefix}.time_embed.2.weight"],
            "time_embedding.linear_2.bias": checkpoint[f"{original_unet_prefix}.time_embed.2.bias"],
        }
    )

    return diffusers_checkpoint


# <original>.input_blocks.0 -> <diffusers>.conv_in
def unet_conv_in(checkpoint, original_unet_prefix):
    diffusers_checkpoint = {}

    diffusers_checkpoint.update(
        {
            "conv_in.weight": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.weight"],
            "conv_in.bias": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.bias"],
        }
    )

    return diffusers_checkpoint


# <original>.out.0 -> <diffusers>.conv_norm_out
def unet_conv_norm_out(checkpoint, original_unet_prefix):
    diffusers_checkpoint = {}

    diffusers_checkpoint.update(
        {
            "conv_norm_out.weight": checkpoint[f"{original_unet_prefix}.out.0.weight"],
            "conv_norm_out.bias": checkpoint[f"{original_unet_prefix}.out.0.bias"],
        }
    )

    return diffusers_checkpoint


# <original>.out.2 -> <diffusers>.conv_out
def unet_conv_out(checkpoint, original_unet_prefix):
    diffusers_checkpoint = {}

    diffusers_checkpoint.update(
        {
            "conv_out.weight": checkpoint[f"{original_unet_prefix}.out.2.weight"],
            "conv_out.bias": checkpoint[f"{original_unet_prefix}.out.2.bias"],
        }
    )

    return diffusers_checkpoint


# <original>.input_blocks -> <diffusers>.down_blocks
def unet_downblock_to_diffusers_checkpoint(
    model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, original_unet_prefix, num_head_channels
):
    diffusers_checkpoint = {}

    diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets"
    original_down_block_prefix = f"{original_unet_prefix}.input_blocks"

    down_block = model.down_blocks[diffusers_down_block_idx]

    num_resnets = len(down_block.resnets)

    if down_block.downsamplers is None:
        downsampler = False
    else:
        assert len(down_block.downsamplers) == 1
        downsampler = True
        # The downsample block is also a resnet
        num_resnets += 1

    for resnet_idx_inc in range(num_resnets):
        full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0"

        if downsampler and resnet_idx_inc == num_resnets - 1:
            # this is a downsample block
            full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0"
        else:
            # this is a regular resnet block
            full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"

        diffusers_checkpoint.update(
            resnet_to_diffusers_checkpoint(
                checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
            )
        )

    if hasattr(down_block, "attentions"):
        num_attentions = len(down_block.attentions)
        diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions"

        for attention_idx_inc in range(num_attentions):
            full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1"
            full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"

            diffusers_checkpoint.update(
                attention_to_diffusers_checkpoint(
                    checkpoint,
                    attention_prefix=full_attention_prefix,
                    diffusers_attention_prefix=full_diffusers_attention_prefix,
                    num_head_channels=num_head_channels,
                )
            )

    num_original_down_blocks = num_resnets

    return diffusers_checkpoint, num_original_down_blocks


# <original>.middle_block -> <diffusers>.mid_block
def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, original_unet_prefix, num_head_channels):
    diffusers_checkpoint = {}

    # block 0

    original_block_idx = 0

    diffusers_checkpoint.update(
        resnet_to_diffusers_checkpoint(
            checkpoint,
            diffusers_resnet_prefix="mid_block.resnets.0",
            resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
        )
    )

    original_block_idx += 1

    # optional block 1

    if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None:
        diffusers_checkpoint.update(
            attention_to_diffusers_checkpoint(
                checkpoint,
                diffusers_attention_prefix="mid_block.attentions.0",
                attention_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
                num_head_channels=num_head_channels,
            )
        )
        original_block_idx += 1

    # block 1 or block 2

    diffusers_checkpoint.update(
        resnet_to_diffusers_checkpoint(
            checkpoint,
            diffusers_resnet_prefix="mid_block.resnets.1",
            resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
        )
    )

    return diffusers_checkpoint


# <original>.output_blocks -> <diffusers>.up_blocks
def unet_upblock_to_diffusers_checkpoint(
    model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, original_unet_prefix, num_head_channels
):
    diffusers_checkpoint = {}

    diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets"
    original_up_block_prefix = f"{original_unet_prefix}.output_blocks"

    up_block = model.up_blocks[diffusers_up_block_idx]

    num_resnets = len(up_block.resnets)

    if up_block.upsamplers is None:
        upsampler = False
    else:
        assert len(up_block.upsamplers) == 1
        upsampler = True
        # The upsample block is also a resnet
        num_resnets += 1

    has_attentions = hasattr(up_block, "attentions")

    for resnet_idx_inc in range(num_resnets):
        if upsampler and resnet_idx_inc == num_resnets - 1:
            # this is an upsample block
            if has_attentions:
                # There is a middle attention block that we skip
                original_resnet_block_idx = 2
            else:
                original_resnet_block_idx = 1

            # we add the `minus 1` because the last two resnets are stuck together in the same output block
            full_resnet_prefix = (
                f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}"
            )

            full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0"
        else:
            # this is a regular resnet block
            full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0"
            full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"

        diffusers_checkpoint.update(
            resnet_to_diffusers_checkpoint(
                checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
            )
        )

    if has_attentions:
        num_attentions = len(up_block.attentions)
        diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions"

        for attention_idx_inc in range(num_attentions):
            full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1"
            full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"

            diffusers_checkpoint.update(
                attention_to_diffusers_checkpoint(
                    checkpoint,
                    attention_prefix=full_attention_prefix,
                    diffusers_attention_prefix=full_diffusers_attention_prefix,
                    num_head_channels=num_head_channels,
                )
            )

    num_original_up_blocks = num_resnets - 1 if upsampler else num_resnets

    return diffusers_checkpoint, num_original_up_blocks


def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    diffusers_checkpoint = {
        f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"],
        f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"],
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"],
        f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"],
        f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"],
        f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"],
        f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"],
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"],
    }

    skip_connection_prefix = f"{resnet_prefix}.skip_connection"

    if f"{skip_connection_prefix}.weight" in checkpoint:
        diffusers_checkpoint.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"],
                f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"],
            }
        )

    return diffusers_checkpoint


def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels):
    diffusers_checkpoint = {}

    # <original>.norm -> <diffusers>.group_norm
    diffusers_checkpoint.update(
        {
            f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
            f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
        }
    )

    # <original>.qkv -> <diffusers>.{query, key, value}
    [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
        weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0],
        bias=checkpoint[f"{attention_prefix}.qkv.bias"],
        split=3,
        chunk_size=num_head_channels,
    )

    diffusers_checkpoint.update(
        {
            f"{diffusers_attention_prefix}.to_q.weight": q_weight,
            f"{diffusers_attention_prefix}.to_q.bias": q_bias,
            f"{diffusers_attention_prefix}.to_k.weight": k_weight,
            f"{diffusers_attention_prefix}.to_k.bias": k_bias,
            f"{diffusers_attention_prefix}.to_v.weight": v_weight,
            f"{diffusers_attention_prefix}.to_v.bias": v_bias,
        }
    )

    # <original>.encoder_kv -> <diffusers>.{context_key, context_value}
    [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
        weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0],
        bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"],
        split=2,
        chunk_size=num_head_channels,
    )

    diffusers_checkpoint.update(
        {
            f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight,
            f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias,
            f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight,
            f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias,
        }
    )

    # <original>.proj_out (1d conv) -> <diffusers>.to_out.0 (linear)
    diffusers_checkpoint.update(
        {
            f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
                :, :, 0
            ],
            f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
        }
    )

    return diffusers_checkpoint


# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
def split_attentions(*, weight, bias, split, chunk_size):
    weights = [None] * split
    biases = [None] * split

    weights_biases_idx = 0

    for starting_row_index in range(0, weight.shape[0], chunk_size):
        row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)

        weight_rows = weight[row_indices, :]
        bias_rows = bias[row_indices]

        if weights[weights_biases_idx] is None:
            assert weights[weights_biases_idx] is None
            weights[weights_biases_idx] = weight_rows
            biases[weights_biases_idx] = bias_rows
        else:
            assert weights[weights_biases_idx] is not None
            weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
            biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])

        weights_biases_idx = (weights_biases_idx + 1) % split

    return weights, biases
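# Illustration (comments only, nothing here runs): split_attentions walks the rows of
# the fused projection weight in chunks of `chunk_size` and deals them out round-robin
# into `split` buckets, so with split=3 and chunk_size equal to the head dimension,
# chunk 0 lands in the query bucket, chunk 1 in the key bucket, chunk 2 in the value
# bucket, chunk 3 back in the query bucket, and so on. This matches fused qkv (and
# encoder kv) weights that are stored interleaved per head rather than as three
# contiguous q/k/v blocks, which is what the conversion assumes about the original
# checkpoints.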


# done unet utils


# Driver functions


def text_encoder():
    print("loading CLIP text encoder")

    clip_name = "openai/clip-vit-large-patch14"

    # sets pad_value to 0
    pad_token = "!"

    tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")

    assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0

    text_encoder_model = CLIPTextModelWithProjection.from_pretrained(
        clip_name,
        # `CLIPTextModel` does not support device_map="auto"
        # device_map="auto"
    )

    print("done loading CLIP text encoder")

    return text_encoder_model, tokenizer_model


def prior(*, args, checkpoint_map_location):
    print("loading prior")

    prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
    prior_checkpoint = prior_checkpoint["state_dict"]

    clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location)

    prior_model = prior_model_from_original_config()

    prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(
        prior_model, prior_checkpoint, clip_stats_checkpoint
    )

    del prior_checkpoint
    del clip_stats_checkpoint

    load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True)

    print("done loading prior")

    return prior_model


def decoder(*, args, checkpoint_map_location):
    print("loading decoder")

    decoder_checkpoint = torch.load(args.decoder_checkpoint_path, map_location=checkpoint_map_location)
    decoder_checkpoint = decoder_checkpoint["state_dict"]

    decoder_model = decoder_model_from_original_config()

    decoder_diffusers_checkpoint = decoder_original_checkpoint_to_diffusers_checkpoint(
        decoder_model, decoder_checkpoint
    )

    # text proj interlude

    # The original decoder implementation includes a set of parameters that are used
    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
    # the parameters into the UnCLIPTextProjModel class
    text_proj_model = text_proj_from_original_config()

    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(decoder_checkpoint)

    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)

    # done text proj interlude

    del decoder_checkpoint

    load_checkpoint_to_model(decoder_diffusers_checkpoint, decoder_model, strict=True)

    print("done loading decoder")

    return decoder_model, text_proj_model


def super_res_unet(*, args, checkpoint_map_location):
    print("loading super resolution unet")

    super_res_checkpoint = torch.load(args.super_res_unet_checkpoint_path, map_location=checkpoint_map_location)
    super_res_checkpoint = super_res_checkpoint["state_dict"]

    # model_first_steps

    super_res_first_model = super_res_unet_first_steps_model_from_original_config()

    super_res_first_steps_checkpoint = super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(
        super_res_first_model, super_res_checkpoint
    )

    # model_last_step
    super_res_last_model = super_res_unet_last_step_model_from_original_config()

    super_res_last_step_checkpoint = super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(
        super_res_last_model, super_res_checkpoint
    )

    del super_res_checkpoint

    load_checkpoint_to_model(super_res_first_steps_checkpoint, super_res_first_model, strict=True)

    load_checkpoint_to_model(super_res_last_step_checkpoint, super_res_last_model, strict=True)

    print("done loading super resolution unet")

    return super_res_first_model, super_res_last_model


def load_checkpoint_to_model(checkpoint, model, strict=False):
    with tempfile.NamedTemporaryFile() as file:
        torch.save(checkpoint, file.name)
        del checkpoint
        if strict:
            model.load_state_dict(torch.load(file.name), strict=True)
        else:
            load_checkpoint_and_dispatch(model, file.name, device_map="auto")
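# The round-trip through a NamedTemporaryFile above is presumably there because
# `load_checkpoint_and_dispatch` expects a checkpoint on disk rather than an
# in-memory state dict; the `del checkpoint` then drops the in-memory copy before
# the weights are loaded back into the model.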


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--prior_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the prior checkpoint to convert.",
    )

    parser.add_argument(
        "--decoder_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the decoder checkpoint to convert.",
    )

    parser.add_argument(
        "--super_res_unet_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the super resolution checkpoint to convert.",
    )

    parser.add_argument(
        "--clip_stat_path", default=None, type=str, required=True, help="Path to the clip stats checkpoint to convert."
    )

    parser.add_argument(
        "--checkpoint_load_device",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading checkpoints.",
    )

    parser.add_argument(
        "--debug",
        default=None,
        type=str,
        required=False,
        help="Only run a specific stage of the convert script. Used for debugging",
    )
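    # Illustrative only: running a single stage, e.g. just the prior. Note that all
    # required arguments (including --dump_path) must still be supplied even though
    # the selected stage only reads some of them:
    #
    #   $ python scripts/convert_kakao_brain_unclip_to_diffusers.py \
    #       --debug prior \
    #       --prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \
    #       --clip_stat_path ./ViT-L-14_stats.th \
    #       --decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \
    #       --super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \
    #       --dump_path <path where to save model>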

    args = parser.parse_args()

    print(f"loading checkpoints to {args.checkpoint_load_device}")

    checkpoint_map_location = torch.device(args.checkpoint_load_device)

    if args.debug is not None:
        print(f"debug: only executing {args.debug}")

    if args.debug is None:
        text_encoder_model, tokenizer_model = text_encoder()

        prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)

        decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location)

        super_res_first_model, super_res_last_model = super_res_unet(
            args=args, checkpoint_map_location=checkpoint_map_location
        )

        prior_scheduler = UnCLIPScheduler(
            variance_type="fixed_small_log",
            prediction_type="sample",
            num_train_timesteps=1000,
            clip_sample_range=5.0,
        )

        decoder_scheduler = UnCLIPScheduler(
            variance_type="learned_range",
            prediction_type="epsilon",
            num_train_timesteps=1000,
        )

        super_res_scheduler = UnCLIPScheduler(
            variance_type="fixed_small_log",
            prediction_type="epsilon",
            num_train_timesteps=1000,
        )

        print(f"saving Kakao Brain unCLIP to {args.dump_path}")

        pipe = UnCLIPPipeline(
            prior=prior_model,
            decoder=decoder_model,
            text_proj=text_proj_model,
            tokenizer=tokenizer_model,
            text_encoder=text_encoder_model,
            super_res_first=super_res_first_model,
            super_res_last=super_res_last_model,
            prior_scheduler=prior_scheduler,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )
        pipe.save_pretrained(args.dump_path)

        print("done writing Kakao Brain unCLIP")
    elif args.debug == "text_encoder":
        text_encoder_model, tokenizer_model = text_encoder()
    elif args.debug == "prior":
        prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
    elif args.debug == "decoder":
        decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location)
    elif args.debug == "super_res_unet":
        super_res_first_model, super_res_last_model = super_res_unet(
            args=args, checkpoint_map_location=checkpoint_map_location
        )
    else:
        raise ValueError(f"unknown debug value : {args.debug}")