GitHub Repository: TheLastBen/fast-stable-diffusion
Path: blob/main/Dreambooth/convertodiffv1.py
import argparse
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel


# Model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 512
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024
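

# Illustrative sketch (added for clarity, not part of the original script): the Diffusers
# UNet block widths follow from the two constants above.
def _example_unet_block_out_channels():
    # [320, 640, 1280, 1280], as derived again in create_unet_diffusers_config() below
    return [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]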


# region StableDiffusion -> Diffusers conversion code
# Copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])
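

# Illustrative sketch (added for clarity, not part of the original script):
# how shave_segments() behaves on a typical LDM key.
def _example_shave_segments():
    # a positive count drops leading dot-separated segments
    assert shave_segments("input_blocks.1.0.in_layers.0.weight", 2) == "0.in_layers.0.weight"
    # a negative count drops trailing segments instead
    assert shave_segments("input_blocks.1.0.in_layers.0.weight", -2) == "input_blocks.1.0.in_layers"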


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
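

# Illustrative sketch (added for clarity, not part of the original script):
# renew_resnet_paths() performs a purely local rename of LDM resnet sub-keys.
def _example_renew_resnet_paths():
    mapping = renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
    # the LDM "in_layers.0" group norm becomes "norm1" in Diffusers naming
    assert mapping == [{"old": "input_blocks.1.0.in_layers.0.weight",
                        "new": "input_blocks.1.0.norm1.weight"}]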


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def linear_transformer_to_conv(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim == 2:
                checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
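

# Illustrative sketch (added for clarity, not part of the original script): the two helpers
# above only toggle trailing 1x1 spatial dimensions on projection weights.
def _example_conv_linear_roundtrip():
    w_linear = torch.randn(320, 320)                  # linear weight, shape (out, in)
    w_conv = w_linear.unsqueeze(2).unsqueeze(2)       # 1x1 conv weight, shape (out, in, 1, 1)
    assert torch.equal(w_conv[:, :, 0, 0], w_linear)  # dropping the 1x1 dims restores the linear weight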


def convert_ldm_unet_checkpoint(v2, checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    # In SD v2 the 1x1 conv2d layers have been replaced by linear layers, so convert linear -> conv here
    if v2:
        linear_transformer_to_conv(new_checkpoint)

    return new_checkpoint


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
    # if len(vae_state_dict) == 0:
    #     # the passed checkpoint is a VAE state_dict rather than a checkpoint loaded from a .ckpt
    #     vae_state_dict = checkpoint

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def create_unet_diffusers_config(v2):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # unet_params = original_config.model.params.unet_config.params

    block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    config = dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
    )

    return config
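

# Illustrative sketch (added for clarity, not part of the original script): for a v1 model
# the function above yields the standard SD 1.x UNet layout.
def _example_unet_config_v1():
    config = create_unet_diffusers_config(v2=False)
    assert config["block_out_channels"] == (320, 640, 1280, 1280)
    assert config["down_block_types"] == (
        "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
    assert config["up_block_types"] == (
        "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    assert config["cross_attention_dim"] == 768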


def create_vae_diffusers_config():
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    # vae_params = original_config.model.params.first_stage_config.params.ddconfig
    # _ = original_config.model.params.first_stage_config.params.embed_dim
    block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=VAE_PARAMS_RESOLUTION,
        in_channels=VAE_PARAMS_IN_CHANNELS,
        out_channels=VAE_PARAMS_OUT_CH,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=VAE_PARAMS_Z_CHANNELS,
        layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
    )
    return config


def convert_ldm_clip_checkpoint_v1(checkpoint):
    keys = list(checkpoint.keys())
    text_model_dict = {}
    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
    return text_model_dict


def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
    # the key layouts are annoyingly different between the two formats
    def convert_key(key):
        if not key.startswith("cond_stage_model"):
            return None

        # common conversion
        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
        key = key.replace("cond_stage_model.model.", "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif '.attn.out_proj' in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif '.attn.in_proj' in key:
                key = None  # special case, handled separately below
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif '.positional_embedding' in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif '.text_projection' in key:
            key = None  # not used???
        elif '.logit_scale' in key:
            key = None  # not used???
        elif '.token_embedding' in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif '.ln_final' in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        # skip resblocks.23 (the last layer; the converted text encoder keeps 23 hidden layers)
        if '.resblocks.23.' in key:
            continue
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert the attention weights
    for key in keys:
        if '.resblocks.23.' in key:
            continue
        if '.resblocks' in key and '.attn.in_proj_' in key:
            # split into three (q, k, v)
            values = torch.chunk(checkpoint[key], 3)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # add position_ids
    new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
    return new_sd
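

# Illustrative sketch (added for clarity, not part of the original script): the attention
# conversion above splits OpenCLIP's fused in_proj weight into separate q/k/v projections.
def _example_split_in_proj():
    hidden = 1024                                    # SD v2 text encoder width
    in_proj_weight = torch.randn(3 * hidden, hidden)
    q, k, v = torch.chunk(in_proj_weight, 3)
    assert q.shape == k.shape == v.shape == (hidden, hidden)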


# endregion

# region Diffusers -> StableDiffusion conversion code
# Copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)

def conv_transformer_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]


def convert_unet_state_dict_to_sd(v2, unet_state_dict):
    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
        ("input_blocks.0.0.weight", "conv_in.weight"),
        ("input_blocks.0.0.bias", "conv_in.bias"),
        ("out.0.weight", "conv_norm_out.weight"),
        ("out.0.bias", "conv_norm_out.bias"),
        ("out.2.weight", "conv_out.weight"),
        ("out.2.bias", "conv_out.bias"),
    ]

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    ]

    unet_conversion_map_layer = []
    for i in range(4):
        # loop over downblocks/upblocks

        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
                unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}

    if v2:
        conv_transformer_to_linear(new_state_dict)

    return new_state_dict
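

# Illustrative sketch (added for clarity, not part of the original script): one pair from the
# layer map built inside convert_unet_state_dict_to_sd() above.
def _example_down_block_prefix():
    i, j = 1, 0
    # HF "down_blocks.1.resnets.0." corresponds to SD "input_blocks.4.0." (3*i + j + 1 = 4)
    assert f"input_blocks.{3*i + j + 1}.0." == "input_blocks.4.0."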


# ================#
# VAE Conversion #
# ================#

def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)
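

# Illustrative sketch (added for clarity, not part of the original script): going back to the
# SD format, a linear attention weight regains the trailing 1x1 conv dimensions.
def _example_reshape_weight_for_sd():
    w = torch.randn(512, 512)  # a 2D linear projection weight
    assert reshape_weight_for_sd(w).shape == (512, 512, 1, 1)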


def convert_vae_state_dict(vae_state_dict):
    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ("nin_shortcut", "conv_shortcut"),
        ("norm_out", "conv_norm_out"),
        ("mid.attn_1.", "mid_block.attentions.0."),
    ]

    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
            sd_down_prefix = f"encoder.down.{i}.block.{j}."
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

        if i < 3:
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
            sd_downsample_prefix = f"down.{i}.downsample."
            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"up.{3-i}.upsample."
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{i}."
        sd_mid_res_prefix = f"mid.block_{i+1}."
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ("norm.", "group_norm."),
        ("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]

    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]

    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                new_state_dict[k] = reshape_weight_for_sd(v)

    return new_state_dict


# endregion


def load_checkpoint_with_text_encoder_conversion(ckpt_path):
    # handle models whose text encoder is stored in a different layout (without 'text_model')
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
        ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
        ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
    ]

    checkpoint = torch.load(ckpt_path, map_location="cuda")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from):]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint


# TODO: the dtype handling here is questionable and should be verified; it is unconfirmed
# whether the text_encoder can be built in the requested dtype
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):

    checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    if dtype is not None:
        for k, v in state_dict.items():
            if type(v) is torch.Tensor:
                state_dict[k] = v.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(v2)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    info = unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)

    # convert text_model
    if v2:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=512,
            torch_dtype="float32",
            transformers_version="4.25.0.dev0",
        )
        text_model = CLIPTextModel._from_config(cfg)
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    else:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=768,
            intermediate_size=3072,
            num_hidden_layers=12,
            num_attention_heads=12,
            max_position_embeddings=77,
            hidden_act="quick_gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=768,
            torch_dtype="float32",
            transformers_version="4.16.0.dev0",
        )

        text_model = CLIPTextModel._from_config(cfg)
        # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)

    return text_model, vae, unet
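

# Usage sketch (added note, not part of the original script; the .ckpt path is hypothetical):
#   text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
#       v2=False, ckpt_path="model.ckpt", dtype=torch.float16)
# returns the three Diffusers-format submodules that save_diffusers_checkpoint() below expects.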


def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
    def convert_key(key):
        # remove position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in key:
                key = None  # special case, handled separately below
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif '.position_embedding' in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert the attention weights
    for key in keys:
        if 'layers' in key and 'q_proj' in key:
            # concatenate the three projections back into in_proj
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    # optionally fabricate the last layer and other missing weights
    if make_dummy_weights:
        keys = list(new_sd.keys())
        for key in keys:
            if key.startswith("transformer.resblocks.22."):
                new_sd[key.replace(".22.", ".23.")] = new_sd[key]

        # create weights that are not present in the Diffusers model
        new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
        new_sd['logit_scale'] = torch.tensor(1)

    return new_sd


def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
    if ckpt_path is not None:
        # reuse epoch/step from the reference checkpoint, and reload it including the VAE
        # (e.g. when the VAE is not held in memory)
        checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
        state_dict = checkpoint["state_dict"]
        strict = True
    else:
        # build a new checkpoint from scratch
        checkpoint = {}
        state_dict = {}
        strict = False

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
    update_sd("model.diffusion_model.", unet_state_dict)

    # Convert the text encoder model
    if v2:
        make_dummy = ckpt_path is None  # without a reference checkpoint, fill in dummy weights (e.g. duplicate the last layer from the previous one)
        text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
        update_sd("cond_stage_model.model.", text_enc_dict)
    else:
        text_enc_dict = text_encoder.state_dict()
        update_sd("cond_stage_model.transformer.", text_enc_dict)

    # Convert the VAE
    if vae is not None:
        vae_dict = convert_vae_state_dict(vae.state_dict())
        update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {'state_dict': state_dict}

    if 'epoch' in checkpoint:
        epochs += checkpoint['epoch']
    if 'global_step' in checkpoint:
        steps += checkpoint['global_step']

    new_ckpt['epoch'] = epochs
    new_ckpt['global_step'] = steps

    torch.save(new_ckpt, output_file)

    return key_count


def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, vae=None):
    if vae is None:
        # NOTE: pretrained_model_name_or_path is not defined in this file, so this fallback
        # only works if that name is provided elsewhere; in practice a VAE should be passed in.
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    pipeline = StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler"),
        tokenizer=CLIPTokenizer.from_pretrained("refmdl", subfolder="tokenizer"),
    )
    pipeline.save_pretrained(output_dir)


def convert(args):
    print("Converting to Diffusers ...")
    load_dtype = torch.float16 if args.fp16 else None

    save_dtype = None
    if args.fp16:
        save_dtype = torch.float16
    elif args.bf16:
        save_dtype = torch.bfloat16
    elif args.float:
        save_dtype = torch.float

    is_load_ckpt = os.path.isfile(args.model_to_load)
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint"
    assert is_save_ckpt is not None, "reference model is required to save as Diffusers"

    # load the model
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))

    if is_load_ckpt:
        v2_model = args.v2
        text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
        text_encoder = pipe.text_encoder
        vae = pipe.vae
        unet = pipe.unet

        if args.v1 == args.v2:
            # auto-detect the model version
            v2_model = unet.config.cross_attention_dim == 1024
            # print("checking model version: model is " + ('v2' if v2_model else 'v1'))
        else:
            v2_model = args.v1

    # convert and save
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"

    if is_save_ckpt:
        original_model = args.model_to_load if is_load_ckpt else None
        key_count = save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
                                                     original_model, args.epoch, args.global_step, save_dtype, vae)
    else:
        save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, vae)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--v1", action='store_true',
                        help='load v1.x model (v1 or v2 is required to load checkpoint)')
    parser.add_argument("--v2", action='store_true',
                        help='load v2.0 model (v1 or v2 is required to load checkpoint)')
    parser.add_argument("--fp16", action='store_true',
                        help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only)')
    parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only)')
    parser.add_argument("--float", action='store_true',
                        help='save as float (float32) (checkpoint only)')
    parser.add_argument("--epoch", type=int, default=0, help='epoch value to write to the checkpoint')
    parser.add_argument("--global_step", type=int, default=0,
                        help='global_step value to write to the checkpoint')

    parser.add_argument("model_to_load", type=str, default=None,
                        help="model to load: checkpoint file or Diffusers model's directory")
    parser.add_argument("model_to_save", type=str, default=None,
                        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension)")

    args = parser.parse_args()
    convert(args)
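

# Usage sketch (added note, not part of the original script; the paths below are hypothetical):
#   python convertodiffv1.py --v1 --fp16 model.ckpt /path/to/diffusers_output_dir
# converts a v1.x .ckpt into a Diffusers-format directory, while
#   python convertodiffv1.py --fp16 /path/to/diffusers_model model_converted.ckpt
# goes the other way (an extension on model_to_save selects checkpoint output,
# no extension selects Diffusers output).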