Path: blob/master/modules/initialize_util.py
3055 views
import json
import os
import signal
import sys
import re

from modules.timer import startup_timer


def gradio_server_name():
    """Return the server name/address the Gradio app should bind to.

    ``cmd_opts.server_name`` takes precedence when set. Otherwise
    ``cmd_opts.listen`` selects ``"0.0.0.0"``, and ``None`` is returned when
    neither is set (the caller decides what ``None`` means).
    """
    from modules.shared_cmd_options import cmd_opts

    if cmd_opts.server_name:
        return cmd_opts.server_name
    return "0.0.0.0" if cmd_opts.listen else None


def fix_torch_version():
    """Truncate the version string of nightly/local PyTorch builds.

    If ``torch.__version__`` contains ``.dev`` or ``+git``, the full string is
    kept in ``torch.__long_version__`` and ``torch.__version__`` is replaced
    by its leading dotted-number part (e.g. ``"2.1.0"``). The point is to keep
    code that parses the version (CodeFormer, Safetensors) from raising.
    """
    import torch

    version = torch.__version__
    if ".dev" not in version and "+git" not in version:
        return

    match = re.search(r'[\d.]+[\d]', version)
    if match is None:
        # Nothing numeric to truncate to; leave the version untouched rather
        # than failing startup with an AttributeError on ``None.group``.
        return

    torch.__long_version__ = version
    torch.__version__ = match.group(0)


def fix_pytorch_lightning():
    """Alias ``pytorch_lightning.utilities.distributed`` to ``rank_zero``.

    If ``pytorch_lightning.utilities.distributed`` is not already present in
    the ``sys.modules`` cache, ``pytorch_lightning.utilities.rank_zero`` is
    registered under that name, so later imports of the old path resolve to
    it. Only the module cache is checked, not whether the original module
    could be imported.

    NOTE(review): presumably needed for pytorch_lightning versions where the
    ``distributed`` utilities module no longer exists; confirm.
    """
    # Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache
    if 'pytorch_lightning.utilities.distributed' not in sys.modules:
        import pytorch_lightning
        # Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero
        print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
        sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero


def fix_asyncio_event_loop_policy():
    """
    The default `asyncio` event loop policy only automatically creates
    event loops in the main threads. Other threads must create event
    loops explicitly or `asyncio.get_event_loop` (and therefore
    `.IOLoop.current`) will fail. Installing this policy allows event
    loops to be created automatically on any thread, matching the
    behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
    """

    import asyncio

    if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
        # "Any thread" and "selector" should be orthogonal, but there's not a clean
        # interface for composing policies so pick the right base.
        _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy  # type: ignore
    else:
        _BasePolicy = asyncio.DefaultEventLoopPolicy

    class AnyThreadEventLoopPolicy(_BasePolicy):  # type: ignore
        """Event loop policy that allows loop creation on any thread.
        Usage::

            asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        """

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                return super().get_event_loop()
            except (RuntimeError, AssertionError):
                # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
                # and changed to a RuntimeError in 3.4.3.
                # "There is no current event loop in thread %r"
                loop = self.new_event_loop()
                self.set_event_loop(loop)
                return loop

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())


def restore_config_state_file():
    """Apply a pending extension-state restore requested through settings.

    ``shared.opts.restore_config_state_file`` holds the path of a JSON backup
    to restore (empty means nothing is pending). The option is cleared and
    saved *before* the restore runs, so the request is consumed exactly once.
    The backup is loaded and passed to
    ``config_states.restore_extension_config``.
    """
    from modules import shared, config_states

    config_state_file = shared.opts.restore_config_state_file
    if not config_state_file:
        return

    # Consume the request up front so a restore that fails is not retried on
    # every subsequent startup.
    shared.opts.restore_config_state_file = ""
    shared.opts.save(shared.config_filename)

    if not os.path.isfile(config_state_file):
        print(f"!!! Config state backup not found: {config_state_file}")
        return

    print(f"*** About to restore extension state from file: {config_state_file}")
    with open(config_state_file, "r", encoding="utf-8") as f:
        config_state = json.load(f)
    config_states.restore_extension_config(config_state)
    startup_timer.record("restore extension config")


def validate_tls_options():
    """Report on the TLS key/cert options before the server starts.

    Nothing happens unless both ``cmd_opts.tls_keyfile`` and
    ``cmd_opts.tls_certfile`` are set. A missing file is only reported; TLS
    stays configured in that case. If either value is not a valid path type
    (``os.path.exists`` raises ``TypeError``), both options are cleared and
    the webui runs without TLS.
    """
    from modules.shared_cmd_options import cmd_opts

    if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
        return

    try:
        if not os.path.exists(cmd_opts.tls_keyfile):
            print(f"Invalid path to TLS keyfile: '{cmd_opts.tls_keyfile}'")
        if not os.path.exists(cmd_opts.tls_certfile):
            print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
    except TypeError:
        cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
        print("TLS setup invalid, running webui without TLS")
    else:
        # NOTE(review): printed even if a path check above failed; the server
        # presumably errors out later when it tries to load the missing file.
        print("Running with TLS")
    startup_timer.record("TLS")


def get_gradio_auth_creds():
    """
    Convert the gradio_auth and gradio_auth_path commandline arguments into
    an iterable of (username, password) tuples.

    Both sources hold comma-separated ``user:password`` entries (the file may
    spread them over several lines). Each entry is split on the first ``:``
    only, so passwords may contain colons. Blank entries are skipped. An entry
    without ``:`` yields a 1-tuple.
    """
    from modules.shared_cmd_options import cmd_opts

    def process_credential_line(s):
        s = s.strip()
        if not s:
            return None
        return tuple(s.split(':', 1))

    if cmd_opts.gradio_auth:
        for cred in cmd_opts.gradio_auth.split(','):
            cred = process_credential_line(cred)
            if cred:
                yield cred

    if cmd_opts.gradio_auth_path:
        with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
            # Iterate lazily instead of reading the whole file via readlines().
            for line in file:
                for cred in line.strip().split(','):
                    cred = process_credential_line(cred)
                    if cred:
                        yield cred


def dumpstacks():
    """Print the current Python stack of every thread to stdout."""
    import threading
    import traceback

    id2name = {th.ident: th.name for th in threading.enumerate()}
    code = []
    # sys._current_frames() maps thread id -> topmost frame of that thread.
    for thread_id, stack in sys._current_frames().items():
        code.append(f"\n# Thread: {id2name.get(thread_id, '')}({thread_id})")
        for filename, lineno, name, line in traceback.extract_stack(stack):
            code.append(f"""File: "{filename}", line {lineno}, in {name}""")
            if line:
                code.append("  " + line.strip())

    print("\n".join(code))


def configure_sigint_handler():
    """Install a SIGINT handler that terminates the process immediately.

    The handler prints the signal and, if ``shared.opts.dump_stacks_on_signal``
    is set, dumps all thread stacks. It then calls ``os._exit(0)``, which skips
    cleanup handlers. The handler is not installed when the ``COVERAGE_RUN``
    environment variable is set.
    """
    # make the program just exit at ctrl+c without waiting for anything

    from modules import shared

    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')

        if shared.opts.dump_stacks_on_signal:
            dumpstacks()

        os._exit(0)

    if not os.environ.get("COVERAGE_RUN"):
        # Don't install the immediate-quit handler when running under coverage,
        # as then the coverage report won't be generated.
        signal.signal(signal.SIGINT, sigint_handler)


def configure_opts_onchange():
    """Register callbacks that react to changes of specific settings.

    Model/VAE reloads and the cross-attention re-hijack go through
    ``wrap_queued_call``. NOTE(review): presumably this serializes them with
    other queued work; confirm. ``call=False`` presumably stops the callback
    from firing at registration time; confirm against ``opts.onchange``.
    """
    from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
    from modules.call_queue import wrap_queued_call

    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
    shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
    shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
    startup_timer.record("opts onchange")


def setup_middleware(app):
    """Add gzip (responses >= 1000 bytes) and CORS middleware to ``app``.

    The existing middleware stack is discarded first so that new middleware
    can be added, and then rebuilt immediately.
    """
    from starlette.middleware.gzip import GZipMiddleware

    app.middleware_stack = None  # reset current middleware to allow modifying user provided list
    app.add_middleware(GZipMiddleware, minimum_size=1000)
    configure_cors_middleware(app)
    app.build_middleware_stack()  # rebuild middleware stack on-the-fly


def configure_cors_middleware(app):
    """Add Starlette's ``CORSMiddleware`` to ``app`` based on command-line options.

    All methods and headers are allowed, and credentials are allowed.
    ``cmd_opts.cors_allow_origins`` is a comma-separated list of exact
    origins. ``cmd_opts.cors_allow_origins_regex`` is passed through as
    ``allow_origin_regex``.
    """
    from starlette.middleware.cors import CORSMiddleware
    from modules.shared_cmd_options import cmd_opts

    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_credentials": True,
    }
    if cmd_opts.cors_allow_origins:
        # Origins are matched exactly, so strip whitespace ("a, b" == "a,b")
        # and drop empty entries left by stray/trailing commas.
        cors_options["allow_origins"] = [
            origin.strip()
            for origin in cmd_opts.cors_allow_origins.split(',')
            if origin.strip()
        ]
    if cmd_opts.cors_allow_origins_regex:
        cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
    app.add_middleware(CORSMiddleware, **cors_options)