Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
automatic1111
GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/extra_networks.py
3055 views
1
import json
2
import os
3
import re
4
import logging
5
from collections import defaultdict
6
7
from modules import errors
8
9
# Maps an extra network's name to the registered ExtraNetwork object
# (filled by register_extra_network, emptied by initialize).
extra_network_registry = {}
10
# Maps an alternative name to an extra network object; consulted by
# lookup_extra_networks only when the name is not in extra_network_registry.
extra_network_aliases = {}
11
12
13
def initialize():
    """Empty both module-level lookup tables (registry and aliases) in place."""
    for table in (extra_network_registry, extra_network_aliases):
        table.clear()
16
17
18
def register_extra_network(extra_network):
    """Store ``extra_network`` in the registry under its ``name`` attribute.

    An existing entry with the same name is replaced.
    """
    key = extra_network.name
    extra_network_registry[key] = extra_network
20
21
22
def register_extra_network_alias(extra_network, alias):
    """Make ``alias`` resolve to ``extra_network`` in the alias table.

    An existing alias with the same key is replaced.
    """
    extra_network_aliases.update({alias: extra_network})
24
25
26
def register_default_extra_networks():
    """Register the built-in hypernetwork extra network."""
    # NOTE(review): the import is kept function-local as in the original;
    # presumably to avoid an import cycle - confirm before hoisting it.
    from modules import extra_networks_hypernet

    hypernet = extra_networks_hypernet.ExtraNetworkHypernet()
    register_extra_network(hypernet)
29
30
31
class ExtraNetworkParams:
    """Arguments of one extra network tag, split into positional and named parts.

    String items containing ``=`` are treated as ``key=value`` pairs and stored
    in ``named``; every other item (including non-string items) is appended to
    ``positional`` in its original order.
    """

    # Instances are mutable and compare by value, so they are explicitly
    # unhashable (this is also what Python does implicitly when only
    # __eq__ is defined).
    __hash__ = None

    def __init__(self, items=None):
        """Build the parameter set.

        Args:
            items: sequence of raw arguments; ``None`` or an empty sequence
                yields an empty parameter set.
        """
        self.items = items or []  # raw arguments exactly as given
        self.positional = []      # arguments without a "key=" prefix
        self.named = {}           # "key=value" arguments, keyed by key

        for item in self.items:
            if isinstance(item, str) and "=" in item:
                # Split on the FIRST "=" only, so values that themselves
                # contain "=" (e.g. "note=a=b") stay named instead of
                # silently falling through to the positional list.
                key, _, value = item.partition("=")
                self.named[key] = value
            else:
                self.positional.append(item)

    def __eq__(self, other):
        """Compare by raw ``items``; defer to the other operand for foreign types."""
        if not isinstance(other, ExtraNetworkParams):
            return NotImplemented
        return self.items == other.items

    def __repr__(self):
        """Return a debugging representation showing the raw items."""
        return f"{type(self).__name__}(items={self.items!r})"
46
47
48
class ExtraNetwork:
    """Base class for extra networks; subclasses override activate() and deactivate()."""

    def __init__(self, name):
        # The name under which this network is registered and which users
        # write in prompt tags.
        self.name = name

    def activate(self, p, params_list):
        """
        Apply this extra network for the current run. Called by processing on every run.

        params_list carries the arguments addressed to this extra network. The user
        supplies them by writing a tag in the prompt:

        <name:arg1:arg2:arg3>

        where name is the name of this ExtraNetwork object and arg1:arg2:arg3 is any
        number of text arguments separated by colons.

        The call is made even when the prompt does not mention this ExtraNetwork; in that
        case params_list is empty and all effects of this extra network should be disabled.

        May be called several times before deactivate() - each call should completely
        override the previous one.

        Example: for an ExtraNetwork named 'hypernet' and the prompt

        > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"

        params_list will be:

        [
            ExtraNetworkParams(items=["agm", "1.1"]),
            ExtraNetworkParams(items=["ray"])
        ]

        The base implementation raises NotImplementedError.
        """
        raise NotImplementedError()

    def deactivate(self, p):
        """
        Housekeeping hook called at the end of processing; nothing is required of it.

        The base implementation raises NotImplementedError, so subclasses must
        still provide an override (which may be a no-op).
        """
        raise NotImplementedError()
88
89
90
def lookup_extra_networks(extra_network_data):
    """Group argument lists by the extra network object their name resolves to.

    Each key of ``extra_network_data`` is looked up in the registry first and
    in the alias table second. Names that resolve to the same object have their
    argument lists concatenated; names that resolve to nothing are logged at
    INFO level and skipped.

    Example input:
    {
        'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
        'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
        'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
    }

    Example output:

    {
        <extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
        <modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
    }
    """
    grouped = {}

    # Iterate over a snapshot of the items, as the original does.
    for extra_network_name, params in list(extra_network_data.items()):
        network = extra_network_registry.get(extra_network_name)
        if network is None:
            # A registered name always wins; aliases are only a fallback.
            network = extra_network_aliases.get(extra_network_name)

        if network is None:
            logging.info(f"Skipping unknown extra network: {extra_network_name}")
            continue

        if network not in grouped:
            grouped[network] = []
        grouped[network].extend(params)

    return grouped
124
125
126
def activate(p, extra_network_data):
    """Activate the extra networks named in extra_network_data, in the given order,
    then activate every other registered network with an empty argument list.

    Errors raised by an individual network are reported through errors.display
    and do not stop the remaining networks from being processed. A network
    whose activation failed is not recorded as activated, so it is called
    again with an empty list in the second pass.
    """
    succeeded = []

    requested = lookup_extra_networks(extra_network_data)
    for network, network_args in requested.items():
        try:
            network.activate(p, network_args)
        except Exception as e:
            errors.display(e, f"activating extra network {network.name} with arguments {network_args}")
        else:
            succeeded.append(network)

    for registered_name, network in extra_network_registry.items():
        if network not in succeeded:
            try:
                network.activate(p, [])
            except Exception as e:
                errors.display(e, f"activating extra network {registered_name}")

    scripts = p.scripts
    if scripts is not None:
        scripts.after_extra_networks_activate(
            p,
            batch_number=p.iteration,
            prompts=p.prompts,
            seeds=p.seeds,
            subseeds=p.subseeds,
            extra_network_data=extra_network_data,
        )
151
152
153
def deactivate(p, extra_network_data):
    """Deactivate the extra networks named in extra_network_data, in the given order,
    then deactivate every other registered network.

    Errors raised by an individual network are reported through errors.display
    and do not stop the remaining networks from being processed.
    """
    mentioned = lookup_extra_networks(extra_network_data)

    for network in mentioned:
        try:
            network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating extra network {network.name}")

    for registered_name, network in extra_network_registry.items():
        if network not in mentioned:
            try:
                network.deactivate(p)
            except Exception as e:
                errors.display(e, f"deactivating unmentioned extra network {registered_name}")
173
174
175
# Matches a prompt tag of the form <name:args>: group 1 is the name (one or
# more word characters), group 2 is everything up to the closing '>'.
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
176
177
178
def parse_prompt(prompt):
    """Extract extra network tags from ``prompt``.

    Returns a tuple ``(stripped_prompt, found)`` where ``stripped_prompt`` is
    the prompt with every ``<name:args>`` tag removed, and ``found`` is a
    defaultdict mapping each tag name to a list of ExtraNetworkParams (one per
    occurrence, in order of appearance) built from the colon-separated args.
    """
    found = defaultdict(list)

    for tag in re_extra_net.finditer(prompt):
        network_name, raw_args = tag.groups()
        found[network_name].append(ExtraNetworkParams(items=raw_args.split(":")))

    stripped_prompt = re_extra_net.sub("", prompt)

    return stripped_prompt, found
192
193
194
def parse_prompts(prompts):
    """Strip extra network tags from every prompt in ``prompts``.

    Returns ``(stripped_prompts, extra_data)``. ``extra_data`` is the tag data
    parsed from the first prompt only; tags in later prompts are removed from
    the text but their data is discarded. For an empty ``prompts`` the result
    is ``([], None)``.
    """
    stripped_prompts = []
    first_extra_data = None

    for text in prompts:
        stripped, parsed = parse_prompt(text)
        stripped_prompts.append(stripped)

        # Only the first prompt's tag data is kept.
        if first_extra_data is None:
            first_extra_data = parsed

    return stripped_prompts, first_extra_data
207
208
209
def get_user_metadata(filename, lister=None):
    """Load the JSON sidecar file that sits next to ``filename``.

    The sidecar path is ``filename`` with its extension replaced by ``.json``.
    Existence is checked with ``lister.exists`` when a truthy ``lister`` is
    given, otherwise with ``os.path.exists``.

    Returns the parsed JSON content, or an empty dict when ``filename`` is
    None, the sidecar does not exist, or reading/parsing fails (the failure is
    reported through errors.display rather than raised).
    """
    if filename is None:
        return {}

    stem = os.path.splitext(filename)[0]
    metadata_filename = stem + '.json'

    try:
        if lister:
            found = lister.exists(metadata_filename)
        else:
            found = os.path.exists(metadata_filename)

        if not found:
            return {}

        with open(metadata_filename, "r", encoding="utf8") as file:
            return json.load(file)
    except Exception as e:
        errors.display(e, f"reading extra network user metadata from {metadata_filename}")

    return {}
226
227