Path: blob/master/modules/extensions.py
3055 views
from __future__ import annotations

import configparser
import dataclasses
import os
import threading
import re

from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401

# Every extension found by list_extensions(), in scan order (built-in directory first).
extensions: list[Extension] = []
# Extension directory path -> Extension; used by find_extension().
extension_paths: dict[str, Extension] = {}
# Normalized canonical name -> Extension; used to resolve requirements between extensions.
loaded_extensions: dict[str, Extension] = {}


os.makedirs(extensions_dir, exist_ok=True)


def active():
    """Return the extensions that should currently be loaded.

    Both the command-line flags and the ``disable_all_extensions`` option are honored:
    "all" disables everything, "extra" keeps only enabled built-in extensions, and
    otherwise every enabled extension is returned.
    """
    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
        return []
    elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
        return [x for x in extensions if x.enabled and x.is_builtin]
    else:
        return [x for x in extensions if x.enabled]


@dataclasses.dataclass
class CallbackOrderInfo:
    """Ordering constraints for one callback, taken from a ``[callbacks/<name>]`` ini section."""

    name: str  # callback name: the part of the section header after "callbacks/"
    before: list  # entries parsed from the section's "Before" key
    after: list  # entries parsed from the section's "After" key


class ExtensionMetadata:
    """Parsed contents of an extension's optional ``metadata.ini`` file."""

    filename = "metadata.ini"
    config: configparser.ConfigParser
    canonical_name: str
    requires: list

    def __init__(self, path, canonical_name):
        """Load ``metadata.ini`` from the extension directory *path*.

        Args:
            path: the extension's directory.
            canonical_name: fallback name, used when the ini file declares no
                (non-empty) ``[Extension] Name``.
        """
        self.config = configparser.ConfigParser()

        filepath = os.path.join(path, self.filename)
        # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
        # so no need to check whether the file exists beforehand.
        try:
            self.config.read(filepath)
        except Exception:
            errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)

        # Prefer the name declared in metadata.ini; an absent or blank entry falls back to
        # the supplied name. Normalized so that name lookups are case-insensitive.
        declared_name = self.config.get("Extension", "Name", fallback="").strip()
        self.canonical_name = (declared_name or canonical_name).lower().strip()

        # Populated by list_extensions() once every extension has been discovered.
        self.requires = None

    def get_script_requirements(self, field, section, extra_section=None):
        """Read a list of requirements from the config.

        ``field`` is the ini key (like ``Requires`` or ``Before``) and ``section`` is the
        ``[section]`` it is read from. If ``extra_section`` is given, the same key is also
        read from that section and appended.

        An entry may name alternatives separated by ``|``. The first alternative present in
        ``loaded_extensions`` is chosen; if none is present, the entry is kept unchanged.
        """
        text = self.config.get(section, field, fallback='')

        if extra_section:
            text = text + ', ' + self.config.get(extra_section, field, fallback='')

        resolved = []
        for requirement in self.parse_list(text.lower()):
            installed_alternatives = (alt for alt in requirement.split("|") if alt in loaded_extensions)
            resolved.append(next(installed_alternatives, requirement))

        return resolved

    def parse_list(self, text):
        """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""

        if not text:
            return []

        # both "," and " " are accepted as separator
        return [x for x in re.split(r"[,\s]+", text.strip()) if x]

    def list_callback_order_instructions(self):
        """Yield a CallbackOrderInfo for each ``[callbacks/...]`` section of the config.

        A section whose callback name does not start with this extension's canonical
        name is reported as an error and skipped.
        """
        prefix = "callbacks/"
        for section in self.config.sections():
            if not section.startswith(prefix):
                continue

            callback_name = section[len(prefix):]

            if not callback_name.startswith(self.canonical_name):
                errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
                continue

            before = self.parse_list(self.config.get(section, 'Before', fallback=''))
            after = self.parse_list(self.config.get(section, 'After', fallback=''))

            yield CallbackOrderInfo(callback_name, before, after)


class Extension:
    """One installed extension directory, together with git information read about it."""

    # Class attribute, so shared by all instances; guards the git read in read_info_from_repo().
    lock = threading.Lock()
    # Attributes serialized by to_dict()/from_dict() for the git-info cache.
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
    metadata: ExtensionMetadata

    def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False
        self.is_builtin = is_builtin
        self.commit_hash = ''
        self.commit_date = None
        self.version = ''
        self.branch = None
        self.remote = None
        self.have_info_from_repo = False
        self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
        # Read from self.metadata rather than the argument, which may be None.
        self.canonical_name = self.metadata.canonical_name

    def to_dict(self):
        """Return the ``cached_fields`` attributes as a plain dict."""
        return {x: getattr(self, x) for x in self.cached_fields}

    def from_dict(self, d):
        """Restore the ``cached_fields`` attributes from a dict produced by to_dict()."""
        for field in self.cached_fields:
            setattr(self, field, d[field])

    def read_info_from_repo(self):
        """Populate the git fields (remote, branch, commit, version), preferring cached data.

        Built-in extensions, and extensions whose info was already read, are skipped.
        The data comes from ``cache.cached_data_for_file``, which is handed the
        extension's ``.git`` path (presumably for invalidation — confirm in the cache module).
        A FileNotFoundError is ignored. Afterwards an empty ``status`` becomes ``'unknown'``.
        """
        if self.is_builtin or self.have_info_from_repo:
            return

        def read_from_repo():
            with self.lock:
                if self.have_info_from_repo:
                    return

                self.do_read_info_from_repo()

                return self.to_dict()

        try:
            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
            self.from_dict(d)
        except FileNotFoundError:
            pass
        self.status = 'unknown' if self.status == '' else self.status

    def do_read_info_from_repo(self):
        """Read the remote URL, HEAD commit and branch directly from the git repository.

        Failures are reported rather than raised, and reset ``remote`` to None.
        ``have_info_from_repo`` is set to True in every case.
        """
        repo = None
        try:
            if os.path.exists(os.path.join(self.path, ".git")):
                repo = Repo(self.path)
        except Exception:
            errors.report(f"Error reading github repository info from {self.path}", exc_info=True)

        if repo is None or repo.bare:
            self.remote = None
        else:
            try:
                self.remote = next(repo.remote().urls, None)
                commit = repo.head.commit
                self.commit_date = commit.committed_date
                self.commit_hash = commit.hexsha
                self.version = self.commit_hash[:8]
                # `repo.active_branch` raises on a detached HEAD (e.g. a checked-out commit),
                # so only query it when HEAD points at a branch.
                if not repo.head.is_detached:
                    self.branch = repo.active_branch.name

            except Exception:
                errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
                self.remote = None

        self.have_info_from_repo = True

    def list_files(self, subdir, extension):
        """Return ScriptFile entries for the regular files in ``<path>/<subdir>`` with a matching suffix.

        ``extension`` is compared with the lowercased file suffix including the dot
        (e.g. ``".py"``). Entries are ordered by filename; a missing subdir yields ``[]``.
        """
        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        candidates = [scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)) for filename in sorted(os.listdir(dirpath))]

        return [x for x in candidates if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

    def check_updates(self):
        """Set ``can_update`` and ``status`` by comparing the local checkout with its remote.

        A dry-run fetch is done first. Any fetch result for the tracked branch that is not
        up to date means "new commits". Otherwise HEAD is compared with ``<remote>/<branch>``.
        """
        repo = Repo(self.path)
        branch_name = f'{repo.remote().name}/{self.branch}'
        for fetch in repo.remote().fetch(dry_run=True):
            if self.branch and fetch.name != branch_name:
                continue
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "new commits"
                return

        try:
            origin = repo.rev_parse(branch_name)
            if repo.head.commit != origin:
                self.can_update = True
                self.status = "behind HEAD"
                return
        except Exception:
            self.can_update = False
            self.status = "unknown (remote error)"
            return

        self.can_update = False
        self.status = "latest"

    def fetch_and_reset_hard(self, commit=None):
        """Fetch all remotes and hard-reset the checkout to *commit*.

        *commit* defaults to ``<remote>/<branch>``. Clears ``have_info_from_repo`` so the
        git information is re-read on the next read_info_from_repo() call.
        """
        repo = Repo(self.path)
        if commit is None:
            commit = f'{repo.remote().name}/{self.branch}'
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
        repo.git.fetch(all=True)
        repo.git.reset(commit, hard=True)
        self.have_info_from_repo = False


def list_extensions():
    """Rescan the extension directories and rebuild the module-level registries.

    Clears and repopulates ``extensions``, ``extension_paths`` and ``loaded_extensions``,
    then resolves each extension's ``Requires`` list. Requirements of enabled extensions
    that are missing or disabled are reported, not raised.
    """
    extensions.clear()
    extension_paths.clear()
    loaded_extensions.clear()

    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
        print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
    elif shared.cmd_opts.disable_extra_extensions:
        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

    # scan through extensions directory and load metadata
    for dirname in [extensions_builtin_dir, extensions_dir]:
        if not os.path.isdir(dirname):
            continue

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            canonical_name = extension_dirname
            metadata = ExtensionMetadata(path, canonical_name)

            # check for duplicated canonical names
            already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
            if already_loaded_extension is not None:
                errors.report(f'Duplicate canonical name "{metadata.canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
                continue

            is_builtin = dirname == extensions_builtin_dir
            extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
            extensions.append(extension)
            extension_paths[extension.path] = extension
            # Key by the normalized canonical name. The duplicate check above and the
            # (lowercased) requirement lookups below use the same keys.
            loaded_extensions[metadata.canonical_name] = extension

    for extension in extensions:
        extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")

    # check for requirements
    for extension in extensions:
        if not extension.enabled:
            continue

        for req in extension.metadata.requires:
            required_extension = loaded_extensions.get(req)
            if required_extension is None:
                errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
                continue

            if not required_extension.enabled:
                errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
                continue


def find_extension(filename):
    """Return the Extension whose directory contains *filename*, or None.

    Starts at the symlink-resolved parent directory of *filename* and walks upward,
    looking each ancestor up in ``extension_paths``. The walk stops at the filesystem root.
    """
    parentdir = os.path.dirname(os.path.realpath(filename))

    while parentdir != filename:
        extension = extension_paths.get(parentdir)
        if extension is not None:
            return extension

        filename = parentdir
        parentdir = os.path.dirname(filename)

    return None