Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
automatic1111
GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/initialize_util.py
3055 views
1
import json
2
import os
3
import signal
4
import sys
5
import re
6
7
from modules.timer import startup_timer
8
9
10
def gradio_server_name():
    """Return the host/interface name the Gradio server should bind to.

    An explicit ``cmd_opts.server_name`` always wins. Otherwise the server binds
    to every interface ("0.0.0.0") when ``cmd_opts.listen`` is set, and
    ``None`` is returned (leaving the choice to the caller) when it is not.
    """
    from modules.shared_cmd_options import cmd_opts

    explicit_name = cmd_opts.server_name
    if explicit_name:
        return explicit_name
    if cmd_opts.listen:
        return "0.0.0.0"
    return None
17
18
19
def fix_torch_version():
    """Truncate a nightly/local PyTorch version string to its numeric release part.

    Version strings containing ".dev" or "+git" (nightly or source builds) are
    cut down to the leading dotted number (e.g. "2.1.0.dev2023+cu121" becomes
    "2.1.0") so that CodeFormer / Safetensors version parsing does not fail.
    The untouched string is kept in ``torch.__long_version__``. Any other
    version string is left alone. If no numeric part can be found, nothing is
    changed instead of crashing startup.
    """
    import torch

    version = torch.__version__
    if ".dev" not in version and "+git" not in version:
        return

    match = re.search(r'[\d.]+[\d]', version)
    if match is None:
        # Previously `.group(0)` on None raised AttributeError here; keep the
        # original version string instead.
        return

    torch.__long_version__ = version
    torch.__version__ = match.group(0)
26
27
def fix_pytorch_lightning():
    """Alias ``pytorch_lightning.utilities.distributed`` to ``pytorch_lightning.utilities.rank_zero``.

    If the ``distributed`` module is not already present in ``sys.modules``,
    a notice is printed and ``pytorch_lightning.utilities.rank_zero`` is
    registered under that name, so later imports of it resolve to
    ``rank_zero``. Presumably this serves code written against pytorch_lightning
    versions that still shipped ``utilities.distributed`` (TODO confirm).
    """
    legacy_name = "pytorch_lightning.utilities.distributed"
    if legacy_name in sys.modules:
        return

    import pytorch_lightning

    print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
    sys.modules[legacy_name] = pytorch_lightning.utilities.rank_zero
34
35
def fix_asyncio_event_loop_policy():
    """Install an asyncio policy that creates an event loop on demand in any thread.

    With the stock policy, asking for the current event loop in a thread that
    has none raises ("There is no current event loop in thread %r"). The
    policy installed here catches that error, creates a new loop, registers it
    as the thread's current loop and returns it. This matches the behaviour
    Tornado had before 5.0. On Windows, the selector-based policy is used as
    the base class when asyncio provides it.
    """
    import asyncio

    # "Any thread" and "selector" are independent concerns, but asyncio has no
    # clean way to compose policies, so the right base class is chosen here.
    use_selector = sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy")
    base_policy = asyncio.WindowsSelectorEventLoopPolicy if use_selector else asyncio.DefaultEventLoopPolicy  # type: ignore

    class AnyThreadEventLoopPolicy(base_policy):  # type: ignore
        """Event loop policy whose ``get_event_loop`` works from any thread.

        Usage::

            asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        """

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                current = super().get_event_loop()
            except (RuntimeError, AssertionError):
                # Python 3.4.2 raised AssertionError here; 3.4.3+ raises RuntimeError.
                current = self.new_event_loop()
                self.set_event_loop(current)
            return current

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
73
74
75
def restore_config_state_file():
    """Perform a one-shot extension config-state restore requested through settings.

    Reads the path stored in ``shared.opts.restore_config_state_file``. When it
    is empty (or otherwise falsy, e.g. ``None``) nothing happens. Otherwise the
    setting is cleared and saved to ``shared.config_filename`` *before* the
    restore is attempted. Then, if the file exists, its JSON content is passed
    to ``config_states.restore_extension_config``. If the file is missing, a
    warning is printed instead.
    """
    from modules import shared, config_states

    config_state_file = shared.opts.restore_config_state_file
    # `not` rather than `== ""`: a None value would otherwise reach
    # os.path.isfile(None) below and raise TypeError.
    if not config_state_file:
        return

    # Clear and persist the request first, presumably so that a failing restore
    # is not retried on every subsequent startup.
    shared.opts.restore_config_state_file = ""
    shared.opts.save(shared.config_filename)

    if not os.path.isfile(config_state_file):
        print(f"!!! Config state backup not found: {config_state_file}")
        return

    print(f"*** About to restore extension state from file: {config_state_file}")
    with open(config_state_file, "r", encoding="utf-8") as f:
        config_state = json.load(f)
    config_states.restore_extension_config(config_state)
    startup_timer.record("restore extension config")
93
94
95
def validate_tls_options():
    """Check the TLS key/certificate options and report the resulting TLS setup.

    Does nothing unless both ``cmd_opts.tls_keyfile`` and
    ``cmd_opts.tls_certfile`` are set. A path that does not exist is reported
    together with the offending path, but TLS stays configured. Only a
    ``TypeError`` raised while checking the paths clears both options, which
    disables TLS. Records a "TLS" entry on the startup timer whenever the check
    runs.
    """
    from modules.shared_cmd_options import cmd_opts

    if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
        return

    try:
        if not os.path.exists(cmd_opts.tls_keyfile):
            # Include the path, consistent with the certfile message below.
            print(f"Invalid path to TLS keyfile: '{cmd_opts.tls_keyfile}'")
        if not os.path.exists(cmd_opts.tls_certfile):
            print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
    except TypeError:
        # The option values are not something os.path.exists accepts: drop TLS.
        cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
        print("TLS setup invalid, running webui without TLS")
    else:
        # NOTE(review): printed even when a path above was reported invalid.
        print("Running with TLS")
    startup_timer.record("TLS")
112
113
114
def get_gradio_auth_creds():
    """Yield Gradio login credentials from the command-line options.

    Credentials come first from ``cmd_opts.gradio_auth`` (a comma-separated
    string), then from each line of the file at ``cmd_opts.gradio_auth_path``
    (each line also comma-separated). Every non-blank entry is stripped and
    split on the first ':' into a ``(username, password)`` tuple, so passwords
    may themselves contain ':'.

    NOTE(review): an entry without ':' is yielded as a 1-tuple ``(username,)``,
    not a pair; callers expecting pairs may fail on it.
    """
    from modules.shared_cmd_options import cmd_opts

    def process_credential_line(s):
        s = s.strip()
        if not s:
            return None
        return tuple(s.split(':', 1))

    def creds_from_text(text):
        # One parsing path for both the CLI value and each file line; entries
        # are stripped individually, so surrounding whitespace/newlines vanish.
        for entry in text.split(','):
            cred = process_credential_line(entry)
            if cred:
                yield cred

    if cmd_opts.gradio_auth:
        yield from creds_from_text(cmd_opts.gradio_auth)

    if cmd_opts.gradio_auth_path:
        with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
            # Iterate the file directly instead of materializing readlines().
            for line in file:
                yield from creds_from_text(line)
140
141
142
def dumpstacks():
    """Print the current Python call stack of every live thread to stdout.

    Each thread gets a "# Thread: <name>(<ident>)" header, followed by one
    'File: "...", line N, in func' line per frame. Frames that have source
    text available are followed by that source line, stripped and indented.
    """
    import threading
    import traceback

    names_by_ident = {}
    for thread in threading.enumerate():
        names_by_ident[thread.ident] = thread.name

    report = []
    for ident, frame in sys._current_frames().items():
        report.append(f"\n# Thread: {names_by_ident.get(ident, '')}({ident})")
        for summary in traceback.extract_stack(frame):
            report.append(f'File: "{summary.filename}", line {summary.lineno}, in {summary.name}')
            if summary.line:
                report.append(" " + summary.line.strip())

    print("\n".join(report))
156
157
158
def configure_sigint_handler():
    """Make Ctrl+C terminate the process immediately, without waiting for anything.

    The installed SIGINT handler prints the signal and frame, and dumps every
    thread's stack when ``shared.opts.dump_stacks_on_signal`` is set. It then
    calls ``os._exit(0)``. The handler is not installed when the
    ``COVERAGE_RUN`` environment variable is set.
    """
    from modules import shared

    if os.environ.get("COVERAGE_RUN"):
        # Under coverage the normal interpreter shutdown has to run, otherwise
        # the coverage report is never generated.
        return

    def on_sigint(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        if shared.opts.dump_stacks_on_signal:
            dumpstacks()
        os._exit(0)

    signal.signal(signal.SIGINT, on_sigint)
175
176
177
def configure_opts_onchange():
    """Register the callbacks that react when specific settings change.

    Model/VAE reloads and the cross-attention re-hijack are wrapped with
    ``wrap_queued_call`` and registered with ``call=False``. The temp-dir and
    Gradio-theme handlers are registered directly with ``onchange``'s default
    ``call`` argument. Registration order is preserved. Finally records
    "opts onchange" on the startup timer.
    """
    from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
    from modules.call_queue import wrap_queued_call

    # (setting key, callback, wrap in the call queue?). The lambdas look their
    # targets up at call time rather than at registration time.
    handlers = [
        ("sd_model_checkpoint", lambda: sd_models.reload_model_weights(), True),
        ("sd_vae", lambda: sd_vae.reload_vae_weights(), True),
        ("sd_vae_overrides_per_model_preferences", lambda: sd_vae.reload_vae_weights(), True),
        ("temp_dir", ui_tempdir.on_tmpdir_changed, False),
        ("gradio_theme", shared.reload_gradio_theme, False),
        ("cross_attention_optimization", lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model), True),
        ("fp8_storage", lambda: sd_models.reload_model_weights(), True),
        ("cache_fp16_weight", lambda: sd_models.reload_model_weights(forced_reload=True), True),
    ]

    for key, callback, queued in handlers:
        if queued:
            shared.opts.onchange(key, wrap_queued_call(callback), call=False)
        else:
            shared.opts.onchange(key, callback)

    startup_timer.record("opts onchange")
190
191
192
def setup_middleware(app):
    """Add gzip compression and CORS middleware to an already-configured ``app``.

    The cached middleware stack is reset so that ``add_middleware`` can be used
    again. GZip (for responses of at least 1000 bytes) and CORS middleware are
    then added, and ``build_middleware_stack()`` is called.
    """
    from starlette.middleware.gzip import GZipMiddleware

    gzip_options = {"minimum_size": 1000}

    # Drop the current stack so the user-provided middleware list can be modified.
    app.middleware_stack = None
    app.add_middleware(GZipMiddleware, **gzip_options)
    configure_cors_middleware(app)
    # NOTE(review): the returned stack is not stored, so app.middleware_stack
    # remains None after this call; confirm Starlette rebuilds it lazily.
    app.build_middleware_stack()
199
200
201
def configure_cors_middleware(app):
    """Add Starlette's ``CORSMiddleware`` to ``app``, configured from the command line.

    All methods and headers are allowed and credentials are enabled.
    ``cmd_opts.cors_allow_origins`` is a comma-separated list of origins. Each
    entry is stripped and empty entries are dropped. ``allow_origins`` is only
    passed when at least one origin remains. ``cmd_opts.cors_allow_origins_regex``,
    when set, is passed through as ``allow_origin_regex``.
    """
    from starlette.middleware.cors import CORSMiddleware
    from modules.shared_cmd_options import cmd_opts

    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_credentials": True,
    }
    if cmd_opts.cors_allow_origins:
        # "a, b" must produce "b", not " b": an entry with a leading space would
        # not equal any Origin value a browser sends.
        origins = [origin.strip() for origin in cmd_opts.cors_allow_origins.split(',')]
        origins = [origin for origin in origins if origin]
        if origins:
            cors_options["allow_origins"] = origins
    if cmd_opts.cors_allow_origins_regex:
        cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
    app.add_middleware(CORSMiddleware, **cors_options)
215
216
217