TheLastBen
GitHub Repository: TheLastBen/fast-stable-diffusion
Path: blob/main/Dreambooth/convertodiffv2.py
import argparse
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel


# Model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 768
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024

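# Note (added for clarity, not in the original script): with the constants above,
# block_out_channels works out to [320, 640, 1280, 1280] for the UNet and
# [128, 256, 512, 512] for the VAE. The V2_* values (context dim 1024 and the
# per-block attention head dims) are what distinguish the SD 2.x UNet config
# built in create_unet_diffusers_config() below.
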
# region StableDiffusion -> Diffusers conversion code
# Copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])

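# Example (added for illustration, not in the original file):
#   shave_segments("output_blocks.2.1.proj_in.weight", 2)  -> "1.proj_in.weight"
#   shave_segments("output_blocks.2.1.proj_in.weight", -1) -> "output_blocks.2.1.proj_in"
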
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]

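# Note (added): conv_attn_to_linear() above and linear_transformer_to_conv() below only
# reshape weights. A 1x1 Conv2d weight is stored as [out, in, 1, 1] while the equivalent
# Linear weight is [out, in], so slicing [:, :, 0, 0] (or [:, :, 0] for the 1D-conv case)
# and unsqueeze(2).unsqueeze(2) convert between the two without changing any values.
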
def linear_transformer_to_conv(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim == 2:
                checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)


def convert_ldm_unet_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."

    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint

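# Note (added): with layers_per_block = 2, the index arithmetic block_id = (i - 1) // 3
# above means input_blocks 1-3 map to down_blocks.0, 4-6 to down_blocks.1, 7-9 to
# down_blocks.2, and 10-11 to down_blocks.3; input_blocks 3, 6 and 9 also carry the
# downsampler "op" convolution handled by the pop() calls in that loop.
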
def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
    # if len(vae_state_dict) == 0:
    #     # the checkpoint passed in is a VAE state_dict, not a checkpoint loaded from a .ckpt file
    #     vae_state_dict = checkpoint

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def create_unet_diffusers_config():
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # unet_params = original_config.model.params.unet_config.params
    # vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    config = dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        use_linear_projection=True,
        cross_attention_dim=V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
    )

    return config

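# Note (added): with UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] and four channel
# multipliers, the loops above produce
#   down_block_types = (CrossAttnDownBlock2D, CrossAttnDownBlock2D, CrossAttnDownBlock2D, DownBlock2D)
#   up_block_types   = (UpBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlock2D)
# which matches the standard Stable Diffusion UNet layout.
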
def create_vae_diffusers_config():
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # vae_params = original_config.model.params.first_stage_config.params.ddconfig
    # _ = original_config.model.params.first_stage_config.params.embed_dim
    block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=VAE_PARAMS_RESOLUTION,
        in_channels=VAE_PARAMS_IN_CHANNELS,
        out_channels=VAE_PARAMS_OUT_CH,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=VAE_PARAMS_Z_CHANNELS,
        layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
    )
    return config


def convert_ldm_clip_checkpoint_v1(checkpoint):
    keys = list(checkpoint.keys())
    text_model_dict = {}
    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
    return text_model_dict


def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
    # The v2 text encoder key layout is annoyingly different from v1!
    def convert_key(key):
        if not key.startswith("cond_stage_model"):
            return None

        # common conversion
        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
        key = key.replace("cond_stage_model.model.", "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif '.attn.out_proj' in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif '.attn.in_proj' in key:
                key = None  # special case, handled separately below
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif '.positional_embedding' in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif '.text_projection' in key:
            key = None  # not used(?)
        elif '.logit_scale' in key:
            key = None  # not used(?)
        elif '.token_embedding' in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif '.ln_final' in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        # remove resblocks 23
        if '.resblocks.23.' in key:
            continue
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert the attention weights
    for key in keys:
        if '.resblocks.23.' in key:
            continue
        if '.resblocks' in key and '.attn.in_proj_' in key:
            # split into three (q, k, v)
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # add position_ids
    new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
    return new_sd

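# Example of the key mapping performed above (added for illustration):
#   cond_stage_model.model.transformer.resblocks.0.ln_1.weight
#     -> text_model.encoder.layers.0.layer_norm1.weight
#   cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight
#     -> split into text_model.encoder.layers.0.self_attn.{q,k,v}_proj.weight
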
# endregion


# region Diffusers -> StableDiffusion conversion code
# Copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)

def conv_transformer_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]


def convert_unet_state_dict_to_sd(v2, unet_state_dict):
    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
        ("input_blocks.0.0.weight", "conv_in.weight"),
        ("input_blocks.0.0.bias", "conv_in.bias"),
        ("out.0.weight", "conv_norm_out.weight"),
        ("out.0.bias", "conv_norm_out.bias"),
        ("out.2.weight", "conv_out.weight"),
        ("out.2.bias", "conv_out.bias"),
    ]

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    unet_conversion_map_layer = []
    for i in range(4):
        # loop over downblocks/upblocks

        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}

    if v2:
        conv_transformer_to_linear(new_state_dict)

    return new_state_dict

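# Example of the resulting prefix mapping (added for illustration):
#   down_blocks.0.resnets.0.    -> input_blocks.1.0.
#   down_blocks.0.attentions.1. -> input_blocks.2.1.
#   up_blocks.3.resnets.2.      -> output_blocks.11.0.
#   mid_block.resnets.1.        -> middle_block.2.
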
# ================#
# VAE Conversion #
# ================#

def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]

    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
            sd_down_prefix = f"encoder.down.{i}.block.{j}."
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

        if i < 3:
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
            sd_downsample_prefix = f"down.{i}.downsample."
            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"up.{3-i}.upsample."
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{i}."
        sd_mid_res_prefix = f"mid.block_{i+1}."
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ("norm.", "group_norm."),
        ("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]

    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]

    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                new_state_dict[k] = reshape_weight_for_sd(v)

    return new_state_dict

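# Note (added): reshape_weight_for_sd() above turns the [out, in] Linear weights of the
# Diffusers VAE mid-block attention back into the [out, in, 1, 1] 1x1-Conv2d weights
# that the original stable-diffusion VAE layout expects.
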
# endregion


def load_checkpoint_with_text_encoder_conversion(ckpt_path):
    # handle models whose text encoder is stored in a different layout (missing 'text_model')
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
        ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
        ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
    ]

    # NOTE: relies on the module-level `args` parsed in the __main__ block below
    if args.from_safetensors:
        from safetensors import safe_open

        checkpoint = {}
        with safe_open(ckpt_path, framework="pt", device="cuda") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
        state_dict = checkpoint
    else:
        checkpoint = torch.load(ckpt_path, map_location="cuda")
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    # while "state_dict" in checkpoint:
    #     checkpoint = checkpoint["state_dict"]
    # else:
    #     state_dict = checkpoint

    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from):]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint


# TODO: the dtype handling here is questionable; it is unverified whether the
# text_encoder can actually be created in the requested dtype
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):

    checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)

    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    if dtype is not None:
        for k, v in state_dict.items():
            if type(v) is torch.Tensor:
                state_dict[k] = v.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config()
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    info = unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)

    # convert text_model
    if v2:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=512,
            torch_dtype="float32",
            transformers_version="4.25.0.dev0",
        )
        text_model = CLIPTextModel._from_config(cfg)
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    else:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)

    return text_model, vae, unet

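# Illustrative usage (added; the checkpoint path is a placeholder):
#   text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(True, "v2-inference.ckpt")
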
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
    def convert_key(key):
        # remove position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in key:
                key = None  # special case, handled separately below
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif '.position_embedding' in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert the attention weights
    for key in keys:
        if 'layers' in key and 'q_proj' in key:
            # concatenate the three projections (q, k, v)
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    # optionally fabricate the final layer and other missing weights
    if make_dummy_weights:

        keys = list(new_sd.keys())
        for key in keys:
            if key.startswith("transformer.resblocks.22."):
                new_sd[key.replace(".22.", ".23.")] = new_sd[key]

        # create weights that are not present in the Diffusers model
        new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
        new_sd['logit_scale'] = torch.tensor(1)

    return new_sd


def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
    if ckpt_path is not None:
        # reload the reference checkpoint to pick up its epoch/global_step, and to include
        # weights (e.g. the VAE) that are not in memory
        checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
        state_dict = checkpoint["state_dict"]
        strict = True
    else:
        # build a new checkpoint from scratch
        checkpoint = {}
        state_dict = {}
        strict = False

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
    update_sd("model.diffusion_model.", unet_state_dict)

    # Convert the text encoder model
    if v2:
        # if there is no reference checkpoint, fill in dummy weights (e.g. duplicate the last layer from the previous one)
        make_dummy = ckpt_path is None
        text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
        update_sd("cond_stage_model.model.", text_enc_dict)
    else:
        text_enc_dict = text_encoder.state_dict()
        update_sd("cond_stage_model.transformer.", text_enc_dict)

    # Convert the VAE
    if vae is not None:
        vae_dict = convert_vae_state_dict(vae.state_dict())
        update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {'state_dict': state_dict}

    if 'epoch' in checkpoint:
        epochs += checkpoint['epoch']
    if 'global_step' in checkpoint:
        steps += checkpoint['global_step']

    new_ckpt['epoch'] = epochs
    new_ckpt['global_step'] = steps

    torch.save(new_ckpt, output_file)

    return key_count


def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None):
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    pipeline = StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
        tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
        safety_checker=None,
        feature_extractor=None,
    )
    pipeline.save_pretrained(output_dir)


def convert(args):
    print("Converting to Diffusers ...")
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16:
        save_dtype = torch.float16
    elif args.bf16:
        save_dtype = torch.bfloat16
    elif args.float:
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint"
    assert is_save_ckpt or args.reference_model is not None, "reference model is required to save as Diffusers"

    # load the model
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))

    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # auto-detect the model version
            v2_model = unet.config.cross_attention_dim == 1024
            # print("checking model version: model is " + ('v2' if v2_model else 'v1'))
        else:
            v2_model = args.v1

    # convert and save
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"

    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        key_count = save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
                                                     original_model, args.epoch, args.global_step, save_dtype, vae)
    else:
        save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--v1", action='store_true',
                        help='load v1.x model (v1 or v2 is required to load checkpoint)')
    parser.add_argument("--v2", action='store_true',
                        help='load v2.0 model (v1 or v2 is required to load checkpoint)')
    parser.add_argument("--fp16", action='store_true',
                        help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only)')
    parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only)')
    parser.add_argument("--float", action='store_true',
                        help='save as float (float32) (checkpoint only)')
    parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint')
    parser.add_argument("--global_step", type=int, default=0,
                        help='global_step to write to checkpoint')
    parser.add_argument("--reference_model", type=str, default=None,
                        help="reference Diffusers model for the scheduler/tokenizer, required when saving as Diffusers; the scheduler/tokenizer are copied from it")

    parser.add_argument("model_to_load", type=str, default=None,
                        help="model to load: checkpoint file or Diffusers model's directory")
    parser.add_argument("model_to_save", type=str, default=None,
                        help="model to save: saved as a checkpoint if the path has an extension, otherwise as a Diffusers model directory")
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    args = parser.parse_args()
    convert(args)
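
# Example invocations (added for illustration; paths and the reference model ID are placeholders):
#   python convertodiffv2.py --v2 model.ckpt /path/to/output_dir --reference_model stabilityai/stable-diffusion-2
#       converts an SD 2.x checkpoint into a Diffusers model directory
#   python convertodiffv2.py --fp16 /path/to/diffusers_model model.ckpt
#       converts a Diffusers model back into a single .ckpt file, saved in fp16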