Path: blob/master/modules/extra_networks.py
3055 views
import json
import os
import re
import logging
from collections import defaultdict

from modules import errors

extra_network_registry = {}
extra_network_aliases = {}


def initialize():
    """Forget all registered extra networks and aliases."""
    extra_network_registry.clear()
    extra_network_aliases.clear()


def register_extra_network(extra_network):
    """Register *extra_network* under its ``name`` attribute, replacing any earlier entry with that name."""
    extra_network_registry[extra_network.name] = extra_network


def register_extra_network_alias(extra_network, alias):
    """Make *alias* resolve to *extra_network* in lookup_extra_networks.

    A name registered with register_extra_network takes precedence over an alias with the same text.
    """
    extra_network_aliases[alias] = extra_network


def register_default_extra_networks():
    """Register the extra networks provided by this package (currently the hypernetwork one)."""
    from modules.extra_networks_hypernet import ExtraNetworkHypernet
    register_extra_network(ExtraNetworkHypernet())


class ExtraNetworkParams:
    """Arguments of a single ``<name:arg1:arg2:...>`` tag from a prompt.

    ``items`` keeps the raw arguments in order. A string argument that splits into exactly two parts on ``=``
    (``key=value``) is stored as ``named[key] = value``; every other argument (including non-strings) is appended
    to ``positional``.
    """

    def __init__(self, items=None):
        self.items = items or []
        self.positional = []
        self.named = {}

        for item in self.items:
            # NOTE(review): maxsplit=2 means "a=b=c" yields three parts and is treated as positional
            # rather than as named "a" -> "b=c"; confirm whether maxsplit=1 was intended.
            parts = item.split('=', 2) if isinstance(item, str) else [item]
            if len(parts) == 2:
                self.named[parts[0]] = parts[1]
            else:
                self.positional.append(item)

    def __eq__(self, other):
        # Return NotImplemented for unrelated types instead of raising AttributeError on `other.items`,
        # so e.g. `params == None` evaluates to False.
        if not isinstance(other, ExtraNetworkParams):
            return NotImplemented
        return self.items == other.items

    # Defining __eq__ already makes instances unhashable; this states it explicitly.
    __hash__ = None

    def __repr__(self):
        # Readable form used in error messages such as the one built in activate().
        return f"{type(self).__name__}(items={self.items!r})"


class ExtraNetwork:
    """Base class for a network type that can be enabled from a prompt with ``<name:...>`` tags."""

    def __init__(self, name):
        self.name = name

    def activate(self, p, params_list):
        """
        Called by processing on every run. Whatever the extra network is meant to do should be activated here.
        Passes arguments related to this extra network in params_list.
        User passes arguments by specifying this in their prompt:

        <name:arg1:arg2:arg3>

        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are one or more text arguments
        separated by colons.

        Even if the user does not mention this ExtraNetwork in their prompt, the call will still be made, with an
        empty params_list - in this case, all effects of this extra network should be disabled.

        Can be called multiple times before deactivate() - each new call should override the previous call completely.

        For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:

        > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"

        params_list will be:

        [
            ExtraNetworkParams(items=["agm", "1.1"]),
            ExtraNetworkParams(items=["ray"])
        ]

        """
        raise NotImplementedError

    def deactivate(self, p):
        """
        Called at the end of processing for housekeeping. An implementation may do nothing here, but subclasses
        must still override this method: the base implementation raises NotImplementedError.
        """

        raise NotImplementedError


def lookup_extra_networks(extra_network_data):
    """returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.

    Each name is resolved first against registered networks, then against aliases; names matching neither are
    logged and skipped. When several names resolve to the same object (e.g. a name and its alias), their argument
    lists are concatenated in input order.

    Example input:
    {
        'lora': [ExtraNetworkParams(items=['a', '1.0'])],
        'lyco': [ExtraNetworkParams(items=['b'])],
        'hypernet': [ExtraNetworkParams(items=['c'])]
    }

    Example output (when 'lyco' is an alias of the lora network):

    {
        <extra_networks_lora.ExtraNetworkLora object at 0x...>: [ExtraNetworkParams(items=['a', '1.0']), ExtraNetworkParams(items=['b'])],
        <modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x...>: [ExtraNetworkParams(items=['c'])]
    }
    """

    res = {}

    for extra_network_name, extra_network_args in extra_network_data.items():
        # Registered names win; fall back to an alias only if no network has this exact name.
        extra_network = extra_network_registry.get(extra_network_name)
        if extra_network is None:
            extra_network = extra_network_aliases.get(extra_network_name)

        if extra_network is None:
            logging.info(f"Skipping unknown extra network: {extra_network_name}")
            continue

        res.setdefault(extra_network, []).extend(extra_network_args)

    return res


def activate(p, extra_network_data):
    """call activate for extra networks in extra_network_data in specified order, then call
    activate for all remaining registered networks with an empty argument list.

    Errors from individual networks are reported through errors.display and do not stop the others. Afterwards,
    p.scripts.after_extra_networks_activate is called if p.scripts is not None.
    """

    activated = []

    for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
        try:
            extra_network.activate(p, extra_network_args)
            activated.append(extra_network)
        except Exception as e:
            # Not recorded as activated, so the loop below calls it again with [] to disable it.
            errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")

    for extra_network_name, extra_network in extra_network_registry.items():
        if extra_network in activated:
            continue

        try:
            extra_network.activate(p, [])
        except Exception as e:
            errors.display(e, f"activating extra network {extra_network_name}")

    if p.scripts is not None:
        p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)


def deactivate(p, extra_network_data):
    """call deactivate for extra networks in extra_network_data in specified order, then call
    deactivate for all remaining registered networks.

    Errors from individual networks are reported through errors.display and do not stop the others.
    """

    data = lookup_extra_networks(extra_network_data)

    for extra_network in data:
        try:
            extra_network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating extra network {extra_network.name}")

    for extra_network_name, extra_network in extra_network_registry.items():
        if extra_network in data:
            continue

        try:
            extra_network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")


# Matches "<name:args>": group 1 is the network name (word characters), group 2 everything up to the next ">".
re_extra_net = re.compile(r"<(\w+):([^>]+)>")


def parse_prompt(prompt):
    """Strip every ``<name:args>`` tag from *prompt*.

    Returns ``(prompt_without_tags, data)`` where ``data`` is a ``defaultdict(list)`` mapping each tag name to one
    ExtraNetworkParams per occurrence, in order of appearance; the tag's arguments are split on ":".
    """
    res = defaultdict(list)

    def found(m):
        name = m.group(1)
        args = m.group(2)

        res[name].append(ExtraNetworkParams(items=args.split(":")))

        return ""

    prompt = re_extra_net.sub(found, prompt)

    return prompt, res


def parse_prompts(prompts):
    """Apply parse_prompt to each prompt in *prompts*.

    Returns ``(list of prompts with tags removed, extra network data of the first prompt)``. Tags in later prompts
    are removed but their arguments are discarded. If *prompts* is empty, the second element is None.
    """
    res = []
    extra_data = None

    for prompt in prompts:
        updated_prompt, parsed_extra_data = parse_prompt(prompt)

        if extra_data is None:
            extra_data = parsed_extra_data

        res.append(updated_prompt)

    return res, extra_data


def get_user_metadata(filename, lister=None):
    """Load the user metadata kept in a ``.json`` file with the same base name as *filename*.

    If *lister* is given, its ``exists(path)`` method is used for the existence check instead of os.path.exists.
    Returns {} when *filename* is None, when the metadata file does not exist, or when checking/reading/parsing
    fails (the error is reported through errors.display); otherwise returns the parsed JSON value.
    """
    if filename is None:
        return {}

    basename, _ = os.path.splitext(filename)
    metadata_filename = basename + '.json'

    metadata = {}
    try:
        exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
        if exists:
            with open(metadata_filename, "r", encoding="utf8") as file:
                metadata = json.load(file)
    except Exception as e:
        errors.display(e, f"reading extra network user metadata from {metadata_filename}")

    return metadata