GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_versatile_diffusion_to_diffusers.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the Versatile Diffusion checkpoints. """

import argparse
from argparse import Namespace

import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
    VersatileDiffusionPipeline,
)
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel

SCHEDULER_CONFIG = Namespace(
    **{
        "beta_linear_start": 0.00085,
        "beta_linear_end": 0.012,
        "timesteps": 1000,
        "scale_factor": 0.18215,
    }
)

IMAGE_UNET_CONFIG = Namespace(
    **{
        "input_channels": 4,
        "model_channels": 320,
        "output_channels": 4,
        "num_noattn_blocks": [2, 2, 2, 2],
        "channel_mult": [1, 2, 4, 4],
        "with_attn": [True, True, True, False],
        "num_heads": 8,
        "context_dim": 768,
        "use_checkpoint": True,
    }
)

TEXT_UNET_CONFIG = Namespace(
    **{
        "input_channels": 768,
        "model_channels": 320,
        "output_channels": 768,
        "num_noattn_blocks": [2, 2, 2, 2],
        "channel_mult": [1, 2, 4, 4],
        "second_dim": [4, 4, 4, 4],
        "with_attn": [True, True, True, False],
        "num_heads": 8,
        "context_dim": 768,
        "use_checkpoint": True,
    }
)

AUTOENCODER_CONFIG = Namespace(
    **{
        "double_z": True,
        "z_channels": 4,
        "resolution": 256,
        "in_channels": 3,
        "out_ch": 3,
        "ch": 128,
        "ch_mult": [1, 2, 4, 4],
        "num_res_blocks": 2,
        "attn_resolutions": [],
        "dropout": 0.0,
    }
)
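
# These Namespaces hard-code the hyperparameters of the original Versatile Diffusion release
# instead of reading them from the original config file. For reference, model_channels=320 with
# channel_mult=[1, 2, 4, 4] gives UNet block widths of 320/640/1280/1280.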


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
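
    For example:
        shave_segments("input_blocks.3.0.op.weight", 1)  -> "3.0.op.weight"
        shave_segments("input_blocks.3.0.op.weight", -1) -> "input_blocks.3.0.op"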
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
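    # (in the original checkpoint, query, key and value are stored stacked along dim 0 of a single tensor)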
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        elif path["old"] in old_checkpoint:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
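    # The VAE attention projections are stored as 1x1 conv kernels of shape (out, in, 1, 1);
    # diffusers expects plain linear weights of shape (out, in), so the trailing singleton dims are dropped.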
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def create_image_unet_diffusers_config(unet_params):
    """
    Creates a config for the diffusers based on the config of the VD model.
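
    With the IMAGE_UNET_CONFIG above this yields block_out_channels (320, 640, 1280, 1280),
    i.e. three CrossAttnDownBlock2D blocks followed by a DownBlock2D, and the mirrored up blocks.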
    """

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
        raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.")

    config = dict(
        sample_size=None,
        in_channels=unet_params.input_channels,
        out_channels=unet_params.output_channels,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=unet_params.num_noattn_blocks[0],
        cross_attention_dim=unet_params.context_dim,
        attention_head_dim=unet_params.num_heads,
    )

    return config


def create_text_unet_diffusers_config(unet_params):
    """
    Creates a config for the diffusers based on the config of the VD model.
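
    Same layout as the image UNet but with the "Flat" block variants; in_channels/out_channels are
    (768, 1, 1) tuples, presumably because the text UNet diffuses 768-dim CLIP text embeddings
    treated as 1x1 feature maps.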
    """

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat"
        up_block_types.append(block_type)
        resolution //= 2

    if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
        raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.")

    config = dict(
        sample_size=None,
        in_channels=(unet_params.input_channels, 1, 1),
        out_channels=(unet_params.output_channels, 1, 1),
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=unet_params.num_noattn_blocks[0],
        cross_attention_dim=unet_params.context_dim,
        attention_head_dim=unet_params.num_heads,
    )

    return config


def create_vae_diffusers_config(vae_params):
    """
    Creates a config for the diffusers based on the config of the VD model.
    """

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=vae_params.resolution,
        in_channels=vae_params.in_channels,
        out_channels=vae_params.out_ch,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=vae_params.z_channels,
        layers_per_block=vae_params.num_res_blocks,
    )
    return config


def create_diffusers_scheduler(original_config):
    scheduler = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return scheduler


def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100:
        print("Checkpoint has both EMA and non-EMA weights.")
        if extract_ema:
            print(
                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
            )
            for key in keys:
                if key.startswith("model.diffusion_model"):
                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
        else:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["model.diffusion_model.time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["model.diffusion_model.time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["model.diffusion_model.time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["model.diffusion_model.time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

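    # In the LDM layout, input_blocks.0 is the stem conv and each later entry holds either
    # resnets (plus optional attention) or a downsampler, so every (layers_per_block + 1)
    # entries map onto one diffusers down block.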
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )
        elif f"input_blocks.{i}.0.weight" in unet_state_dict:
            # text_unet uses linear layers in place of downsamplers
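            # (their weights are square because the flat UNet keeps the channel count when
            # "downsampling"; a non-square weight here is presumably not a downsampler, so skip it)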
            shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape
            if shape[0] != shape[1]:
                continue
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]
                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []
            elif f"output_blocks.{i}.1.weight" in unet_state_dict:
                # text_unet uses linear layers in place of upsamplers
                shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape
                if shape[0] != shape[1]:
                    continue
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
                    f"output_blocks.{i}.1.weight"
                )
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
                    f"output_blocks.{i}.1.bias"
                )
                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []
            elif f"output_blocks.{i}.2.weight" in unet_state_dict:
                # text_unet uses linear layers in place of upsamplers
                shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape
                if shape[0] != shape[1]:
                    continue
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
                    f"output_blocks.{i}.2.weight"
                )
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
                    f"output_blocks.{i}.2.bias"
                )

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint


def convert_vd_vae_checkpoint(checkpoint, config):
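    # Remaps the LDM-style autoencoder naming (down/up, block_1/block_2, attn_1, nin_shortcut)
    # onto diffusers' AutoencoderKL layout.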
    # extract state dict for VAE
    vae_state_dict = {}
    keys = list(checkpoint.keys())
    for key in keys:
        vae_state_dict[key] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--unet_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--vae_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--optimus_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--scheduler_type",
        default="pndm",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    scheduler_config = SCHEDULER_CONFIG

    num_train_timesteps = scheduler_config.timesteps
    beta_start = scheduler_config.beta_linear_start
    beta_end = scheduler_config.beta_linear_end
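    # These betas (0.00085 -> 0.012, scaled_linear, 1000 steps) are the same training schedule
    # used by the Stable Diffusion v1.x checkpoints.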
    if args.scheduler_type == "pndm":
        scheduler = PNDMScheduler(
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            skip_prk_steps=True,
            steps_offset=1,
        )
    elif args.scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
    elif args.scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
    elif args.scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler(
            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
        )
    elif args.scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler(
            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
        )
    elif args.scheduler_type == "ddim":
        scheduler = DDIMScheduler(
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )
    else:
        raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")

    # Convert the UNet2DConditionModel models.
    if args.unet_checkpoint_path is not None:
        # image UNet
        image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG)
        checkpoint = torch.load(args.unet_checkpoint_path)
        converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
            checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
        )
        image_unet = UNet2DConditionModel(**image_unet_config)
        image_unet.load_state_dict(converted_image_unet_checkpoint)

        # text UNet
        text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG)
        converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
            checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
        )
        text_unet = UNetFlatConditionModel(**text_unet_config)
        text_unet.load_state_dict(converted_text_unet_checkpoint)

    # Convert the VAE model.
    if args.vae_checkpoint_path is not None:
        vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG)
        checkpoint = torch.load(args.vae_checkpoint_path)
        converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config)

        vae = AutoencoderKL(**vae_config)
        vae.load_state_dict(converted_vae_checkpoint)

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    image_feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

    pipe = VersatileDiffusionPipeline(
        scheduler=scheduler,
        tokenizer=tokenizer,
        image_feature_extractor=image_feature_extractor,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        image_unet=image_unet,
        text_unet=text_unet,
        vae=vae,
    )
    pipe.save_pretrained(args.dump_path)
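
# Example invocation (the checkpoint paths below are placeholders; point them at the
# Versatile Diffusion weights you downloaded):
#
#   python convert_versatile_diffusion_to_diffusers.py \
#       --unet_checkpoint_path ./vd-four-flow-v1-0.pth \
#       --vae_checkpoint_path ./kl-f8.pth \
#       --scheduler_type ddim \
#       --dump_path ./versatile-diffusion-converted
#
# The dumped directory can then be reloaded with
# `VersatileDiffusionPipeline.from_pretrained("./versatile-diffusion-converted")`.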