Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
TheLastBen
GitHub Repository: TheLastBen/fast-stable-diffusion
Path: blob/main/AUTOMATIC1111_files/styles.py
540 views
1
from pathlib import Path
2
from modules import errors
3
import csv
4
import os
5
import typing
6
import shutil
7
8
9
class PromptStyle(typing.NamedTuple):
    """A named prompt style as read from / written to a styles CSV file.

    Instances are immutable tuples; use ``_replace`` to derive a modified copy.
    """

    # Display name of the style; StyleDatabase also uses it as the dict key.
    name: str
    # Text merged into the positive prompt; may contain a "{prompt}" placeholder.
    # None is used for list-divider entries that carry no prompt text.
    prompt: typing.Optional[str]
    # Text merged into the negative prompt (same rules as ``prompt``).
    negative_prompt: typing.Optional[str]
    # File the style was loaded from / is saved to; None until one is assigned.
    path: typing.Optional[str]
14
15
16
def merge_prompts(style_prompt: str, prompt: str) -> str:
    """Combine a style's text with a user prompt.

    When the style text contains the ``{prompt}`` placeholder, every occurrence
    is replaced by *prompt* exactly as given. Otherwise the stripped prompt and
    the stripped style text are joined with ", ", leaving out whichever is empty.
    """
    placeholder = "{prompt}"
    if placeholder not in style_prompt:
        # No placeholder: append the style after the prompt, skipping blanks.
        pieces = [text for text in (prompt.strip(), style_prompt.strip()) if text]
        return ", ".join(pieces)
    return style_prompt.replace(placeholder, prompt)
24
25
26
def apply_styles_to_prompt(prompt, styles):
    """Fold each style text in *styles* into *prompt*, left to right.

    Each step hands the result of the previous one to ``merge_prompts``, so a
    later style wraps (via ``{prompt}``) or is appended after the earlier ones.
    """
    result = prompt
    for style_text in styles:
        result = merge_prompts(style_text, result)
    return result
31
32
33
def extract_style_text_from_prompt(style_text, prompt):
    """Try to undo ``merge_prompts``: remove a style's text from a prompt.

    Returns ``(True, remaining_prompt)`` when *prompt* looks like the result of
    applying *style_text*, otherwise ``(False, prompt)`` with *prompt* unchanged.

    - If *style_text* contains ``{prompt}``, the stripped prompt must start with
      the text before the first placeholder and end with the text after it, with
      the two parts not overlapping; the part in between is returned.
    - Otherwise the stripped prompt must either equal the style text, or end
      with it where the preceding text ends in a comma (the ", " separator that
      ``merge_prompts`` inserts); the text before that comma is returned.
    - An empty or None *style_text* always matches and returns the stripped prompt.

    extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
    extract_style_text_from_prompt("art", "pop art") outputs (False, "pop art")
    """

    stripped_prompt = prompt.strip()
    # Styles read from short CSV rows can carry None instead of "".
    stripped_style_text = (style_text or "").strip()

    if "{prompt}" in stripped_style_text:
        # Split on the first placeholder only. A style with several placeholders
        # then simply fails to match instead of raising ValueError on unpacking.
        left, _, right = stripped_style_text.partition("{prompt}")
        # Guard against left/right overlapping in a short prompt, which would
        # otherwise report a match with a nonsensical slice.
        fits = len(stripped_prompt) >= len(left) + len(right)
        if fits and stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
            return True, stripped_prompt[len(left):len(stripped_prompt) - len(right)]
    elif not stripped_style_text:
        # Nothing to remove: every prompt trivially "contains" an empty style.
        return True, stripped_prompt
    elif stripped_prompt.endswith(stripped_style_text):
        head = stripped_prompt[:len(stripped_prompt) - len(stripped_style_text)]
        if not head:
            # merge_prompts("", style) yields just the style text.
            return True, head
        # Require the comma separator so e.g. style "art" does not match "pop art".
        head = head.rstrip()
        if head.endswith(","):
            return True, head[:-1]

    return False, prompt
59
60
61
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
    """
    Check whether *style* was applied to both prompts and, if so, strip it.

    Returns ``(True, prompt_without_style, negative_without_style)`` when both
    the positive and the negative style text are found; otherwise returns
    ``(False, prompt, negative_prompt)`` with the inputs untouched. A style with
    neither positive nor negative text never matches.
    """
    no_match = (False, prompt, negative_prompt)

    if not (style.prompt or style.negative_prompt):
        return no_match

    positive_ok, positive_rest = extract_style_text_from_prompt(style.prompt, prompt)
    if not positive_ok:
        return no_match

    negative_ok, negative_rest = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
    if not negative_ok:
        return no_match

    return True, positive_rest, negative_rest
79
80
81
class StyleDatabase:
    """Prompt styles loaded from one or more CSV files.

    ``self.styles`` maps a style name to its ``PromptStyle``. Each style keeps
    the path of the file it was loaded from, so ``save_styles`` can write it
    back to that file. When more than one styles file is in use, a divider
    entry (path ``"do_not_save"``) is added for each file.
    """

    def __init__(self, paths: list[str]):
        """Create the database and immediately load styles from *paths*.

        Args:
            paths: file paths or glob patterns (``*``/``?`` in the file-name part
                only) naming the styles CSV files. The first entry determines
                ``default_path``, the file that receives styles without a path.
        """
        self.no_style = PromptStyle("None", "", "", None)
        self.styles = {}
        # Copy so inserting the default path below does not mutate the caller's list.
        self.paths = list(paths)
        self.all_styles_files: list[Path] = []

        folder, file = os.path.split(self.paths[0])
        if '*' in file or '?' in file:
            # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
            self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
            self.paths.insert(0, self.default_path)
        else:
            self.default_path = Path(self.paths[0])

        # CSV columns written by save_styles (every PromptStyle field except "path").
        self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

        self.reload()

    def reload(self):
        """
        Clears the style database and reloads the styles from the CSV file(s)
        matching the path used to initialize the database.
        """
        self.styles.clear()

        # scans for all styles files
        all_styles_files = []
        for pattern in self.paths:
            folder, file = os.path.split(pattern)
            if '*' in file or '?' in file:
                all_styles_files.extend(Path(folder).glob(file))
            else:
                all_styles_files.append(Path(pattern))

        # Remove duplicate entries, keeping the first occurrence's position.
        # (A wildcard default path is listed both explicitly and via its glob.)
        self.all_styles_files = list(dict.fromkeys(all_styles_files))

        # Decide on dividers from the de-duplicated list, so a single file that
        # happened to be listed twice does not get a divider.
        add_dividers = len(self.all_styles_files) > 1
        for styles_file in self.all_styles_files:
            if add_dividers:
                # '---------------- STYLES ----------------'
                divider = f' {styles_file.stem.upper()} '.center(40, '-')
                self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
            if styles_file.is_file():
                self.load_from_csv(styles_file)

    def load_from_csv(self, path: str):
        """Add the styles from the CSV file at *path* to the database.

        Reads ``name``, ``prompt`` and ``negative_prompt`` columns; the legacy
        ``name, text`` layout is also accepted. Failures are reported through
        ``errors.report`` rather than raised; styles parsed before the failure
        are kept.
        """
        try:
            with open(path, "r", encoding="utf-8-sig", newline="") as file:
                reader = csv.DictReader(file, skipinitialspace=True)
                for row in reader:
                    # csv.DictReader fills columns missing from a short row with None.
                    name = row["name"] or ""
                    # Ignore rows whose cells are all empty, and rows starting with a comment
                    if not any(row.values()) or name.startswith("#"):
                        continue
                    # Support loading old CSV format with "name, text"-columns
                    prompt = row["prompt"] if "prompt" in row else row["text"]
                    negative_prompt = row.get("negative_prompt")
                    # Add style to database; missing cells become "" so prompts are always strings.
                    self.styles[name] = PromptStyle(
                        name, prompt or "", negative_prompt or "", str(path)
                    )
        except Exception:
            errors.report(f'Error loading styles from {path}: ', exc_info=True)

    def get_style_paths(self) -> set:
        """Returns a set of all distinct paths of files that styles are loaded from.

        Side effect: any style without a path is reassigned to ``default_path``.
        The default path is always included; the divider marker is excluded.
        """
        # Update any styles without a path to the default path
        for style in list(self.styles.values()):
            if not style.path:
                self.styles[style.name] = style._replace(path=str(self.default_path))

        # Create a set of all distinct paths, including the default path
        style_paths = {str(self.default_path)}
        for style in self.styles.values():
            if style.path:
                style_paths.add(style.path)

        # Remove any paths for styles that are just list dividers
        style_paths.discard("do_not_save")

        return style_paths

    def get_style_prompts(self, styles):
        """Return the positive prompt of each named style (no_style for unknown names)."""
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        """Return the negative prompt of each named style (no_style for unknown names)."""
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        """Apply the positive prompts of the named *styles* to *prompt*.

        Unknown names behave like an empty style. Divider entries (prompt None)
        are treated the same way instead of crashing ``merge_prompts``.
        """
        return apply_styles_to_prompt(
            prompt, [self.styles.get(x, self.no_style).prompt or "" for x in styles]
        )

    def apply_negative_styles_to_prompt(self, prompt, styles):
        """Apply the negative prompts of the named *styles* to *prompt*.

        Unknown names and divider entries behave like an empty style.
        """
        return apply_styles_to_prompt(
            prompt, [self.styles.get(x, self.no_style).negative_prompt or "" for x in styles]
        )

    def save_styles(self, path: typing.Optional[str] = None) -> None:
        """Write every style back to the CSV file it belongs to.

        A ``.bak`` copy of each existing file is made before it is overwritten.
        The *path* argument is deprecated and ignored; it is kept for backwards
        compatibility.
        """
        style_paths = self.get_style_paths()

        csv_names = [os.path.split(p)[1].lower() for p in style_paths]

        for style_path in style_paths:
            # Always keep a backup file around
            if os.path.exists(style_path):
                shutil.copy(style_path, f"{style_path}.bak")

            # Write the styles to the CSV file
            with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
                writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
                writer.writeheader()
                for style in (s for s in self.styles.values() if s.path == style_path):
                    # Skip style list dividers, e.g. "STYLES.CSV"
                    if style.name.lower().strip("# ") in csv_names:
                        continue
                    # Write style fields, ignoring the path field
                    writer.writerow(
                        {k: v for k, v in style._asdict().items() if k != "path"}
                    )

    def extract_styles_from_prompt(self, prompt, negative_prompt):
        """Strip as many known styles as possible off the given prompts.

        Repeatedly searches the styles (each at most once) for one whose text is
        found in both prompts and removes it. Returns ``(names, prompt,
        negative_prompt)`` where *names* is the list of matched style names in
        reverse order of extraction, followed by the remaining prompts.
        """
        extracted = []

        applicable_styles = list(self.styles.values())

        while True:
            found_style = None

            for style in applicable_styles:
                is_match, new_prompt, new_neg_prompt = extract_original_prompts(
                    style, prompt, negative_prompt
                )
                if is_match:
                    found_style = style
                    prompt = new_prompt
                    negative_prompt = new_neg_prompt
                    break

            if not found_style:
                break

            applicable_styles.remove(found_style)
            extracted.append(found_style.name)

        # Styles are peeled off outermost-first; reverse to get application order.
        return list(reversed(extracted)), prompt, negative_prompt
234
235