Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TheLastBen
GitHub Repository: TheLastBen/fast-stable-diffusion
Path: blob/main/Dreambooth/convertodiffv2-768.py
540 views
1
import argparse
2
import os
3
import torch
4
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
5
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
6
7
8
# DiffUsers版StableDiffusionのモデルパラメータ
9
NUM_TRAIN_TIMESTEPS = 1000
10
BETA_START = 0.00085
11
BETA_END = 0.0120
12
13
UNET_PARAMS_MODEL_CHANNELS = 320
14
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
15
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
16
UNET_PARAMS_IMAGE_SIZE = 96
17
UNET_PARAMS_IN_CHANNELS = 4
18
UNET_PARAMS_OUT_CHANNELS = 4
19
UNET_PARAMS_NUM_RES_BLOCKS = 2
20
UNET_PARAMS_CONTEXT_DIM = 768
21
UNET_PARAMS_NUM_HEADS = 8
22
23
VAE_PARAMS_Z_CHANNELS = 4
24
VAE_PARAMS_RESOLUTION = 768
25
VAE_PARAMS_IN_CHANNELS = 3
26
VAE_PARAMS_OUT_CH = 3
27
VAE_PARAMS_CH = 128
28
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
29
VAE_PARAMS_NUM_RES_BLOCKS = 2
30
31
# V2
32
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
33
V2_UNET_PARAMS_CONTEXT_DIM = 1024
34
35
36
# region StableDiffusion -> Diffusers conversion code
# Copied from convert_original_stable_diffusion_to_diffusers (Apache License 2.0).
38
39
40
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
48
49
50
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    Returns a list of {"old": original_key, "new": renamed_key} dicts.
    """
    # LDM resnet sub-module name -> Diffusers name, applied in order.
    replacements = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        renamed = old_item
        for src, dst in replacements:
            renamed = renamed.replace(src, dst)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})
    return mapping
70
71
72
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    VAE resnets only need the shortcut conv renamed; everything else matches.
    """
    return [
        {
            "old": old_item,
            "new": shave_segments(
                old_item.replace("nin_shortcut", "conv_shortcut"),
                n_shave_prefix_segments=n_shave_prefix_segments,
            ),
        }
        for old_item in old_list
    ]
86
87
88
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming).

    UNet attention parameter names already match the Diffusers layout, so every
    entry maps a key to itself. ``n_shave_prefix_segments`` is accepted for
    signature parity with the other renew_* helpers but is not used here.
    """
    return [{"old": item, "new": item} for item in old_list]
107
108
109
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming).

    Maps the LDM VAE attention parameter names (norm / q / k / v / proj_out)
    onto the Diffusers names (group_norm / query / key / value / proj_attn).
    """
    # (old substring, new substring), applied in order to each key.
    replacements = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "query.weight"),
        ("q.bias", "query.bias"),
        ("k.weight", "key.weight"),
        ("k.bias", "key.bias"),
        ("v.weight", "value.weight"),
        ("v.bias", "value.bias"),
        ("proj_out.weight", "proj_attn.weight"),
        ("proj_out.bias", "proj_attn.bias"),
    )

    mapping = []
    for old_item in old_list:
        renamed = old_item
        for src, dst in replacements:
            renamed = renamed.replace(src, dst)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})
    return mapping
137
138
139
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.

    Args:
        paths: list of {"old": ..., "new": ...} rename dicts (from the renew_* helpers).
        checkpoint: destination state dict, mutated in place.
        old_checkpoint: source state dict the tensors are read from.
        attention_paths_to_split: optional {old_key: {"query"/"key"/"value": new_key}} map of
            fused qkv tensors to split three ways.
        additional_replacements: optional list of {"old": ..., "new": ...} substring
            replacements applied to every new key.
        config: must provide "num_head_channels" when attention_paths_to_split is used.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            # fused tensor stacks q, k, v along dim 0
            channels = old_tensor.shape[0] // 3

            # a bare -1 (int, not tuple) is a valid reshape argument: flatten to 1-D
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
189
190
191
def conv_attn_to_linear(checkpoint):
    """
    In-place: squeeze VAE attention conv weights down to linear weights.

    q/k/v conv kernels (out, in, 1, 1) become (out, in); proj_attn kernels
    (out, in, 1) become (out, in). Tensors already 2-D are left untouched.
    """
    qkv_suffixes = {"query.weight", "key.weight", "value.weight"}
    for key in list(checkpoint.keys()):
        tail = ".".join(key.split(".")[-2:])
        if tail in qkv_suffixes:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]
201
202
203
def linear_transformer_to_conv(checkpoint):
    """
    In-place: expand transformer projection linear weights (out, in) to
    1x1 conv weights (out, in, 1, 1). Only proj_in/proj_out weights change.
    """
    proj_names = ("proj_in.weight", "proj_out.weight")
    for key in list(checkpoint.keys()):
        if ".".join(key.split(".")[-2:]) in proj_names and checkpoint[key].ndim == 2:
            checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
210
211
212
def convert_ldm_unet_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Maps the LDM ``model.diffusion_model.*`` keys onto the Diffusers
    ``UNet2DConditionModel`` layout. Matching keys are popped out of
    ``checkpoint`` as a side effect. ``config`` must provide
    ``layers_per_block``.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."

    # pop the UNet weights out of the full checkpoint, stripping the prefix
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # time-embedding MLP and the in/out convs have fixed names
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # down path: input_blocks.0 is conv_in (handled above), so start at 1
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # '.op' entries are downsampling convs, handled separately below
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # middle block: resnet, attention, resnet
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # up path
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # group the block's parameter names by their sub-layer index
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # a sub-layer holding exactly ["conv.bias", "conv.weight"] is an upsampler
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # resnet-only block: rename keys directly without assign_to_checkpoint
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
366
367
368
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert the LDM first-stage VAE weights in ``checkpoint`` to the
    Diffusers ``AutoencoderKL`` naming scheme and return the new state dict.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
    # if len(vae_state_dict) == 0:
    #     # the given checkpoint is a bare VAE state_dict, not one loaded from a .ckpt
    #     vae_state_dict = checkpoint

    new_checkpoint = {}

    # fixed-name encoder/decoder convs and norms
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    # encoder down blocks
    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # encoder mid block: two resnets plus one attention
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    # decoder up blocks — hf numbers them in reverse relative to sd
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # decoder mid block: two resnets plus one attention
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
476
477
478
def create_unet_diffusers_config():
    """
    Creates a config for the diffusers based on the config of the LDM model.

    Builds the keyword arguments for ``UNet2DConditionModel(**config)`` from
    the module-level SD v2 (768-v) hyper-parameter constants.

    Returns:
        dict: the UNet2DConditionModel constructor arguments.
    """
    block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]

    # A down block uses cross-attention when its downsampling factor appears
    # in UNET_PARAMS_ATTENTION_RESOLUTIONS; the factor doubles after every
    # block except the last.
    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    # The up path mirrors the down path in reverse, halving the factor each step.
    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    # Note: the original computed class_embed_type /
    # projection_class_embeddings_input_dim locals here but never used them;
    # they have been removed (no behavior change).
    config = dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        # SD2 uses linear transformer projections and OpenCLIP conditioning
        use_linear_projection=True,
        cross_attention_dim=V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
    )

    return config
519
520
521
def create_vae_diffusers_config():
    """
    Creates a config for the diffusers based on the config of the LDM model.

    Builds the keyword arguments for ``AutoencoderKL(**config)`` from the
    module-level VAE hyper-parameter constants.
    """
    block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
    num_blocks = len(block_out_channels)

    return dict(
        sample_size=VAE_PARAMS_RESOLUTION,
        in_channels=VAE_PARAMS_IN_CHANNELS,
        out_channels=VAE_PARAMS_OUT_CH,
        down_block_types=("DownEncoderBlock2D",) * num_blocks,
        up_block_types=("UpDecoderBlock2D",) * num_blocks,
        block_out_channels=tuple(block_out_channels),
        latent_channels=VAE_PARAMS_Z_CHANNELS,
        layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
    )
542
543
544
def convert_ldm_clip_checkpoint_v1(checkpoint):
    """
    Extract the v1 CLIP text-encoder weights from a full SD checkpoint,
    stripping the "cond_stage_model.transformer." prefix from each key.
    """
    prefix = "cond_stage_model.transformer"
    # matching keys keep everything after the prefix plus its trailing dot
    return {
        key[len(prefix) + 1:]: value
        for key, value in checkpoint.items()
        if key.startswith(prefix)
    }
551
552
553
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
    """
    Convert the SD2 (OpenCLIP) text-encoder weights in ``checkpoint`` to the
    Hugging Face CLIPTextModel state-dict layout.

    Args:
        checkpoint: full SD checkpoint state dict.
        max_length: token sequence length used to build position_ids.
    """
    # The two layouts differ annoyingly much.
    def convert_key(key):
        # Rename one OpenCLIP key to its HF equivalent; return None for keys
        # that are dropped or handled separately (fused attn in_proj).
        if not key.startswith("cond_stage_model"):
            return None

        # common conversion
        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
        key = key.replace("cond_stage_model.model.", "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                # ".ln_1" -> ".layer_norm1", ".ln_2" -> ".layer_norm2"
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif '.attn.out_proj' in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif '.attn.in_proj' in key:
                key = None  # special case: fused qkv, split in the second pass below
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif '.positional_embedding' in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif '.text_projection' in key:
            key = None  # not used by the HF text model (?)
        elif '.logit_scale' in key:
            key = None  # not used by the HF text model (?)
        elif '.token_embedding' in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif '.ln_final' in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        # remove resblocks 23 — presumably because SD2 conditions on the
        # penultimate layer's output; TODO confirm
        if '.resblocks.23.' in key:
            continue
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # second pass: convert the fused attention in_proj weights/biases
    for key in keys:
        if '.resblocks.23.' in key:
            continue
        if '.resblocks' in key and '.attn.in_proj_' in key:
            # split into three chunks (q, k, v)
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # add position_ids (int64 row vector [0 .. max_length-1])
    new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
    return new_sd
620
621
# endregion
622
623
624
# region Diffusers -> StableDiffusion conversion code
# Copied from convert_diffusers_to_original_stable_diffusion (Apache License 2.0).
626
627
def conv_transformer_to_linear(checkpoint):
    """
    In-place: squeeze transformer projection 1x1 conv weights
    (out, in, 1, 1) back down to linear weights (out, in).
    Inverse of linear_transformer_to_conv.
    """
    proj_names = ("proj_in.weight", "proj_out.weight")
    for key in list(checkpoint.keys()):
        if ".".join(key.split(".")[-2:]) in proj_names and checkpoint[key].ndim > 2:
            checkpoint[key] = checkpoint[key][:, :, 0, 0]
634
635
636
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
    """
    Convert a Diffusers UNet state dict back to the original StableDiffusion
    ("model.diffusion_model") key naming.

    Args:
        v2: if truthy, also squeeze the transformer proj_in/proj_out 1x1 conv
            weights back to linear weights (SD2 uses linear projections).
        unet_state_dict: Diffusers-format UNet state dict.

    Returns:
        A new state dict with StableDiffusion key names.
    """
    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
        ("input_blocks.0.0.weight", "conv_in.weight"),
        ("input_blocks.0.0.bias", "conv_in.bias"),
        ("out.0.weight", "conv_norm_out.weight"),
        ("out.0.bias", "conv_norm_out.bias"),
        ("out.2.weight", "conv_out.weight"),
        ("out.2.bias", "conv_out.bias"),
    ]

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    # build the per-block prefix renames (inverse of the forward conversion)
    unet_conversion_map_layer = []
    for i in range(4):
        # loop over downblocks/upblocks

        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    # NOTE(review): these entries are added unconditionally, so a fixed-name
    # key missing from unet_state_dict would raise KeyError in the final dict
    # comprehension — assumes a complete UNet state dict.
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}

    if v2:
        conv_transformer_to_linear(new_state_dict)

    return new_state_dict
730
731
732
# ================#
733
# VAE Conversion #
734
# ================#
735
736
def reshape_weight_for_sd(w):
    """Expand an HF linear weight (out, in) to an SD 1x1 conv weight (out, in, 1, 1)."""
    return w.unsqueeze(-1).unsqueeze(-1)
739
740
741
def convert_vae_state_dict(vae_state_dict):
    """
    Convert a Diffusers AutoencoderKL state dict back to the original
    StableDiffusion ("first_stage_model") key naming, and reshape the
    mid-block attention q/k/v/proj_out linear weights to 1x1 conv weights.
    """
    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]

    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
            sd_down_prefix = f"encoder.down.{i}.block.{j}."
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

        if i < 3:
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
            sd_downsample_prefix = f"down.{i}.downsample."
            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"up.{3-i}.upsample."
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{i}."
        sd_mid_res_prefix = f"mid.block_{i+1}."
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ("norm.", "group_norm."),
        ("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]

    # rename every key via prefix/substring substitution
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]

    # attention weights were linear in HF; SD expects 1x1 convs
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                new_state_dict[k] = reshape_weight_for_sd(v)

    return new_state_dict
806
807
808
# endregion
809
810
811
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
    """Load a Stable Diffusion checkpoint and normalize text-encoder key names.

    Some checkpoints store the text encoder without the ``text_model.`` path
    segment in their keys; this inserts it so the rest of the conversion code
    can assume a single layout.

    NOTE(review): reads the global ``args`` (set in ``__main__``) to decide
    whether to load via safetensors, and loads onto CUDA — confirm this module
    is only ever run as a script on a GPU machine.

    Returns the full checkpoint dict, with text-encoder keys renamed in place.
    """
    # (from-layout prefix, to-layout prefix) pairs for the key rename.
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
        ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
        ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
    ]

    if args.from_safetensors:
        from safetensors import safe_open

        checkpoint = {}
        with safe_open(ckpt_path, framework="pt", device="cuda") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
        state_dict = checkpoint
    else:
        checkpoint = torch.load(ckpt_path, map_location="cuda")
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    # Collect the renames first, then apply them, so we never mutate the dict
    # while iterating over it.
    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from):]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint
# TODO dtype handling is unverified; building the text_encoder in the requested dtype is unconfirmed
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    """Load an SD checkpoint and convert it into Diffusers/transformers models.

    Args:
        v2: True for SD 2.x (1024-dim text encoder built from config),
            False for SD 1.x (weights loaded into openai/clip-vit-large-patch14).
        ckpt_path: path to the .ckpt / .safetensors file.
        dtype: optional torch dtype; when given, every tensor in the loaded
            state dict is cast to it before conversion.

    Returns:
        (text_model, vae, unet)
    """
    checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)

    # Unwrap nested "state_dict" wrappers until we reach the raw weights.
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    state_dict = checkpoint

    if dtype is not None:
        for k, v in state_dict.items():
            if isinstance(v, torch.Tensor):
                state_dict[k] = v.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config()
    unet_config["upcast_attention"] = True  # presumably needed for SD2.x 768-v attention stability — TODO confirm
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text encoder.
    if v2:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        # 23 hidden layers of width 1024 (cf. the dummy 24th-layer handling in
        # convert_text_encoder_state_dict_to_sd_v2 for the reverse direction).
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=512,
            torch_dtype="float32",
            transformers_version="4.25.0.dev0",
        )
        text_model = CLIPTextModel._from_config(cfg)
        text_model.load_state_dict(converted_text_encoder_checkpoint)
    else:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        text_model.load_state_dict(converted_text_encoder_checkpoint)

    return text_model, vae, unet
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
    """Convert a Diffusers (HF CLIP) text-encoder state dict to the SD v2 layout.

    The HF layer names are mapped to the open_clip-style names used inside an
    SD 2.x checkpoint, the separate q/k/v projections are merged into a single
    ``in_proj`` tensor, and — when *make_dummy_weights* is set — the weights
    that Diffusers models do not carry (a 24th transformer layer,
    ``text_projection``, ``logit_scale``) are fabricated as placeholders.
    """

    def _rename(hf_key):
        # position_ids is a buffer, not a trainable weight — drop it.
        if ".position_ids" in hf_key:
            return None
        key = hf_key.replace("text_model.encoder.", "transformer.").replace("text_model.", "")
        if "layers" not in key:
            if '.position_embedding' in key:
                return key.replace("embeddings.position_embedding.weight", "positional_embedding")
            if '.token_embedding' in key:
                return key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
            if 'final_layer_norm' in key:
                return key.replace("final_layer_norm", "ln_final")
            return key
        # Per-resblock names.
        key = key.replace(".layers.", ".resblocks.")
        if ".layer_norm" in key:
            return key.replace(".layer_norm", ".ln_")
        if ".mlp." in key:
            return key.replace(".fc1.", ".c_fc.").replace(".fc2.", ".c_proj.")
        if '.self_attn.out_proj' in key:
            return key.replace(".self_attn.out_proj.", ".attn.out_proj.")
        if '.self_attn.' in key:
            return None  # q/k/v are merged separately below
        raise ValueError(f"unexpected key in DiffUsers model: {key}")

    source_keys = list(checkpoint.keys())

    converted = {}
    for hf_key in source_keys:
        sd_key = _rename(hf_key)
        if sd_key is not None:
            converted[sd_key] = checkpoint[hf_key]

    # HF keeps q/k/v as three tensors; SD v2 expects them concatenated into
    # one in_proj tensor per resblock.
    for hf_key in source_keys:
        if 'layers' in hf_key and 'q_proj' in hf_key:
            qkv = torch.cat([
                checkpoint[hf_key],
                checkpoint[hf_key.replace("q_proj", "k_proj")],
                checkpoint[hf_key.replace("q_proj", "v_proj")],
            ])
            merged_key = hf_key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            merged_key = merged_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            converted[merged_key] = qkv

    if make_dummy_weights:
        # Fabricate the final (24th) layer by duplicating layer 22's weights.
        for sd_key in list(converted.keys()):
            if sd_key.startswith("transformer.resblocks.22."):
                converted[sd_key.replace(".22.", ".23.")] = converted[sd_key]

        # Placeholders for weights that Diffusers models do not contain.
        first = converted[list(converted.keys())[0]]
        converted['text_projection'] = torch.ones((1024, 1024), dtype=first.dtype, device=first.device)
        converted['logit_scale'] = torch.tensor(1)

    return converted
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
    """Assemble an SD-format checkpoint from Diffusers models and save it.

    When *ckpt_path* is given, that checkpoint supplies the base state dict
    (so weights we do not overwrite, e.g. the VAE when none is passed, are
    kept) plus the epoch/global_step counters to add to. Otherwise a fresh
    checkpoint is built from scratch and missing v2 weights are faked.

    Returns the number of keys written to the state dict.
    """
    if ckpt_path is None:
        # Build a brand-new checkpoint; accept any keys.
        checkpoint, state_dict, strict = {}, {}, False
    else:
        # Start from the reference checkpoint; every key we write must
        # already exist there.
        checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
        state_dict = checkpoint["state_dict"]
        strict = True

    def _merge(prefix, sd):
        # Copy sd into state_dict under prefix, optionally casting to save_dtype.
        for name, tensor in sd.items():
            key = prefix + name
            assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
            if save_dtype is not None:
                tensor = tensor.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = tensor

    # UNet.
    _merge("model.diffusion_model.", convert_unet_state_dict_to_sd(v2, unet.state_dict()))

    # Text encoder.
    if v2:
        # With no reference checkpoint, dummy weights (last layer etc.) are fabricated.
        make_dummy = ckpt_path is None
        _merge("cond_stage_model.model.",
               convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy))
    else:
        _merge("cond_stage_model.transformer.", text_encoder.state_dict())

    # VAE (only when one was supplied; otherwise the reference weights remain).
    if vae is not None:
        _merge("first_stage_model.", convert_vae_state_dict(vae.state_dict()))

    key_count = len(state_dict)

    new_ckpt = {'state_dict': state_dict}
    new_ckpt['epoch'] = epochs + checkpoint.get('epoch', 0)
    new_ckpt['global_step'] = steps + checkpoint.get('global_step', 0)

    torch.save(new_ckpt, output_file)

    return key_count
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None):
    """Save the models to *output_dir* as a Diffusers pipeline directory.

    Scheduler and tokenizer are copied from the reference model at
    *pretrained_model_name_or_path*, and so is the VAE when none is given.
    """
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer,
        safety_checker=None,
        feature_extractor=None,
    ).save_pretrained(output_dir)
def convert(args):
    """Convert between SD checkpoint and Diffusers formats according to *args*.

    Direction is inferred from the paths: a file loads as a checkpoint, a
    directory as Diffusers; an output path with an extension saves as a
    checkpoint, otherwise as a Diffusers directory.
    """
    print("Converting to Diffusers ...")
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16:
        save_dtype = torch.float16
    elif args.bf16:
        save_dtype = torch.bfloat16
    elif args.float:
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
    assert is_save_ckpt or args.reference_model is not None, "reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"

    # Load the source model.
    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # Neither (or both) flag given: auto-detect from the cross-attention width.
            v2_model = unet.config.cross_attention_dim == 1024
        else:
            # Exactly one flag is set, so args.v2 names the version directly.
            # (BUGFIX: this previously read args.v1, marking --v1 models as v2.)
            v2_model = args.v2

    # Convert and save.
    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
                                         original_model, args.epoch, args.global_step, save_dtype, vae)
    else:
        save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae)
if __name__ == '__main__':
    # CLI entry point: positional input/output model paths plus conversion flags.
    parser = argparse.ArgumentParser()
    parser.add_argument("--v1", action='store_true',
                        help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
    parser.add_argument("--v2", action='store_true',
                        help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
    parser.add_argument("--fp16", action='store_true',
                        help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
    parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
    parser.add_argument("--float", action='store_true',
                        help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
    parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
    parser.add_argument("--global_step", type=int, default=0,
                        help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
    # Typo fixes below: "schduler" -> "scheduler", "Diffuses" -> "Diffusers".
    parser.add_argument("--reference_model", type=str, default=None,
                        help="reference model for scheduler/tokenizer, required in saving Diffusers, copy scheduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")

    parser.add_argument("model_to_load", type=str, default=None,
                        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
    parser.add_argument("model_to_save", type=str, default=None,
                        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存")
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    args = parser.parse_args()
    convert(args)