# Path: blob/main/AUTOMATIC1111_files/styles.py
# 540 views
import csv
import os
import shutil
import typing
from pathlib import Path

from modules import errors


class PromptStyle(typing.NamedTuple):
    """A named style: positive/negative prompt text plus the file it belongs to.

    ``prompt`` / ``negative_prompt`` are either plain text (appended to the user's
    prompt) or templates containing the ``{prompt}`` placeholder (see
    ``merge_prompts``). List-divider entries use ``None`` for both prompt fields
    and the sentinel path ``"do_not_save"``; ``path`` is ``None`` for styles that
    have not yet been assigned to a file.
    """

    name: str
    prompt: typing.Optional[str]
    negative_prompt: typing.Optional[str]
    path: typing.Optional[str]


def merge_prompts(style_prompt: str, prompt: str) -> str:
    """Combine a style's prompt text with the user's prompt.

    If *style_prompt* contains ``{prompt}``, every occurrence is replaced by
    *prompt*. Otherwise the result is ``"<prompt>, <style_prompt>"`` built from
    the whitespace-stripped parts, with empty parts dropped so no dangling
    separator is produced. A ``None`` style prompt (used by list dividers) is
    treated as empty.
    """
    style_prompt = style_prompt or ""
    if "{prompt}" in style_prompt:
        return style_prompt.replace("{prompt}", prompt)

    parts = filter(None, (prompt.strip(), style_prompt.strip()))
    return ", ".join(parts)


def apply_styles_to_prompt(prompt, styles):
    """Apply each style text in *styles*, in order, to *prompt* via ``merge_prompts``."""
    for style in styles:
        prompt = merge_prompts(style, prompt)

    return prompt


def extract_style_text_from_prompt(style_text, prompt):
    """Undo ``merge_prompts`` for one style text, if it was applied to *prompt*.

    If *style_text* contains the ``{prompt}`` placeholder, the prompt must start
    with the text before it and end with the text after it. Otherwise the style
    text must appear at the end of the prompt, either as the whole prompt or
    preceded by a comma, so that the match is not just the tail of a longer word.
    On a match, returns ``(True, extracted_prompt)``; otherwise returns
    ``(False, prompt)`` with the original prompt untouched. An empty style text
    always matches and returns the stripped prompt.

    extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
    extract_style_text_from_prompt("piece", "masterpiece") outputs (False, "masterpiece")
    """
    stripped_prompt = prompt.strip()
    stripped_style_text = style_text.strip()

    if not stripped_style_text:
        # An empty style contributes nothing, so it trivially "matches".
        return True, stripped_prompt

    if "{prompt}" in stripped_style_text:
        # partition() splits on the first placeholder only; split(..., 2) could
        # yield three parts and break the two-name unpacking.
        left, _, right = stripped_style_text.partition("{prompt}")
        # The length check rejects prompts too short to hold both the prefix and
        # the suffix, where the two would overlap and the slice would be bogus.
        if (
            len(stripped_prompt) >= len(left) + len(right)
            and stripped_prompt.startswith(left)
            and stripped_prompt.endswith(right)
        ):
            return True, stripped_prompt[len(left):len(stripped_prompt) - len(right)]
    elif stripped_prompt.endswith(stripped_style_text):
        remainder = stripped_prompt[:len(stripped_prompt) - len(stripped_style_text)]
        # merge_prompts() joins prompt and style text with ", ", so any non-empty
        # remainder must end with a comma; otherwise the style text only matched
        # the end of a longer word (e.g. "piece" inside "masterpiece").
        if not remainder or remainder.rstrip().endswith(","):
            if remainder.endswith(', '):
                remainder = remainder[:-2]
            return True, remainder

    return False, prompt


def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
    """
    Takes a style and compares it to the prompt and negative prompt. If the style
    matches, returns True plus the prompt and negative prompt with the style text
    removed. Otherwise, returns False with the original prompt and negative prompt.
    Styles with neither prompt text (e.g. list dividers) never match.
    """
    # Divider entries (and short CSV rows) can carry None instead of text.
    style_prompt = style.prompt or ""
    style_negative_prompt = style.negative_prompt or ""

    if not style_prompt and not style_negative_prompt:
        return False, prompt, negative_prompt

    match_positive, extracted_positive = extract_style_text_from_prompt(style_prompt, prompt)
    if not match_positive:
        return False, prompt, negative_prompt

    match_negative, extracted_negative = extract_style_text_from_prompt(style_negative_prompt, negative_prompt)
    if not match_negative:
        return False, prompt, negative_prompt

    return True, extracted_positive, extracted_negative


class StyleDatabase:
    """In-memory collection of prompt styles backed by one or more CSV files.

    Entries of ``paths`` are file paths or glob patterns (``*`` / ``?`` in the
    file-name part). Styles are stored in ``self.styles`` keyed by name; when
    more than one distinct file is loaded, a divider entry named after each file
    precedes that file's styles.
    """

    def __init__(self, paths: list[str]):
        """Resolve the default styles file and load all styles.

        *paths* must be non-empty; its first entry determines ``default_path``.
        The list is copied, so the caller's list is not modified.
        """
        self.no_style = PromptStyle("None", "", "", None)
        self.styles = {}
        # Copy so that the insert() below does not mutate the caller's list.
        self.paths = list(paths)
        self.all_styles_files: list[Path] = []

        folder, file = os.path.split(self.paths[0])
        if '*' in file or '?' in file:
            # If the first path is a wildcard pattern, use its first match as the
            # default file, or "<folder>/styles.csv" when nothing matches yet.
            self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
            self.paths.insert(0, self.default_path)
        else:
            self.default_path = Path(self.paths[0])

        # CSV columns: every PromptStyle field except the internal "path".
        self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

        self.reload()

    def reload(self):
        """
        Clears the style database and reloads the styles from the CSV file(s)
        matching the path used to initialize the database.
        """
        self.styles.clear()

        # Expand wildcard patterns. Plain paths are kept even if the file does not
        # exist yet; non-existent files are skipped when loading below.
        all_styles_files = []
        for pattern in self.paths:
            folder, file = os.path.split(pattern)
            if '*' in file or '?' in file:
                all_styles_files.extend(Path(folder).glob(file))
            else:
                all_styles_files.append(Path(pattern))

        # Remove duplicates (e.g. the default file inserted explicitly and also
        # matched by its wildcard) while keeping first-seen order.
        self.all_styles_files = list(dict.fromkeys(all_styles_files))

        # Dividers are only useful when styles come from more than one distinct file;
        # count after de-duplication so a single file never gets a divider.
        add_dividers = len(self.all_styles_files) > 1
        for styles_file in self.all_styles_files:
            if add_dividers:
                # e.g. '---------------- STYLES ----------------'
                divider = f' {styles_file.stem.upper()} '.center(40, '-')
                self.styles[divider] = PromptStyle(divider, None, None, "do_not_save")
            if styles_file.is_file():
                self.load_from_csv(styles_file)

    def load_from_csv(self, path: typing.Union[str, Path]):
        """Add the styles from the CSV file at *path* to the database.

        Rows whose name starts with "#" are skipped, and the legacy
        "name, text" column layout is accepted in place of "prompt". A style whose
        name is already present replaces the earlier entry. Any error is reported
        through ``errors.report`` instead of being raised.
        """
        try:
            with open(path, "r", encoding="utf-8-sig", newline="") as file:
                reader = csv.DictReader(file, skipinitialspace=True)
                for row in reader:
                    # Ignore empty rows or rows starting with a comment
                    if not row or row["name"].startswith("#"):
                        continue
                    # Support loading old CSV format with "name, text"-columns
                    prompt = row["prompt"] if "prompt" in row else row["text"]
                    negative_prompt = row.get("negative_prompt", "")
                    # csv.DictReader fills columns missing from a short row with None;
                    # normalize to "" so the prompt helpers never receive None.
                    self.styles[row["name"]] = PromptStyle(
                        row["name"], prompt or "", negative_prompt or "", str(path)
                    )
        except Exception:
            errors.report(f'Error loading styles from {path}: ', exc_info=True)

    def get_style_paths(self) -> set[str]:
        """Returns a set of all distinct paths of files that styles are loaded from.

        As a side effect, styles without a path are assigned to the default path.
        The divider sentinel "do_not_save" is never included.
        """
        # Update any styles without a path to the default path
        for style in list(self.styles.values()):
            if not style.path:
                self.styles[style.name] = style._replace(path=str(self.default_path))

        # Collect all distinct paths, always including the default path
        style_paths = {str(self.default_path)}
        style_paths.update(style.path for style in self.styles.values() if style.path)

        # Remove any paths for styles that are just list dividers
        style_paths.discard("do_not_save")

        return style_paths

    def get_style_prompts(self, styles):
        """Return the prompt of each named style (``no_style`` for unknown names)."""
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        """Return the negative prompt of each named style (``no_style`` for unknown names)."""
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        """Apply the prompts of the named *styles*, in order, to *prompt*."""
        return apply_styles_to_prompt(
            prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
        )

    def apply_negative_styles_to_prompt(self, prompt, styles):
        """Apply the negative prompts of the named *styles*, in order, to *prompt*."""
        return apply_styles_to_prompt(
            prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
        )

    def save_styles(self, path: typing.Optional[str] = None) -> None:
        """Write every style back to the CSV file it belongs to.

        The *path* argument is deprecated and ignored; it is kept for backwards
        compatibility. Each existing target file is first copied to "<file>.bak".
        """
        style_paths = self.get_style_paths()

        csv_names = [os.path.split(style_path)[1].lower() for style_path in style_paths]

        for style_path in style_paths:
            # Always keep a backup file around
            if os.path.exists(style_path):
                shutil.copy(style_path, f"{style_path}.bak")

            # Write the styles to the CSV file
            with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
                writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
                writer.writeheader()
                for style in (s for s in self.styles.values() if s.path == style_path):
                    # Skip style list dividers, e.g. "STYLES.CSV"
                    if style.name.lower().strip("# ") in csv_names:
                        continue
                    # Write style fields, ignoring the path field
                    writer.writerow(
                        {k: v for k, v in style._asdict().items() if k != "path"}
                    )

    def extract_styles_from_prompt(self, prompt, negative_prompt):
        """Detect which known styles were applied to a prompt pair and strip them.

        Repeatedly removes any single matching style from *prompt* and
        *negative_prompt* until none matches; each style is used at most once.
        Returns ``(style_names, remaining_prompt, remaining_negative_prompt)``.
        The name list is reversed because the last-applied style wraps the
        others and is therefore typically peeled off first.
        """
        extracted = []

        applicable_styles = list(self.styles.values())

        while True:
            found_style = None

            for style in applicable_styles:
                is_match, new_prompt, new_neg_prompt = extract_original_prompts(
                    style, prompt, negative_prompt
                )
                if is_match:
                    found_style = style
                    prompt = new_prompt
                    negative_prompt = new_neg_prompt
                    break

            if not found_style:
                break

            applicable_styles.remove(found_style)
            extracted.append(found_style.name)

        return list(reversed(extracted)), prompt, negative_prompt