GitHub Repository: AUTOMATIC1111/stable-diffusion-webui
Path: blob/master/extensions-builtin/Lora/networks.py
from __future__ import annotations
import gradio as gr
import logging
import os
import re

import lora_patches
import network
import network_lora
import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
import network_oft

import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.models.sd3.mmdit

from lora_logger import logger

module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
    network_ia3.ModuleTypeIa3(),
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
    network_oft.ModuleTypeOFT(),
]


re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}

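# Maps a diffusers/kohya-style LoRA key prefix to the compvis-style module name used in
# sd_model.network_layer_mapping; keys that match no pattern are returned unchanged. As an
# illustrative example, "lora_unet_down_blocks_0_attentions_1_proj_in" with is_sd2=False
# maps to "diffusion_model_input_blocks_2_1_proj_in".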
def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key

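# Builds sd_model.network_layer_mapping: a dict from underscore-joined module paths (for an
# SD1-style UNet an illustrative entry would be "diffusion_model_input_blocks_1_1_proj_in")
# to the corresponding torch modules, and stamps each module with .network_layer_name so
# network_apply_weights can look up its deltas later.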
def assign_network_names_to_compvis_modules(sd_model):
    network_layer_mapping = {}

    if shared.sd_model.is_sdxl:
        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                network_name = f'{i}_{name.replace(".", "_")}'
                network_layer_mapping[network_name] = module
                module.network_layer_name = network_name
    else:
        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)

        for name, module in cond_stage_model.named_modules():
            network_name = name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    for name, module in shared.sd_model.model.named_modules():
        network_name = name.replace(".", "_")
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    sd_model.network_layer_mapping = network_layer_mapping

class BundledTIHash(str):
    def __init__(self, hash_str):
        self.hash = hash_str

    def __str__(self):
        return self.hash if shared.opts.lora_bundled_ti_to_infotext else ''

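# Reads a network's state dict from disk and groups its tensors by target layer. Each state
# dict key is split into a layer part and a weight part (an illustrative key
# "lora_unet_conv_in.lora_down.weight" splits into "lora_unet_conv_in" and "lora_down.weight");
# the layer part is resolved against network_layer_mapping, tensors are collected into
# network.NetworkWeights entries, and one of module_types turns each entry into a runtime
# module. Keys under "bundle_emb." are treated as bundled textual-inversion embeddings.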
def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
    if hasattr(shared.sd_model, 'diffusers_weight_map'):
        diffusers_weight_map = shared.sd_model.diffusers_weight_map
    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
        diffusers_weight_map = {}
        for k, v in shared.sd_model.diffusers_weight_mapping():
            diffusers_weight_map[k] = v
        shared.sd_model.diffusers_weight_map = diffusers_weight_map
    else:
        diffusers_weight_map = None

    matched_networks = {}
    bundle_embeddings = {}

    for key_network, weight in sd.items():

        if diffusers_weight_map:
            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
            network_part = network_name + '.' + network_weight
        else:
            key_network_without_network_parts, _, network_part = key_network.partition(".")

        if key_network_without_network_parts == "bundle_emb":
            emb_name, vec_name = network_part.split(".", 1)
            emb_dict = bundle_embeddings.get(emb_name, {})
            if vec_name.split('.')[0] == 'string_to_param':
                _, k2 = vec_name.split('.', 1)
                emb_dict['string_to_param'] = {k2: weight}
            else:
                emb_dict[vec_name] = weight
            bundle_embeddings[emb_name] = emb_dict

        if diffusers_weight_map:
            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
        else:
            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)

        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # kohya_ss OFT module
        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # KohakuBlueLeaf OFT module
        if sd_module is None and "oft_diag" in key:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
        embedding.loaded = None
        embedding.shorthash = BundledTIHash(name)
        embeddings[emb_name] = embedding

    net.bundle_embeddings = embeddings

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net

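# networks_in_memory acts as a small cache of parsed networks; the oldest entries are evicted
# until the cache fits within shared.opts.lora_in_memory_limit, after which a torch garbage
# collection pass releases the freed tensors.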
def purge_networks_from_memory():
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
        name = next(iter(networks_in_memory))
        networks_in_memory.pop(name, None)

    devices.torch_gc()

def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    emb_db = sd_hijack.model_hijack.embedding_db
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded:
                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()

    unavailable_networks = []
    for name in names:
        if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
            unavailable_networks.append(name)
        elif available_network_aliases.get(name) is None:
            unavailable_networks.append(name)

    if unavailable_networks:
        update_available_networks_by_names(unavailable_networks)

    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()

        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    net = load_network(name, network_on_disk)

                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
                logger.warning(
                    f'Skip bundle embedding: "{emb_name}"'
                    ' as it was already loaded from embeddings folder'
                )
                continue

            embedding.loaded = False
            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
                embedding.loaded = True
                emb_db.register_embedding(embedding, shared.sd_model)
            else:
                emb_db.skipped_embeddings[name] = embedding

    if failed_to_load_networks:
        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
        sd_hijack.model_hijack.comments.append(lora_not_found_message)
        if shared.opts.lora_not_found_warning_console:
            print(f'\n{lora_not_found_message}\n')
        if shared.opts.lora_not_found_gradio_warning:
            gr.Warning(lora_not_found_message)

    purge_networks_from_memory()

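# The helpers below keep a CPU-side copy of each layer's original weights and biases so that
# LoRA deltas can be applied to the live parameters in place and undone whenever the set of
# active networks changes.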
def allowed_layer_without_weight(layer):
    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
        return True

    return False


def store_weights_backup(weight):
    if weight is None:
        return None

    return weight.to(devices.cpu, copy=True)


def restore_weights_backup(obj, field, weight):
    if weight is None:
        setattr(obj, field, None)
        return

    getattr(obj, field).copy_(weight)


def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    if weights_backup is not None:
        if isinstance(self, torch.nn.MultiheadAttention):
            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
        else:
            restore_weights_backup(self, 'weight', weights_backup)

    if isinstance(self, torch.nn.MultiheadAttention):
        restore_weights_backup(self.out_proj, 'bias', bias_backup)
    else:
        restore_weights_backup(self, 'bias', bias_backup)

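# network_apply_weights patches a layer in place: it backs up the original parameters once,
# then, whenever the wanted set of (name, multipliers) differs from what is currently applied,
# restores the backup and adds each network's delta, i.e. effectively weight = original + updown
# (plus ex_bias for modules that produce one).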
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if current_names != () and not allowed_layer_without_weight(self):
            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
        else:
            weights_backup = store_weights_backup(self.weight)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None and wanted_names != ():
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = store_weights_backup(self.out_proj.bias)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = store_weights_backup(self.bias)
        else:
            bias_backup = None

        # Unlike weight which always has value, some modules don't have bias.
        # Only report if bias is not None and current bias are not unchanged.
        if bias_backup is not None and current_names != ():
            raise RuntimeError("no backup bias found and current bias are not unchanged")

        self.network_bias_backup = bias_backup

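    # Re-apply only when the active networks or their multipliers changed since the last call;
    # otherwise the already-patched weights are left untouched.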
    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
                try:
                    with torch.no_grad():
                        if getattr(self, 'fp16_weight', None) is None:
                            weight = self.weight
                            bias = self.bias
                        else:
                            weight = self.fp16_weight.clone().to(self.weight.device)
                            bias = getattr(self, 'fp16_bias', None)
                            if bias is not None:
                                bias = bias.clone().to(self.bias.device)
                        updown, ex_bias = module.calc_updown(weight)

                        if len(weight.shape) == 4 and weight.shape[1] == 9:
                            # inpainting model. zero pad updown to make channel[1] 4 to 9
                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                        self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                        if ex_bias is not None and hasattr(self, 'bias'):
                            if self.bias is None:
                                self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                            else:
                                self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                try:
                    with torch.no_grad():
                        # Send "real" orig_weight into MHA's lora module
                        qw, kw, vw = self.in_proj_weight.chunk(3, 0)
                        updown_q, _ = module_q.calc_updown(qw)
                        updown_k, _ = module_k.calc_updown(kw)
                        updown_v, _ = module_v.calc_updown(vw)
                        del qw, kw, vw
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                        self.in_proj_weight += updown_qkv
                        self.out_proj.weight += updown_out
                        if ex_bias is not None:
                            if self.out_proj.bias is None:
                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.out_proj.bias += ex_bias

                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
                try:
                    with torch.no_grad():
                        # Send "real" orig_weight into MHA's lora module
                        qw, kw, vw = self.weight.chunk(3, 0)
                        updown_q, _ = module_q.calc_updown(qw)
                        updown_k, _ = module_k.calc_updown(kw)
                        updown_v, _ = module_v.calc_updown(vw)
                        del qw, kw, vw
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                        self.weight += updown_qkv

                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

    self.network_current_names = wanted_names

def network_forward(org_module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.
    """

    if len(loaded_networks) == 0:
        return original_forward(org_module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(org_module)
    network_reset_cached_weight(org_module)

    y = original_forward(org_module, input)

    network_layer_name = getattr(org_module, 'network_layer_name', None)
    for lora in loaded_networks:
        module = lora.modules.get(network_layer_name, None)
        if module is None:
            continue

        y = module.forward(input, y)

    return y

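# The *_load_state_dict wrappers below call network_reset_cached_weight so that freshly loaded
# checkpoint weights are not mistaken for weights that already have networks applied.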
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None


def network_Linear_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Linear_forward)

    network_apply_weights(self)

    return originals.Linear_forward(self, input)


def network_Linear_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Linear_load_state_dict(self, *args, **kwargs)


def network_Conv2d_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Conv2d_forward)

    network_apply_weights(self)

    return originals.Conv2d_forward(self, input)


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.Conv2d_load_state_dict(self, *args, **kwargs)


def network_GroupNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.GroupNorm_forward)

    network_apply_weights(self)

    return originals.GroupNorm_forward(self, input)


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)


def network_LayerNorm_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.LayerNorm_forward)

    network_apply_weights(self)

    return originals.LayerNorm_forward(self, input)


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)


def network_MultiheadAttention_forward(self, *args, **kwargs):
    network_apply_weights(self)

    return originals.MultiheadAttention_forward(self, *args, **kwargs)


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)

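# Scans the LoRA directory (and the LyCORIS back-compat directory) for .pt/.ckpt/.safetensors
# files and registers them in available_networks / available_network_aliases; aliases that
# collide are recorded in forbidden_network_aliases so lookups fall back to the file name.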
def process_network_files(names: list[str] | None = None):
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue
        name = os.path.splitext(os.path.basename(filename))[0]
        # if names is provided, only load networks with names in the list
        if names and name not in names:
            continue
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


def update_available_networks_by_names(names: list[str]):
    process_network_files(names)


def list_available_networks():
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    process_network_files()

re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)

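# Module-level registries shared across the extension: discovered network files, alias lookups,
# the currently active networks, and the in-memory cache trimmed by purge_networks_from_memory().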
originals: lora_patches.LoraPatches = None

extra_network_lora = None

available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()