Path: blob/master/extensions-builtin/Lora/network.py
from __future__ import annotations
import os
from collections import namedtuple
import enum

import torch.nn as nn
import torch.nn.functional as F

from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit


NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()

    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.bundle_embeddings = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
            s = self.sd_module.weight.shape
            self.shape = (s[0] // 3, s[1])
        elif hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape
        elif isinstance(self.sd_module, nn.MultiheadAttention):
            # For now, only self-attn use Pytorch's MHA
            # So assume all qkvo proj have same shape
            self.shape = self.sd_module.out_proj.weight.shape
        else:
            self.shape = None

        self.ops = None
        self.extra_kwargs = {}
        if isinstance(self.sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {
                'stride': self.sd_module.stride,
                'padding': self.sd_module.padding
            }
        elif isinstance(self.sd_module, nn.Linear):
            self.ops = F.linear
        elif isinstance(self.sd_module, nn.LayerNorm):
            self.ops = F.layer_norm
            self.extra_kwargs = {
                'normalized_shape': self.sd_module.normalized_shape,
                'eps': self.sd_module.eps
            }
        elif isinstance(self.sd_module, nn.GroupNorm):
            self.ops = F.group_norm
            self.extra_kwargs = {
                'num_groups': self.sd_module.num_groups,
                'eps': self.sd_module.eps
            }

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

        self.dora_scale = weights.w.get("dora_scale", None)
        self.dora_norm_dims = len(self.shape) - 1

    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0

    def apply_weight_decompose(self, updown, orig_weight):
        # Match the device/dtype
        orig_weight = orig_weight.to(updown.dtype)
        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
        updown = updown.to(orig_weight.device)

        merged_scale1 = updown + orig_weight
        merged_scale1_norm = (
            merged_scale1.transpose(0, 1)
            .reshape(merged_scale1.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
            .transpose(0, 1)
        )

        dora_merged = (
            merged_scale1 * (dora_scale / merged_scale1_norm)
        )
        final_updown = dora_merged - orig_weight
        return final_updown

    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        updown = updown * self.calc_scale()

        if self.dora_scale is not None:
            updown = self.apply_weight_decompose(updown, orig_weight)

        return updown * self.multiplier(), ex_bias

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        """A general forward implementation for all modules"""
        if self.ops is None:
            raise NotImplementedError()
        else:
            updown, ex_bias = self.calc_updown(self.sd_module.weight)
            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
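
Below is a minimal, illustrative sketch (not part of network.py) of how a concrete module type could plug into the classes above. It assumes a plain LoRA weight pair stored under the hypothetical keys "lora_up.weight" / "lora_down.weight"; the names ExampleModuleTypeLora and ExampleNetworkModuleLora are invented for this example. The extension's real implementations (e.g. network_lora.py in the same directory) cover more weight layouts and layer types; this only shows the extension points: create_module() decides whether the weights match, calc_updown() builds the weight delta, and finalize_updown() applies alpha/scale, DoRA and the per-network multiplier.

# Illustrative sketch only -- not part of the original file.
class ExampleModuleTypeLora(ModuleType):
    def create_module(self, net: Network, weights: NetworkWeights):
        # Only claim the weights if the expected keys are present;
        # returning None lets other module types try instead.
        if all(k in weights.w for k in ("lora_up.weight", "lora_down.weight")):
            return ExampleNetworkModuleLora(net, weights)

        return None


class ExampleNetworkModuleLora(NetworkModule):
    # Hypothetical example module, not the extension's actual LoRA module.
    def __init__(self, net: Network, weights: NetworkWeights):
        super().__init__(net, weights)

        self.up = weights.w["lora_up.weight"]
        self.down = weights.w["lora_down.weight"]
        self.dim = self.down.shape[0]  # rank; calc_scale() uses alpha / dim

    def calc_updown(self, orig_weight):
        up = self.up.to(orig_weight.device, dtype=orig_weight.dtype)
        down = self.down.to(orig_weight.device, dtype=orig_weight.dtype)

        # delta W = up @ down, computed on flattened 2D views so the same
        # code handles both linear and convolution weights
        updown = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
        output_shape = [up.shape[0], down.shape[1]]

        # finalize_updown() applies alpha/scale, the optional DoRA
        # decomposition and the per-network multiplier defined above
        return self.finalize_updown(updown, orig_weight, output_shape)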