KoboldAI
GitHub Repository: KoboldAI/KoboldAI-Client
Path: blob/main/aiserver.py
1
#!/usr/bin/python3
2
#==================================================================#
3
# KoboldAI
4
# Version: 1.19.2
5
# By: The KoboldAI Community
6
#==================================================================#
7
8
# External packages
9
import eventlet
10
eventlet.monkey_patch(all=True, thread=False, os=False)
11
import os
12
os.system("")
13
__file__ = os.path.dirname(os.path.realpath(__file__))
14
os.chdir(__file__)
15
os.environ['EVENTLET_THREADPOOL_SIZE'] = '1'
16
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
17
from eventlet import tpool
18
19
import logging
20
from logger import logger, set_logger_verbosity, quiesce_logger
21
22
logging.getLogger("urllib3").setLevel(logging.ERROR)
23
24
from os import path, getcwd
25
import time
26
import re
27
import json
28
import collections
29
import zipfile
30
import packaging
31
import packaging.version
32
import contextlib
33
import traceback
34
import threading
35
import markdown
36
import bleach
37
import itertools
38
import bisect
39
import functools
40
import traceback
41
import inspect
42
import warnings
43
from collections.abc import Iterable
44
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List, Optional, Type
45
46
import requests
47
import html
48
import argparse
49
import sys
50
import gc
51
52
import lupa
53
import importlib
54
55
# KoboldAI
56
import fileops
57
import gensettings
58
from utils import debounce
59
import utils
60
import structures
61
import torch
62
from transformers import StoppingCriteria, GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, modeling_utils
63
from transformers import __version__ as transformers_version
64
import transformers
65
try:
66
from transformers.models.opt.modeling_opt import OPTDecoder
67
except:
68
pass
69
import transformers.generation_utils
70
71
global tpu_mtj_backend
72
73
74
if lupa.LUA_VERSION[:2] != (5, 4):
75
logger.error(f"Please install lupa==1.10. You have lupa {lupa.__version__}.")
76
77
patch_causallm_patched = False
78
79
# Make sure tqdm progress bars display properly in Colab
80
from tqdm.auto import tqdm
81
old_init = tqdm.__init__
82
def new_init(self, *args, **kwargs):
83
old_init(self, *args, **kwargs)
84
if(self.ncols == 0 and kwargs.get("ncols") != 0):
85
self.ncols = 99
86
tqdm.__init__ = new_init
87
88
# Fix some issues with the OPT tokenizer
89
from transformers import PreTrainedTokenizerBase
90
old_pretrainedtokenizerbase_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__
91
@classmethod
92
def new_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs):
93
tokenizer = old_pretrainedtokenizerbase_from_pretrained(cls, *args, **kwargs)
94
tokenizer._koboldai_header = tokenizer.encode("")
95
tokenizer.add_bos_token = False
96
tokenizer.add_prefix_space = False
97
return tokenizer
98
PreTrainedTokenizerBase.from_pretrained = new_pretrainedtokenizerbase_from_pretrained
99
100
#==================================================================#
101
# Variables & Storage
102
#==================================================================#
103
104
# Terminal tags for colored text
105
class colors:
106
PURPLE = '\033[95m'
107
BLUE = '\033[94m'
108
CYAN = '\033[96m'
109
GREEN = '\033[92m'
110
YELLOW = '\033[93m'
111
RED = '\033[91m'
112
END = '\033[0m'
113
UNDERLINE = '\033[4m'
114
115
# AI models Menu
116
# This is a dict of lists where the key is the menu name and the list is the menu items.
117
# Each item has 4 elements: 1: text to display, 2: model name (vars.model) or menu name (key name for another menu),
118
# 3: the memory requirement for the model, 4: whether the item is a menu (True/False)
119
model_menu = {
120
'mainmenu': [
121
["Load a model from its directory", "NeoCustom", "", False],
122
["Load an old GPT-2 model (eg CloverEdition)", "GPT2Custom", "", False],
123
["Adventure Models", "adventurelist", "", True],
124
["Novel Models", "novellist", "", True],
125
["NSFW Models", "nsfwlist", "", True],
126
["Untuned OPT", "optlist", "", True],
127
["Untuned GPT-Neo/J", "gptneolist", "", True],
128
["Untuned Pythia", "pythialist", "", True],
129
["Untuned Fairseq Dense", "fsdlist", "", True],
130
["Untuned Bloom", "bloomlist", "", True],
131
["Untuned XGLM", "xglmlist", "", True],
132
["Untuned GPT2", "gpt2list", "", True],
133
["Online Services", "apilist", "", True],
134
["Read Only (No AI)", "ReadOnly", "", False]
135
],
136
'adventurelist': [
137
["Skein 20B", "KoboldAI/GPT-NeoX-20B-Skein", "64GB", False],
138
["Nerys OPT 13B V2 (Hybrid)", "KoboldAI/OPT-13B-Nerys-v2", "32GB", False],
139
["Nerys FSD 13B V2 (Hybrid)", "KoboldAI/fairseq-dense-13B-Nerys-v2", "32GB", False],
140
["Nerys FSD 13B (Hybrid)", "KoboldAI/fairseq-dense-13B-Nerys", "32GB", False],
141
["Skein 6B", "KoboldAI/GPT-J-6B-Skein", "16GB", False],
142
["OPT Nerys 6B V2 (Hybrid)", "KoboldAI/OPT-6B-nerys-v2", "16GB", False],
143
["Adventure 6B", "KoboldAI/GPT-J-6B-Adventure", "16GB", False],
144
["Nerys FSD 2.7B (Hybrid)", "KoboldAI/fairseq-dense-2.7B-Nerys", "8GB", False],
145
["Adventure 2.7B", "KoboldAI/GPT-Neo-2.7B-AID", "8GB", False],
146
["Adventure 1.3B", "KoboldAI/GPT-Neo-1.3B-Adventure", "6GB", False],
147
["Adventure 125M (Mia)", "Merry/AID-Neo-125M", "2GB", False],
148
["Return to Main Menu", "mainmenu", "", True],
149
],
150
'novellist': [
151
["Nerys OPT 13B V2 (Hybrid)", "KoboldAI/OPT-13B-Nerys-v2", "32GB", False],
152
["Nerys FSD 13B V2 (Hybrid)", "KoboldAI/fairseq-dense-13B-Nerys-v2", "32GB", False],
153
["Janeway FSD 13B", "KoboldAI/fairseq-dense-13B-Janeway", "32GB", False],
154
["Nerys FSD 13B (Hybrid)", "KoboldAI/fairseq-dense-13B-Nerys", "32GB", False],
155
["OPT Nerys 6B V2 (Hybrid)", "KoboldAI/OPT-6B-nerys-v2", "16GB", False],
156
["Janeway FSD 6.7B", "KoboldAI/fairseq-dense-6.7B-Janeway", "16GB", False],
157
["Janeway Neo 6B", "KoboldAI/GPT-J-6B-Janeway", "16GB", False],
158
["Qilin Lit 6B (SFW)", "rexwang8/qilin-lit-6b", "16GB", False],
159
["Janeway Neo 2.7B", "KoboldAI/GPT-Neo-2.7B-Janeway", "8GB", False],
160
["Janeway FSD 2.7B", "KoboldAI/fairseq-dense-2.7B-Janeway", "8GB", False],
161
["Nerys FSD 2.7B (Hybrid)", "KoboldAI/fairseq-dense-2.7B-Nerys", "8GB", False],
162
["Horni-LN 2.7B", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "8GB", False],
163
["Picard 2.7B (Older Janeway)", "KoboldAI/GPT-Neo-2.7B-Picard", "8GB", False],
164
["Return to Main Menu", "mainmenu", "", True],
165
],
166
'nsfwlist': [
167
["Erebus 20B (NSFW)", "KoboldAI/GPT-NeoX-20B-Erebus", "64GB", False],
168
["Nerybus 13B (NSFW)", "KoboldAI/OPT-13B-Nerybus-Mix", "32GB", False],
169
["Erebus 13B (NSFW)", "KoboldAI/OPT-13B-Erebus", "32GB", False],
170
["Shinen FSD 13B (NSFW)", "KoboldAI/fairseq-dense-13B-Shinen", "32GB", False],
171
["Nerybus 6.7B (NSFW)", "KoboldAI/OPT-6.7B-Nerybus-Mix", "16GB", False],
172
["Erebus 6.7B (NSFW)", "KoboldAI/OPT-6.7B-Erebus", "16GB", False],
173
["Shinen FSD 6.7B (NSFW)", "KoboldAI/fairseq-dense-6.7B-Shinen", "16GB", False],
174
["Lit V2 6B (NSFW)", "hakurei/litv2-6B-rev3", "16GB", False],
175
["Lit 6B (NSFW)", "hakurei/lit-6B", "16GB", False],
176
["Shinen 6B (NSFW)", "KoboldAI/GPT-J-6B-Shinen", "16GB", False],
177
["Nerybus 2.7B (NSFW)", "KoboldAI/OPT-2.7B-Nerybus-Mix", "8GB", False],
178
["Erebus 2.7B (NSFW)", "KoboldAI/OPT-2.7B-Erebus", "8GB", False],
179
["Horni 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni", "8GB", False],
180
["Shinen 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Shinen", "8GB", False],
181
["Return to Main Menu", "mainmenu", "", True],
182
],
183
'chatlist': [
184
["Convo 6B (Chatbot)", "hitomi-team/convo-6B", "16GB", False],
185
["C1 6B (Chatbot)", "hakurei/c1-6B", "16GB", False],
186
["C1 1.3B (Chatbot)", "iokru/c1-1.3B", "6GB", False],
187
["Return to Main Menu", "mainmenu", "", True],
188
],
189
'gptneolist': [
190
["GPT-NeoX 20B", "EleutherAI/gpt-neox-20b", "64GB", False],
191
["Pythia 13B (NeoX, Same dataset)", "EleutherAI/pythia-13b", "32GB", False],
192
["GPT-J 6B", "EleutherAI/gpt-j-6B", "16GB", False],
193
["GPT-Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "8GB", False],
194
["GPT-Neo 1.3B", "EleutherAI/gpt-neo-1.3B", "6GB", False],
195
["Pythia 800M (NeoX, Same dataset)", "EleutherAI/pythia-800m", "4GB", False],
196
["Pythia 350M (NeoX, Same dataset)", "EleutherAI/pythia-350m", "2GB", False],
197
["GPT-Neo 125M", "EleutherAI/gpt-neo-125M", "2GB", False],
198
["Return to Main Menu", "mainmenu", "", True],
199
],
200
'pythialist': [
201
["Pythia 13B Deduped", "EleutherAI/pythia-13b-deduped", "32GB", False],
202
["Pythia 13B", "EleutherAI/pythia-13b", "32GB", False],
203
["Pythia 6.7B Deduped", "EleutherAI/pythia-6.7b-deduped", "16GB", False],
204
["Pythia 6.7B", "EleutherAI/pythia-6.7b", "16GB", False],
205
["Pythia 1.3B Deduped", "EleutherAI/pythia-1.3b-deduped", "6GB", False],
206
["Pythia 1.3B", "EleutherAI/pythia-1.3b", "6GB", False],
207
["Pythia 800M", "EleutherAI/pythia-800m", "4GB", False],
208
["Pythia 350M Deduped", "EleutherAI/pythia-350m-deduped", "2GB", False],
209
["Pythia 350M", "EleutherAI/pythia-350m", "2GB", False],
210
["Pythia 125M Deduped", "EleutherAI/pythia-125m-deduped", "2GB", False],
211
["Pythia 125M", "EleutherAI/pythia-125m", "2GB", False],
212
["Pythia 19M Deduped", "EleutherAI/pythia-19m-deduped", "1GB", False],
213
["Pythia 19M", "EleutherAI/pythia-19m", "1GB", False],
214
["Return to Main Menu", "mainmenu", "", True],
215
],
216
'gpt2list': [
217
["GPT-2 XL", "gpt2-xl", "6GB", False],
218
["GPT-2 Large", "gpt2-large", "4GB", False],
219
["GPT-2 Med", "gpt2-medium", "2GB", False],
220
["GPT-2", "gpt2", "2GB", False],
221
["Return to Main Menu", "mainmenu", "", True],
222
],
223
'bloomlist': [
224
["Bloom 176B", "bigscience/bloom", "", False],
225
["Bloom 7.1B", "bigscience/bloom-7b1", "", False],
226
["Bloom 3B", "bigscience/bloom-3b", "", False],
227
["Bloom 1.7B", "bigscience/bloom-1b7", "", False],
228
["Bloom 560M", "bigscience/bloom-560m", "", False],
229
["Return to Main Menu", "mainmenu", "", True],
230
],
231
'optlist': [
232
["OPT 66B", "facebook/opt-66b", "128GB", False],
233
["OPT 30B", "facebook/opt-30b", "64GB", False],
234
["OPT 13B", "facebook/opt-13b", "32GB", False],
235
["OPT 6.7B", "facebook/opt-6.7b", "16GB", False],
236
["OPT 2.7B", "facebook/opt-2.7b", "8GB", False],
237
["OPT 1.3B", "facebook/opt-1.3b", "4GB", False],
238
["OPT 350M", "facebook/opt-350m", "2GB", False],
239
["OPT 125M", "facebook/opt-125m", "1GB", False],
240
["Return to Main Menu", "mainmenu", "", True],
241
],
242
'fsdlist': [
243
["Fairseq Dense 13B", "KoboldAI/fairseq-dense-13B", "32GB", False],
244
["Fairseq Dense 6.7B", "KoboldAI/fairseq-dense-6.7B", "16GB", False],
245
["Fairseq Dense 2.7B", "KoboldAI/fairseq-dense-2.7B", "8GB", False],
246
["Fairseq Dense 1.3B", "KoboldAI/fairseq-dense-1.3B", "4GB", False],
247
["Fairseq Dense 355M", "KoboldAI/fairseq-dense-355M", "2GB", False],
248
["Fairseq Dense 125M", "KoboldAI/fairseq-dense-125M", "1GB", False],
249
["Return to Main Menu", "mainmenu", "", True],
250
],
251
'xglmlist': [
252
["XGLM 4.5B (Larger Dataset)", "facebook/xglm-4.5B", "12GB", False],
253
["XGLM 7.5B", "facebook/xglm-7.5B", "18GB", False],
254
["XGLM 2.9B", "facebook/xglm-2.9B", "10GB", False],
255
["XGLM 1.7B", "facebook/xglm-1.7B", "6GB", False],
256
["XGLM 564M", "facebook/xglm-564M", "4GB", False],
257
["Return to Main Menu", "mainmenu", "", True],
258
],
259
'apilist': [
260
["GooseAI API (requires API key)", "GooseAI", "", False],
261
["OpenAI API (requires API key)", "OAI", "", False],
262
["InferKit API (requires API key)", "InferKit", "", False],
263
# ["KoboldAI Server API (Old Google Colab)", "Colab", "", False],
264
["KoboldAI API", "API", "", False],
265
["KoboldAI Horde", "CLUSTER", "", False],
266
["Return to Main Menu", "mainmenu", "", True],
267
]
268
}
269
270
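# Buffers decoded token text (together with any per-token probability data) until it can be streamed to the client.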
class TokenStreamQueue:
271
def __init__(self):
272
self.probability_buffer = None
273
self.queue = []
274
275
def add_text(self, text):
276
self.queue.append({
277
"decoded": text,
278
"probabilities": self.probability_buffer
279
})
280
self.probability_buffer = None
281
282
# Variables
283
class vars:
284
lastact = "" # The last action received from the user
285
submission = "" # Same as above, but after applying input formatting
286
lastctx = "" # The last context submitted to the generator
287
model = "ReadOnly" # Model ID string chosen at startup
288
online_model = "" # Used when Model ID is an online service, and there is a secondary option for the actual model name
289
model_selected = "" #selected model in UI
290
model_type = "" # Model Type (Automatically taken from the model config)
291
noai = False # Runs the script without starting up the transformers pipeline
292
aibusy = False # Stops submissions while the AI is working
293
max_length = 1024 # Maximum number of tokens to submit per action
294
ikmax = 3000 # Maximum number of characters to submit to InferKit
295
genamt = 80 # Amount of text for each action to generate
296
ikgen = 200 # Number of characters for InferKit to generate
297
rep_pen = 1.1 # Default generator repetition_penalty
298
rep_pen_slope = 0.7 # Default generator repetition penalty slope
299
rep_pen_range = 1024 # Default generator repetition penalty range
300
temp = 0.5 # Default generator temperature
301
top_p = 0.9 # Default generator top_p
302
top_k = 0 # Default generator top_k
303
top_a = 0.0 # Default generator top-a
304
tfs = 1.0 # Default generator tfs (tail-free sampling)
305
typical = 1.0 # Default generator typical sampling threshold
306
numseqs = 1 # Number of sequences to ask the generator to create
307
full_determinism = False # Whether or not full determinism is enabled
308
seed_specified = False # Whether or not the current RNG seed was specified by the user (in their settings file)
309
seed = None # The current RNG seed (as an int), or None if unknown
310
gamestarted = False # Whether the game has started (disables UI elements)
311
gamesaved = True # Whether or not current game is saved
312
serverstarted = False # Whether or not the Flask server has started
313
prompt = "" # Prompt
314
memory = "" # Text submitted to memory field
315
authornote = "" # Text submitted to Author's Note field
316
authornotetemplate = "[Author's note: <|>]" # Author's note template
317
setauthornotetemplate = authornotetemplate # Saved author's note template in settings
318
andepth = 3 # How far back in history to append author's note
319
actions = structures.KoboldStoryRegister() # Actions submitted by user and AI
320
actions_metadata = {} # List of dictionaries, one dictionary for every action, containing information about the action such as alternative options.
321
# Contains at least as many items as actions. The Back action removes an item from actions, but not from actions_metadata
322
# Dictionary keys are:
323
# Selected Text: (text the user had selected. None when this is a newly generated action)
324
# Alternative Generated Text: {Text, Pinned, Previous Selection, Edited}
325
#
326
worldinfo = [] # List of World Info key/value objects
327
worldinfo_i = [] # List of World Info key/value objects sans uninitialized entries
328
worldinfo_u = {} # Dictionary of World Info UID - key/value pairs
329
wifolders_d = {} # Dictionary of World Info folder UID-info pairs
330
wifolders_l = [] # List of World Info folder UIDs
331
wifolders_u = {} # Dictionary of pairs of folder UID - list of WI UID
332
modelconfig = {} # Raw contents of the model's config.json, or empty dictionary if none found
333
lua_state = None # Lua state of the Lua scripting system
334
lua_koboldbridge = None # `koboldbridge` from bridge.lua
335
lua_kobold = None # `kobold` from bridge.lua
336
lua_koboldcore = None # `koboldcore` from bridge.lua
337
lua_logname = ... # Name of previous userscript that logged to terminal
338
lua_running = False # Whether or not Lua is running (i.e. wasn't stopped due to an error)
339
lua_edited = set() # Set of chunk numbers that were edited from a Lua generation modifier
340
lua_deleted = set() # Set of chunk numbers that were deleted from a Lua generation modifier
341
generated_tkns = 0 # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
342
abort = False # Whether or not generation was aborted by clicking on the submit button during generation
343
compiling = False # If using a TPU Colab, this will be set to True when the TPU backend starts compiling and then set to False again
344
checking = False # Whether or not we are actively checking to see if TPU backend is compiling or not
345
sp_changed = False # This gets set to True whenever a userscript changes the soft prompt so that check_for_sp_change() can alert the browser that the soft prompt has changed
346
spfilename = "" # Filename of soft prompt to load, or an empty string if not using a soft prompt
347
userscripts = [] # List of userscripts to load
348
last_userscripts = [] # List of userscript filenames from the last time userscripts were sent via usstatitems
349
corescript = "default.lua" # Filename of corescript to load
350
# badwords = [] # Array of str/chr values that should be removed from output
351
badwordsids = []
352
badwordsids_default = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
353
badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], 
[11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]]
354
badwordsids_opt = [[44717], [46613], [48513], [49923], [50185], [48755], [8488], [43303], [49659], [48601], [49817], [45405], [48742], [49925], [47720], [11227], [48937], [48784], [50017], [42248], [49310], [48082], [49895], [50025], [49092], [49007], [8061], [44226], [0], [742], [28578], [15698], [49784], [46679], [39365], [49281], [49609], [48081], [48906], [46161], [48554], [49670], [48677], [49721], [49632], [48610], [48462], [47457], [10975], [46077], [28696], [48709], [43839], [49798], [49154], [48203], [49625], [48395], [50155], [47161], [49095], [48833], [49420], [49666], [48443], [22176], [49242], [48651], [49138], [49750], [40389], [48021], [21838], [49070], [45333], [40862], [1], [49915], [33525], [49858], [50254], [44403], [48992], [48872], [46117], [49853], [47567], [50206], [41552], [50068], [48999], [49703], [49940], [49329], [47620], [49868], [49962], [2], [44082], [50236], [31274], [50260], [47052], [42645], [49177], [17523], [48691], [49900], [49069], [49358], [48794], [47529], [46479], [48457], [646], [49910], [48077], [48935], [46386], [48902], [49151], [48759], [49803], [45587], [48392], [47789], [48654], [49836], [49230], [48188], [50264], [46844], [44690], [48505], [50161], [27779], [49995], [41833], [50154], [49097], [48520], [50018], [8174], [50084], [49366], [49526], [50193], [7479], [49982], [3]]
355
fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
356
deletewi = None # Temporary storage for UID to delete
357
wirmvwhtsp = True # Whether to remove leading whitespace from WI entries
358
widepth = 3 # How many historical actions to scan for WI hits
359
mode = "play" # Whether the interface is in play, memory, or edit mode
360
editln = 0 # Which line was last selected in Edit Mode
361
gpu_device = 0 # Which PyTorch device to use when using pure GPU generation
362
url = "https://api.inferkit.com/v1/models/standard/generate" # InferKit API URL
363
oaiurl = "" # OpenAI API URL
364
oaiengines = "https://api.openai.com/v1/engines"
365
colaburl = "" # Ngrok url for Google Colab mode
366
apikey = "" # API key to use for InferKit API calls
367
oaiapikey = "" # API key to use for OpenAI API calls
368
cluster_requested_models = [] # The models which we allow to generate during cluster mode
369
savedir = getcwd()+"\\stories"
370
hascuda = False # Whether torch has detected CUDA on the system
371
usegpu = False # Whether to launch pipeline with GPU support
372
custmodpth = "" # Filesystem location of custom model to run
373
formatoptns = {'frmttriminc': True, 'frmtrmblln': False, 'frmtrmspch': False, 'frmtadsnsp': True, 'singleline': False} # Container for state of formatting options
374
importnum = -1 # Selection on import popup list
375
importjs = {} # Temporary storage for import data
376
loadselect = "" # Temporary storage for story filename to load
377
spselect = "" # Temporary storage for soft prompt filename to load
378
spmeta = None # Metadata of current soft prompt, or None if not using a soft prompt
379
sp = None # Current soft prompt tensor (as a NumPy array)
380
sp_length = 0 # Length of current soft prompt in tokens, or 0 if not using a soft prompt
381
has_genmod = False # Whether or not at least one loaded Lua userscript has a generation modifier
382
svowname = "" # Filename that was flagged for overwrite confirm
383
saveow = False # Whether or not overwrite confirm has been displayed
384
autosave = False # Whether or not to automatically save after each action
385
genseqs = [] # Temporary storage for generated sequences
386
recentback = False # Whether Back button was recently used without Submitting or Retrying after
387
recentrng = None # If a new random game was recently generated without Submitting after, this is the topic used (as a string), otherwise this is None
388
recentrngm = None # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None
389
useprompt = False # Whether to send the full prompt with every submit action
390
breakmodel = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
391
bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM/OPT only, currently)
392
nobreakmodel = False # Whether something specifically requested breakmodel to be disabled (for example, a model's config)
393
smandelete = False # Whether stories can be deleted from inside the browser
394
smanrename = False # Whether stories can be renamed from inside the browser
395
allowsp = False # Whether we are allowed to use soft prompts (by default enabled if we're using GPT-2, GPT-Neo or GPT-J)
396
modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
397
laststory = None # Filename (without extension) of most recent story JSON file we loaded
398
regex_sl = re.compile(r'\n*(?<=.) *\n(.|\n)*') # Pattern for limiting the output to a single line
399
acregex_ai = re.compile(r'\n* *>(.|\n)*') # Pattern for matching adventure actions from the AI so we can remove them
400
acregex_ui = re.compile(r'^ *(&gt;.*)$', re.MULTILINE) # Pattern for matching actions in the HTML-escaped story so we can apply colouring, etc (make sure to encase part to format in parentheses)
401
comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
402
comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)') # Pattern for matching comments in the editor
403
sampler_order = utils.default_sampler_order.copy()
404
rng_states = {} # Used by the POST /generate endpoint to store sampler RNG states
405
chatmode = False
406
chatname = "You"
407
adventure = False
408
actionmode = 1
409
dynamicscan = False
410
host = False
411
nopromptgen = False
412
rngpersist = False
413
nogenmod = False
414
welcome = False # Custom Welcome Text (False is default)
415
newlinemode = "ns"
416
quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
417
debug = False # If set to true, will send debug information to the client for display
418
lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
419
use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
420
revision = None
421
standalone = False
422
api_tokenizer_id = None
423
disable_set_aibusy = False
424
disable_input_formatting = False
425
disable_output_formatting = False
426
output_streaming = True
427
token_stream_queue = TokenStreamQueue() # Queue for the token streaming
428
show_probs = False # Whether or not to show token probabilities
429
show_budget = False # Whether or not to show the remaining token budget
430
configname = None
431
432
utils.vars = vars
433
434
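# File-like object whose write() prints model-loading progress to the terminal and, when possible, mirrors it to the browser via a socket.io 'model_load_status' message.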
class Send_to_socketio(object):
435
def write(self, bar):
436
print(bar, end="")
437
time.sleep(0.01)
438
try:
439
gui_msg = bar.replace(f"{colors.PURPLE}INIT{colors.END} | ","").replace(" ", "&nbsp;")
440
emit('from_server', {'cmd': 'model_load_status', 'data': gui_msg}, broadcast=True)
441
except:
442
pass
443
444
# Set logging level to reduce chatter from Flask
445
import logging
446
log = logging.getLogger('werkzeug')
447
log.setLevel(logging.ERROR)
448
449
from flask import Flask, render_template, Response, request, copy_current_request_context, send_from_directory, session, jsonify, abort, redirect
450
from flask_socketio import SocketIO
451
from flask_socketio import emit as _emit
452
from flask_session import Session
453
from werkzeug.exceptions import HTTPException, NotFound, InternalServerError
454
import secrets
455
app = Flask(__name__, root_path=os.getcwd())
456
app.secret_key = secrets.token_hex()
457
app.config['SESSION_TYPE'] = 'filesystem'
458
app.config['TEMPLATES_AUTO_RELOAD'] = True
459
socketio = SocketIO(app, async_method="eventlet")
460
461
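# Wrap socketio.on so that registered event handlers become no-ops when the web UI is disabled with --no_ui.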
old_socketio_on = socketio.on
462
def new_socketio_on(*a, **k):
463
decorator = old_socketio_on(*a, **k)
464
def new_decorator(f):
465
@functools.wraps(f)
466
def g(*a, **k):
467
if args.no_ui:
468
return
469
return f(*a, **k)
470
return decorator(g)
471
return new_decorator
472
socketio.on = new_socketio_on
473
474
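# Prefer the request-bound flask_socketio emit; if it raises AttributeError (no active request context), fall back to socketio.emit.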
def emit(*args, **kwargs):
475
try:
476
return _emit(*args, **kwargs)
477
except AttributeError:
478
return socketio.emit(*args, **kwargs)
479
utils.emit = emit
480
481
# marshmallow/apispec setup
482
from apispec import APISpec
483
from apispec.ext.marshmallow import MarshmallowPlugin
484
from apispec.ext.marshmallow.field_converter import make_min_max_attributes
485
from apispec_webframeworks.flask import FlaskPlugin
486
from marshmallow import Schema, fields, validate, EXCLUDE
487
from marshmallow.exceptions import ValidationError
488
489
class KoboldSchema(Schema):
490
pass
491
492
def new_make_min_max_attributes(validators, min_attr, max_attr) -> dict:
493
# Patched apispec function that creates "exclusiveMinimum"/"exclusiveMaximum" OpenAPI attributes instead of "minimum"/"maximum" when using validators.Range or validators.Length with min_inclusive=False or max_inclusive=False
494
attributes = {}
495
min_list = [validator.min for validator in validators if validator.min is not None]
496
max_list = [validator.max for validator in validators if validator.max is not None]
497
min_inclusive_list = [getattr(validator, "min_inclusive", True) for validator in validators if validator.min is not None]
498
max_inclusive_list = [getattr(validator, "max_inclusive", True) for validator in validators if validator.max is not None]
499
if min_list:
500
if min_attr == "minimum" and not min_inclusive_list[max(range(len(min_list)), key=min_list.__getitem__)]:
501
min_attr = "exclusiveMinimum"
502
attributes[min_attr] = max(min_list)
503
if max_list:
504
if min_attr == "maximum" and not max_inclusive_list[min(range(len(max_list)), key=max_list.__getitem__)]:
505
min_attr = "exclusiveMaximum"
506
attributes[max_attr] = min(max_list)
507
return attributes
508
make_min_max_attributes.__code__ = new_make_min_max_attributes.__code__
509
510
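# Re-evaluate the endpoint's docstring as an f-string so that brace expressions in the API documentation are interpolated.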
def api_format_docstring(f):
511
f.__doc__ = eval('f"""{}"""'.format(f.__doc__.replace("\\", "\\\\")))
512
return f
513
514
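# Convert out-of-memory exceptions raised by an API handler into a KoboldOutOfMemoryError (HTTP 507), classifying the error type from the traceback text.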
def api_catch_out_of_memory_errors(f):
515
@functools.wraps(f)
516
def decorated(*args, **kwargs):
517
try:
518
return f(*args, **kwargs)
519
except Exception as e:
520
if any (s in traceback.format_exc().lower() for s in ("out of memory", "not enough memory")):
521
for line in reversed(traceback.format_exc().split("\n")):
522
if any(s in line.lower() for s in ("out of memory", "not enough memory")) and line.count(":"):
523
line = line.split(":", 1)[1]
524
line = re.sub(r"\[.+?\] +data\.", "", line).strip()
525
raise KoboldOutOfMemoryError("KoboldAI ran out of memory: " + line, type="out_of_memory.gpu.cuda" if "cuda out of memory" in line.lower() else "out_of_memory.gpu.hip" if "hip out of memory" in line.lower() else "out_of_memory.tpu.hbm" if "memory space hbm" in line.lower() else "out_of_memory.cpu.default_memory_allocator" if "defaultmemoryallocator" in line.lower() else "out_of_memory.unknown.unknown")
526
raise KoboldOutOfMemoryError(type="out_of_memory.unknown.unknown")
527
raise e
528
return decorated
529
530
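# If the handler's first parameter is annotated with a marshmallow Schema, validate the JSON request body against it before calling the handler, and jsonify any non-Response return value.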
def api_schema_wrap(f):
531
try:
532
input_schema: Type[Schema] = next(iter(inspect.signature(f).parameters.values())).annotation
533
except:
534
HAS_SCHEMA = False
535
else:
536
HAS_SCHEMA = inspect.isclass(input_schema) and issubclass(input_schema, Schema)
537
f = api_format_docstring(f)
538
f = api_catch_out_of_memory_errors(f)
539
@functools.wraps(f)
540
def decorated(*args, **kwargs):
541
if HAS_SCHEMA:
542
body = request.get_json()
543
schema = input_schema.from_dict(input_schema().load(body))
544
response = f(schema, *args, **kwargs)
545
else:
546
response = f(*args, **kwargs)
547
if not isinstance(response, Response):
548
response = jsonify(response)
549
return response
550
return decorated
551
552
@app.errorhandler(HTTPException)
553
def handler(e):
554
if request.path != "/api" and not request.path.startswith("/api/"):
555
return e
556
resp = jsonify(detail={"msg": str(e), "type": "generic.error_" + str(e.code)})
557
if e.code == 405 and e.valid_methods is not None:
558
resp.headers["Allow"] = ", ".join(e.valid_methods)
559
return resp, e.code
560
561
class KoboldOutOfMemoryError(HTTPException):
562
code = 507
563
description = "KoboldAI ran out of memory."
564
type = "out_of_memory.unknown.unknown"
565
def __init__(self, *args, type=None, **kwargs):
566
super().__init__(*args, **kwargs)
567
if type is not None:
568
self.type = type
569
@app.errorhandler(KoboldOutOfMemoryError)
570
def handler(e):
571
if request.path != "/api" and not request.path.startswith("/api/"):
572
return InternalServerError()
573
return jsonify(detail={"type": e.type, "msg": e.description}), e.code
574
575
@app.errorhandler(ValidationError)
576
def handler(e):
577
if request.path != "/api" and not request.path.startswith("/api/"):
578
return InternalServerError()
579
return jsonify(detail=e.messages), 422
580
581
@app.errorhandler(NotImplementedError)
582
def handler(e):
583
if request.path != "/api" and not request.path.startswith("/api/"):
584
return InternalServerError()
585
return jsonify(detail={"type": "not_implemented", "msg": str(e).strip()}), 501
586
587
api_versions: List[str] = []
588
589
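# APISpec subclass that registers every documented route under each prefix in self._prefixes (e.g. /api/v1 and /api/latest) and serves the Swagger UI and openapi.json for those prefixes.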
class KoboldAPISpec(APISpec):
590
class KoboldFlaskPlugin(FlaskPlugin):
591
def __init__(self, api: "KoboldAPISpec", *args, **kwargs):
592
self._kobold_api_spec = api
593
super().__init__(*args, **kwargs)
594
595
def path_helper(self, *args, **kwargs):
596
return super().path_helper(*args, **kwargs)[len(self._kobold_api_spec._prefixes[0]):]
597
598
def __init__(self, *args, title: str = "KoboldAI API", openapi_version: str = "3.0.3", version: str = "1.0.0", prefixes: List[str] = None, **kwargs):
599
plugins = [KoboldAPISpec.KoboldFlaskPlugin(self), MarshmallowPlugin()]
600
self._prefixes = prefixes if prefixes is not None else [""]
601
self._kobold_api_spec_version = version
602
api_versions.append(version)
603
api_versions.sort(key=lambda x: [int(e) for e in x.split(".")])
604
super().__init__(*args, title=title, openapi_version=openapi_version, version=version, plugins=plugins, servers=[{"url": self._prefixes[0]}], **kwargs)
605
for prefix in self._prefixes:
606
app.route(prefix, endpoint="~KoboldAPISpec~" + prefix)(lambda: redirect(request.path + "/docs/"))
607
app.route(prefix + "/", endpoint="~KoboldAPISpec~" + prefix + "/")(lambda: redirect("docs/"))
608
app.route(prefix + "/docs", endpoint="~KoboldAPISpec~" + prefix + "/docs")(lambda: redirect("docs/"))
609
app.route(prefix + "/docs/", endpoint="~KoboldAPISpec~" + prefix + "/docs/")(lambda: render_template("swagger-ui.html", url=self._prefixes[0] + "/openapi.json"))
610
app.route(prefix + "/openapi.json", endpoint="~KoboldAPISpec~" + prefix + "/openapi.json")(lambda: jsonify(self.to_dict()))
611
612
def route(self, rule: str, methods=["GET"], **kwargs):
613
__F = TypeVar("__F", bound=Callable[..., Any])
614
if "strict_slashes" not in kwargs:
615
kwargs["strict_slashes"] = False
616
def new_decorator(f: __F) -> __F:
617
@functools.wraps(f)
618
def g(*args, **kwargs):
619
global api_version
620
api_version = self._kobold_api_spec_version
621
try:
622
return f(*args, **kwargs)
623
finally:
624
api_version = None
625
for prefix in self._prefixes:
626
g = app.route(prefix + rule, methods=methods, **kwargs)(g)
627
with app.test_request_context():
628
self.path(view=g, **kwargs)
629
return g
630
return new_decorator
631
632
def get(self, rule: str, **kwargs):
633
return self.route(rule, methods=["GET"], **kwargs)
634
635
def post(self, rule: str, **kwargs):
636
return self.route(rule, methods=["POST"], **kwargs)
637
638
def put(self, rule: str, **kwargs):
639
return self.route(rule, methods=["PUT"], **kwargs)
640
641
def patch(self, rule: str, **kwargs):
642
return self.route(rule, methods=["PATCH"], **kwargs)
643
644
def delete(self, rule: str, **kwargs):
645
return self.route(rule, methods=["DELETE"], **kwargs)
646
647
tags = [
648
{"name": "info", "description": "Metadata about this API"},
649
{"name": "generate", "description": "Text generation endpoints"},
650
{"name": "model", "description": "Information about the current text generation model"},
651
{"name": "story", "description": "Endpoints for managing the story in the KoboldAI GUI"},
652
{"name": "world_info", "description": "Endpoints for managing the world info in the KoboldAI GUI"},
653
{"name": "config", "description": "Allows you to get/set various setting values"},
654
]
655
656
api_version = None # This gets set automatically so don't change this value
657
658
api_v1 = KoboldAPISpec(
659
version="1.2.1",
660
prefixes=["/api/v1", "/api/latest"],
661
tags=tags,
662
)
663
664
# Returns the expected config filename for the current setup.
665
# If the model_name is specified, it returns what the settings file would be for that model
666
def get_config_filename(model_name = None):
667
if model_name:
668
return(f"settings/{model_name.replace('/', '_')}.settings")
669
elif args.configname:
670
return(f"settings/{args.configname.replace('/', '_')}.settings")
671
elif vars.configname != '':
672
return(f"settings/{vars.configname.replace('/', '_')}.settings")
673
else:
674
logger.warning(f"Empty configfile name sent back. Defaulting to ReadOnly")
675
return(f"settings/ReadOnly.settings")
676
#==================================================================#
677
# Function to get model selection at startup
678
#==================================================================#
679
def sendModelSelection(menu="mainmenu", folder="./models"):
680
# If we send one of the manual load options, send back the list of model directories; otherwise send the menu
681
if menu in ('NeoCustom', 'GPT2Custom'):
682
(paths, breadcrumbs) = get_folder_path_info(folder)
683
if vars.host:
684
breadcrumbs = []
685
menu_list = [[folder, menu, "", False] for folder in paths]
686
menu_list.append(["Return to Main Menu", "mainmenu", "", True])
687
if os.path.abspath("{}/models".format(os.getcwd())) == os.path.abspath(folder):
688
showdelete=True
689
else:
690
showdelete=False
691
emit('from_server', {'cmd': 'show_model_menu', 'data': menu_list, 'menu': menu, 'breadcrumbs': breadcrumbs, "showdelete": showdelete}, broadcast=True)
692
else:
693
emit('from_server', {'cmd': 'show_model_menu', 'data': model_menu[menu], 'menu': menu, 'breadcrumbs': [], "showdelete": False}, broadcast=True)
694
695
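# Return (paths, breadcrumbs) for the folder browser: paths lists the subdirectories of base, breadcrumbs the path components leading to it (drive letters when browsing 'This PC' on Windows).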
def get_folder_path_info(base):
696
if base == 'This PC':
697
breadcrumbs = [['This PC', 'This PC']]
698
paths = [["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]
699
else:
700
path = os.path.abspath(base)
701
if path[-1] == "\\":
702
path = path[:-1]
703
breadcrumbs = []
704
for i in range(len(path.replace("/", "\\").split("\\"))):
705
breadcrumbs.append(["\\".join(path.replace("/", "\\").split("\\")[:i+1]),
706
path.replace("/", "\\").split("\\")[i]])
707
if len(breadcrumbs) == 1:
708
breadcrumbs = [["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]
709
else:
710
if len([["{}:\\".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]) > 0:
711
breadcrumbs.insert(0, ['This PC', 'This PC'])
712
paths = []
713
base_path = os.path.abspath(base)
714
for item in os.listdir(base_path):
715
if os.path.isdir(os.path.join(base_path, item)):
716
paths.append([os.path.join(base_path, item), item])
717
# Paths/breadcrumbs is a list of lists, where the first element in the sublist is the full path and the second is the folder name
718
return (paths, breadcrumbs)
719
720
721
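# Print the given model menu to the terminal and prompt until the user picks a valid entry, storing the choice in vars.model.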
def getModelSelection(modellist):
722
print(" # Model\t\t\t\t\t\tVRAM\n ========================================================")
723
i = 1
724
for m in modellist:
725
print(" {0} - {1}\t\t\t{2}".format("{:<2}".format(i), m[0].ljust(25), m[2]))
726
i += 1
727
print(" ");
728
modelsel = 0
729
vars.model = ''
730
while(vars.model == ''):
731
modelsel = input("Model #> ")
732
if(modelsel.isnumeric() and int(modelsel) > 0 and int(modelsel) <= len(modellist)):
733
vars.model = modellist[int(modelsel)-1][1]
734
else:
735
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
736
737
# Model Lists
738
try:
739
getModelSelection(eval(vars.model))
740
except Exception as e:
741
if(vars.model == "Return"):
742
getModelSelection(mainmenu)
743
744
# If custom model was selected, get the filesystem location and store it
745
if(vars.model == "NeoCustom" or vars.model == "GPT2Custom"):
746
print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
747
modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")
748
749
if(modpath):
750
# Save directory to vars
751
vars.custmodpth = modpath
752
else:
753
# Print error and retry model selection
754
print("{0}Model select cancelled!{1}".format(colors.RED, colors.END))
755
print("{0}Select an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
756
getModelSelection(mainmenu)
757
758
def check_if_dir_is_model(path):
759
return os.path.exists(os.path.join(path, 'config.json'))
760
761
#==================================================================#
762
# Return all keys in tokenizer dictionary containing char
763
#==================================================================#
764
#def gettokenids(char):
765
# keys = []
766
# for key in vocab_keys:
767
# if(key.find(char) != -1):
768
# keys.append(key)
769
# return keys
770
771
#==================================================================#
772
# Return Model Name
773
#==================================================================#
774
def getmodelname():
775
if(vars.online_model != ''):
776
return(f"{vars.model}/{vars.online_model}")
777
if(vars.model in ("NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
778
modelname = os.path.basename(os.path.normpath(vars.custmodpth))
779
return modelname
780
else:
781
modelname = vars.model
782
return modelname
783
784
#==================================================================#
785
# Get hidden size from model
786
#==================================================================#
787
def get_hidden_size_from_model(model):
788
return model.get_input_embeddings().embedding_dim
789
790
#==================================================================#
791
# Breakmodel configuration functions
792
#==================================================================#
793
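# Print a table showing, for each CUDA device (plus the optional disk cache and the CPU), how many model layers are currently assigned to it.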
def device_list(n_layers, primary=None, selected=None):
794
device_count = torch.cuda.device_count()
795
if(device_count < 2):
796
primary = None
797
gpu_blocks = breakmodel.gpu_blocks + (device_count - len(breakmodel.gpu_blocks))*[0]
798
print(f"{colors.YELLOW} DEVICE ID | LAYERS | DEVICE NAME{colors.END}")
799
for i in range(device_count):
800
name = torch.cuda.get_device_name(i)
801
if(len(name) > 47):
802
name = "..." + name[-44:]
803
row_color = colors.END
804
sep_color = colors.YELLOW
805
print(f"{row_color}{colors.YELLOW + '->' + row_color if i == selected else ' '} {'(primary)' if i == primary else ' '*9} {i:3} {sep_color}|{row_color} {gpu_blocks[i]:3} {sep_color}|{row_color} {name}{colors.END}")
806
row_color = colors.END
807
sep_color = colors.YELLOW
808
if(utils.HAS_ACCELERATE):
809
print(f"{row_color}{colors.YELLOW + '->' + row_color if -1 == selected else ' '} {' '*9} N/A {sep_color}|{row_color} {breakmodel.disk_blocks:3} {sep_color}|{row_color} (Disk cache){colors.END}")
810
print(f"{row_color} {' '*9} N/A {sep_color}|{row_color} {n_layers:3} {sep_color}|{row_color} (CPU){colors.END}")
811
812
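# Decide how the model's layers are split across GPUs, the CPU and (with accelerate) the disk cache, either from the --breakmodel_* command-line arguments or by prompting the user interactively.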
def device_config(config):
813
global breakmodel, generator
814
import breakmodel
815
n_layers = utils.num_layers(config)
816
if args.cpu:
817
breakmodel.gpu_blocks = [0]*n_layers
818
return
819
elif(args.breakmodel_gpulayers is not None or (utils.HAS_ACCELERATE and args.breakmodel_disklayers is not None)):
820
try:
821
if(not args.breakmodel_gpulayers):
822
breakmodel.gpu_blocks = []
823
else:
824
breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
825
assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count()
826
s = n_layers
827
for i in range(len(breakmodel.gpu_blocks)):
828
if(breakmodel.gpu_blocks[i] <= -1):
829
breakmodel.gpu_blocks[i] = s
830
break
831
else:
832
s -= breakmodel.gpu_blocks[i]
833
assert sum(breakmodel.gpu_blocks) <= n_layers
834
n_layers -= sum(breakmodel.gpu_blocks)
835
if(args.breakmodel_disklayers is not None):
836
assert args.breakmodel_disklayers <= n_layers
837
breakmodel.disk_blocks = args.breakmodel_disklayers
838
n_layers -= args.breakmodel_disklayers
839
except:
840
logger.warning("--breakmodel_gpulayers is malformatted. Please use the --help option to see correct usage of --breakmodel_gpulayers. Defaulting to all layers on device 0.")
841
breakmodel.gpu_blocks = [n_layers]
842
n_layers = 0
843
elif(args.breakmodel_layers is not None):
844
breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))]
845
n_layers -= sum(breakmodel.gpu_blocks)
846
elif(args.model is not None):
847
logger.info("Breakmodel not specified, assuming GPU 0")
848
breakmodel.gpu_blocks = [n_layers]
849
n_layers = 0
850
else:
851
device_count = torch.cuda.device_count()
852
if(device_count > 1):
853
print(colors.CYAN + "\nPlease select one of your GPUs to be your primary GPU.")
854
print("VRAM usage in your primary GPU will be higher than for your other ones.")
855
print("It is recommended you make your fastest GPU your primary GPU.")
856
device_list(n_layers)
857
while(True):
858
primaryselect = input("device ID> ")
859
if(primaryselect.isnumeric() and 0 <= int(primaryselect) < device_count):
860
breakmodel.primary_device = int(primaryselect)
861
break
862
else:
863
print(f"{colors.RED}Please enter an integer between 0 and {device_count-1}.{colors.END}")
864
else:
865
breakmodel.primary_device = 0
866
867
print(colors.PURPLE + "\nIf you don't have enough VRAM to run the model on a single GPU")
868
print("you can split the model between your CPU and your GPU(s), or between")
869
print("multiple GPUs if you have more than one.")
870
print("By putting more 'layers' on a GPU or CPU, more computations will be")
871
print("done on that device and more VRAM or RAM will be required on that device")
872
print("(roughly proportional to number of layers).")
873
print("It should be noted that GPUs are orders of magnitude faster than the CPU.")
874
print(f"This model has{colors.YELLOW} {n_layers} {colors.PURPLE}layers.{colors.END}\n")
875
876
for i in range(device_count):
877
device_list(n_layers, primary=breakmodel.primary_device, selected=i)
878
print(f"{colors.CYAN}\nHow many of the remaining{colors.YELLOW} {n_layers} {colors.CYAN}layers would you like to put into device {i}?\nYou can also enter -1 to allocate all remaining layers to this device.{colors.END}\n")
879
while(True):
880
layerselect = input("# of layers> ")
881
if((layerselect.isnumeric() or layerselect.strip() == '-1') and -1 <= int(layerselect) <= n_layers):
882
layerselect = int(layerselect)
883
layerselect = n_layers if layerselect == -1 else layerselect
884
breakmodel.gpu_blocks.append(layerselect)
885
n_layers -= layerselect
886
break
887
else:
888
print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
889
if(n_layers == 0):
890
break
891
892
if(utils.HAS_ACCELERATE and n_layers > 0):
893
device_list(n_layers, primary=breakmodel.primary_device, selected=-1)
894
print(f"{colors.CYAN}\nHow many of the remaining{colors.YELLOW} {n_layers} {colors.CYAN}layers would you like to put into the disk cache?\nYou can also enter -1 to allocate all remaining layers to this device.{colors.END}\n")
895
while(True):
896
layerselect = input("# of layers> ")
897
if((layerselect.isnumeric() or layerselect.strip() == '-1') and -1 <= int(layerselect) <= n_layers):
898
layerselect = int(layerselect)
899
layerselect = n_layers if layerselect == -1 else layerselect
900
breakmodel.disk_blocks = layerselect
901
n_layers -= layerselect
902
break
903
else:
904
print(f"{colors.RED}Please enter an integer between -1 and {n_layers}.{colors.END}")
905
906
logger.init_ok("Final device configuration:", status="Info")
907
device_list(n_layers, primary=breakmodel.primary_device)
908
909
# If all layers are on the same device, use the old GPU generation mode
910
while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
911
breakmodel.gpu_blocks.pop()
912
if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, utils.num_layers(config))):
913
vars.breakmodel = False
914
vars.usegpu = True
915
vars.gpu_device = len(breakmodel.gpu_blocks)-1
916
return
917
918
if(not breakmodel.gpu_blocks):
919
logger.warning("Nothing assigned to a GPU, reverting to CPU only mode")
920
import breakmodel
921
breakmodel.primary_device = "cpu"
922
vars.breakmodel = False
923
vars.usegpu = False
924
return
925
926
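# Place the model onto the devices chosen by device_config(): plain GPU/CPU placement when breakmodel is off, accelerate's dispatch with a per-layer device map when available, or the legacy breakmodel forward-patching path otherwise.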
def move_model_to_devices(model):
927
global generator
928
929
if(not utils.HAS_ACCELERATE and not vars.breakmodel):
930
if(vars.usegpu):
931
model = model.half().to(vars.gpu_device)
932
else:
933
model = model.to('cpu').float()
934
generator = model.generate
935
return
936
937
import breakmodel
938
939
if(utils.HAS_ACCELERATE):
940
import accelerate.utils
941
for key, value in model.state_dict().items():
942
target_dtype = torch.float32 if breakmodel.primary_device == "cpu" else torch.float16
943
if(value.dtype is not target_dtype):
944
accelerate.utils.set_module_tensor_to_device(model, key, target_dtype)
945
disk_blocks = breakmodel.disk_blocks
946
gpu_blocks = breakmodel.gpu_blocks
947
ram_blocks = len(utils.layers_module_names) - sum(gpu_blocks)
948
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
949
device_map = {}
950
for name in utils.layers_module_names:
951
layer = int(name.rsplit(".", 1)[1])
952
device = ("disk" if layer < disk_blocks else "cpu") if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
953
device_map[name] = device
954
for name in utils.get_missing_module_names(model, list(device_map.keys())):
955
device_map[name] = breakmodel.primary_device
956
breakmodel.dispatch_model_ex(model, device_map, main_device=breakmodel.primary_device, offload_buffers=True, offload_dir="accelerate-disk-cache")
957
gc.collect()
958
generator = model.generate
959
return
960
961
model.half()
962
gc.collect()
963
964
if(hasattr(model, "transformer")):
965
model.transformer.wte.to(breakmodel.primary_device)
966
model.transformer.ln_f.to(breakmodel.primary_device)
967
if(hasattr(model, 'lm_head')):
968
model.lm_head.to(breakmodel.primary_device)
969
if(hasattr(model.transformer, 'wpe')):
970
model.transformer.wpe.to(breakmodel.primary_device)
971
elif(not hasattr(model.model, "decoder")):
972
model.model.embed_tokens.to(breakmodel.primary_device)
973
model.model.layer_norm.to(breakmodel.primary_device)
974
model.lm_head.to(breakmodel.primary_device)
975
model.model.embed_positions.to(breakmodel.primary_device)
976
else:
977
model.model.decoder.embed_tokens.to(breakmodel.primary_device)
978
if(model.model.decoder.project_in is not None):
979
model.model.decoder.project_in.to(breakmodel.primary_device)
980
if(model.model.decoder.project_out is not None):
981
model.model.decoder.project_out.to(breakmodel.primary_device)
982
model.model.decoder.embed_positions.to(breakmodel.primary_device)
983
gc.collect()
984
GPTNeoModel.forward = breakmodel.new_forward_neo
985
if("GPTJModel" in globals()):
986
GPTJModel.forward = breakmodel.new_forward_neo # type: ignore
987
if("XGLMModel" in globals()):
988
XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore
989
if("OPTDecoder" in globals()):
990
OPTDecoder.forward = breakmodel.new_forward_opt # type: ignore
991
generator = model.generate
992
if(hasattr(model, "transformer")):
993
breakmodel.move_hidden_layers(model.transformer)
994
elif(not hasattr(model.model, "decoder")):
995
breakmodel.move_hidden_layers(model.model, model.model.layers)
996
else:
997
breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)
998
999
#==================================================================#
1000
# Allow the models to override some settings
1001
#==================================================================#
1002
def loadmodelsettings():
1003
try:
1004
js = json.loads(str(model_config).partition(' ')[2])
1005
except Exception as e:
1006
try:
1007
try:
1008
js = json.load(open(vars.custmodpth + "/config.json", "r"))
1009
except Exception as e:
1010
js = json.load(open(vars.custmodpth.replace('/', '_') + "/config.json", "r"))
1011
except Exception as e:
1012
js = {}
1013
if vars.model_type == "xglm" or js.get("compat", "j") == "fairseq_lm":
1014
vars.newlinemode = "s" # Default to </s> newline mode if using XGLM
1015
if vars.model_type == "opt" or vars.model_type == "bloom":
1016
vars.newlinemode = "ns" # Handle </s> but don't convert newlines if using Fairseq models that have newlines trained in them
1017
vars.modelconfig = js
1018
if("badwordsids" in js):
1019
vars.badwordsids = js["badwordsids"]
1020
if("nobreakmodel" in js):
1021
vars.nobreakmodel = js["nobreakmodel"]
1022
if("sampler_order" in js):
1023
sampler_order = js["sampler_order"]
1024
if(len(sampler_order) < 7):
1025
sampler_order = [6] + sampler_order
1026
vars.sampler_order = sampler_order
1027
if("temp" in js):
1028
vars.temp = js["temp"]
1029
if("top_p" in js):
1030
vars.top_p = js["top_p"]
1031
if("top_k" in js):
1032
vars.top_k = js["top_k"]
1033
if("tfs" in js):
1034
vars.tfs = js["tfs"]
1035
if("typical" in js):
1036
vars.typical = js["typical"]
1037
if("top_a" in js):
1038
vars.top_a = js["top_a"]
1039
if("rep_pen" in js):
1040
vars.rep_pen = js["rep_pen"]
1041
if("rep_pen_slope" in js):
1042
vars.rep_pen_slope = js["rep_pen_slope"]
1043
if("rep_pen_range" in js):
1044
vars.rep_pen_range = js["rep_pen_range"]
1045
if("adventure" in js):
1046
vars.adventure = js["adventure"]
1047
if("chatmode" in js):
1048
vars.chatmode = js["chatmode"]
1049
if("dynamicscan" in js):
1050
vars.dynamicscan = js["dynamicscan"]
1051
if("formatoptns" in js):
1052
vars.formatoptns = js["formatoptns"]
1053
if("welcome" in js):
1054
vars.welcome = js["welcome"]
1055
if("newlinemode" in js):
1056
vars.newlinemode = js["newlinemode"]
1057
if("antemplate" in js):
1058
vars.setauthornotetemplate = js["antemplate"]
1059
if(not vars.gamestarted):
1060
vars.authornotetemplate = vars.setauthornotetemplate
1061
1062
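
# Note: str(model_config) on a Hugging Face config renders as the class name,
# a space, then a JSON document, so partition(' ')[2] above isolates the JSON
# for json.loads(); if that fails, the raw config.json on disk is read instead.
# For example (hypothetical output), str(model_config) might begin with:
#
#     GPTNeoConfig {
#       "activation_function": "gelu_new",
#       ...
#     }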
#==================================================================#
1063
# Take settings from vars and write them to client settings file
1064
#==================================================================#
1065
def savesettings():
1066
# Build json to write
1067
js = {}
1068
js["apikey"] = vars.apikey
1069
js["andepth"] = vars.andepth
1070
js["sampler_order"] = vars.sampler_order
1071
js["temp"] = vars.temp
1072
js["top_p"] = vars.top_p
1073
js["top_k"] = vars.top_k
1074
js["tfs"] = vars.tfs
1075
js["typical"] = vars.typical
1076
js["top_a"] = vars.top_a
1077
js["rep_pen"] = vars.rep_pen
1078
js["rep_pen_slope"] = vars.rep_pen_slope
1079
js["rep_pen_range"] = vars.rep_pen_range
1080
js["genamt"] = vars.genamt
1081
js["max_length"] = vars.max_length
1082
js["ikgen"] = vars.ikgen
1083
js["formatoptns"] = vars.formatoptns
1084
js["numseqs"] = vars.numseqs
1085
js["widepth"] = vars.widepth
1086
js["useprompt"] = vars.useprompt
1087
js["adventure"] = vars.adventure
1088
js["chatmode"] = vars.chatmode
1089
js["chatname"] = vars.chatname
1090
js["dynamicscan"] = vars.dynamicscan
1091
js["nopromptgen"] = vars.nopromptgen
1092
js["rngpersist"] = vars.rngpersist
1093
js["nogenmod"] = vars.nogenmod
1094
js["fulldeterminism"] = vars.full_determinism
1095
js["autosave"] = vars.autosave
1096
js["welcome"] = vars.welcome
1097
js["output_streaming"] = vars.output_streaming
1098
js["show_probs"] = vars.show_probs
1099
js["show_budget"] = vars.show_budget
1100
1101
if(vars.seed_specified):
1102
js["seed"] = vars.seed
1103
else:
1104
js["seed"] = None
1105
1106
js["newlinemode"] = vars.newlinemode
1107
1108
js["antemplate"] = vars.setauthornotetemplate
1109
1110
js["userscripts"] = vars.userscripts
1111
js["corescript"] = vars.corescript
1112
js["softprompt"] = vars.spfilename
1113
1114
# Write it
1115
if not os.path.exists('settings'):
1116
os.mkdir('settings')
1117
file = open(get_config_filename(), "w")
1118
try:
1119
file.write(json.dumps(js, indent=3))
1120
finally:
1121
file.close()
1122
1123
#==================================================================#
# Don't save settings unless 2 seconds have passed without modification
#==================================================================#
@debounce(2)
def settingschanged():
    logger.info("Saving settings.")
    savesettings()

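# `debounce` above comes from utils. A minimal sketch of what such a decorator
# might look like (an illustrative assumption, not the actual utils
# implementation): each call restarts a timer, so a burst of settings changes
# produces a single save once things have been quiet for `wait` seconds.
def _debounce_sketch(wait):
    import threading
    def decorator(fn):
        timer = None
        def debounced(*args, **kwargs):
            nonlocal timer
            if timer is not None:
                timer.cancel()
            timer = threading.Timer(wait, lambda: fn(*args, **kwargs))
            timer.start()
        return debounced
    return decorator
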
#==================================================================#
1132
# Read settings from client file JSON and send to vars
1133
#==================================================================#
1134
1135
def loadsettings():
1136
if(path.exists("defaults/" + getmodelname().replace('/', '_') + ".settings")):
1137
# Read file contents into JSON object
1138
file = open("defaults/" + getmodelname().replace('/', '_') + ".settings", "r")
1139
js = json.load(file)
1140
1141
processsettings(js)
1142
file.close()
1143
if(path.exists(get_config_filename())):
1144
# Read file contents into JSON object
1145
file = open(get_config_filename(), "r")
1146
js = json.load(file)
1147
1148
processsettings(js)
1149
file.close()
1150
1151
def processsettings(js):
1152
# Copy file contents to vars
1153
if("apikey" in js):
1154
# If the model is the HORDE, the previously saved API key in settings
# will always override a newly set key.
1156
if vars.model != "CLUSTER" or vars.apikey == '':
1157
vars.apikey = js["apikey"]
1158
if("andepth" in js):
1159
vars.andepth = js["andepth"]
1160
if("sampler_order" in js):
1161
sampler_order = js["sampler_order"]
1162
if(len(sampler_order) < 7):
1163
sampler_order = [6] + sampler_order
1164
vars.sampler_order = sampler_order
1165
if("temp" in js):
1166
vars.temp = js["temp"]
1167
if("top_p" in js):
1168
vars.top_p = js["top_p"]
1169
if("top_k" in js):
1170
vars.top_k = js["top_k"]
1171
if("tfs" in js):
1172
vars.tfs = js["tfs"]
1173
if("typical" in js):
1174
vars.typical = js["typical"]
1175
if("top_a" in js):
1176
vars.top_a = js["top_a"]
1177
if("rep_pen" in js):
1178
vars.rep_pen = js["rep_pen"]
1179
if("rep_pen_slope" in js):
1180
vars.rep_pen_slope = js["rep_pen_slope"]
1181
if("rep_pen_range" in js):
1182
vars.rep_pen_range = js["rep_pen_range"]
1183
if("genamt" in js):
1184
vars.genamt = js["genamt"]
1185
if("max_length" in js):
1186
vars.max_length = js["max_length"]
1187
if("ikgen" in js):
1188
vars.ikgen = js["ikgen"]
1189
if("formatoptns" in js):
1190
vars.formatoptns = js["formatoptns"]
1191
if("numseqs" in js):
1192
vars.numseqs = js["numseqs"]
1193
if("widepth" in js):
1194
vars.widepth = js["widepth"]
1195
if("useprompt" in js):
1196
vars.useprompt = js["useprompt"]
1197
if("adventure" in js):
1198
vars.adventure = js["adventure"]
1199
if("chatmode" in js):
1200
vars.chatmode = js["chatmode"]
1201
if("chatname" in js):
1202
vars.chatname = js["chatname"]
1203
if("dynamicscan" in js):
1204
vars.dynamicscan = js["dynamicscan"]
1205
if("nopromptgen" in js):
1206
vars.nopromptgen = js["nopromptgen"]
1207
if("rngpersist" in js):
1208
vars.rngpersist = js["rngpersist"]
1209
if("nogenmod" in js):
1210
vars.nogenmod = js["nogenmod"]
1211
if("fulldeterminism" in js):
1212
vars.full_determinism = js["fulldeterminism"]
1213
if("autosave" in js):
1214
vars.autosave = js["autosave"]
1215
if("newlinemode" in js):
1216
vars.newlinemode = js["newlinemode"]
1217
if("welcome" in js):
1218
vars.welcome = js["welcome"]
1219
if("output_streaming" in js):
1220
vars.output_streaming = js["output_streaming"]
1221
if("show_probs" in js):
1222
vars.show_probs = js["show_probs"]
1223
if("show_budget" in js):
1224
vars.show_budget = js["show_budget"]
1225
1226
if("seed" in js):
1227
vars.seed = js["seed"]
1228
if(vars.seed is not None):
1229
vars.seed_specified = True
1230
else:
1231
vars.seed_specified = False
1232
else:
1233
vars.seed_specified = False
1234
1235
if("antemplate" in js):
1236
vars.setauthornotetemplate = js["antemplate"]
1237
if(not vars.gamestarted):
1238
vars.authornotetemplate = vars.setauthornotetemplate
1239
1240
if("userscripts" in js):
1241
vars.userscripts = []
1242
for userscript in js["userscripts"]:
1243
if type(userscript) is not str:
1244
continue
1245
userscript = userscript.strip()
1246
if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
1247
vars.userscripts.append(userscript)
1248
1249
if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
1250
vars.corescript = js["corescript"]
1251
else:
1252
vars.corescript = "default.lua"
1253
1254
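
# Userscript and corescript names accepted above must not be able to escape
# their folder: no "..", no drive-style ":", and no leading path separator
# (userscripts are additionally checked for existence). A compact restatement
# of that check, for illustration only:
def _is_safe_script_name_sketch(name):
    return (
        len(name) != 0
        and ".." not in name
        and ":" not in name
        and not name.startswith(("/", "\\"))
    )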
#==================================================================#
# Load a soft prompt from a file
#==================================================================#

def check_for_sp_change():
    while(True):
        time.sleep(0.05)

        if(vars.sp_changed):
            with app.app_context():
                emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, namespace=None, broadcast=True)
            vars.sp_changed = False

        if(vars.token_stream_queue.queue):
            # If emit blocks, waiting for it to complete before clearing could
            # introduce a race condition that drops tokens.
            queued_tokens = list(vars.token_stream_queue.queue)
            vars.token_stream_queue.queue.clear()
            socketio.emit("from_server", {"cmd": "streamtoken", "data": queued_tokens}, namespace=None, broadcast=True)

socketio.start_background_task(check_for_sp_change)

def spRequest(filename):
1277
if(not vars.allowsp):
1278
raise RuntimeError("Soft prompts are not supported by your current model/backend")
1279
1280
old_filename = vars.spfilename
1281
1282
vars.spfilename = ""
1283
settingschanged()
1284
1285
if(len(filename) == 0):
1286
vars.sp = None
1287
vars.sp_length = 0
1288
if(old_filename != filename):
1289
vars.sp_changed = True
1290
return
1291
1292
global np
1293
if 'np' not in globals():
1294
import numpy as np
1295
1296
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
1297
if not isinstance(z, zipfile.ZipFile):
1298
raise RuntimeError(f"{repr(filename)} is not a valid soft prompt file")
1299
with z.open('meta.json') as f:
1300
vars.spmeta = json.load(f)
1301
z.close()
1302
1303
with np.load(fileops.sppath(filename), allow_pickle=False) as f:
1304
tensor = f['tensor.npy']
1305
1306
# If the tensor is in bfloat16 format, convert it to float32
1307
if(tensor.dtype == 'V2'):
1308
tensor.dtype = np.uint16
1309
tensor = np.uint32(tensor) << 16
1310
tensor.dtype = np.float32
1311
1312
if(tensor.dtype != np.float16):
1313
tensor = np.float32(tensor)
1314
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
1315
1316
vars.sp_length = tensor.shape[-2]
1317
vars.spmeta["n_tokens"] = vars.sp_length
1318
1319
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
1320
rows = tensor.shape[0]
1321
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
1322
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
1323
tensor = tensor.reshape(
1324
tpu_mtj_backend.params["cores_per_replica"],
1325
-1,
1326
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
1327
)
1328
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
1329
else:
1330
vars.sp = torch.from_numpy(tensor)
1331
1332
vars.spfilename = filename
1333
settingschanged()
1334
if(old_filename != filename):
1335
vars.sp_changed = True
1336
1337
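
# The 'V2' branch in spRequest() above widens bfloat16 data to float32 by bit
# manipulation: a bfloat16 value is the top 16 bits of the corresponding
# float32, so shifting left by 16 recreates the float32 bit pattern. A
# standalone sketch of the same trick (illustrative only):
def _bfloat16_bytes_to_float32_sketch(raw_bytes):
    import numpy as np
    as_uint16 = np.frombuffer(raw_bytes, dtype=np.uint16)
    as_uint32 = np.uint32(as_uint16) << 16    # place the stored bits in the high half
    return as_uint32.view(np.float32)         # reinterpret the bits as float32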
#==================================================================#
1338
# Startup
1339
#==================================================================#
1340
def general_startup(override_args=None):
1341
global args
1342
# Parsing Parameters
1343
parser = argparse.ArgumentParser(description="KoboldAI Server")
1344
parser.add_argument("--remote", action='store_true', help="Optimizes KoboldAI for Remote Play")
1345
parser.add_argument("--noaimenu", action='store_true', help="Disables the ability to select the AI")
1346
parser.add_argument("--ngrok", action='store_true', help="Optimizes KoboldAI for Remote Play using Ngrok")
1347
parser.add_argument("--localtunnel", action='store_true', help="Optimizes KoboldAI for Remote Play using Localtunnel")
1348
parser.add_argument("--host", action='store_true', help="Optimizes KoboldAI for Remote Play without using a proxy service")
1349
parser.add_argument("--port", type=int, help="Specify the port on which the application will be joinable")
1350
parser.add_argument("--aria2_port", type=int, help="Specify the port on which aria2's RPC interface will be open if aria2 is installed (defaults to 6799)")
1351
parser.add_argument("--model", help="Specify the Model Type to skip the Menu")
1352
parser.add_argument("--path", help="Specify the Path for local models (For model NeoCustom or GPT2Custom)")
1353
parser.add_argument("--apikey", help="Specify the API key to use for online services")
1354
parser.add_argument("--req_model", type=str, action='append', required=False, help="Which models which we allow to generate for us during cluster mode. Can be specified multiple times.")
1355
parser.add_argument("--revision", help="Specify the model revision for huggingface models (can be a git branch/tag name or a git commit hash)")
1356
parser.add_argument("--cpu", action='store_true', help="By default unattended launches are on the GPU use this option to force CPU usage.")
1357
parser.add_argument("--breakmodel", action='store_true', help=argparse.SUPPRESS)
1358
parser.add_argument("--breakmodel_layers", type=int, help=argparse.SUPPRESS)
1359
parser.add_argument("--breakmodel_gpulayers", type=str, help="If using a model that supports hybrid generation, this is a comma-separated list that specifies how many layers to put on each GPU device. For example to put 8 layers on device 0, 9 layers on device 1 and 11 layers on device 2, use --breakmodel_gpulayers 8,9,11")
1360
parser.add_argument("--breakmodel_disklayers", type=int, help="If using a model that supports hybrid generation, this is the number of layers to put in disk cache.")
1361
parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
1362
parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
1363
parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
1364
parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
1365
parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakmodel support completely.")
1366
parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)")
1367
parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console")
1368
parser.add_argument("--no_aria2", action='store_true', default=False, help="Prevents KoboldAI from using aria2 to download huggingface models more efficiently, in case aria2 is causing you issues")
1369
parser.add_argument("--lowmem", action='store_true', help="Extra Low Memory loading for the GPU, slower but memory does not peak to twice the usage")
1370
parser.add_argument("--savemodel", action='store_true', help="Saves the model to the models folder even if --colab is used (Allows you to save models to Google Drive)")
1371
parser.add_argument("--customsettings", help="Preloads arguements from json file. You only need to provide the location of the json file. Use customsettings.json template file. It can be renamed if you wish so that you can store multiple configurations. Leave any settings you want as default as null. Any values you wish to set need to be in double quotation marks")
1372
parser.add_argument("--no_ui", action='store_true', default=False, help="Disables the GUI and Socket.IO server while leaving the API server running.")
1373
parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
1374
parser.add_argument('-q', '--quiesce', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
1375
1376
#args: argparse.Namespace = None
1377
if "pytest" in sys.modules and override_args is None:
1378
args = parser.parse_args([])
1379
return
1380
if override_args is not None:
1381
import shlex
1382
args = parser.parse_args(shlex.split(override_args))
1383
elif(os.environ.get("KOBOLDAI_ARGS") is not None):
1384
import shlex
1385
args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
1386
else:
1387
args = parser.parse_args()
1388
1389
utils.args = args
1390
1391
set_logger_verbosity(args.verbosity)
1392
quiesce_logger(args.quiesce)
1393
if args.customsettings:
1394
f = open (args.customsettings)
1395
importedsettings = json.load(f)
1396
for items in importedsettings:
1397
if importedsettings[items] is not None:
1398
setattr(args, items, importedsettings[items])
1399
f.close()
1400
1401
if args.no_ui:
1402
def new_emit(*args, **kwargs):
1403
return
1404
old_emit = socketio.emit
1405
socketio.emit = new_emit
1406
1407
vars.model = args.model
1408
vars.revision = args.revision
1409
1410
if args.apikey:
1411
vars.apikey = args.apikey
1412
if args.req_model:
1413
vars.cluster_requested_models = args.req_model
1414
1415
if args.colab:
    args.remote = True
    args.override_rename = True
    args.override_delete = True
    args.nobreakmodel = True
    args.quiet = True
    args.lowmem = True
    args.noaimenu = True

if args.quiet:
    vars.quiet = True

if args.nobreakmodel:
    vars.nobreakmodel = True

if args.remote:
    vars.host = True

if args.ngrok:
    vars.host = True

if args.localtunnel:
    vars.host = True

if args.host:
    vars.host = True

if args.cpu:
1443
vars.use_colab_tpu = False
1444
1445
vars.smandelete = vars.host == args.override_delete
1446
vars.smanrename = vars.host == args.override_rename
1447
1448
vars.aria2_port = args.aria2_port or 6799
1449
1450
#Now let's look to see if we are going to force a load of a model from a user selected folder
1451
if(vars.model == "selectfolder"):
1452
print("{0}Please choose the folder where pytorch_model.bin is located:{1}\n".format(colors.CYAN, colors.END))
1453
modpath = fileops.getdirpath(getcwd() + "/models", "Select Model Folder")
1454
1455
if(modpath):
1456
# Save directory to vars
1457
vars.model = "NeoCustom"
1458
vars.custmodpth = modpath
1459
elif args.model:
1460
logger.message(f"Welcome to KoboldAI!")
1461
logger.message(f"You have selected the following Model: {vars.model}")
1462
if args.path:
1463
logger.message(f"You have selected the following path for your Model: {args.path}")
1464
vars.custmodpth = args.path;
1465
vars.colaburl = args.path + "/request" # Let's just use the same parameter to keep it simple
1466
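
# general_startup() can also be driven programmatically: override_args is split
# with shlex and parsed exactly like a command line, and the KOBOLDAI_ARGS
# environment variable is honored when no override is given. For example
# (hypothetical flag values, shown for illustration only):
#
#     general_startup("--model ReadOnly --port 5000 --quiet")
#
# behaves the same as launching aiserver.py with those flags.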
#==================================================================#
1467
# Load Model
1468
#==================================================================#
1469
1470
def tpumtjgetsofttokens():
1471
soft_tokens = None
1472
if(vars.sp is None):
1473
global np
1474
if 'np' not in globals():
1475
import numpy as np
1476
tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
1477
rows = tensor.shape[0]
1478
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
1479
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
1480
tensor = tensor.reshape(
1481
tpu_mtj_backend.params["cores_per_replica"],
1482
-1,
1483
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
1484
)
1485
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
1486
soft_tokens = np.arange(
1487
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
1488
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
1489
dtype=np.uint32
1490
)
1491
return soft_tokens
1492
1493
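
# The padding arithmetic above relies on Python's floor modulo: `seq % -cores`
# is always <= 0, so `seq - (seq % -cores)` is `seq` rounded up to the next
# multiple of `cores`, and subtracting the existing rows gives how many rows of
# padding to add before sharding across TPU cores. A small illustrative check
# of that identity (not called anywhere):
def _round_up_to_multiple_sketch(seq, cores):
    rounded = seq - (seq % -cores)
    assert rounded % cores == 0 and rounded >= seq and rounded - seq < cores
    return rounded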
def get_model_info(model, directory=""):
1494
# if the model is in the api list
1495
disk_blocks = 0
1496
key = False
1497
breakmodel = False
1498
gpu = False
1499
layer_count = None
1500
key_value = ""
1501
break_values = []
1502
url = False
1503
default_url = None
1504
models_on_url = False
1505
multi_online_models = False
1506
gpu_count = torch.cuda.device_count()
1507
gpu_names = []
1508
send_horde_models = False
1509
for i in range(gpu_count):
1510
gpu_names.append(torch.cuda.get_device_name(i))
1511
if model in ['Colab', 'API']:
1512
url = True
1513
elif model == 'CLUSTER':
1514
models_on_url = True
1515
url = True
1516
key = True
1517
default_url = 'https://horde.koboldai.net'
1518
multi_online_models = True
1519
if path.exists(get_config_filename(model)):
1520
with open(get_config_filename(model), "r") as file:
1521
# Check if API key exists
1522
js = json.load(file)
1523
if("apikey" in js and js["apikey"] != ""):
1524
# API key exists, grab it and close the file
1525
key_value = js["apikey"]
1526
elif 'oaiapikey' in js and js['oaiapikey'] != "":
1527
key_value = js["oaiapikey"]
1528
if 'url' in js and js['url'] != "":
1529
url = js['url']
1530
if key_value != "":
1531
send_horde_models = True
1532
elif model in [x[1] for x in model_menu['apilist']]:
1533
if path.exists(get_config_filename(model)):
1534
with open(get_config_filename(model), "r") as file:
1535
# Check if API key exists
1536
js = json.load(file)
1537
if("apikey" in js and js["apikey"] != ""):
1538
# API key exists, grab it and close the file
1539
key_value = js["apikey"]
1540
elif 'oaiapikey' in js and js['oaiapikey'] != "":
1541
key_value = js["oaiapikey"]
1542
key = True
1543
elif model == 'ReadOnly':
1544
pass
1545
elif not utils.HAS_ACCELERATE and not torch.cuda.is_available():
1546
pass
1547
elif args.cpu:
1548
pass
1549
else:
1550
layer_count = get_layer_count(model, directory=directory)
1551
if layer_count is None:
1552
breakmodel = False
1553
gpu = True
1554
else:
1555
breakmodel = True
1556
if model in ["NeoCustom", "GPT2Custom"]:
1557
filename = "settings/{}.breakmodel".format(os.path.basename(os.path.normpath(directory)))
1558
else:
1559
filename = "settings/{}.breakmodel".format(model.replace("/", "_"))
1560
if path.exists(filename):
1561
with open(filename, "r") as file:
1562
data = file.read().split("\n")[:2]
1563
if len(data) < 2:
1564
data.append("0")
1565
break_values, disk_blocks = data
1566
break_values = break_values.split(",")
1567
else:
1568
break_values = [layer_count]
1569
break_values += [0] * (gpu_count - len(break_values))
1570
#print("Model_info: {}".format({'cmd': 'selected_model_info', 'key_value': key_value, 'key':key,
1571
# 'gpu':gpu, 'layer_count':layer_count, 'breakmodel':breakmodel,
1572
# 'break_values': break_values, 'gpu_count': gpu_count,
1573
# 'url': url, 'gpu_names': gpu_names}))
1574
emit('from_server', {'cmd': 'selected_model_info', 'key_value': key_value, 'key':key,
1575
'gpu':gpu, 'layer_count':layer_count, 'breakmodel':breakmodel,
1576
'disk_break_value': disk_blocks, 'accelerate': utils.HAS_ACCELERATE,
1577
'break_values': break_values, 'gpu_count': gpu_count, 'multi_online_models': multi_online_models,
1578
'url': url, 'default_url': default_url, 'gpu_names': gpu_names, 'models_on_url': models_on_url}, broadcast=True)
1579
if send_horde_models:
1580
get_cluster_models({'key': key_value, 'url': default_url})
1581
elif key_value != "" and model in [x[1] for x in model_menu['apilist']] and model != 'CLUSTER':
1582
get_oai_models(key_value)
1583
1584
1585
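
# The settings/<model>.breakmodel file read in get_model_info() above is a
# two-line text file: line one is a comma-separated list of per-GPU layer
# counts, line two is the number of disk-cached layers. A minimal parser sketch
# of that format (illustrative only, not what the server calls):
def _parse_breakmodel_file_sketch(filename):
    with open(filename, "r") as f:
        lines = f.read().split("\n")[:2]
    if len(lines) < 2:
        lines.append("0")
    gpu_layers = [int(x) for x in lines[0].split(",")]
    disk_layers = int(lines[1])
    return gpu_layers, disk_layers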
def get_layer_count(model, directory=""):
1586
if(model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ"]):
1587
if(model == "GPT2Custom"):
1588
with open(os.path.join(directory, "config.json"), "r") as f:
1589
model_config = json.load(f)
1590
# Get the model_type from the config or assume a model type if it isn't present
1591
else:
1592
if(directory):
1593
model = directory
1594
from transformers import AutoConfig
1595
if(os.path.isdir(model.replace('/', '_'))):
1596
model_config = AutoConfig.from_pretrained(model.replace('/', '_'), revision=args.revision, cache_dir="cache")
1597
elif(os.path.isdir("models/{}".format(model.replace('/', '_')))):
1598
model_config = AutoConfig.from_pretrained("models/{}".format(model.replace('/', '_')), revision=args.revision, cache_dir="cache")
1599
elif(os.path.isdir(directory)):
1600
model_config = AutoConfig.from_pretrained(directory, revision=args.revision, cache_dir="cache")
1601
else:
1602
model_config = AutoConfig.from_pretrained(model, revision=args.revision, cache_dir="cache")
1603
try:
1604
if ((utils.HAS_ACCELERATE and model_config.model_type != 'gpt2') or model_config.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel:
1605
return utils.num_layers(model_config)
1606
else:
1607
return None
1608
except:
1609
return None
1610
else:
1611
return None
1612
1613
def get_oai_models(key):
1614
vars.oaiapikey = key
1615
if vars.model_selected == 'OAI':
1616
url = "https://api.openai.com/v1/engines"
1617
elif vars.model_selected == 'GooseAI':
1618
url = "https://api.goose.ai/v1/engines"
1619
else:
1620
return
1621
1622
# Get list of models from OAI
1623
logger.init("OAI Engines", status="Retrieving")
1624
req = requests.get(
1625
url,
1626
headers = {
1627
'Authorization': 'Bearer '+key
1628
}
1629
)
1630
if(req.status_code == 200):
1631
engines = req.json()["data"]
1632
try:
1633
engines = [[en["id"], "{} ({})".format(en['id'], "Ready" if en["ready"] == True else "Not Ready")] for en in engines]
1634
except:
1635
logger.error(engines)
1636
raise
1637
1638
online_model = ""
1639
changed=False
1640
1641
#Save the key
1642
if not path.exists("settings"):
1643
# If the client settings file doesn't exist, create it
1644
# Write API key to file
1645
os.makedirs('settings', exist_ok=True)
1646
if path.exists(get_config_filename(vars.model_selected)):
1647
with open(get_config_filename(vars.model_selected), "r") as file:
1648
js = json.load(file)
1649
if 'online_model' in js:
1650
online_model = js['online_model']
1651
if "apikey" in js:
1652
if js['apikey'] != key:
1653
changed=True
1654
else:
1655
changed=True
1656
if changed:
1657
js={}
1658
with open(get_config_filename(vars.model_selected), "w") as file:
1659
js["apikey"] = key
1660
file.write(json.dumps(js, indent=3))
1661
1662
logger.init_ok("OAI Engines", status="OK")
1663
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
1664
else:
1665
# Something went wrong, print the message and quit since we can't initialize an engine
1666
logger.init_err("OAI Engines", status="Failed")
1667
logger.error(req.json())
1668
emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
1669
1670
def get_cluster_models(msg):
1671
vars.oaiapikey = msg['key']
1672
vars.apikey = vars.oaiapikey
1673
url = msg['url']
1674
# Get list of models from public cluster
1675
logger.init("KAI Horde Models", status="Retrieving")
1676
try:
1677
req = requests.get(f"{url}/api/v2/status/models?type=text")
1678
except requests.exceptions.ConnectionError:
1679
logger.init_err("KAI Horde Models", status="Failed")
1680
logger.error("Provided KoboldAI Horde URL unreachable")
1681
emit('from_server', {'cmd': 'errmsg', 'data': "Provided KoboldAI Horde URL unreachable"})
1682
return
1683
if(not req.ok):
1684
# Something went wrong, print the message and quit since we can't initialize an engine
1685
logger.init_err("KAI Horde Models", status="Failed")
1686
logger.error(req.json())
1687
emit('from_server', {'cmd': 'errmsg', 'data': req.json()})
1688
return
1689
1690
engines = req.json()
1691
logger.debug(engines)
1692
try:
1693
engines = [[en["name"], en["name"]] for en in engines]
1694
except:
1695
logger.error(engines)
1696
raise
1697
logger.debug(engines)
1698
1699
online_model = ""
1700
changed=False
1701
1702
#Save the key
1703
if not path.exists("settings"):
1704
# If the client settings file doesn't exist, create it
1705
# Write API key to file
1706
os.makedirs('settings', exist_ok=True)
1707
if path.exists(get_config_filename(vars.model_selected)):
1708
with open(get_config_filename(vars.model_selected), "r") as file:
1709
js = json.load(file)
1710
if 'online_model' in js:
1711
online_model = js['online_model']
1712
if "apikey" in js:
1713
if js['apikey'] != vars.oaiapikey:
1714
changed=True
1715
else:
1716
changed=True
1717
if changed:
1718
js={}
1719
with open(get_config_filename(vars.model_selected), "w") as file:
1720
js["apikey"] = vars.oaiapikey
1721
js["url"] = url
1722
file.write(json.dumps(js, indent=3))
1723
1724
logger.init_ok("KAI Horde Models", status="OK")
1725
emit('from_server', {'cmd': 'oai_engines', 'data': engines, 'online_model': online_model}, broadcast=True)
1726
1727
1728
# Function to patch transformers to use our soft prompt
def patch_causallm(model):
    from torch.nn import Embedding
    if(getattr(Embedding, "_koboldai_patch_causallm_model", None)):
        Embedding._koboldai_patch_causallm_model = model
        return model
    old_embedding_call = Embedding.__call__
    def new_embedding_call(self, input_ids, *args, **kwargs):
        if(Embedding._koboldai_patch_causallm_model.get_input_embeddings() is not self):
            return old_embedding_call(self, input_ids, *args, **kwargs)
        assert input_ids is not None
        if(vars.sp is not None):
            shifted_input_ids = input_ids - model.config.vocab_size
            input_ids.clamp_(max=model.config.vocab_size-1)
        inputs_embeds = old_embedding_call(self, input_ids, *args, **kwargs)
        if(vars.sp is not None):
            vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
            inputs_embeds = torch.where(
                (shifted_input_ids >= 0)[..., None],
                vars.sp[shifted_input_ids.clamp(min=0)],
                inputs_embeds,
            )
        return inputs_embeds
    Embedding.__call__ = new_embedding_call
    Embedding._koboldai_patch_causallm_model = model
    return model

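# In the patch above, soft prompt positions are encoded as token IDs at or past
# vocab_size: subtracting vocab_size yields a row index into vars.sp for those
# positions and a negative number for ordinary tokens. A tensor-free sketch of
# that mapping (illustrative only, not used by the server):
def _soft_token_split_sketch(input_ids, vocab_size):
    regular, soft = [], []
    for token_id in input_ids:
        if token_id >= vocab_size:
            soft.append(token_id - vocab_size)    # row in the soft prompt tensor
        else:
            regular.append(token_id)              # looked up in the normal embedding matrix
    return regular, soft
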
def patch_transformers_download():
1756
global transformers
1757
import copy, requests, tqdm, time
1758
class Send_to_socketio(object):
1759
def write(self, bar):
1760
bar = bar.replace("\r", "").replace("\n", "")
1761
if bar != "":
1762
try:
1763
print(bar, end="\r")
1764
emit('from_server', {'cmd': 'model_load_status', 'data': bar.replace(" ", "&nbsp;")}, broadcast=True)
1765
eventlet.sleep(seconds=0)
1766
except:
1767
pass
1768
def http_get(
1769
url: str,
1770
temp_file,
1771
proxies=None,
1772
resume_size=0,
1773
headers=None,
1774
file_name=None,
1775
):
1776
"""
1777
Download remote file. Do not gobble up errors.
1778
"""
1779
headers = copy.deepcopy(headers)
1780
if resume_size > 0:
1781
headers["Range"] = f"bytes={resume_size}-"
1782
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
1783
transformers.utils.hub._raise_for_status(r)
1784
content_length = r.headers.get("Content-Length")
1785
total = resume_size + int(content_length) if content_length is not None else None
1786
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
1787
# and can be set using `utils.logging.enable/disable_progress_bar()`
1788
if url[-11:] != 'config.json':
1789
progress = tqdm.tqdm(
1790
unit="B",
1791
unit_scale=True,
1792
unit_divisor=1024,
1793
total=total,
1794
initial=resume_size,
1795
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
1796
file=Send_to_socketio(),
1797
)
1798
for chunk in r.iter_content(chunk_size=1024):
1799
if chunk: # filter out keep-alive new chunks
1800
if url[-11:] != 'config.json':
1801
progress.update(len(chunk))
1802
temp_file.write(chunk)
1803
if url[-11:] != 'config.json':
1804
progress.close()
1805
1806
transformers.utils.hub.http_get = http_get
1807
1808
1809
def patch_transformers():
1810
global transformers
1811
1812
patch_transformers_download()
1813
1814
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
1815
@classmethod
1816
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1817
vars.fp32_model = False
1818
utils.num_shards = None
1819
utils.current_shard = 0
1820
utils.from_pretrained_model_name = pretrained_model_name_or_path
1821
utils.from_pretrained_index_filename = None
1822
utils.from_pretrained_kwargs = kwargs
1823
utils.bar = None
1824
if not args.no_aria2:
1825
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
1826
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
1827
if(not hasattr(PreTrainedModel, "_kai_patched")):
1828
PreTrainedModel.from_pretrained = new_from_pretrained
1829
PreTrainedModel._kai_patched = True
1830
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
1831
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
1832
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
1833
utils.num_shards = utils.get_num_shards(index_filename)
1834
utils.from_pretrained_index_filename = index_filename
1835
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
1836
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
1837
1838
# Some versions of transformers 4.17.0.dev0 are affected by
1839
# https://github.com/huggingface/transformers/issues/15736
1840
# This is a workaround for those versions of transformers.
1841
if(transformers_version == "4.17.0.dev0"):
1842
try:
1843
from transformers.models.xglm.modeling_xglm import XGLMSinusoidalPositionalEmbedding
1844
except ImportError:
1845
pass
1846
else:
1847
@torch.no_grad()
1848
def new_forward(self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0):
1849
bsz, seq_len = inputs_embeds.size()[:-1]
1850
input_shape = inputs_embeds.size()[:-1]
1851
sequence_length = input_shape[1]
1852
position_ids = torch.arange(
1853
past_key_values_length + self.padding_idx + 1, past_key_values_length + sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
1854
).unsqueeze(0).expand(input_shape).contiguous()
1855
max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
1856
if max_pos > self.weights.size(0):
1857
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
1858
return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
1859
XGLMSinusoidalPositionalEmbedding.forward = new_forward
1860
1861
1862
# Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
1863
if(packaging.version.parse("4.19.0.dev0") <= packaging.version.parse(transformers_version) < packaging.version.parse("4.20.0")):
1864
try:
1865
from transformers import OPTForCausalLM, OPTModel
1866
except ImportError:
1867
pass
1868
else:
1869
# This is the same as the original __init__ but with
1870
# config.hidden_size
1871
# replaced with
1872
# config.word_embed_proj_dim
1873
def new_init(self, config):
1874
super(OPTForCausalLM, self).__init__(config)
1875
self.model = OPTModel(config)
1876
self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
1877
self.post_init()
1878
OPTForCausalLM.__init__ = new_init
1879
1880
1881
# Patch transformers to use our custom logit warpers
1882
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
1883
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper, TopALogitsWarper
1884
1885
def dynamic_processor_wrap(cls, field_name, var_name, cond=None):
1886
old_call = cls.__call__
1887
def new_call(self, *args, **kwargs):
1888
if(not isinstance(field_name, str) and isinstance(field_name, Iterable)):
1889
conds = []
1890
for f, v in zip(field_name, var_name):
1891
conds.append(getattr(vars, v))
1892
setattr(self, f, conds[-1])
1893
else:
1894
conds = getattr(vars, var_name)
1895
setattr(self, field_name, conds)
1896
assert len(args) == 2
1897
if(cond is None or cond(conds)):
1898
return old_call(self, *args, **kwargs)
1899
return args[1]
1900
cls.__call__ = new_call
1901
dynamic_processor_wrap(AdvancedRepetitionPenaltyLogitsProcessor, ("penalty", "penalty_slope", "penalty_range"), ("rep_pen", "rep_pen_slope", "rep_pen_range"), cond=lambda x: x[0] != 1.0)
1902
dynamic_processor_wrap(TopKLogitsWarper, "top_k", "top_k", cond=lambda x: x > 0)
1903
dynamic_processor_wrap(TopALogitsWarper, "top_a", "top_a", cond=lambda x: x > 0.0)
1904
dynamic_processor_wrap(TopPLogitsWarper, "top_p", "top_p", cond=lambda x: x < 1.0)
1905
dynamic_processor_wrap(TailFreeLogitsWarper, "tfs", "tfs", cond=lambda x: x < 1.0)
1906
dynamic_processor_wrap(TypicalLogitsWarper, "typical", "typical", cond=lambda x: x < 1.0)
1907
dynamic_processor_wrap(TemperatureLogitsWarper, "temperature", "temp", cond=lambda x: x != 1.0)
1908
1909
class LuaLogitsProcessor(LogitsProcessor):
1910
1911
def __init__(self):
1912
pass
1913
1914
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
1915
assert scores.ndim == 2
1916
assert input_ids.ndim == 2
1917
self.regeneration_required = False
1918
self.halt = False
1919
1920
if(vars.standalone):
1921
return scores
1922
1923
scores_shape = scores.shape
1924
scores_list = scores.tolist()
1925
vars.lua_koboldbridge.logits = vars.lua_state.table()
1926
for r, row in enumerate(scores_list):
1927
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
1928
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
1929
1930
execute_genmod()
1931
1932
scores = torch.tensor(
1933
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
1934
device=scores.device,
1935
dtype=scores.dtype,
1936
)
1937
assert scores.shape == scores_shape
1938
1939
return scores
1940
1941
from torch.nn import functional as F
1942
1943
def visualize_probabilities(scores: torch.FloatTensor) -> None:
1944
assert scores.ndim == 2
1945
1946
if vars.numseqs > 1 or not vars.show_probs:
1947
return
1948
1949
probs = F.softmax(scores, dim = -1).cpu().numpy()[0]
1950
token_prob_info = []
1951
for token_id, score in sorted(enumerate(probs), key=lambda x: x[1], reverse=True)[:8]:
1952
token_prob_info.append({
1953
"tokenId": token_id,
1954
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
1955
"score": float(score),
1956
})
1957
1958
vars.token_stream_queue.probability_buffer = token_prob_info
1959
1960
def new_get_logits_processor(*args, **kwargs) -> LogitsProcessorList:
1961
processors = new_get_logits_processor.old_get_logits_processor(*args, **kwargs)
1962
processors.insert(0, LuaLogitsProcessor())
1963
return processors
1964
new_get_logits_processor.old_get_logits_processor = transformers.generation_utils.GenerationMixin._get_logits_processor
1965
transformers.generation_utils.GenerationMixin._get_logits_processor = new_get_logits_processor
1966
1967
class KoboldLogitsWarperList(LogitsProcessorList):
1968
def __init__(self, beams: int = 1, **kwargs):
1969
self.__warper_list: List[LogitsWarper] = []
1970
self.__warper_list.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1 + (beams > 1)))
1971
self.__warper_list.append(TopALogitsWarper(top_a=0.5, min_tokens_to_keep=1 + (beams > 1)))
1972
self.__warper_list.append(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1 + (beams > 1)))
1973
self.__warper_list.append(TailFreeLogitsWarper(tfs=0.5, min_tokens_to_keep=1 + (beams > 1)))
1974
self.__warper_list.append(TypicalLogitsWarper(typical=0.5, min_tokens_to_keep=1 + (beams > 1)))
1975
self.__warper_list.append(TemperatureLogitsWarper(temperature=0.5))
1976
self.__warper_list.append(AdvancedRepetitionPenaltyLogitsProcessor())
1977
1978
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, *args, **kwargs):
1979
sampler_order = vars.sampler_order[:]
1980
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
1981
sampler_order = [6] + sampler_order
1982
for k in sampler_order:
1983
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
1984
visualize_probabilities(scores)
1985
return scores
1986
1987
def new_get_logits_warper(beams: int = 1,) -> LogitsProcessorList:
1988
return KoboldLogitsWarperList(beams=beams)
1989
1990
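
# vars.sampler_order holds indices into the warper list built above:
# 0=top_k, 1=top_a, 2=top_p, 3=tfs, 4=typical, 5=temperature, 6=repetition
# penalty (prepended when missing). Only the listed indices are visited, so a
# hypothetical order of [6, 0, 2, 5] would apply repetition penalty, top-k,
# top-p and temperature in that sequence and skip the other warpers.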
def new_sample(self, *args, **kwargs):
1991
assert kwargs.pop("logits_warper", None) is not None
1992
kwargs["logits_warper"] = new_get_logits_warper(
1993
beams=1,
1994
)
1995
if(vars.newlinemode == "s") or (vars.newlinemode == "ns"):
1996
kwargs["eos_token_id"] = -1
1997
kwargs.setdefault("pad_token_id", 2)
1998
return new_sample.old_sample(self, *args, **kwargs)
1999
new_sample.old_sample = transformers.generation_utils.GenerationMixin.sample
2000
transformers.generation_utils.GenerationMixin.sample = new_sample
2001
2002
2003
# Allow bad words filter to ban <|endoftext|> token
2004
import transformers.generation_logits_process
2005
def new_init(self, bad_words_ids: List[List[int]], eos_token_id: int):
2006
return new_init.old_init(self, bad_words_ids, -1)
2007
new_init.old_init = transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__
2008
transformers.generation_logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
2009
2010
class TokenStreamer(StoppingCriteria):
2011
# A StoppingCriteria is used here because it seems to run after
2012
# everything has been evaluated score-wise.
2013
def __init__(self, tokenizer):
2014
self.tokenizer = tokenizer
2015
2016
def __call__(
2017
self,
2018
input_ids: torch.LongTensor,
2019
scores: torch.FloatTensor,
2020
**kwargs,
2021
) -> bool:
2022
# Do not intermingle multiple generations' outputs!
2023
if vars.numseqs > 1:
2024
return False
2025
2026
if not (vars.show_probs or vars.output_streaming):
2027
return False
2028
2029
if vars.chatmode:
2030
return False
2031
tokenizer_text = utils.decodenewlines(tokenizer.decode(input_ids[0, -1]))
2032
vars.token_stream_queue.add_text(tokenizer_text)
2033
return False
2034
2035
2036
# Sets up dynamic world info scanner
2037
class DynamicWorldInfoScanCriteria(StoppingCriteria):
2038
def __init__(
2039
self,
2040
tokenizer,
2041
excluded_world_info: List[Set],
2042
):
2043
self.regeneration_required = False
2044
self.halt = False
2045
self.tokenizer = tokenizer
2046
self.excluded_world_info = excluded_world_info
2047
def __call__(
2048
self,
2049
input_ids: torch.LongTensor,
2050
scores: torch.FloatTensor,
2051
**kwargs,
2052
) -> bool:
2053
vars.generated_tkns += 1
2054
if(not vars.standalone and vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
2055
raise RuntimeError(f"Inconsistency detected between KoboldAI Python and Lua backends ({vars.generated_tkns} != {vars.lua_koboldbridge.generated_cols})")
2056
if(vars.abort or vars.generated_tkns >= vars.genamt):
2057
self.regeneration_required = False
2058
self.halt = False
2059
return True
2060
if(vars.standalone):
2061
return False
2062
2063
assert input_ids.ndim == 2
2064
assert len(self.excluded_world_info) == input_ids.shape[0]
2065
self.regeneration_required = vars.lua_koboldbridge.regeneration_required
2066
self.halt = not vars.lua_koboldbridge.generating
2067
vars.lua_koboldbridge.regeneration_required = False
2068
2069
for i in range(vars.numseqs):
2070
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(input_ids[i, -1].item())
2071
2072
if(not vars.dynamicscan):
2073
return self.regeneration_required or self.halt
2074
tail = input_ids[..., -vars.generated_tkns:]
2075
for i, t in enumerate(tail):
2076
decoded = utils.decodenewlines(tokenizer.decode(t))
2077
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
2078
found -= self.excluded_world_info[i]
2079
if(len(found) != 0):
2080
self.regeneration_required = True
2081
break
2082
return self.regeneration_required or self.halt
2083
old_get_stopping_criteria = transformers.generation_utils.GenerationMixin._get_stopping_criteria
2084
def new_get_stopping_criteria(self, *args, **kwargs):
2085
stopping_criteria = old_get_stopping_criteria(self, *args, **kwargs)
2086
global tokenizer
2087
self.kai_scanner = DynamicWorldInfoScanCriteria(
2088
tokenizer=tokenizer,
2089
excluded_world_info=self.kai_scanner_excluded_world_info,
2090
)
2091
token_streamer = TokenStreamer(tokenizer=tokenizer)
2092
2093
stopping_criteria.insert(0, self.kai_scanner)
2094
stopping_criteria.insert(0, token_streamer)
2095
return stopping_criteria
2096
transformers.generation_utils.GenerationMixin._get_stopping_criteria = new_get_stopping_criteria
2097
2098
def reset_model_settings():
2099
vars.socketio = socketio
2100
vars.max_length = 1024 # Maximum number of tokens to submit per action
2101
vars.ikmax = 3000 # Maximum number of characters to submit to InferKit
2102
vars.genamt = 80 # Amount of text for each action to generate
2103
vars.ikgen = 200 # Number of characters for InferKit to generate
2104
vars.rep_pen = 1.1 # Default generator repetition_penalty
2105
vars.rep_pen_slope = 0.7 # Default generator repetition penalty slope
2106
vars.rep_pen_range = 1024 # Default generator repetition penalty range
2107
vars.temp = 0.5 # Default generator temperature
2108
vars.top_p = 0.9 # Default generator top_p
2109
vars.top_k = 0 # Default generator top_k
2110
vars.top_a = 0.0 # Default generator top-a
2111
vars.tfs = 1.0 # Default generator tfs (tail-free sampling)
2112
vars.typical = 1.0 # Default generator typical sampling threshold
2113
vars.numseqs = 1 # Number of sequences to ask the generator to create
2114
vars.generated_tkns = 0 # If using a backend that supports Lua generation modifiers, how many tokens have already been generated, otherwise 0
2115
vars.badwordsids = []
2116
vars.fp32_model = False # Whether or not the most recently loaded HF model was in fp32 format
2117
vars.modeldim = -1 # Embedding dimension of your model (e.g. it's 4096 for GPT-J-6B and 2560 for GPT-Neo-2.7B)
2118
vars.sampler_order = [6, 0, 1, 2, 3, 4, 5]
2119
vars.newlinemode = "n"
2120
vars.revision = None
2121
vars.lazy_load = True
2122
2123
2124
def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=False, online_model="", use_breakmodel_args=False, breakmodel_args_default_to_cpu=False):
2125
global model
2126
global generator
2127
global torch
2128
global model_config
2129
global GPT2Tokenizer
2130
global tokenizer
2131
if(initial_load):
2132
use_breakmodel_args = True
2133
reset_model_settings()
2134
if not utils.HAS_ACCELERATE:
2135
disk_layers = None
2136
vars.noai = False
2137
if not use_breakmodel_args:
2138
set_aibusy(True)
2139
if vars.model != 'ReadOnly':
2140
emit('from_server', {'cmd': 'model_load_status', 'data': "Loading {}".format(vars.model)}, broadcast=True)
2141
#Have to add a sleep so the server will send the emit for some reason
2142
time.sleep(0.1)
2143
if gpu_layers is not None:
2144
args.breakmodel_gpulayers = gpu_layers
2145
elif use_breakmodel_args:
2146
gpu_layers = args.breakmodel_gpulayers
2147
if breakmodel_args_default_to_cpu and gpu_layers is None:
2148
gpu_layers = args.breakmodel_gpulayers = []
2149
if disk_layers is not None:
2150
args.breakmodel_disklayers = int(disk_layers)
2151
elif use_breakmodel_args:
2152
disk_layers = args.breakmodel_disklayers
2153
if breakmodel_args_default_to_cpu and disk_layers is None:
2154
disk_layers = args.breakmodel_disklayers = 0
2155
2156
#We need to wipe out the existing model and refresh the cuda cache
2157
model = None
2158
generator = None
2159
model_config = None
2160
vars.online_model = ''
2161
with torch.no_grad():
2162
with warnings.catch_warnings():
2163
warnings.filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")
2164
for tensor in gc.get_objects():
2165
try:
2166
if torch.is_tensor(tensor):
2167
tensor.set_(torch.tensor((), device=tensor.device, dtype=tensor.dtype))
2168
except:
2169
pass
2170
gc.collect()
2171
try:
2172
torch.cuda.empty_cache()
2173
except:
2174
pass
2175
2176
#Reload our badwords
2177
vars.badwordsids = vars.badwordsids_default
2178
2179
if online_model == "":
2180
vars.configname = getmodelname()
2181
#Let's set the GooseAI or OpenAI server URLs if that's applicable
2182
else:
2183
vars.online_model = online_model
2184
# Swap OAI Server if GooseAI was selected
2185
if(vars.model == "GooseAI"):
2186
vars.oaiengines = "https://api.goose.ai/v1/engines"
2187
vars.model = "OAI"
2188
vars.configname = f"GooseAI_{online_model.replace('/', '_')}"
2189
elif(vars.model == "CLUSTER") and type(online_model) is list:
2190
if len(online_model) != 1:
2191
vars.configname = vars.model
2192
else:
2193
vars.configname = f"{vars.model}_{online_model[0].replace('/', '_')}"
2194
else:
2195
vars.configname = f"{vars.model}_{online_model.replace('/', '_')}"
2196
if path.exists(get_config_filename()):
2197
changed=False
2198
with open(get_config_filename(), "r") as file:
2199
# Check if API key exists
2200
js = json.load(file)
2201
if 'online_model' in js:
2202
if js['online_model'] != online_model:
2203
changed=True
2204
js['online_model'] = online_model
2205
else:
2206
changed=True
2207
js['online_model'] = online_model
2208
if changed:
2209
with open(get_config_filename(), "w") as file:
2210
file.write(json.dumps(js, indent=3))
2211
2212
# Swap OAI Server if GooseAI was selected
2213
if(vars.model == "GooseAI"):
2214
vars.oaiengines = "https://api.goose.ai/v1/engines"
2215
vars.model = "OAI"
2216
args.configname = "GooseAI" + "/" + online_model
2217
elif vars.model != "CLUSTER":
2218
args.configname = vars.model + "/" + online_model
2219
vars.oaiurl = vars.oaiengines + "/{0}/completions".format(online_model)
2220
2221
2222
# If transformers model was selected & GPU available, ask to use CPU or GPU
2223
if(vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
2224
vars.allowsp = True
2225
# Test for GPU support
2226
2227
# Make model path the same as the model name to make this consistent with the other loading method if it isn't a known model type
2228
# This code is not just a workaround for below, it is also used to make the behavior consistent with other loading methods - Henk717
2229
if(not vars.model in ["NeoCustom", "GPT2Custom"]):
2230
vars.custmodpth = vars.model
2231
elif(vars.model == "NeoCustom"):
2232
vars.model = os.path.basename(os.path.normpath(vars.custmodpth))
2233
2234
# Get the model_type from the config or assume a model type if it isn't present
2235
from transformers import AutoConfig
2236
if(os.path.isdir(vars.custmodpth.replace('/', '_'))):
2237
try:
2238
model_config = AutoConfig.from_pretrained(vars.custmodpth.replace('/', '_'), revision=args.revision, cache_dir="cache")
2239
vars.model_type = model_config.model_type
2240
except ValueError as e:
2241
vars.model_type = "not_found"
2242
elif(os.path.isdir("models/{}".format(vars.custmodpth.replace('/', '_')))):
2243
try:
2244
model_config = AutoConfig.from_pretrained("models/{}".format(vars.custmodpth.replace('/', '_')), revision=args.revision, cache_dir="cache")
2245
vars.model_type = model_config.model_type
2246
except ValueError as e:
2247
vars.model_type = "not_found"
2248
else:
2249
try:
2250
model_config = AutoConfig.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache")
2251
vars.model_type = model_config.model_type
2252
except ValueError as e:
2253
vars.model_type = "not_found"
2254
if(vars.model_type == "not_found" and vars.model == "NeoCustom"):
2255
vars.model_type = "gpt_neo"
2256
elif(vars.model_type == "not_found" and vars.model == "GPT2Custom"):
2257
vars.model_type = "gpt2"
2258
elif(vars.model_type == "not_found"):
2259
logger.warning("No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
2260
vars.model_type = "gpt_neo"
2261
2262
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
2263
loadmodelsettings()
2264
loadsettings()
2265
logger.init("GPU support", status="Searching")
2266
vars.hascuda = torch.cuda.is_available() and not args.cpu
2267
vars.bmsupported = ((utils.HAS_ACCELERATE and vars.model_type != 'gpt2') or vars.model_type in ("gpt_neo", "gptj", "xglm", "opt")) and not vars.nobreakmodel
2268
if(args.breakmodel is not None and args.breakmodel):
2269
logger.warning("--breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).")
2270
if(args.breakmodel_layers is not None):
2271
logger.warning("--breakmodel_layers is deprecated. Use --breakmodel_gpulayers instead (see --help for details).")
2272
if(args.model and vars.bmsupported and not args.breakmodel_gpulayers and not args.breakmodel_layers and (not utils.HAS_ACCELERATE or not args.breakmodel_disklayers)):
2273
logger.warning("Model launched without the --breakmodel_gpulayers argument, defaulting to GPU only mode.")
2274
vars.bmsupported = False
2275
if(not vars.bmsupported and (args.breakmodel_gpulayers is not None or args.breakmodel_layers is not None or args.breakmodel_disklayers is not None)):
2276
logger.warning("This model does not support hybrid generation. --breakmodel_gpulayers will be ignored.")
2277
if(vars.hascuda):
2278
logger.init_ok("GPU support", status="Found")
2279
else:
2280
logger.init_warn("GPU support", status="Not Found")
2281
2282
if args.cpu:
2283
vars.usegpu = False
2284
gpu_layers = None
2285
disk_layers = None
2286
vars.breakmodel = False
2287
elif vars.hascuda:
2288
if(vars.bmsupported):
2289
vars.usegpu = False
2290
vars.breakmodel = True
2291
else:
2292
vars.breakmodel = False
2293
vars.usegpu = use_gpu
2294
2295
2296
# Ask for API key if InferKit was selected
2297
if(vars.model == "InferKit"):
2298
vars.apikey = vars.oaiapikey
2299
2300
# Swap OAI Server if GooseAI was selected
2301
if(vars.model == "GooseAI"):
2302
vars.oaiengines = "https://api.goose.ai/v1/engines"
2303
vars.model = "OAI"
2304
vars.configname = "GooseAI"
2305
2306
# Ask for API key if OpenAI was selected
2307
if(vars.model == "OAI"):
2308
if not vars.configname:
2309
vars.configname = "OAI"
2310
2311
if(vars.model == "ReadOnly"):
2312
vars.noai = True
2313
2314
# Start transformers and create pipeline
2315
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
2316
if(not vars.noai):
2317
logger.init("Transformers", status='Starting')
2318
for m in ("GPTJModel", "XGLMModel"):
2319
try:
2320
globals()[m] = getattr(__import__("transformers"), m)
2321
except:
2322
pass
2323
2324
# Lazy loader
2325
import torch_lazy_loader
2326
def get_lazy_load_callback(n_layers, convert_to_float16=True):
2327
if not vars.lazy_load:
2328
return
2329
2330
from tqdm.auto import tqdm
2331
2332
global breakmodel
2333
import breakmodel
2334
2335
if utils.HAS_ACCELERATE:
2336
import accelerate.utils
2337
2338
if args.breakmodel_disklayers is not None:
2339
breakmodel.disk_blocks = args.breakmodel_disklayers
2340
2341
disk_blocks = breakmodel.disk_blocks
2342
gpu_blocks = breakmodel.gpu_blocks
2343
ram_blocks = n_layers - sum(gpu_blocks)
2344
cumulative_gpu_blocks = tuple(itertools.accumulate(gpu_blocks))
2345
2346
def lazy_load_callback(model_dict: Dict[str, Union[torch_lazy_loader.LazyTensor, torch.Tensor]], f, **_):
2347
if lazy_load_callback.nested:
2348
return
2349
lazy_load_callback.nested = True
2350
2351
device_map: Dict[str, Union[str, int]] = {}
2352
2353
@functools.lru_cache(maxsize=None)
2354
def get_original_key(key):
2355
return max((original_key for original_key in utils.module_names if original_key.endswith(key)), key=len)
2356
2357
for key, value in model_dict.items():
2358
original_key = get_original_key(key)
2359
if isinstance(value, torch_lazy_loader.LazyTensor) and not any(original_key.startswith(n) for n in utils.layers_module_names):
2360
device_map[key] = vars.gpu_device if vars.hascuda and vars.usegpu else "cpu" if not vars.hascuda or not vars.breakmodel else breakmodel.primary_device
2361
else:
2362
layer = int(max((n for n in utils.layers_module_names if original_key.startswith(n)), key=len).rsplit(".", 1)[1])
2363
device = vars.gpu_device if vars.hascuda and vars.usegpu else "disk" if layer < disk_blocks and layer < ram_blocks else "cpu" if not vars.hascuda or not vars.breakmodel else "shared" if layer < ram_blocks else bisect.bisect_right(cumulative_gpu_blocks, layer - ram_blocks)
2364
device_map[key] = device
2365
2366
if utils.num_shards is None or utils.current_shard == 0:
2367
utils.offload_index = {}
2368
if utils.HAS_ACCELERATE:
2369
if os.path.isdir("accelerate-disk-cache"):
2370
# Delete all of the files in the disk cache folder without deleting the folder itself to allow people to create symbolic links for this folder
2371
# (the folder doesn't contain any subfolders so os.remove will do just fine)
2372
for filename in os.listdir("accelerate-disk-cache"):
2373
try:
2374
os.remove(os.path.join("accelerate-disk-cache", filename))
2375
except OSError:
2376
pass
2377
os.makedirs("accelerate-disk-cache", exist_ok=True)
2378
if utils.num_shards is not None:
2379
num_tensors = len(utils.get_sharded_checkpoint_num_tensors(utils.from_pretrained_model_name, utils.from_pretrained_index_filename, **utils.from_pretrained_kwargs))
2380
else:
2381
num_tensors = len(device_map)
2382
utils.bar = tqdm(total=num_tensors, desc=f"{colors.PURPLE}INIT{colors.END} | Loading model tensors", file=Send_to_socketio())
2383
2384
with zipfile.ZipFile(f, "r") as z:
2385
try:
2386
last_storage_key = None
2387
zipfolder = os.path.basename(os.path.normpath(f)).split('.')[0]
2388
f = None
2389
current_offset = 0
2390
able_to_pin_layers = True
2391
if utils.num_shards is not None:
2392
utils.current_shard += 1
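# The keys below are visited in (storage key, seek offset) order so each
# storage record inside the checkpoint zip is opened once and read
# strictly forward; an out-of-order offset only forces the record to be
# reopened, never a backwards seek.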
2393
for key in sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)):
2394
storage_key = model_dict[key].key
2395
if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
2396
last_storage_key = storage_key
2397
if isinstance(f, zipfile.ZipExtFile):
2398
f.close()
2399
try:
2400
f = z.open(f"archive/data/{storage_key}")
2401
except:
2402
f = z.open(f"{zipfolder}/data/{storage_key}")
2403
current_offset = 0
2404
if current_offset != model_dict[key].seek_offset:
2405
f.read(model_dict[key].seek_offset - current_offset)
2406
current_offset = model_dict[key].seek_offset
2407
device = device_map[key]
2408
size = functools.reduce(lambda x, y: x * y, model_dict[key].shape, 1)
2409
dtype = model_dict[key].dtype
2410
nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
2411
#print(f"Transferring <{key}> to {f'({device.upper()})' if isinstance(device, str) else '[device ' + str(device) + ']'} ... ", end="", flush=True)
2412
model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
2413
if model_dict[key].dtype is torch.float32:
2414
vars.fp32_model = True
2415
if convert_to_float16 and breakmodel.primary_device != "cpu" and vars.hascuda and (vars.breakmodel or vars.usegpu) and model_dict[key].dtype is torch.float32:
2416
model_dict[key] = model_dict[key].to(torch.float16)
2417
if breakmodel.primary_device == "cpu" or (not vars.usegpu and not vars.breakmodel and model_dict[key].dtype is torch.float16):
2418
model_dict[key] = model_dict[key].to(torch.float32)
2419
if device == "shared":
2420
model_dict[key] = model_dict[key].to("cpu").detach_()
2421
if able_to_pin_layers and utils.HAS_ACCELERATE:
2422
try:
2423
model_dict[key] = model_dict[key].pin_memory()
2424
except:
2425
able_to_pin_layers = False
2426
elif device == "disk":
2427
accelerate.utils.offload_weight(model_dict[key], get_original_key(key), "accelerate-disk-cache", index=utils.offload_index)
2428
model_dict[key] = model_dict[key].to("meta")
2429
else:
2430
model_dict[key] = model_dict[key].to(device)
2431
#print("OK", flush=True)
2432
current_offset += nbytes
2433
utils.bar.update(1)
2434
finally:
2435
if utils.num_shards is None or utils.current_shard >= utils.num_shards:
2436
if utils.offload_index:
2437
for name, tensor in utils.named_buffers:
2438
dtype = tensor.dtype
2439
if convert_to_float16 and breakmodel.primary_device != "cpu" and vars.hascuda and (vars.breakmodel or vars.usegpu):
2440
dtype = torch.float16
2441
if breakmodel.primary_device == "cpu" or (not vars.usegpu and not vars.breakmodel):
2442
dtype = torch.float32
2443
if name in model_dict and model_dict[name].dtype is not dtype:
2444
model_dict[name] = model_dict[name].to(dtype)
2445
if tensor.dtype is not dtype:
2446
tensor = tensor.to(dtype)
2447
if name not in utils.offload_index:
2448
accelerate.utils.offload_weight(tensor, name, "accelerate-disk-cache", index=utils.offload_index)
2449
accelerate.utils.save_offload_index(utils.offload_index, "accelerate-disk-cache")
2450
utils.bar.close()
2451
utils.bar = None
2452
lazy_load_callback.nested = False
2453
if isinstance(f, zipfile.ZipExtFile):
2454
f.close()
2455
2456
lazy_load_callback.nested = False
2457
return lazy_load_callback
2458
2459
2460
def maybe_low_cpu_mem_usage() -> Dict[str, Any]:
2461
if(packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0")):
2462
logger.warning(f"Please upgrade to transformers 4.11.0 for lower RAM usage. You have transformers {transformers_version}.")
2463
return {}
2464
return {"low_cpu_mem_usage": True}
2465
2466
@contextlib.contextmanager
2467
def maybe_use_float16(always_use=False):
2468
if(always_use or (vars.hascuda and args.lowmem and (vars.usegpu or vars.breakmodel))):
2469
original_dtype = torch.get_default_dtype()
2470
torch.set_default_dtype(torch.float16)
2471
yield True
2472
torch.set_default_dtype(original_dtype)
2473
else:
2474
yield False
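# Usage sketch (illustrative only; the real call sites are further down in
# this file):
#
#   with maybe_use_float16():
#       model = AutoModelForCausalLM.from_pretrained(vars.custmodpth)
#
# While the block runs, newly created floating point tensors default to
# float16 whenever always_use is set, or CUDA is available, args.lowmem is
# enabled and the model will live on a GPU (fully or via breakmodel);
# otherwise the default dtype is left alone and the manager yields False.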
2475
2476
# If custom GPT2 model was chosen
2477
if(vars.model_type == "gpt2"):
2478
vars.lazy_load = False
2479
if os.path.exists(vars.custmodpth):
2480
model_config = open(vars.custmodpth + "/config.json", "r")
2481
elif os.path.exists(os.path.join("models/", vars.custmodpth)):
2482
config_path = os.path.join("models/", vars.custmodpth)
2483
config_path = os.path.join(config_path, "config.json").replace("\\", "//")
2484
model_config = open(config_path, "r")
2485
#js = json.load(model_config)
2486
with(maybe_use_float16()):
2487
try:
2488
if os.path.exists(vars.custmodpth):
2489
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache")
2490
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache")
2491
elif os.path.exists(os.path.join("models/", vars.custmodpth)):
2492
model = GPT2LMHeadModel.from_pretrained(os.path.join("models/", vars.custmodpth), revision=args.revision, cache_dir="cache")
2493
tokenizer = GPT2Tokenizer.from_pretrained(os.path.join("models/", vars.custmodpth), revision=args.revision, cache_dir="cache")
2494
else:
2495
model = GPT2LMHeadModel.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache")
2496
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache")
2497
except Exception as e:
2498
if("out of memory" in traceback.format_exc().lower()):
2499
raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
2500
raise e
2501
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache")
2502
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')), max_shard_size="500MiB")
2503
tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
2504
vars.modeldim = get_hidden_size_from_model(model)
2505
# Is CUDA available? If so, use GPU, otherwise fall back to CPU
2506
if(vars.hascuda and vars.usegpu):
2507
model = model.half().to(vars.gpu_device)
2508
generator = model.generate
2509
else:
2510
model = model.to('cpu').float()
2511
generator = model.generate
2512
patch_causallm(model)
2513
# Use the Generic implementation
2514
else:
2515
lowmem = maybe_low_cpu_mem_usage()
2516
# We must disable low_cpu_mem_usage (by setting lowmem to {}) if
2517
# using a GPT-2 model because GPT-2 is not compatible with this
2518
# feature yet
2519
if(vars.model_type == "gpt2"):
2520
lowmem = {}
2521
vars.lazy_load = False # Also, lazy loader doesn't support GPT-2 models
2522
2523
# If we're using torch_lazy_loader, we need to get breakmodel config
2524
# early so that it knows where to load the individual model tensors
2525
if (utils.HAS_ACCELERATE or vars.lazy_load and vars.hascuda and vars.breakmodel) and not vars.nobreakmodel:
2526
device_config(model_config)
2527
2528
# Download model from Huggingface if it does not exist, otherwise load locally
2529
2530
# If we specify a model and it's in the root directory, we need to move it to the models directory (legacy folder structure to new)
2531
if os.path.isdir(vars.model.replace('/', '_')):
2532
import shutil
2533
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
2534
if(vars.lazy_load): # If we're using lazy loader, we need to figure out what the model's hidden layers are called
2535
with torch_lazy_loader.use_lazy_torch_load(dematerialized_modules=True, use_accelerate_init_empty_weights=True):
2536
try:
2537
metamodel = AutoModelForCausalLM.from_config(model_config)
2538
except Exception as e:
2539
metamodel = GPTNeoForCausalLM.from_config(model_config)
2540
utils.layers_module_names = utils.get_layers_module_names(metamodel)
2541
utils.module_names = list(metamodel.state_dict().keys())
2542
utils.named_buffers = list(metamodel.named_buffers(recurse=True))
2543
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
2544
if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
2545
lowmem = {}
2546
if(os.path.isdir(vars.custmodpth)):
2547
try:
2548
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache", use_fast=False)
2549
except Exception as e:
2550
try:
2551
tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache")
2552
except Exception as e:
2553
try:
2554
tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache")
2555
except Exception as e:
2556
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=args.revision, cache_dir="cache")
2557
try:
2558
model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache", **lowmem)
2559
except Exception as e:
2560
if("out of memory" in traceback.format_exc().lower()):
2561
raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
2562
model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, revision=args.revision, cache_dir="cache", **lowmem)
2563
elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
2564
try:
2565
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=args.revision, cache_dir="cache", use_fast=False)
2566
except Exception as e:
2567
try:
2568
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=args.revision, cache_dir="cache")
2569
except Exception as e:
2570
try:
2571
tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=args.revision, cache_dir="cache")
2572
except Exception as e:
2573
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=args.revision, cache_dir="cache")
2574
try:
2575
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=args.revision, cache_dir="cache", **lowmem)
2576
except Exception as e:
2577
if("out of memory" in traceback.format_exc().lower()):
2578
raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
2579
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=args.revision, cache_dir="cache", **lowmem)
2580
else:
2581
old_rebuild_tensor = torch._utils._rebuild_tensor
2582
def new_rebuild_tensor(storage: Union[torch_lazy_loader.LazyTensor, torch.Storage], storage_offset, shape, stride):
2583
if(not isinstance(storage, torch_lazy_loader.LazyTensor)):
2584
dtype = storage.dtype
2585
else:
2586
dtype = storage.storage_type.dtype
2587
if(not isinstance(dtype, torch.dtype)):
2588
dtype = storage.storage_type(0).dtype
2589
if(dtype is torch.float32 and len(shape) >= 2):
2590
vars.fp32_model = True
2591
return old_rebuild_tensor(storage, storage_offset, shape, stride)
2592
torch._utils._rebuild_tensor = new_rebuild_tensor
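# The patch above only observes tensors while torch deserializes the
# checkpoint: any float32 weight with two or more dimensions flips
# vars.fp32_model, which later decides whether the model gets converted to
# float16 before being saved back out. The original _rebuild_tensor is
# restored a few lines below once loading has finished.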
2593
2594
try:
2595
tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=args.revision, cache_dir="cache", use_fast=False)
2596
except Exception as e:
2597
try:
2598
tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=args.revision, cache_dir="cache")
2599
except Exception as e:
2600
try:
2601
tokenizer = GPT2Tokenizer.from_pretrained(vars.model, revision=args.revision, cache_dir="cache")
2602
except Exception as e:
2603
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=args.revision, cache_dir="cache")
2604
try:
2605
model = AutoModelForCausalLM.from_pretrained(vars.model, revision=args.revision, cache_dir="cache", **lowmem)
2606
except Exception as e:
2607
if("out of memory" in traceback.format_exc().lower()):
2608
raise RuntimeError("One of your GPUs ran out of memory when KoboldAI tried to load your model.")
2609
model = GPTNeoForCausalLM.from_pretrained(vars.model, revision=args.revision, cache_dir="cache", **lowmem)
2610
2611
torch._utils._rebuild_tensor = old_rebuild_tensor
2612
2613
if not args.colab or args.savemodel:
2614
import shutil
2615
tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
2616
if(vars.fp32_model and ("breakmodel" not in globals() or not breakmodel.disk_blocks)): # Use save_pretrained to convert fp32 models to fp16, unless we are using disk cache because save_pretrained is not supported in that case
2617
model = model.half()
2618
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')), max_shard_size="500MiB")
2619
else: # For fp16 models, we can just copy the model files directly
2620
import transformers.configuration_utils
2621
import transformers.modeling_utils
2622
import transformers.file_utils
2623
import huggingface_hub
2624
legacy = packaging.version.parse(transformers_version) < packaging.version.parse("4.22.0.dev0")
2625
# Save the config.json
2626
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.configuration_utils.CONFIG_NAME, revision=args.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.configuration_utils.CONFIG_NAME))
2627
if(utils.num_shards is None):
2628
# Save the pytorch_model.bin of an unsharded model
2629
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, transformers.modeling_utils.WEIGHTS_NAME, revision=args.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_NAME))
2630
else:
2631
with open(utils.from_pretrained_index_filename) as f:
2632
map_data = json.load(f)
2633
filenames = set(map_data["weight_map"].values())
2634
# Save the pytorch_model.bin.index.json of a sharded model
2635
shutil.move(os.path.realpath(utils.from_pretrained_index_filename), os.path.join("models/{}".format(vars.model.replace('/', '_')), transformers.modeling_utils.WEIGHTS_INDEX_NAME))
2636
# Then save the pytorch_model-#####-of-#####.bin files
2637
for filename in filenames:
2638
shutil.move(os.path.realpath(huggingface_hub.hf_hub_download(vars.model, filename, revision=args.revision, cache_dir="cache", local_files_only=True, legacy_cache_layout=legacy)), os.path.join("models/{}".format(vars.model.replace('/', '_')), filename))
2639
shutil.rmtree("cache/")
2640
2641
if(vars.badwordsids is vars.badwordsids_default and vars.model_type not in ("gpt2", "gpt_neo", "gptj")):
2642
vars.badwordsids = [[v] for k, v in tokenizer.get_vocab().items() if any(c in str(k) for c in "<>[]") if vars.newlinemode != "s" or str(k) != "</s>"]
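# In plain terms: for model types without a curated bad-words list, every
# vocabulary token whose text contains one of the characters < > [ ] is
# added to the banned-token list, except that "</s>" is left usable when
# vars.newlinemode is "s" (presumably because that mode relies on "</s>"
# as its newline substitute).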
2643
2644
patch_causallm(model)
2645
2646
if(vars.hascuda):
2647
if(vars.usegpu):
2648
vars.modeldim = get_hidden_size_from_model(model)
2649
model = model.half().to(vars.gpu_device)
2650
generator = model.generate
2651
elif(vars.breakmodel): # Use both RAM and VRAM (breakmodel)
2652
vars.modeldim = get_hidden_size_from_model(model)
2653
if(not vars.lazy_load):
2654
device_config(model.config)
2655
move_model_to_devices(model)
2656
elif(utils.HAS_ACCELERATE and __import__("breakmodel").disk_blocks > 0):
2657
move_model_to_devices(model)
2658
vars.modeldim = get_hidden_size_from_model(model)
2659
generator = model.generate
2660
else:
2661
model = model.to('cpu').float()
2662
vars.modeldim = get_hidden_size_from_model(model)
2663
generator = model.generate
2664
elif(utils.HAS_ACCELERATE and __import__("breakmodel").disk_blocks > 0):
2665
move_model_to_devices(model)
2666
vars.modeldim = get_hidden_size_from_model(model)
2667
generator = model.generate
2668
else:
2669
model = model.to('cpu').float()
2670
vars.modeldim = get_hidden_size_from_model(model)
2671
generator = model.generate
2672
2673
# Suppress Author's Note by flagging square brackets (Old implementation)
2674
#vocab = tokenizer.get_vocab()
2675
#vocab_keys = vocab.keys()
2676
#vars.badwords = gettokenids("[")
2677
#for key in vars.badwords:
2678
# vars.badwordsids.append([vocab[key]])
2679
2680
logger.info(f"Pipeline created: {vars.model}")
2681
2682
else:
2683
from transformers import GPT2Tokenizer
2684
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=args.revision, cache_dir="cache")
2685
else:
2686
from transformers import PreTrainedModel
2687
from transformers import modeling_utils
2688
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
2689
@classmethod
2690
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
2691
vars.fp32_model = False
2692
utils.num_shards = None
2693
utils.current_shard = 0
2694
utils.from_pretrained_model_name = pretrained_model_name_or_path
2695
utils.from_pretrained_index_filename = None
2696
utils.from_pretrained_kwargs = kwargs
2697
utils.bar = None
2698
if not args.no_aria2:
2699
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
2700
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
2701
if(not hasattr(PreTrainedModel, "_kai_patched")):
2702
PreTrainedModel.from_pretrained = new_from_pretrained
2703
PreTrainedModel._kai_patched = True
2704
if(hasattr(modeling_utils, "get_checkpoint_shard_files")):
2705
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
2706
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
2707
utils.num_shards = utils.get_num_shards(index_filename)
2708
utils.from_pretrained_index_filename = index_filename
2709
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
2710
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
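# Together these two monkey patches let KoboldAI observe every
# from_pretrained() call on this code path: shard counts and the index
# filename are recorded in utils so later loading code and the progress
# bar know how much work remains, and utils.aria2_hook gets a chance to
# pre-download the checkpoint unless args.no_aria2 is set.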
2711
2712
2713
def tpumtjgenerate_warper_callback(scores) -> "np.array":
2714
scores_shape = scores.shape
2715
scores_list = scores.tolist()
2716
vars.lua_koboldbridge.logits = vars.lua_state.table()
2717
for r, row in enumerate(scores_list):
2718
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
2719
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
2720
2721
execute_genmod()
2722
2723
scores = np.array(
2724
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
2725
dtype=scores.dtype,
2726
)
2727
assert scores.shape == scores_shape
2728
2729
return scores
2730
2731
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
2732
vars.generated_tkns += 1
2733
2734
assert len(excluded_world_info) == len(generated)
2735
regeneration_required = vars.lua_koboldbridge.regeneration_required
2736
halt = vars.abort or not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
2737
vars.lua_koboldbridge.regeneration_required = False
2738
2739
global past
2740
2741
for i in range(vars.numseqs):
2742
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
2743
2744
if(not vars.dynamicscan or halt):
2745
return excluded_world_info, regeneration_required, halt
2746
2747
for i, t in enumerate(generated):
2748
decoded = utils.decodenewlines(tokenizer.decode(past[i])) + utils.decodenewlines(tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated]))
2749
_, found = checkworldinfo(decoded, force_use_txt=True, actions=vars._actions)
2750
found -= excluded_world_info[i]
2751
if(len(found) != 0):
2752
regeneration_required = True
2753
break
2754
return excluded_world_info, regeneration_required, halt
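# The three return values are: the per-sequence sets of world info entries
# that were already triggered, whether the sequences need to be
# regenerated (requested by the Lua bridge or because dynamic WI scanning
# found new entries), and whether generation should halt now (abort
# requested, the Lua bridge stopped generating, or the requested number of
# tokens has been produced).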
2755
2756
def tpumtjgenerate_compiling_callback() -> None:
2757
print(colors.GREEN + "TPU backend compilation triggered" + colors.END)
2758
vars.compiling = True
2759
2760
def tpumtjgenerate_stopped_compiling_callback() -> None:
2761
vars.compiling = False
2762
2763
def tpumtjgenerate_settings_callback() -> dict:
2764
sampler_order = vars.sampler_order[:]
2765
if len(sampler_order) < 7: # Add repetition penalty at beginning if it's not present
2766
sampler_order = [6] + sampler_order
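# (Sampler id 6 is the repetition penalty slot; older saved orders appear
# to contain only six entries, hence the length check above. Prepending it
# keeps repetition penalty applied before the other samplers, which seems
# to be the intended default.)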
2767
return {
2768
"sampler_order": sampler_order,
2769
"top_p": float(vars.top_p),
2770
"temp": float(vars.temp),
2771
"top_k": int(vars.top_k),
2772
"tfs": float(vars.tfs),
2773
"typical": float(vars.typical),
2774
"top_a": float(vars.top_a),
2775
"repetition_penalty": float(vars.rep_pen),
2776
"rpslope": float(vars.rep_pen_slope),
2777
"rprange": int(vars.rep_pen_range),
2778
}
2779
2780
# If we're running Colab, API, CLUSTER or OAI, we still need a tokenizer.
2781
if(vars.model in ("Colab", "API", "CLUSTER")):
2782
from transformers import GPT2Tokenizer
2783
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B", revision=args.revision, cache_dir="cache")
2784
loadsettings()
2785
elif(vars.model == "OAI"):
2786
from transformers import GPT2Tokenizer
2787
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=args.revision, cache_dir="cache")
2788
loadsettings()
2789
# Load the TPU backend if requested
2790
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
2791
global tpu_mtj_backend
2792
import tpu_mtj_backend
2793
if(vars.model == "TPUMeshTransformerGPTNeoX"):
2794
vars.badwordsids = vars.badwordsids_neox
2795
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
2796
if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
2797
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
2798
import tpu_mtj_backend
2799
if(vars.model == "TPUMeshTransformerGPTNeoX"):
2800
tpu_mtj_backend.pad_token_id = 2
2801
tpu_mtj_backend.vars = vars
2802
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
2803
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
2804
tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
2805
tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
2806
tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
2807
vars.allowsp = True
2808
loadmodelsettings()
2809
loadsettings()
2810
tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
2811
vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
2812
tokenizer = tpu_mtj_backend.tokenizer
2813
if(vars.badwordsids is vars.badwordsids_default and vars.model_type not in ("gpt2", "gpt_neo", "gptj")):
2814
vars.badwordsids = [[v] for k, v in tokenizer.get_vocab().items() if any(c in str(k) for c in "<>[]") if vars.newlinemode != "s" or str(k) != "</s>"]
2815
else:
2816
loadsettings()
2817
2818
lua_startup()
2819
# Load scripts
2820
load_lua_scripts()
2821
2822
final_startup()
2823
if not initial_load:
2824
set_aibusy(False)
2825
emit('from_server', {'cmd': 'hide_model_name'}, broadcast=True)
2826
time.sleep(0.1)
2827
2828
if not vars.gamestarted:
2829
setStartState()
2830
sendsettings()
2831
refresh_settings()
2832
2833
2834
# Set up Flask routes
2835
@app.route('/')
2836
@app.route('/index')
2837
def index():
2838
if args.no_ui:
2839
return redirect('/api/latest')
2840
else:
2841
return render_template('index.html', hide_ai_menu=args.noaimenu)
2842
@app.route('/api', strict_slashes=False)
2843
def api():
2844
return redirect('/api/latest')
2845
@app.route('/favicon.ico')
2846
def favicon():
2847
return send_from_directory(app.root_path,
2848
'koboldai.ico', mimetype='image/vnd.microsoft.icon')
2849
@app.route('/download')
2850
def download():
2851
if args.no_ui:
2852
raise NotFound()
2853
2854
save_format = request.args.get("format", "json").strip().lower()
2855
2856
if(save_format == "plaintext"):
2857
txt = vars.prompt + "".join(vars.actions.values())
2858
save = Response(txt)
2859
filename = path.basename(vars.savedir)
2860
if filename[-5:] == ".json":
2861
filename = filename[:-5]
2862
save.headers.set('Content-Disposition', 'attachment', filename='%s.txt' % filename)
2863
return(save)
2864
2865
# Build json to write
2866
js = {}
2867
js["gamestarted"] = vars.gamestarted
2868
js["prompt"] = vars.prompt
2869
js["memory"] = vars.memory
2870
js["authorsnote"] = vars.authornote
2871
js["anotetemplate"] = vars.authornotetemplate
2872
js["actions"] = tuple(vars.actions.values())
2873
js["actions_metadata"] = vars.actions_metadata
2874
js["worldinfo"] = []
2875
2876
# Extract only the important bits of WI
2877
for wi in vars.worldinfo:
2878
if(wi["constant"] or wi["key"] != ""):
2879
js["worldinfo"].append({
2880
"key": wi["key"],
2881
"keysecondary": wi["keysecondary"],
2882
"content": wi["content"],
2883
"comment": wi["comment"],
2884
"folder": wi["folder"],
2885
"selective": wi["selective"],
2886
"constant": wi["constant"]
2887
})
2888
2889
save = Response(json.dumps(js, indent=3))
2890
filename = path.basename(vars.savedir)
2891
if filename[-5:] == ".json":
2892
filename = filename[:-5]
2893
save.headers.set('Content-Disposition', 'attachment', filename='%s.json' % filename)
2894
return(save)
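# Example request (assuming a default local setup; substitute your own
# host and port):
#
#   curl -OJ "http://localhost:5000/download?format=plaintext"
#
# Omit the format parameter (or pass format=json) to receive the full
# JSON save file instead of just the story text.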
2895
2896
2897
#============================ LUA API =============================#
2898
_bridged = {}
2899
F = TypeVar("F", bound=Callable)
2900
def lua_startup():
2901
global _bridged
2902
global F
2903
global bridged
2904
if(path.exists(get_config_filename())):
2905
file = open(get_config_filename(), "r")
2906
js = json.load(file)
2907
if("userscripts" in js):
2908
vars.userscripts = []
2909
for userscript in js["userscripts"]:
2910
if type(userscript) is not str:
2911
continue
2912
userscript = userscript.strip()
2913
if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
2914
vars.userscripts.append(userscript)
2915
if("corescript" in js and type(js["corescript"]) is str and all(q not in js["corescript"] for q in ("..", ":")) and all(js["corescript"][0] not in q for q in ("/", "\\"))):
2916
vars.corescript = js["corescript"]
2917
else:
2918
vars.corescript = "default.lua"
2919
file.close()
2920
2921
#==================================================================#
2922
# Lua runtime startup
2923
#==================================================================#
2924
2925
print("", end="", flush=True)
2926
logger.init("LUA bridge", status="Starting")
2927
2928
# Set up Lua state
2929
vars.lua_state = lupa.LuaRuntime(unpack_returned_tuples=True)
2930
2931
# Load bridge.lua
2932
bridged = {
2933
"corescript_path": "cores",
2934
"userscript_path": "userscripts",
2935
"config_path": "userscripts",
2936
"lib_paths": vars.lua_state.table("lualibs", os.path.join("extern", "lualibs")),
2937
"vars": vars,
2938
}
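# bridge.lua is invoked just below with lupa's "python" helper object and
# this table: the *_path entries tell it where core scripts, userscripts
# and their config files live, lib_paths points at the bundled Lua
# libraries, and every function registered with @bridged_kwarg further
# down this file is merged into the same table first.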
2939
for kwarg in _bridged:
2940
bridged[kwarg] = _bridged[kwarg]
2941
try:
2942
vars.lua_kobold, vars.lua_koboldcore, vars.lua_koboldbridge = vars.lua_state.globals().dofile("bridge.lua")(
2943
vars.lua_state.globals().python,
2944
bridged,
2945
)
2946
except lupa.LuaError as e:
2947
print(colors.RED + "ERROR!" + colors.END)
2948
vars.lua_koboldbridge.obliterate_multiverse()
2949
logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
2950
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
2951
exit(1)
2952
logger.init_ok("LUA bridge", status="OK")
2953
2954
2955
def lua_log_format_name(name):
2956
return f"[{name}]" if type(name) is str else "CORE"
2957
2958
2959
def bridged_kwarg(name=None):
2960
def _bridged_kwarg(f: F):
2961
_bridged[name if name is not None else f.__name__[4:] if f.__name__[:4] == "lua_" else f.__name__] = f
2962
return f
2963
return _bridged_kwarg
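# The decorator registers the wrapped function under its own name with a
# leading "lua_" stripped (unless an explicit name is passed), so for
# example @bridged_kwarg() on lua_decode makes it available to bridge.lua
# as "decode".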
2964
2965
#==================================================================#
2966
# Event triggered when a userscript is loaded
2967
#==================================================================#
2968
@bridged_kwarg()
2969
def load_callback(filename, modulename):
2970
print(colors.GREEN + f"Loading Userscript [{modulename}] <{filename}>" + colors.END)
2971
2972
#==================================================================#
2973
# Load all Lua scripts
2974
#==================================================================#
2975
def load_lua_scripts():
2976
logger.init("LUA Scripts", status="Starting")
2977
2978
filenames = []
2979
modulenames = []
2980
descriptions = []
2981
2982
lst = fileops.getusfiles(long_desc=True)
2983
filenames_dict = {ob["filename"]: i for i, ob in enumerate(lst)}
2984
2985
for filename in vars.userscripts:
2986
if filename in filenames_dict:
2987
i = filenames_dict[filename]
2988
filenames.append(filename)
2989
modulenames.append(lst[i]["modulename"])
2990
descriptions.append(lst[i]["description"])
2991
2992
vars.has_genmod = False
2993
2994
try:
2995
vars.lua_koboldbridge.obliterate_multiverse()
2996
tpool.execute(vars.lua_koboldbridge.load_corescript, vars.corescript)
2997
vars.has_genmod = tpool.execute(vars.lua_koboldbridge.load_userscripts, filenames, modulenames, descriptions)
2998
vars.lua_running = True
2999
except lupa.LuaError as e:
3000
try:
3001
vars.lua_koboldbridge.obliterate_multiverse()
3002
except:
3003
pass
3004
vars.lua_running = False
3005
if(vars.serverstarted):
3006
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
3007
sendUSStatItems()
3008
logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
3009
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
3010
if(vars.serverstarted):
3011
set_aibusy(0)
3012
logger.init_ok("LUA Scripts", status="OK")
3013
3014
#==================================================================#
3015
# Print message that originates from the userscript with the given name
3016
#==================================================================#
3017
@bridged_kwarg()
3018
def lua_print(msg):
3019
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
3020
vars.lua_logname = vars.lua_koboldbridge.logging_name
3021
print(colors.BLUE + lua_log_format_name(vars.lua_logname) + ":" + colors.END, file=sys.stderr)
3022
print(colors.PURPLE + msg.replace("\033", "") + colors.END)
3023
3024
#==================================================================#
3025
# Print warning that originates from the userscript with the given name
3026
#==================================================================#
3027
@bridged_kwarg()
3028
def lua_warn(msg):
3029
if(vars.lua_logname != vars.lua_koboldbridge.logging_name):
3030
vars.lua_logname = vars.lua_koboldbridge.logging_name
3031
print(colors.BLUE + lua_log_format_name(vars.lua_logname) + ":" + colors.END, file=sys.stderr)
3032
print(colors.YELLOW + msg.replace("\033", "") + colors.END)
3033
3034
#==================================================================#
3035
# Decode tokens into a string using current tokenizer
3036
#==================================================================#
3037
@bridged_kwarg()
3038
def lua_decode(tokens):
3039
tokens = list(tokens.values())
3040
assert type(tokens) is list
3041
if("tokenizer" not in globals()):
3042
from transformers import GPT2Tokenizer
3043
global tokenizer
3044
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=args.revision, cache_dir="cache")
3045
return utils.decodenewlines(tokenizer.decode(tokens))
3046
3047
#==================================================================#
3048
# Encode string into list of token IDs using current tokenizer
3049
#==================================================================#
3050
@bridged_kwarg()
3051
def lua_encode(string):
3052
assert type(string) is str
3053
if("tokenizer" not in globals()):
3054
from transformers import GPT2Tokenizer
3055
global tokenizer
3056
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=args.revision, cache_dir="cache")
3057
return tokenizer.encode(utils.encodenewlines(string), max_length=int(4e9), truncation=True)
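# Illustrative round trip (the ids shown are only an example; actual
# values depend on the tokenizer that is loaded): lua_encode("Hello world")
# might return [15496, 995], and lua_decode on a Lua table holding those
# ids yields "Hello world" again. utils.encodenewlines/decodenewlines
# translate between literal "\n" and the model's internal newline
# representation on the way in and out.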
3058
3059
#==================================================================#
3060
# Computes context given a submission, Lua array of entry UIDs and a Lua array
3061
# of folder UIDs
3062
#==================================================================#
3063
@bridged_kwarg()
3064
def lua_compute_context(submission, entries, folders, kwargs):
3065
assert type(submission) is str
3066
if(kwargs is None):
3067
kwargs = vars.lua_state.table()
3068
actions = vars._actions if vars.lua_koboldbridge.userstate == "genmod" else vars.actions
3069
allowed_entries = None
3070
allowed_folders = None
3071
if(entries is not None):
3072
allowed_entries = set()
3073
i = 1
3074
while(entries[i] is not None):
3075
allowed_entries.add(int(entries[i]))
3076
i += 1
3077
if(folders is not None):
3078
allowed_folders = set()
3079
i = 1
3080
while(folders[i] is not None):
3081
allowed_folders.add(int(folders[i]))
3082
i += 1
3083
winfo, mem, anotetxt, _ = calcsubmitbudgetheader(
3084
submission,
3085
allowed_entries=allowed_entries,
3086
allowed_folders=allowed_folders,
3087
force_use_txt=True,
3088
scan_story=kwargs["scan_story"] if kwargs["scan_story"] is not None else True,
3089
)
3090
if kwargs["include_anote"] is not None and not kwargs["include_anote"]:
3091
anotetxt = ""
3092
txt, _, _ = calcsubmitbudget(
3093
len(actions),
3094
winfo,
3095
mem,
3096
anotetxt,
3097
actions,
3098
)
3099
return utils.decodenewlines(tokenizer.decode(txt))
3100
3101
#==================================================================#
3102
# Get property of a world info entry given its UID and property name
3103
#==================================================================#
3104
@bridged_kwarg()
3105
def lua_get_attr(uid, k):
3106
assert type(uid) is int and type(k) is str
3107
if(uid in vars.worldinfo_u and k in (
3108
"key",
3109
"keysecondary",
3110
"content",
3111
"comment",
3112
"folder",
3113
"num",
3114
"selective",
3115
"constant",
3116
"uid",
3117
)):
3118
return vars.worldinfo_u[uid][k]
3119
3120
#==================================================================#
3121
# Set property of a world info entry given its UID, property name and new value
3122
#==================================================================#
3123
@bridged_kwarg()
3124
def lua_set_attr(uid, k, v):
3125
assert type(uid) is int and type(k) is str
3126
assert uid in vars.worldinfo_u and k in (
3127
"key",
3128
"keysecondary",
3129
"content",
3130
"comment",
3131
"selective",
3132
"constant",
3133
)
3134
if(type(vars.worldinfo_u[uid][k]) is int and type(v) is float):
3135
v = int(v)
3136
assert type(vars.worldinfo_u[uid][k]) is type(v)
3137
vars.worldinfo_u[uid][k] = v
3138
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {k} of world info entry {uid} to {v}" + colors.END)
3139
3140
#==================================================================#
3141
# Get property of a world info folder given its UID and property name
3142
#==================================================================#
3143
@bridged_kwarg()
3144
def lua_folder_get_attr(uid, k):
3145
assert type(uid) is int and type(k) is str
3146
if(uid in vars.wifolders_d and k in (
3147
"name",
3148
)):
3149
return vars.wifolders_d[uid][k]
3150
3151
#==================================================================#
3152
# Set property of a world info folder given its UID, property name and new value
3153
#==================================================================#
3154
@bridged_kwarg()
3155
def lua_folder_set_attr(uid, k, v):
3156
assert type(uid) is int and type(k) is str
3157
assert uid in vars.wifolders_d and k in (
3158
"name",
3159
)
3160
if(type(vars.wifolders_d[uid][k]) is int and type(v) is float):
3161
v = int(v)
3162
assert type(vars.wifolders_d[uid][k]) is type(v)
3163
vars.wifolders_d[uid][k] = v
3164
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {k} of world info folder {uid} to {v}" + colors.END)
3165
3166
#==================================================================#
3167
# Get the "Amount to Generate"
3168
#==================================================================#
3169
@bridged_kwarg()
3170
def lua_get_genamt():
3171
return vars.genamt
3172
3173
#==================================================================#
3174
# Set the "Amount to Generate"
3175
#==================================================================#
3176
@bridged_kwarg()
3177
def lua_set_genamt(genamt):
3178
assert vars.lua_koboldbridge.userstate != "genmod" and type(genamt) in (int, float) and genamt >= 0
3179
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set genamt to {int(genamt)}" + colors.END)
3180
vars.genamt = int(genamt)
3181
3182
#==================================================================#
3183
# Get the "Gens Per Action"
3184
#==================================================================#
3185
@bridged_kwarg()
3186
def lua_get_numseqs():
3187
return vars.numseqs
3188
3189
#==================================================================#
3190
# Set the "Gens Per Action"
3191
#==================================================================#
3192
@bridged_kwarg()
3193
def lua_set_numseqs(numseqs):
3194
assert type(numseqs) in (int, float) and numseqs >= 1
3195
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set numseqs to {int(numseqs)}" + colors.END)
3196
vars.numseqs = int(numseqs)
3197
3198
#==================================================================#
3199
# Check if a setting exists with the given name
3200
#==================================================================#
3201
@bridged_kwarg()
3202
def lua_has_setting(setting):
3203
return setting in (
3204
"anotedepth",
3205
"settemp",
3206
"settopp",
3207
"settopk",
3208
"settfs",
3209
"settypical",
3210
"settopa",
3211
"setreppen",
3212
"setreppenslope",
3213
"setreppenrange",
3214
"settknmax",
3215
"setwidepth",
3216
"setuseprompt",
3217
"setadventure",
3218
"setchatmode",
3219
"setdynamicscan",
3220
"setnopromptgen",
3221
"autosave",
3222
"setrngpersist",
3223
"temp",
3224
"topp",
3225
"top_p",
3226
"topk",
3227
"top_k",
3228
"tfs",
3229
"typical",
3230
"topa",
3231
"reppen",
3232
"reppenslope",
3233
"reppenrange",
3234
"tknmax",
3235
"widepth",
3236
"useprompt",
3237
"chatmode",
3238
"chatname",
3239
"adventure",
3240
"dynamicscan",
3241
"nopromptgen",
3242
"rngpersist",
3243
"frmttriminc",
3244
"frmtrmblln",
3245
"frmtrmspch",
3246
"frmtadsnsp",
3247
"frmtsingleline",
3248
"triminc",
3249
"rmblln",
3250
"rmspch",
3251
"adsnsp",
3252
"singleline",
3253
"output_streaming",
3254
"show_probs"
3255
)
3256
3257
#==================================================================#
3258
# Return the setting with the given name if it exists
3259
#==================================================================#
3260
@bridged_kwarg()
3261
def lua_get_setting(setting):
3262
if(setting in ("settemp", "temp")): return vars.temp
3263
if(setting in ("settopp", "topp", "top_p")): return vars.top_p
3264
if(setting in ("settopk", "topk", "top_k")): return vars.top_k
3265
if(setting in ("settfs", "tfs")): return vars.tfs
3266
if(setting in ("settypical", "typical")): return vars.typical
3267
if(setting in ("settopa", "topa")): return vars.top_a
3268
if(setting in ("setreppen", "reppen")): return vars.rep_pen
3269
if(setting in ("setreppenslope", "reppenslope")): return vars.rep_pen_slope
3270
if(setting in ("setreppenrange", "reppenrange")): return vars.rep_pen_range
3271
if(setting in ("settknmax", "tknmax")): return vars.max_length
3272
if(setting == "anotedepth"): return vars.andepth
3273
if(setting in ("setwidepth", "widepth")): return vars.widepth
3274
if(setting in ("setuseprompt", "useprompt")): return vars.useprompt
3275
if(setting in ("setadventure", "adventure")): return vars.adventure
3276
if(setting in ("setchatmode", "chatmode")): return vars.chatmode
3277
if(setting in ("setdynamicscan", "dynamicscan")): return vars.dynamicscan
3278
if(setting in ("setnopromptgen", "nopromptgen")): return vars.nopromptgen
3279
if(setting in ("autosave", "autosave")): return vars.autosave
3280
if(setting in ("setrngpersist", "rngpersist")): return vars.rngpersist
3281
if(setting in ("frmttriminc", "triminc")): return vars.formatoptns["frmttriminc"]
3282
if(setting in ("frmtrmblln", "rmblln")): return vars.formatoptns["frmttrmblln"]
3283
if(setting in ("frmtrmspch", "rmspch")): return vars.formatoptns["frmttrmspch"]
3284
if(setting in ("frmtadsnsp", "adsnsp")): return vars.formatoptns["frmtadsnsp"]
3285
if(setting in ("frmtsingleline", "singleline")): return vars.formatoptns["singleline"]
3286
if(setting == "output_streaming"): return vars.output_streaming
3287
if(setting == "show_probs"): return vars.show_probs
3288
3289
#==================================================================#
3290
# Set the setting with the given name if it exists
3291
#==================================================================#
3292
@bridged_kwarg()
3293
def lua_set_setting(setting, v):
3294
actual_type = type(lua_get_setting(setting))
3295
assert v is not None and (actual_type is type(v) or (actual_type is int and type(v) is float))
3296
v = actual_type(v)
3297
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} set {setting} to {v}" + colors.END)
3298
if(setting in ("setadventure", "adventure") and v):
3299
vars.actionmode = 1
3300
if(setting in ("settemp", "temp")): vars.temp = v
3301
if(setting in ("settopp", "topp")): vars.top_p = v
3302
if(setting in ("settopk", "topk")): vars.top_k = v
3303
if(setting in ("settfs", "tfs")): vars.tfs = v
3304
if(setting in ("settypical", "typical")): vars.typical = v
3305
if(setting in ("settopa", "topa")): vars.top_a = v
3306
if(setting in ("setreppen", "reppen")): vars.rep_pen = v
3307
if(setting in ("setreppenslope", "reppenslope")): vars.rep_pen_slope = v
3308
if(setting in ("setreppenrange", "reppenrange")): vars.rep_pen_range = v
3309
if(setting in ("settknmax", "tknmax")): vars.max_length = v; return True
3310
if(setting == "anotedepth"): vars.andepth = v; return True
3311
if(setting in ("setwidepth", "widepth")): vars.widepth = v; return True
3312
if(setting in ("setuseprompt", "useprompt")): vars.useprompt = v; return True
3313
if(setting in ("setadventure", "adventure")): vars.adventure = v
3314
if(setting in ("setdynamicscan", "dynamicscan")): vars.dynamicscan = v
3315
if(setting in ("setnopromptgen", "nopromptgen")): vars.nopromptgen = v
3316
if(setting in ("autosave", "noautosave")): vars.autosave = v
3317
if(setting in ("setrngpersist", "rngpersist")): vars.rngpersist = v
3318
if(setting in ("setchatmode", "chatmode")): vars.chatmode = v
3319
if(setting in ("frmttriminc", "triminc")): vars.formatoptns["frmttriminc"] = v
3320
if(setting in ("frmtrmblln", "rmblln")): vars.formatoptns["frmttrmblln"] = v
3321
if(setting in ("frmtrmspch", "rmspch")): vars.formatoptns["frmttrmspch"] = v
3322
if(setting in ("frmtadsnsp", "adsnsp")): vars.formatoptns["frmtadsnsp"] = v
3323
if(setting in ("frmtsingleline", "singleline")): vars.formatoptns["singleline"] = v
3324
if(setting == "output_streaming"): vars.output_streaming = v
3325
if(setting == "show_probs"): vars.show_probs = v
3326
3327
#==================================================================#
3328
# Get contents of memory
3329
#==================================================================#
3330
@bridged_kwarg()
3331
def lua_get_memory():
3332
return vars.memory
3333
3334
#==================================================================#
3335
# Set contents of memory
3336
#==================================================================#
3337
@bridged_kwarg()
3338
def lua_set_memory(m):
3339
assert type(m) is str
3340
vars.memory = m
3341
3342
#==================================================================#
3343
# Get contents of author's note
3344
#==================================================================#
3345
@bridged_kwarg()
3346
def lua_get_authorsnote():
3347
return vars.authornote
3348
3349
#==================================================================#
3350
# Set contents of author's note
3351
#==================================================================#
3352
@bridged_kwarg()
3353
def lua_set_authorsnote(m):
3354
assert type(m) is str
3355
vars.authornote = m
3356
3357
#==================================================================#
3358
# Get contents of author's note template
3359
#==================================================================#
3360
@bridged_kwarg()
3361
def lua_get_authorsnotetemplate():
3362
return vars.authornotetemplate
3363
3364
#==================================================================#
3365
# Set contents of author's note template
3366
#==================================================================#
3367
@bridged_kwarg()
3368
def lua_set_authorsnotetemplate(m):
3369
assert type(m) is str
3370
vars.authornotetemplate = m
3371
3372
#==================================================================#
3373
# Save settings and send them to client
3374
#==================================================================#
3375
@bridged_kwarg()
3376
def lua_resend_settings():
3377
settingschanged()
3378
refresh_settings()
3379
3380
#==================================================================#
3381
# Set story chunk text and delete the chunk if the new chunk is empty
3382
#==================================================================#
3383
@bridged_kwarg()
3384
def lua_set_chunk(k, v):
3385
assert type(k) is int and type(v) is str
3386
assert k >= 0
3387
assert k != 0 or len(v) != 0
3388
if(len(v) == 0):
3389
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} deleted story chunk {k}" + colors.END)
3390
chunk = int(k)
3391
if(vars.lua_koboldbridge.userstate == "genmod"):
3392
del vars._actions[chunk-1]
3393
vars.lua_deleted.add(chunk)
3394
if(not hasattr(vars, "_actions") or vars._actions is not vars.actions):
3395
#Instead of deleting we'll blank out the text. This way our actions and actions_metadata stay in sync and we can restore the chunk on an undo
3396
vars.actions[chunk-1] = ""
3397
vars.actions_metadata[chunk-1]['Alternative Text'] = [{"Text": vars.actions_metadata[chunk-1]['Selected Text'], "Pinned": False, "Editted": True}] + vars.actions_metadata[chunk-1]['Alternative Text']
3398
vars.actions_metadata[chunk-1]['Selected Text'] = ''
3399
send_debug()
3400
else:
3401
if(k == 0):
3402
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited prompt chunk" + colors.END)
3403
else:
3404
print(colors.GREEN + f"{lua_log_format_name(vars.lua_koboldbridge.logging_name)} edited story chunk {k}" + colors.END)
3405
chunk = int(k)
3406
if(chunk == 0):
3407
if(vars.lua_koboldbridge.userstate == "genmod"):
3408
vars._prompt = v
3409
vars.lua_edited.add(chunk)
3410
vars.prompt = v
3411
else:
3412
if(vars.lua_koboldbridge.userstate == "genmod"):
3413
vars._actions[chunk-1] = v
3414
vars.lua_edited.add(chunk)
3415
vars.actions[chunk-1] = v
3416
vars.actions_metadata[chunk-1]['Alternative Text'] = [{"Text": vars.actions_metadata[chunk-1]['Selected Text'], "Pinned": False, "Editted": True}] + vars.actions_metadata[chunk-1]['Alternative Text']
3417
vars.actions_metadata[chunk-1]['Selected Text'] = v
3418
send_debug()
3419
3420
#==================================================================#
3421
# Get model type as "gpt2-xl", "gpt-neo-2.7B", etc.
3422
#==================================================================#
3423
@bridged_kwarg()
3424
def lua_get_modeltype():
3425
if(vars.noai):
3426
return "readonly"
3427
if(vars.model in ("Colab", "API", "CLUSTER", "OAI", "InferKit")):
3428
return "api"
3429
if(not vars.use_colab_tpu and vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (vars.model in ("GPT2Custom", "NeoCustom") or vars.model_type in ("gpt2", "gpt_neo", "gptj"))):
3430
hidden_size = get_hidden_size_from_model(model)
3431
if(vars.model in ("gpt2",) or (vars.model_type == "gpt2" and hidden_size == 768)):
3432
return "gpt2"
3433
if(vars.model in ("gpt2-medium",) or (vars.model_type == "gpt2" and hidden_size == 1024)):
3434
return "gpt2-medium"
3435
if(vars.model in ("gpt2-large",) or (vars.model_type == "gpt2" and hidden_size == 1280)):
3436
return "gpt2-large"
3437
if(vars.model in ("gpt2-xl",) or (vars.model_type == "gpt2" and hidden_size == 1600)):
3438
return "gpt2-xl"
3439
if(vars.model_type == "gpt_neo" and hidden_size == 768):
3440
return "gpt-neo-125M"
3441
if(vars.model in ("EleutherAI/gpt-neo-1.3B",) or (vars.model_type == "gpt_neo" and hidden_size == 2048)):
3442
return "gpt-neo-1.3B"
3443
if(vars.model in ("EleutherAI/gpt-neo-2.7B",) or (vars.model_type == "gpt_neo" and hidden_size == 2560)):
3444
return "gpt-neo-2.7B"
3445
if(vars.model in ("EleutherAI/gpt-j-6B",) or ((vars.use_colab_tpu or vars.model == "TPUMeshTransformerGPTJ") and tpu_mtj_backend.params["d_model"] == 4096) or (vars.model_type in ("gpt_neo", "gptj") and hidden_size == 4096)):
3446
return "gpt-j-6B"
3447
return "unknown"
3448
3449
#==================================================================#
3450
# Get model backend as "transformers" or "mtj"
3451
#==================================================================#
3452
@bridged_kwarg()
3453
def lua_get_modelbackend():
3454
if(vars.noai):
3455
return "readonly"
3456
if(vars.model in ("Colab", "API", "CLUSTER", "OAI", "InferKit")):
3457
return "api"
3458
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
3459
return "mtj"
3460
return "transformers"
3461
3462
#==================================================================#
3463
# Check whether model is loaded from a custom path
3464
#==================================================================#
3465
@bridged_kwarg()
3466
def lua_is_custommodel():
3467
return vars.model in ("GPT2Custom", "NeoCustom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
3468
3469
#==================================================================#
3470
# Return the filename (as a string) of the current soft prompt, or
3471
# None if no soft prompt is loaded
3472
#==================================================================#
3473
@bridged_kwarg()
3474
def lua_get_spfilename():
3475
return vars.spfilename.strip() or None
3476
3477
#==================================================================#
3478
# When called with a string as argument, sets the current soft prompt;
3479
# when called with None as argument, uses no soft prompt.
3480
# Returns True if soft prompt changed, False otherwise.
3481
#==================================================================#
3482
@bridged_kwarg()
3483
def lua_set_spfilename(filename: Union[str, None]):
3484
if(filename is None):
3485
filename = ""
3486
filename = str(filename).strip()
3487
changed = lua_get_spfilename() != filename
3488
assert all(q not in filename for q in ("/", "\\"))
3489
spRequest(filename)
3490
return changed
3491
3492
#==================================================================#
3493
# Run the Lua input, generation and output modifier hooks
3494
#==================================================================#
3495
def execute_inmod():
3496
setgamesaved(False)
3497
vars.lua_logname = ...
3498
vars.lua_edited = set()
3499
vars.lua_deleted = set()
3500
try:
3501
tpool.execute(vars.lua_koboldbridge.execute_inmod)
3502
except lupa.LuaError as e:
3503
vars.lua_koboldbridge.obliterate_multiverse()
3504
vars.lua_running = False
3505
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
3506
sendUSStatItems()
3507
logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
3508
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
3509
set_aibusy(0)
3510
3511
def execute_genmod():
3512
vars.lua_koboldbridge.execute_genmod()
3513
3514
def execute_outmod():
3515
setgamesaved(False)
3516
emit('from_server', {'cmd': 'hidemsg', 'data': ''}, broadcast=True)
3517
try:
3518
tpool.execute(vars.lua_koboldbridge.execute_outmod)
3519
except lupa.LuaError as e:
3520
vars.lua_koboldbridge.obliterate_multiverse()
3521
vars.lua_running = False
3522
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
3523
sendUSStatItems()
3524
logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
3525
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
3526
set_aibusy(0)
3527
if(vars.lua_koboldbridge.resend_settings_required):
3528
vars.lua_koboldbridge.resend_settings_required = False
3529
lua_resend_settings()
3530
for k in vars.lua_edited:
3531
inlineedit(k, vars.actions[k])
3532
for k in vars.lua_deleted:
3533
inlinedelete(k)
3534
3535
3536
3537
3538
#============================ METHODS =============================#
3539
3540
#==================================================================#
3541
# Event triggered when browser SocketIO is loaded and connects to server
3542
#==================================================================#
3543
@socketio.on('connect')
3544
def do_connect():
3545
logger.info("Client connected!")
3546
emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname})
3547
emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate})
3548
emit('from_server', {'cmd': 'connected', 'smandelete': vars.smandelete, 'smanrename': vars.smanrename, 'modelname': getmodelname()})
3549
if(vars.host):
3550
emit('from_server', {'cmd': 'runs_remotely'})
3551
if(vars.allowsp):
3552
emit('from_server', {'cmd': 'allowsp', 'data': vars.allowsp})
3553
3554
sendUSStatItems()
3555
emit('from_server', {'cmd': 'spstatitems', 'data': {vars.spfilename: vars.spmeta} if vars.allowsp and len(vars.spfilename) else {}}, broadcast=True)
3556
3557
if(not vars.gamestarted):
3558
setStartState()
3559
sendsettings()
3560
refresh_settings()
3561
vars.laststory = None
3562
emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory})
3563
sendwi()
3564
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory})
3565
emit('from_server', {'cmd': 'setanote', 'data': vars.authornote})
3566
vars.mode = "play"
3567
else:
3568
# Game in session, send current game data and ready state to browser
3569
refresh_story()
3570
sendsettings()
3571
refresh_settings()
3572
emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory})
3573
sendwi()
3574
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory})
3575
emit('from_server', {'cmd': 'setanote', 'data': vars.authornote})
3576
if(vars.mode == "play"):
3577
if(not vars.aibusy):
3578
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'})
3579
else:
3580
emit('from_server', {'cmd': 'setgamestate', 'data': 'wait'})
3581
elif(vars.mode == "edit"):
3582
emit('from_server', {'cmd': 'editmode', 'data': 'true'})
3583
elif(vars.mode == "memory"):
3584
emit('from_server', {'cmd': 'memmode', 'data': 'true'})
3585
elif(vars.mode == "wi"):
3586
emit('from_server', {'cmd': 'wimode', 'data': 'true'})
3587
3588
emit('from_server', {'cmd': 'gamesaved', 'data': vars.gamesaved}, broadcast=True)
3589
3590
#==================================================================#
3591
# Event triggered when browser SocketIO sends data to the server
3592
#==================================================================#
3593
@socketio.on('message')
3594
def get_message(msg):
3595
if not vars.quiet:
3596
logger.debug(f"Data received: {msg}")
3597
# Submit action
3598
if(msg['cmd'] == 'submit'):
3599
if(vars.mode == "play"):
3600
if(vars.aibusy):
3601
if(msg.get('allowabort', False)):
3602
vars.abort = True
3603
return
3604
vars.abort = False
3605
vars.lua_koboldbridge.feedback = None
3606
if(vars.chatmode):
3607
if(type(msg['chatname']) is not str):
3608
raise ValueError("Chatname must be a string")
3609
vars.chatname = msg['chatname']
3610
settingschanged()
3611
emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname})
3612
vars.recentrng = vars.recentrngm = None
3613
actionsubmit(msg['data'], actionmode=msg['actionmode'])
3614
elif(vars.mode == "edit"):
3615
editsubmit(msg['data'])
3616
elif(vars.mode == "memory"):
3617
memsubmit(msg['data'])
3618
# Retry Action
3619
elif(msg['cmd'] == 'retry'):
3620
if(vars.aibusy):
3621
if(msg.get('allowabort', False)):
3622
vars.abort = True
3623
return
3624
vars.abort = False
3625
if(vars.chatmode):
3626
if(type(msg['chatname']) is not str):
3627
raise ValueError("Chatname must be a string")
3628
vars.chatname = msg['chatname']
3629
settingschanged()
3630
emit('from_server', {'cmd': 'setchatname', 'data': vars.chatname})
3631
actionretry(msg['data'])
3632
# Back/Undo Action
3633
elif(msg['cmd'] == 'back'):
3634
actionback()
3635
# Forward/Redo Action
3636
elif(msg['cmd'] == 'redo'):
3637
actionredo()
3638
# EditMode Action (old)
3639
elif(msg['cmd'] == 'edit'):
3640
if(vars.mode == "play"):
3641
vars.mode = "edit"
3642
emit('from_server', {'cmd': 'editmode', 'data': 'true'}, broadcast=True)
3643
elif(vars.mode == "edit"):
3644
vars.mode = "play"
3645
emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
3646
# EditLine Action (old)
3647
elif(msg['cmd'] == 'editline'):
3648
editrequest(int(msg['data']))
3649
# Inline edit
3650
elif(msg['cmd'] == 'inlineedit'):
3651
inlineedit(msg['chunk'], msg['data'])
3652
elif(msg['cmd'] == 'inlinedelete'):
3653
inlinedelete(msg['data'])
3654
# DeleteLine Action (old)
3655
elif(msg['cmd'] == 'delete'):
3656
deleterequest()
3657
elif(msg['cmd'] == 'memory'):
3658
togglememorymode()
3659
elif(not vars.host and msg['cmd'] == 'savetofile'):
3660
savetofile()
3661
elif(not vars.host and msg['cmd'] == 'loadfromfile'):
3662
loadfromfile()
3663
elif(msg['cmd'] == 'loadfromstring'):
3664
loadRequest(json.loads(msg['data']), filename=msg['filename'])
3665
elif(not vars.host and msg['cmd'] == 'import'):
3666
importRequest()
3667
elif(msg['cmd'] == 'newgame'):
3668
newGameRequest()
3669
elif(msg['cmd'] == 'rndgame'):
3670
randomGameRequest(msg['data'], memory=msg['memory'])
3671
elif(msg['cmd'] == 'settemp'):
3672
vars.temp = float(msg['data'])
3673
emit('from_server', {'cmd': 'setlabeltemp', 'data': msg['data']}, broadcast=True)
3674
settingschanged()
3675
refresh_settings()
3676
elif(msg['cmd'] == 'settopp'):
3677
vars.top_p = float(msg['data'])
3678
emit('from_server', {'cmd': 'setlabeltopp', 'data': msg['data']}, broadcast=True)
3679
settingschanged()
3680
refresh_settings()
3681
elif(msg['cmd'] == 'settopk'):
3682
vars.top_k = int(msg['data'])
3683
emit('from_server', {'cmd': 'setlabeltopk', 'data': msg['data']}, broadcast=True)
3684
settingschanged()
3685
refresh_settings()
3686
elif(msg['cmd'] == 'settfs'):
3687
vars.tfs = float(msg['data'])
3688
emit('from_server', {'cmd': 'setlabeltfs', 'data': msg['data']}, broadcast=True)
3689
settingschanged()
3690
refresh_settings()
3691
elif(msg['cmd'] == 'settypical'):
3692
vars.typical = float(msg['data'])
3693
emit('from_server', {'cmd': 'setlabeltypical', 'data': msg['data']}, broadcast=True)
3694
settingschanged()
3695
refresh_settings()
3696
elif(msg['cmd'] == 'settopa'):
3697
vars.top_a = float(msg['data'])
3698
emit('from_server', {'cmd': 'setlabeltopa', 'data': msg['data']}, broadcast=True)
3699
settingschanged()
3700
refresh_settings()
3701
elif(msg['cmd'] == 'setreppen'):
3702
vars.rep_pen = float(msg['data'])
3703
emit('from_server', {'cmd': 'setlabelreppen', 'data': msg['data']}, broadcast=True)
3704
settingschanged()
3705
refresh_settings()
3706
elif(msg['cmd'] == 'setreppenslope'):
3707
vars.rep_pen_slope = float(msg['data'])
3708
emit('from_server', {'cmd': 'setlabelreppenslope', 'data': msg['data']}, broadcast=True)
3709
settingschanged()
3710
refresh_settings()
3711
elif(msg['cmd'] == 'setreppenrange'):
3712
vars.rep_pen_range = float(msg['data'])
3713
emit('from_server', {'cmd': 'setlabelreppenrange', 'data': msg['data']}, broadcast=True)
3714
settingschanged()
3715
refresh_settings()
3716
elif(msg['cmd'] == 'setoutput'):
3717
vars.genamt = int(msg['data'])
3718
emit('from_server', {'cmd': 'setlabeloutput', 'data': msg['data']}, broadcast=True)
3719
settingschanged()
3720
refresh_settings()
3721
elif(msg['cmd'] == 'settknmax'):
3722
vars.max_length = int(msg['data'])
3723
emit('from_server', {'cmd': 'setlabeltknmax', 'data': msg['data']}, broadcast=True)
3724
settingschanged()
3725
refresh_settings()
3726
elif(msg['cmd'] == 'setikgen'):
3727
vars.ikgen = int(msg['data'])
3728
emit('from_server', {'cmd': 'setlabelikgen', 'data': msg['data']}, broadcast=True)
3729
settingschanged()
3730
refresh_settings()
3731
# Author's Note field update
3732
elif(msg['cmd'] == 'anote'):
3733
anotesubmit(msg['data'], template=msg['template'])
3734
# Author's Note depth update
3735
elif(msg['cmd'] == 'anotedepth'):
3736
vars.andepth = int(msg['data'])
3737
emit('from_server', {'cmd': 'setlabelanotedepth', 'data': msg['data']}, broadcast=True)
3738
settingschanged()
3739
refresh_settings()
3740
# Format - Trim incomplete sentences
3741
elif(msg['cmd'] == 'frmttriminc'):
3742
if('frmttriminc' in vars.formatoptns):
3743
vars.formatoptns["frmttriminc"] = msg['data']
3744
settingschanged()
3745
refresh_settings()
3746
elif(msg['cmd'] == 'frmtrmblln'):
3747
if('frmtrmblln' in vars.formatoptns):
3748
vars.formatoptns["frmtrmblln"] = msg['data']
3749
settingschanged()
3750
refresh_settings()
3751
elif(msg['cmd'] == 'frmtrmspch'):
3752
if('frmtrmspch' in vars.formatoptns):
3753
vars.formatoptns["frmtrmspch"] = msg['data']
3754
settingschanged()
3755
refresh_settings()
3756
elif(msg['cmd'] == 'frmtadsnsp'):
3757
if('frmtadsnsp' in vars.formatoptns):
3758
vars.formatoptns["frmtadsnsp"] = msg['data']
3759
settingschanged()
3760
refresh_settings()
3761
elif(msg['cmd'] == 'singleline'):
3762
if('singleline' in vars.formatoptns):
3763
vars.formatoptns["singleline"] = msg['data']
3764
settingschanged()
3765
refresh_settings()
3766
elif(msg['cmd'] == 'importselect'):
3767
vars.importnum = int(msg["data"].replace("import", ""))
3768
elif(msg['cmd'] == 'importcancel'):
3769
emit('from_server', {'cmd': 'popupshow', 'data': False})
3770
vars.importjs = {}
3771
elif(msg['cmd'] == 'importaccept'):
3772
emit('from_server', {'cmd': 'popupshow', 'data': False})
3773
importgame()
3774
elif(msg['cmd'] == 'wi'):
3775
togglewimode()
3776
elif(msg['cmd'] == 'wiinit'):
3777
if(int(msg['data']) < len(vars.worldinfo)):
3778
setgamesaved(False)
3779
vars.worldinfo[msg['data']]["init"] = True
3780
addwiitem(folder_uid=msg['folder'])
3781
elif(msg['cmd'] == 'wifolderinit'):
3782
addwifolder()
3783
elif(msg['cmd'] == 'wimoveitem'):
3784
movewiitem(msg['destination'], msg['data'])
3785
elif(msg['cmd'] == 'wimovefolder'):
3786
movewifolder(msg['destination'], msg['data'])
3787
elif(msg['cmd'] == 'widelete'):
3788
deletewi(msg['data'])
3789
elif(msg['cmd'] == 'wifolderdelete'):
3790
deletewifolder(msg['data'])
3791
elif(msg['cmd'] == 'wiexpand'):
3792
assert 0 <= int(msg['data']) < len(vars.worldinfo)
3793
setgamesaved(False)
3794
emit('from_server', {'cmd': 'wiexpand', 'data': msg['data']}, broadcast=True)
3795
elif(msg['cmd'] == 'wiexpandfolder'):
3796
assert 0 <= int(msg['data']) < len(vars.worldinfo)
3797
setgamesaved(False)
3798
emit('from_server', {'cmd': 'wiexpandfolder', 'data': msg['data']}, broadcast=True)
3799
elif(msg['cmd'] == 'wifoldercollapsecontent'):
3800
setgamesaved(False)
3801
vars.wifolders_d[msg['data']]['collapsed'] = True
3802
emit('from_server', {'cmd': 'wifoldercollapsecontent', 'data': msg['data']}, broadcast=True)
3803
elif(msg['cmd'] == 'wifolderexpandcontent'):
3804
setgamesaved(False)
3805
vars.wifolders_d[msg['data']]['collapsed'] = False
3806
emit('from_server', {'cmd': 'wifolderexpandcontent', 'data': msg['data']}, broadcast=True)
3807
elif(msg['cmd'] == 'wiupdate'):
3808
setgamesaved(False)
3809
num = int(msg['num'])
3810
fields = ("key", "keysecondary", "content", "comment")
3811
for field in fields:
3812
if(field in msg['data'] and type(msg['data'][field]) is str):
3813
vars.worldinfo[num][field] = msg['data'][field]
3814
emit('from_server', {'cmd': 'wiupdate', 'num': msg['num'], 'data': {field: vars.worldinfo[num][field] for field in fields}}, broadcast=True)
3815
elif(msg['cmd'] == 'wifolderupdate'):
3816
setgamesaved(False)
3817
uid = int(msg['uid'])
3818
fields = ("name", "collapsed")
3819
for field in fields:
3820
if(field in msg['data'] and type(msg['data'][field]) is (str if field != "collapsed" else bool)):
3821
vars.wifolders_d[uid][field] = msg['data'][field]
3822
emit('from_server', {'cmd': 'wifolderupdate', 'uid': msg['uid'], 'data': {field: vars.wifolders_d[uid][field] for field in fields}}, broadcast=True)
3823
elif(msg['cmd'] == 'wiselon'):
3824
setgamesaved(False)
3825
vars.worldinfo[msg['data']]["selective"] = True
3826
emit('from_server', {'cmd': 'wiselon', 'data': msg['data']}, broadcast=True)
3827
elif(msg['cmd'] == 'wiseloff'):
3828
setgamesaved(False)
3829
vars.worldinfo[msg['data']]["selective"] = False
3830
emit('from_server', {'cmd': 'wiseloff', 'data': msg['data']}, broadcast=True)
3831
elif(msg['cmd'] == 'wiconstanton'):
3832
setgamesaved(False)
3833
vars.worldinfo[msg['data']]["constant"] = True
3834
emit('from_server', {'cmd': 'wiconstanton', 'data': msg['data']}, broadcast=True)
3835
elif(msg['cmd'] == 'wiconstantoff'):
3836
setgamesaved(False)
3837
vars.worldinfo[msg['data']]["constant"] = False
3838
emit('from_server', {'cmd': 'wiconstantoff', 'data': msg['data']}, broadcast=True)
3839
elif(msg['cmd'] == 'sendwilist'):
3840
commitwi(msg['data'])
3841
elif(msg['cmd'] == 'aidgimport'):
3842
importAidgRequest(msg['data'])
3843
elif(msg['cmd'] == 'saveasrequest'):
3844
saveas(msg['data'])
3845
elif(msg['cmd'] == 'saverequest'):
3846
save()
3847
elif(msg['cmd'] == 'loadlistrequest'):
3848
getloadlist()
3849
elif(msg['cmd'] == 'splistrequest'):
3850
getsplist()
3851
elif(msg['cmd'] == 'uslistrequest'):
3852
unloaded, loaded = getuslist()
3853
emit('from_server', {'cmd': 'buildus', 'data': {"unloaded": unloaded, "loaded": loaded}})
3854
elif(msg['cmd'] == 'samplerlistrequest'):
3855
emit('from_server', {'cmd': 'buildsamplers', 'data': vars.sampler_order})
3856
elif(msg['cmd'] == 'usloaded'):
3857
vars.userscripts = []
3858
for userscript in msg['data']:
3859
if type(userscript) is not str:
3860
continue
3861
userscript = userscript.strip()
3862
if len(userscript) != 0 and all(q not in userscript for q in ("..", ":")) and all(userscript[0] not in q for q in ("/", "\\")) and os.path.exists(fileops.uspath(userscript)):
3863
vars.userscripts.append(userscript)
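# Only names that are non-empty, contain no ".." or ":", do not start with a path
# separator, and actually exist under the userscripts folder are accepted.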
3864
settingschanged()
3865
elif(msg['cmd'] == 'usload'):
3866
load_lua_scripts()
3867
unloaded, loaded = getuslist()
3868
sendUSStatItems()
3869
elif(msg['cmd'] == 'samplers'):
3870
sampler_order = msg["data"]
3871
sampler_order_min_length = 6
3872
sampler_order_max_length = 7
3873
if(not isinstance(sampler_order, list)):
3874
raise ValueError(f"Sampler order must be a list, but got a {type(sampler_order)}")
3875
if(not (sampler_order_min_length <= len(sampler_order) <= sampler_order_max_length)):
3876
raise ValueError(f"Sampler order must be a list of length greater than or equal to {sampler_order_min_length} and less than or equal to {sampler_order_max_length}, but got a list of length {len(sampler_order)}")
3877
if(not all(isinstance(e, int) for e in sampler_order)):
3878
raise ValueError(f"Sampler order must be a list of ints, but got a list with at least one non-int element")
3879
if(min(sampler_order) != 0 or max(sampler_order) != len(sampler_order) - 1 or len(set(sampler_order)) != len(sampler_order)):
3880
raise ValueError(f"Sampler order list of length {len(sampler_order)} must be a permutation of the first {len(sampler_order)} nonnegative integers")
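# Illustrative example: with 7 samplers, any permutation of 0..6 such as
# [6, 0, 1, 2, 3, 4, 5] is accepted; [0, 1, 1, 2, 3, 4, 5] is rejected because it is
# not a permutation.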
3881
vars.sampler_order = sampler_order
3882
settingschanged()
3883
elif(msg['cmd'] == 'list_model'):
3884
sendModelSelection(menu=msg['data'])
3885
elif(msg['cmd'] == 'load_model'):
3886
logger.debug(f"Selected Model: {vars.model_selected}")
3887
if not os.path.exists("settings/"):
3888
os.mkdir("settings")
3889
changed = True
3890
if not utils.HAS_ACCELERATE:
3891
msg['disk_layers'] = "0"
3892
if os.path.exists("settings/" + vars.model_selected.replace('/', '_') + ".breakmodel"):
3893
with open("settings/" + vars.model_selected.replace('/', '_') + ".breakmodel", "r") as file:
3894
data = file.read().split('\n')[:2]
3895
if len(data) < 2:
3896
data.append("0")
3897
gpu_layers, disk_layers = data
3898
if gpu_layers == msg['gpu_layers'] and disk_layers == msg['disk_layers']:
3899
changed = False
3900
if changed:
3901
if vars.model_selected in ["NeoCustom", "GPT2Custom"]:
3902
filename = "settings/{}.breakmodel".format(os.path.basename(os.path.normpath(vars.custmodpth)))
3903
else:
3904
filename = "settings/{}.breakmodel".format(vars.model_selected.replace('/', '_'))
3905
f = open(filename, "w")
3906
f.write(str(msg['gpu_layers']) + '\n' + str(msg['disk_layers']))
3907
f.close()
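# The .breakmodel file holds two lines: the GPU layer split on the first line and the
# number of disk-cached layers on the second.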
3908
vars.colaburl = msg['url'] + "/request"
3909
vars.model = vars.model_selected
3910
if vars.model == "CLUSTER":
3911
if type(msg['online_model']) is not list:
3912
if msg['online_model'] == '':
3913
vars.cluster_requested_models = []
3914
else:
3915
vars.cluster_requested_models = [msg['online_model']]
3916
else:
3917
vars.cluster_requested_models = msg['online_model']
3918
load_model(use_gpu=msg['use_gpu'], gpu_layers=msg['gpu_layers'], disk_layers=msg['disk_layers'], online_model=msg['online_model'])
3919
elif(msg['cmd'] == 'show_model'):
3920
logger.info(f"Model Name: {getmodelname()}")
3921
emit('from_server', {'cmd': 'show_model_name', 'data': getmodelname()}, broadcast=True)
3922
elif(msg['cmd'] == 'selectmodel'):
3923
# This is run when a model line that is tagged as not a menu is selected from the UI
# (a line from the model_menu variable); menu lines are handled by msg['cmd'] == 'list_model' instead.
# A bit of processing is still needed: selecting a custom path means listing the contents
# of that folder, while other selections may need to show model layers per GPU or ask for
# a key. All of that happens here.
# The data variable contains the model name, but the Custom lines need extra handling.
# If a model has already been chosen for a custom line, the path variable will be in msg;
# if it is missing, the menu is shown again to list the model folders in the models folder.
if msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path' not in msg and 'path_modelname' not in msg:
3934
if 'folder' not in msg or vars.host:
3935
folder = "./models"
3936
else:
3937
folder = msg['folder']
3938
sendModelSelection(menu=msg['data'], folder=folder)
3939
elif msg['data'] in ('NeoCustom', 'GPT2Custom') and 'path_modelname' in msg:
3940
#Here the user entered custom text in the text box. This could be either a model name or a path.
3941
if check_if_dir_is_model(msg['path_modelname']):
3942
vars.model_selected = msg['data']
3943
vars.custmodpth = msg['path_modelname']
3944
get_model_info(msg['data'], directory=msg['path'])
3945
else:
3946
vars.model_selected = msg['path_modelname']
3947
try:
3948
get_model_info(vars.model_selected)
3949
except:
3950
emit('from_server', {'cmd': 'errmsg', 'data': "The model entered doesn't exist."})
3951
elif msg['data'] in ('NeoCustom', 'GPT2Custom'):
3952
if check_if_dir_is_model(msg['path']):
3953
vars.model_selected = msg['data']
3954
vars.custmodpth = msg['path']
3955
get_model_info(msg['data'], directory=msg['path'])
3956
else:
3957
if vars.host:
3958
sendModelSelection(menu=msg['data'], folder="./models")
3959
else:
3960
sendModelSelection(menu=msg['data'], folder=msg['path'])
3961
else:
3962
vars.model_selected = msg['data']
3963
if 'path' in msg:
3964
vars.custmodpth = msg['path']
3965
get_model_info(msg['data'], directory=msg['path'])
3966
else:
3967
get_model_info(vars.model_selected)
3968
elif(msg['cmd'] == 'delete_model'):
3969
if "{}/models".format(os.getcwd()) in os.path.abspath(msg['data']) or "{}\\models".format(os.getcwd()) in os.path.abspath(msg['data']):
3970
if check_if_dir_is_model(msg['data']):
3971
logger.warning(f"Someone deleted {msg['data']}")
3972
import shutil
3973
shutil.rmtree(msg['data'])
3974
sendModelSelection(menu=msg['menu'])
3975
else:
3976
logger.error(f"Someone attempted to delete {msg['data']} but this is not a valid model")
3977
else:
3978
logger.critical(f"Someone maliciously attempted to delete {msg['data']}. The attempt has been blocked.")
3979
elif(msg['cmd'] == 'OAI_Key_Update'):
3980
get_oai_models(msg['key'])
3981
elif(msg['cmd'] == 'Cluster_Key_Update'):
3982
get_cluster_models(msg)
3983
elif(msg['cmd'] == 'loadselect'):
3984
vars.loadselect = msg["data"]
3985
elif(msg['cmd'] == 'spselect'):
3986
vars.spselect = msg["data"]
3987
elif(msg['cmd'] == 'loadrequest'):
3988
loadRequest(fileops.storypath(vars.loadselect))
3989
elif(msg['cmd'] == 'sprequest'):
3990
spRequest(vars.spselect)
3991
elif(msg['cmd'] == 'deletestory'):
3992
deletesave(msg['data'])
3993
elif(msg['cmd'] == 'renamestory'):
3994
renamesave(msg['data'], msg['newname'])
3995
elif(msg['cmd'] == 'clearoverwrite'):
3996
vars.svowname = ""
3997
vars.saveow = False
3998
elif(msg['cmd'] == 'seqsel'):
3999
selectsequence(msg['data'])
4000
elif(msg['cmd'] == 'seqpin'):
4001
pinsequence(msg['data'])
4002
elif(msg['cmd'] == 'setnumseq'):
4003
vars.numseqs = int(msg['data'])
4004
emit('from_server', {'cmd': 'setlabelnumseq', 'data': msg['data']})
4005
settingschanged()
4006
refresh_settings()
4007
elif(msg['cmd'] == 'setwidepth'):
4008
vars.widepth = int(msg['data'])
4009
emit('from_server', {'cmd': 'setlabelwidepth', 'data': msg['data']})
4010
settingschanged()
4011
refresh_settings()
4012
elif(msg['cmd'] == 'setuseprompt'):
4013
vars.useprompt = msg['data']
4014
settingschanged()
4015
refresh_settings()
4016
elif(msg['cmd'] == 'setadventure'):
4017
vars.adventure = msg['data']
4018
vars.chatmode = False
4019
settingschanged()
4020
refresh_settings()
4021
elif(msg['cmd'] == 'autosave'):
4022
vars.autosave = msg['data']
4023
settingschanged()
4024
refresh_settings()
4025
elif(msg['cmd'] == 'setchatmode'):
4026
vars.chatmode = msg['data']
4027
vars.adventure = False
4028
settingschanged()
4029
refresh_settings()
4030
elif(msg['cmd'] == 'setdynamicscan'):
4031
vars.dynamicscan = msg['data']
4032
settingschanged()
4033
refresh_settings()
4034
elif(msg['cmd'] == 'setnopromptgen'):
4035
vars.nopromptgen = msg['data']
4036
settingschanged()
4037
refresh_settings()
4038
elif(msg['cmd'] == 'setrngpersist'):
4039
vars.rngpersist = msg['data']
4040
settingschanged()
4041
refresh_settings()
4042
elif(msg['cmd'] == 'setnogenmod'):
4043
vars.nogenmod = msg['data']
4044
settingschanged()
4045
refresh_settings()
4046
elif(msg['cmd'] == 'setfulldeterminism'):
4047
vars.full_determinism = msg['data']
4048
settingschanged()
4049
refresh_settings()
4050
elif(msg['cmd'] == 'setoutputstreaming'):
4051
vars.output_streaming = msg['data']
4052
settingschanged()
4053
refresh_settings()
4054
elif(msg['cmd'] == 'setshowbudget'):
4055
vars.show_budget = msg['data']
4056
settingschanged()
4057
refresh_settings()
4058
elif(msg['cmd'] == 'setshowprobs'):
4059
vars.show_probs = msg['data']
4060
settingschanged()
4061
refresh_settings()
4062
elif(not vars.host and msg['cmd'] == 'importwi'):
4063
wiimportrequest()
4064
elif(msg['cmd'] == 'debug'):
4065
vars.debug = msg['data']
4066
emit('from_server', {'cmd': 'set_debug', 'data': msg['data']}, broadcast=True)
4067
if vars.debug:
4068
send_debug()
4069
elif(msg['cmd'] == 'getfieldbudget'):
4070
unencoded = msg["data"]["unencoded"]
4071
field = msg["data"]["field"]
4072
4073
# Tokenizer may be undefined here when a model has not been chosen.
4074
if "tokenizer" not in globals():
4075
# We don't have a tokenizer, just return nulls.
4076
emit(
4077
'from_server',
4078
{'cmd': 'showfieldbudget', 'data': {"length": None, "max": None, "field": field}},
4079
)
4080
return
4081
4082
header_length = len(tokenizer._koboldai_header)
4083
max_tokens = vars.max_length - header_length - vars.sp_length - vars.genamt
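# The budget shown to the user excludes the tokenizer header, the soft prompt and the
# requested generation length from the Max Tokens setting.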
4084
4085
if not unencoded:
4086
# Unencoded is empty, just return 0
4087
emit(
4088
'from_server',
4089
{'cmd': 'showfieldbudget', 'data': {"length": 0, "max": max_tokens, "field": field}},
4090
broadcast=True
4091
)
4092
else:
4093
if field == "anoteinput":
4094
unencoded = buildauthorsnote(unencoded, msg["data"]["anotetemplate"])
4095
tokens_length = len(tokenizer.encode(unencoded))
4096
4097
emit(
4098
'from_server',
4099
{'cmd': 'showfieldbudget', 'data': {"length": tokens_length, "max": max_tokens, "field": field}},
4100
broadcast=True
4101
)
4102
4103
#==================================================================#
4104
# Send userscripts list to client
4105
#==================================================================#
4106
def sendUSStatItems():
4107
_, loaded = getuslist()
4108
loaded = loaded if vars.lua_running else []
4109
last_userscripts = [e["filename"] for e in loaded]
4110
emit('from_server', {'cmd': 'usstatitems', 'data': loaded, 'flash': last_userscripts != vars.last_userscripts}, broadcast=True)
4111
vars.last_userscripts = last_userscripts
4112
4113
#==================================================================#
4114
# KoboldAI Markup Formatting (Mixture of Markdown and sanitized html)
4115
#==================================================================#
4116
def kml(txt):
4117
txt = txt.replace('>', '&gt;')
4118
txt = bleach.clean(markdown.markdown(txt), tags = ['p', 'em', 'strong', 'code', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'ul', 'b', 'i', 'a', 'span', 'button'], styles = ['color', 'font-weight'], attributes=['id', 'class', 'style', 'href'])
4119
return txt
4120
4121
#==================================================================#
4122
# Send start message and tell Javascript to set UI state
4123
#==================================================================#
4124
def setStartState():
4125
if(vars.welcome):
4126
txt = kml(vars.welcome) + "<br/>"
4127
else:
4128
txt = "<span>Welcome to <span class=\"color_cyan\">KoboldAI</span>! You are running <span class=\"color_green\">"+getmodelname()+"</span>.<br/>"
4129
if(not vars.noai and not vars.welcome):
4130
txt = txt + "Please load a game or enter a prompt below to begin!</span>"
4131
if(vars.noai):
4132
txt = txt + "Please load or import a story to read. There is no AI in this mode."
4133
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': vars.gamestarted, 'data': txt}, broadcast=True)
4134
emit('from_server', {'cmd': 'setgamestate', 'data': 'start'}, broadcast=True)
4135
4136
#==================================================================#
4137
# Transmit applicable settings to SocketIO to build UI sliders/toggles
4138
#==================================================================#
4139
def sendsettings():
4140
# Send settings for selected AI type
4141
emit('from_server', {'cmd': 'reset_menus'})
4142
if(vars.model != "InferKit"):
4143
for set in gensettings.gensettingstf:
4144
emit('from_server', {'cmd': 'addsetting', 'data': set})
4145
else:
4146
for set in gensettings.gensettingsik:
4147
emit('from_server', {'cmd': 'addsetting', 'data': set})
4148
4149
# Send formatting options
4150
for frm in gensettings.formatcontrols:
4151
emit('from_server', {'cmd': 'addformat', 'data': frm})
4152
# Add format key to vars if it wasn't loaded with client.settings
4153
if(not frm["id"] in vars.formatoptns):
4154
vars.formatoptns[frm["id"]] = False
4155
4156
#==================================================================#
4157
# Set value of gamesaved
4158
#==================================================================#
4159
def setgamesaved(gamesaved):
4160
assert type(gamesaved) is bool
4161
if(gamesaved != vars.gamesaved):
4162
emit('from_server', {'cmd': 'gamesaved', 'data': gamesaved}, broadcast=True)
4163
vars.gamesaved = gamesaved
4164
4165
#==================================================================#
4166
# Take input text from SocketIO and decide what to do with it
4167
#==================================================================#
4168
4169
def check_for_backend_compilation():
4170
if(vars.checking):
4171
return
4172
vars.checking = True
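# Poll roughly every 63 ms for about two seconds total (31 iterations); if the TPU
# backend is still compiling, warn the user once and stop checking.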
4173
for _ in range(31):
4174
time.sleep(0.06276680299820175)
4175
if(vars.compiling):
4176
emit('from_server', {'cmd': 'warnmsg', 'data': 'Compiling TPU backend&mdash;this usually takes 1&ndash;2 minutes...'}, broadcast=True)
4177
break
4178
vars.checking = False
4179
4180
def actionsubmit(data, actionmode=0, force_submit=False, force_prompt_gen=False, disable_recentrng=False, no_generate=False, ignore_aibusy=False):
4181
# Ignore new submissions if the AI is currently busy
4182
if(not ignore_aibusy and vars.aibusy):
4183
return
4184
4185
while(True):
4186
set_aibusy(1)
4187
4188
if(vars.model in ["API","CLUSTER"]):
4189
global tokenizer
4190
if vars.model == "API":
4191
tokenizer_id = requests.get(
4192
vars.colaburl[:-8] + "/api/v1/model",
4193
).json()["result"]
4194
elif len(vars.cluster_requested_models) >= 1:
4195
# If the player has requested one or more models, use the first one for the tokenizer;
# the cluster may serve any of the requested models for each generation, but which one
# is only known after this step.
tokenizer_id = vars.cluster_requested_models[0]
4199
else:
4200
tokenizer_id = ""
4201
if tokenizer_id != vars.api_tokenizer_id:
4202
try:
4203
if(os.path.isdir(tokenizer_id)):
4204
try:
4205
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=args.revision, cache_dir="cache")
4206
except:
4207
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=args.revision, cache_dir="cache", use_fast=False)
4208
elif(os.path.isdir("models/{}".format(tokenizer_id.replace('/', '_')))):
4209
try:
4210
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(tokenizer_id.replace('/', '_')), revision=args.revision, cache_dir="cache")
4211
except:
4212
tokenizer = AutoTokenizer.from_pretrained("models/{}".format(tokenizer_id.replace('/', '_')), revision=args.revision, cache_dir="cache", use_fast=False)
4213
else:
4214
try:
4215
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=args.revision, cache_dir="cache")
4216
except:
4217
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, revision=args.revision, cache_dir="cache", use_fast=False)
4218
except:
4219
logger.warning(f"Unknown tokenizer {repr(tokenizer_id)}")
4220
vars.api_tokenizer_id = tokenizer_id
4221
4222
if(disable_recentrng):
4223
vars.recentrng = vars.recentrngm = None
4224
4225
vars.recentback = False
4226
vars.recentedit = False
4227
vars.actionmode = actionmode
4228
4229
# "Action" mode
4230
if(actionmode == 1):
4231
data = data.strip().lstrip('>')
4232
data = re.sub(r'\n+', ' ', data)
4233
if(len(data)):
4234
data = f"\n\n> {data}\n"
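# e.g. "look around" becomes "\n\n> look around\n" so it reads as an adventure command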
4235
4236
# "Chat" mode
4237
if(vars.chatmode and vars.gamestarted):
4238
data = re.sub(r'\n+', ' ', data)
4239
if(len(data)):
4240
data = f"\n{vars.chatname}: {data}\n"
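# e.g. with chatname "You", "hello" becomes "\nYou: hello\n"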
4241
4242
# If we're not continuing, store a copy of the raw input
4243
if(data != ""):
4244
vars.lastact = data
4245
4246
if(not vars.gamestarted):
4247
vars.submission = data
4248
if(not no_generate):
4249
execute_inmod()
4250
vars.submission = re.sub(r"[^\S\r\n]*([\r\n]*)$", r"\1", vars.submission) # Remove trailing whitespace, excluding newlines
4251
data = vars.submission
4252
if(not force_submit and len(data.strip()) == 0):
4253
assert False
4254
# Start the game
4255
vars.gamestarted = True
4256
if(not no_generate and not vars.noai and vars.lua_koboldbridge.generating and (not vars.nopromptgen or force_prompt_gen)):
4257
# Save this first action as the prompt
4258
vars.prompt = data
4259
# Clear the startup text from game screen
4260
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True)
4261
calcsubmit(data) # Run the first action through the generator
4262
if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
4263
data = ""
4264
force_submit = True
4265
disable_recentrng = True
4266
continue
4267
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
4268
break
4269
else:
4270
# Save this first action as the prompt
4271
vars.prompt = data if len(data) > 0 else '"'
4272
for i in range(vars.numseqs):
4273
vars.lua_koboldbridge.outputs[i+1] = ""
4274
if(not no_generate):
4275
execute_outmod()
4276
vars.lua_koboldbridge.regeneration_required = False
4277
genout = []
4278
for i in range(vars.numseqs):
4279
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
4280
assert type(genout[-1]["generated_text"]) is str
4281
if(len(genout) == 1):
4282
genresult(genout[0]["generated_text"], flash=False)
4283
refresh_story()
4284
if(len(vars.actions) > 0):
4285
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1}, broadcast=True)
4286
if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None):
4287
data = ""
4288
force_submit = True
4289
disable_recentrng = True
4290
continue
4291
else:
4292
if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
4293
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"], flash=False)
4294
refresh_story()
4295
data = ""
4296
force_submit = True
4297
disable_recentrng = True
4298
continue
4299
genselect(genout)
4300
refresh_story()
4301
set_aibusy(0)
4302
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
4303
break
4304
else:
4305
# Apply input formatting & scripts before sending to tokenizer
4306
if(vars.actionmode == 0):
4307
data = applyinputformatting(data)
4308
vars.submission = data
4309
if(not no_generate):
4310
execute_inmod()
4311
vars.submission = re.sub(r"[^\S\r\n]*([\r\n]*)$", r"\1", vars.submission) # Remove trailing whitespace, excluding newlines
4312
data = vars.submission
4313
# Don't append submission if it's a blank/continue action
4314
if(data != ""):
4315
# Store the result in the Action log
4316
if(len(vars.prompt.strip()) == 0):
4317
vars.prompt = data
4318
else:
4319
vars.actions.append(data)
4320
# We now need to update actions_metadata. There are two cases:
# 1. This is a totally new (user-entered) action
4323
if vars.actions.get_last_key() not in vars.actions_metadata:
4324
vars.actions_metadata[vars.actions.get_last_key()] = {"Selected Text": data, "Alternative Text": []}
4325
else:
4326
# 2. We've selected a chunk of text that was presented previously
4327
try:
4328
alternatives = [item['Text'] for item in vars.actions_metadata[len(vars.actions)-1]["Alternative Text"]]
4329
except:
4330
logger.debug(len(vars.actions))
4331
logger.debug(vars.actions_metadata)
4332
raise
4333
if data in alternatives:
4334
alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key() ]["Alternative Text"] if item['Text'] != data]
4335
vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives
4336
vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = data
4337
update_story_chunk('last')
4338
send_debug()
4339
4340
if(not no_generate and not vars.noai and vars.lua_koboldbridge.generating):
4341
# Off to the tokenizer!
4342
calcsubmit(data)
4343
if(not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and len(vars.genseqs) == 0):
4344
data = ""
4345
force_submit = True
4346
disable_recentrng = True
4347
continue
4348
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
4349
break
4350
else:
4351
if(not no_generate):
4352
for i in range(vars.numseqs):
4353
vars.lua_koboldbridge.outputs[i+1] = ""
4354
execute_outmod()
4355
vars.lua_koboldbridge.regeneration_required = False
4356
genout = []
4357
for i in range(vars.numseqs):
4358
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1] if not no_generate else ""})
4359
assert type(genout[-1]["generated_text"]) is str
4360
if(len(genout) == 1):
4361
genresult(genout[0]["generated_text"])
4362
if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None):
4363
data = ""
4364
force_submit = True
4365
disable_recentrng = True
4366
continue
4367
else:
4368
if(not no_generate and not vars.abort and vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
4369
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
4370
data = ""
4371
force_submit = True
4372
disable_recentrng = True
4373
continue
4374
genselect(genout)
4375
set_aibusy(0)
4376
emit('from_server', {'cmd': 'scrolldown', 'data': ''}, broadcast=True)
4377
break
4378
4379
def apiactionsubmit_generate(txt, minimum, maximum):
4380
vars.generated_tkns = 0
4381
4382
if not vars.quiet:
4383
logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
4384
logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))
4385
4386
# Clear CUDA cache if using GPU
4387
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
4388
gc.collect()
4389
torch.cuda.empty_cache()
4390
4391
# Submit input text to generator
4392
_genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, set())
4393
4394
genout = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))) for tokens in _genout]
4395
4396
# Clear CUDA cache again if using GPU
4397
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
4398
del _genout
4399
gc.collect()
4400
torch.cuda.empty_cache()
4401
4402
return genout
4403
4404
def apiactionsubmit_tpumtjgenerate(txt, minimum, maximum):
4405
vars.generated_tkns = 0
4406
4407
if(vars.full_determinism):
4408
tpu_mtj_backend.set_rng_seed(vars.seed)
4409
4410
if not vars.quiet:
4411
logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
4412
logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))
4413
4414
vars._actions = vars.actions
4415
vars._prompt = vars.prompt
4416
if(vars.dynamicscan):
4417
vars._actions = vars._actions.copy()
4418
4419
# Submit input text to generator
4420
soft_tokens = tpumtjgetsofttokens()
4421
genout = tpool.execute(
4422
tpu_mtj_backend.infer_static,
4423
np.uint32(txt),
4424
gen_len = maximum-minimum+1,
4425
temp=vars.temp,
4426
top_p=vars.top_p,
4427
top_k=vars.top_k,
4428
tfs=vars.tfs,
4429
typical=vars.typical,
4430
top_a=vars.top_a,
4431
numseqs=vars.numseqs,
4432
repetition_penalty=vars.rep_pen,
4433
rpslope=vars.rep_pen_slope,
4434
rprange=vars.rep_pen_range,
4435
soft_embeddings=vars.sp,
4436
soft_tokens=soft_tokens,
4437
sampler_order=vars.sampler_order,
4438
)
4439
genout = [applyoutputformatting(utils.decodenewlines(tokenizer.decode(txt))) for txt in genout]
4440
4441
return genout
4442
4443
def apiactionsubmit(data, use_memory=False, use_world_info=False, use_story=False, use_authors_note=False):
4444
if(vars.model == "Colab"):
4445
raise NotImplementedError("API generation is not supported in old Colab API mode.")
4446
elif(vars.model == "API"):
4447
raise NotImplementedError("API generation is not supported in API mode.")
4448
elif(vars.model == "CLUSTER"):
4449
raise NotImplementedError("API generation is not supported in cluster mode.")
4450
elif(vars.model == "OAI"):
4451
raise NotImplementedError("API generation is not supported in OpenAI/GooseAI mode.")
4452
elif(vars.model == "ReadOnly"):
4453
raise NotImplementedError("API generation is not supported in read-only mode; please load a model and then try again.")
4454
4455
data = applyinputformatting(data)
4456
4457
if(vars.memory != "" and vars.memory[-1] != "\n"):
4458
mem = vars.memory + "\n"
4459
else:
4460
mem = vars.memory
4461
if(use_authors_note and vars.authornote != ""):
4462
anotetxt = ("\n" + vars.authornotetemplate + "\n").replace("<|>", vars.authornote)
4463
else:
4464
anotetxt = ""
4465
MIN_STORY_TOKENS = 8
4466
story_tokens = []
4467
mem_tokens = []
4468
wi_tokens = []
4469
story_budget = lambda: vars.max_length - vars.sp_length - vars.genamt - len(tokenizer._koboldai_header) - len(story_tokens) - len(mem_tokens) - len(wi_tokens)
4470
budget = lambda: story_budget() + MIN_STORY_TOKENS
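# story_budget(): tokens still available for story text once the header, soft prompt,
# generation length, memory and world info have been deducted from the Max Tokens setting.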
4471
if budget() < 0:
4472
abort(Response(json.dumps({"detail": {
4473
"msg": f"Your Max Tokens setting is too low for your current soft prompt and tokenizer to handle. It needs to be at least {vars.max_length - budget()}.",
4474
"type": "token_overflow",
4475
}}), mimetype="application/json", status=500))
4476
if use_memory:
4477
mem_tokens = tokenizer.encode(utils.encodenewlines(mem))[-budget():]
4478
if use_world_info:
4479
world_info, _ = checkworldinfo(data, force_use_txt=True, scan_story=use_story)
4480
wi_tokens = tokenizer.encode(utils.encodenewlines(world_info))[-budget():]
4481
if use_story:
4482
if vars.useprompt:
4483
story_tokens = tokenizer.encode(utils.encodenewlines(vars.prompt))[-budget():]
4484
story_tokens = tokenizer.encode(utils.encodenewlines(data))[-story_budget():] + story_tokens
4485
if use_story:
4486
for i, action in enumerate(reversed(vars.actions.values())):
4487
if story_budget() <= 0:
4488
assert story_budget() == 0
4489
break
4490
story_tokens = tokenizer.encode(utils.encodenewlines(action))[-story_budget():] + story_tokens
4491
if i == vars.andepth - 1:
4492
story_tokens = tokenizer.encode(utils.encodenewlines(anotetxt))[-story_budget():] + story_tokens
4493
if not vars.useprompt:
4494
story_tokens = tokenizer.encode(utils.encodenewlines(vars.prompt))[-budget():] + story_tokens
4495
tokens = tokenizer._koboldai_header + mem_tokens + wi_tokens + story_tokens
4496
assert story_budget() >= 0
4497
minimum = len(tokens) + 1
4498
maximum = len(tokens) + vars.genamt
4499
4500
if(not vars.use_colab_tpu and vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
4501
genout = apiactionsubmit_generate(tokens, minimum, maximum)
4502
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
4503
genout = apiactionsubmit_tpumtjgenerate(tokens, minimum, maximum)
4504
4505
return genout
4506
4507
#==================================================================#
4508
#
4509
#==================================================================#
4510
def actionretry(data):
4511
if(vars.noai):
4512
emit('from_server', {'cmd': 'errmsg', 'data': "Retry function unavailable in Read Only mode."})
4513
return
4514
if(vars.recentrng is not None):
4515
if(not vars.aibusy):
4516
randomGameRequest(vars.recentrng, memory=vars.recentrngm)
4517
return
4518
if actionback():
4519
actionsubmit("", actionmode=vars.actionmode, force_submit=True)
4520
send_debug()
4521
elif(not vars.useprompt):
4522
emit('from_server', {'cmd': 'errmsg', 'data': "Please enable \"Always Add Prompt\" to retry with your prompt."})
4523
4524
#==================================================================#
4525
#
4526
#==================================================================#
4527
def actionback():
4528
if(vars.aibusy):
4529
return
4530
# Remove last index of actions and refresh game screen
4531
if(len(vars.genseqs) == 0 and len(vars.actions) > 0):
4532
# We are going to move the selected text to alternative text in the actions_metadata variable so we can redo this action
4533
vars.actions_metadata[vars.actions.get_last_key() ]['Alternative Text'] = [{'Text': vars.actions_metadata[vars.actions.get_last_key() ]['Selected Text'],
4534
'Pinned': False,
4535
"Previous Selection": True,
4536
"Edited": False}] + vars.actions_metadata[vars.actions.get_last_key() ]['Alternative Text']
4537
vars.actions_metadata[vars.actions.get_last_key() ]['Selected Text'] = ""
4538
4539
last_key = vars.actions.get_last_key()
4540
vars.actions.pop()
4541
vars.recentback = True
4542
remove_story_chunk(last_key + 1)
4543
# To keep redo consistent, reset the next id in the actions sequence to the removed key
4544
vars.actions.set_next_id(last_key)
4545
success = True
4546
elif(len(vars.genseqs) == 0):
4547
emit('from_server', {'cmd': 'errmsg', 'data': "Cannot delete the prompt."})
4548
success = False
4549
else:
4550
vars.genseqs = []
4551
success = True
4552
send_debug()
4553
return success
4554
4555
def actionredo():
4556
i = 0
4557
# First we need to find the next valid key.
# Text may have been deleted, so don't show a redo for that blank chunk.
4559
4560
restore_id = vars.actions.get_last_key()+1
4561
if restore_id in vars.actions_metadata:
4562
ok_to_use = False
4563
while not ok_to_use:
4564
for item in vars.actions_metadata[restore_id]['Alternative Text']:
4565
if item['Previous Selection'] and item['Text'] != "":
4566
ok_to_use = True
4567
if not ok_to_use:
4568
restore_id+=1
4569
if restore_id not in vars.actions_metadata:
4570
return
4571
else:
4572
vars.actions.set_next_id(restore_id)
4573
4574
4575
if restore_id in vars.actions_metadata:
4576
genout = [{"generated_text": item['Text']} for item in vars.actions_metadata[restore_id]['Alternative Text'] if (item["Previous Selection"]==True)]
4577
if len(genout) > 0:
4578
genout = genout + [{"generated_text": item['Text']} for item in vars.actions_metadata[restore_id]['Alternative Text'] if (item["Pinned"]==True) and (item["Previous Selection"]==False)]
4579
if len(genout) == 1:
4580
vars.actions_metadata[restore_id]['Alternative Text'] = [item for item in vars.actions_metadata[restore_id]['Alternative Text'] if (item["Previous Selection"]!=True)]
4581
genresult(genout[0]['generated_text'], flash=True, ignore_formatting=True)
4582
else:
4583
# Store sequences in memory until selection is made
4584
vars.genseqs = genout
4585
4586
4587
# Send sequences to UI for selection
4588
genout = [[item['Text'], "redo"] for item in vars.actions_metadata[restore_id]['Alternative Text'] if (item["Previous Selection"]==True)]
4589
4590
emit('from_server', {'cmd': 'genseqs', 'data': genout}, broadcast=True)
4591
else:
4592
emit('from_server', {'cmd': 'popuperror', 'data': "There's nothing to undo"}, broadcast=True)
4593
send_debug()
4594
4595
#==================================================================#
4596
#
4597
#==================================================================#
4598
def buildauthorsnote(authorsnote, template):
4599
# Build Author's Note if set
4600
if authorsnote == "":
4601
return ""
4602
return ("\n" + template + "\n").replace("<|>", authorsnote)
4603
4604
def calcsubmitbudgetheader(txt, **kwargs):
4605
# Scan for WorldInfo matches
4606
winfo, found_entries = checkworldinfo(txt, **kwargs)
4607
4608
# Add a newline to the end of memory
4609
if(vars.memory != "" and vars.memory[-1] != "\n"):
4610
mem = vars.memory + "\n"
4611
else:
4612
mem = vars.memory
4613
4614
anotetxt = buildauthorsnote(vars.authornote, vars.authornotetemplate)
4615
4616
return winfo, mem, anotetxt, found_entries
4617
4618
def calcsubmitbudget(actionlen, winfo, mem, anotetxt, actions, submission=None, budget_deduction=0):
4619
forceanote = False # In case we don't have enough actions to hit A.N. depth
4620
anoteadded = False # In case our budget runs out before we hit A.N. depth
4621
anotetkns = [] # Placeholder for Author's Note tokens
4622
lnanote = 0 # Placeholder for Author's Note length
4623
4624
lnsp = vars.sp_length
4625
4626
if("tokenizer" not in globals()):
4627
from transformers import GPT2Tokenizer
4628
global tokenizer
4629
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=args.revision, cache_dir="cache")
4630
4631
lnheader = len(tokenizer._koboldai_header)
4632
4633
# Calculate token budget
4634
prompttkns = tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', vars.prompt)), max_length=int(2e9), truncation=True)
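# max_length=int(2e9) with truncation=True effectively disables truncation here, since
# the limit is far larger than any real input.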
4635
lnprompt = len(prompttkns)
4636
4637
memtokens = tokenizer.encode(utils.encodenewlines(mem), max_length=int(2e9), truncation=True)
4638
lnmem = len(memtokens)
4639
if(lnmem > vars.max_length - lnheader - lnsp - vars.genamt - budget_deduction):
4640
raise OverflowError("The memory in your story is too long. Please either write a shorter memory text or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
4641
4642
witokens = tokenizer.encode(utils.encodenewlines(winfo), max_length=int(2e9), truncation=True)
4643
lnwi = len(witokens)
4644
if(lnmem + lnwi > vars.max_length - lnheader - lnsp - vars.genamt - budget_deduction):
4645
raise OverflowError("The current active world info keys take up too many tokens. Please either write shorter world info, decrease World Info Depth or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
4646
4647
if(anotetxt != ""):
4648
anotetkns = tokenizer.encode(utils.encodenewlines(anotetxt), max_length=int(2e9), truncation=True)
4649
lnanote = len(anotetkns)
4650
if(lnmem + lnwi + lnanote > vars.max_length - lnheader - lnsp - vars.genamt - budget_deduction):
4651
raise OverflowError("The author's note in your story is too long. Please either write a shorter author's note or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt.")
4652
4653
if(vars.useprompt):
4654
budget = vars.max_length - lnheader - lnsp - lnprompt - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
4655
else:
4656
budget = vars.max_length - lnheader - lnsp - lnmem - lnanote - lnwi - vars.genamt - budget_deduction
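# Whatever remains after the header, soft prompt, memory, author's note, world info and
# the requested generation length is the budget available for prompt and action tokens.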
4657
4658
lnsubmission = len(tokenizer.encode(utils.encodenewlines(vars.comregex_ai.sub('', submission)), max_length=int(2e9), truncation=True)) if submission is not None else 0
4659
maybe_lnprompt = lnprompt if vars.useprompt and actionlen > 0 else 0
4660
4661
if(lnmem + lnwi + lnanote + maybe_lnprompt + lnsubmission > vars.max_length - lnheader - lnsp - vars.genamt - budget_deduction):
4662
raise OverflowError("Your submission is too long. Please either write a shorter submission or increase the Max Tokens setting. If you are using a soft prompt, additionally consider using a smaller soft prompt. If you are using the Always Add Prompt setting, turning it off may help.")
4663
4664
assert budget >= 0
4665
4666
if(actionlen == 0):
4667
# First/Prompt action
4668
tokens = (tokenizer._koboldai_header if vars.model not in ("Colab", "API", "CLUSTER", "OAI") else []) + memtokens + witokens + anotetkns + prompttkns
4669
assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
4670
ln = len(tokens) + lnsp
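# Return the context tokens along with the minimum and maximum total sequence lengths
# (including the soft prompt) that the generator should produce.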
4671
return tokens, ln+1, ln+vars.genamt
4672
else:
4673
tokens = []
4674
4675
# Check if we have the action depth to hit our A.N. depth
4676
if(anotetxt != "" and actionlen < vars.andepth):
4677
forceanote = True
4678
4679
# Get most recent action tokens up to our budget
4680
n = 0
4681
for key in reversed(actions):
4682
chunk = vars.comregex_ai.sub('', actions[key])
4683
4684
assert budget >= 0
4685
if(budget <= 0):
4686
break
4687
acttkns = tokenizer.encode(utils.encodenewlines(chunk), max_length=int(2e9), truncation=True)
4688
tknlen = len(acttkns)
4689
if(tknlen < budget):
4690
tokens = acttkns + tokens
4691
budget -= tknlen
4692
else:
4693
count = budget * -1
4694
tokens = acttkns[count:] + tokens
4695
budget = 0
4696
break
4697
4698
# Inject Author's Note if we've reached the desired depth
4699
if(n == vars.andepth-1):
4700
if(anotetxt != ""):
4701
tokens = anotetkns + tokens # A.N. length already deducted from the budget
4702
anoteadded = True
4703
n += 1
4704
4705
# If we're not using the prompt every time and there's still budget left,
4706
# add some prompt.
4707
if(not vars.useprompt):
4708
if(budget > 0):
4709
prompttkns = prompttkns[-budget:]
4710
else:
4711
prompttkns = []
4712
4713
# Did we get to add the A.N.? If not, do it here
4714
if(anotetxt != ""):
4715
if((not anoteadded) or forceanote):
4716
tokens = (tokenizer._koboldai_header if vars.model not in ("Colab", "API", "CLUSTER", "OAI") else []) + memtokens + witokens + anotetkns + prompttkns + tokens
4717
else:
4718
tokens = (tokenizer._koboldai_header if vars.model not in ("Colab", "API", "CLUSTER", "OAI") else []) + memtokens + witokens + prompttkns + tokens
4719
else:
4720
# Prepend Memory, WI, and Prompt before action tokens
4721
tokens = (tokenizer._koboldai_header if vars.model not in ("Colab", "API", "CLUSTER", "OAI") else []) + memtokens + witokens + prompttkns + tokens
4722
4723
# Send completed bundle to generator
4724
assert len(tokens) <= vars.max_length - lnsp - vars.genamt - budget_deduction
4725
ln = len(tokens) + lnsp
4726
return tokens, ln+1, ln+vars.genamt
4727
4728
#==================================================================#
4729
# Take submitted text and build the text to be given to generator
4730
#==================================================================#
4731
def calcsubmit(txt):
4732
anotetxt = "" # Placeholder for Author's Note text
4733
forceanote = False # In case we don't have enough actions to hit A.N. depth
4734
anoteadded = False # In case our budget runs out before we hit A.N. depth
4735
actionlen = len(vars.actions)
4736
4737
winfo, mem, anotetxt, found_entries = calcsubmitbudgetheader(txt)
4738
4739
# For all transformers models
4740
if(vars.model != "InferKit"):
4741
subtxt, min, max = calcsubmitbudget(actionlen, winfo, mem, anotetxt, vars.actions, submission=txt)
4742
if(actionlen == 0):
4743
if(not vars.use_colab_tpu and vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
4744
generate(subtxt, min, max, found_entries=found_entries)
4745
elif(vars.model == "Colab"):
4746
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
4747
elif(vars.model == "API"):
4748
sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
4749
elif(vars.model == "CLUSTER"):
4750
sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
4751
elif(vars.model == "OAI"):
4752
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
4753
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
4754
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
4755
else:
4756
if(not vars.use_colab_tpu and vars.model not in ["Colab", "API", "CLUSTER", "OAI", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
4757
generate(subtxt, min, max, found_entries=found_entries)
4758
elif(vars.model == "Colab"):
4759
sendtocolab(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
4760
elif(vars.model == "API"):
4761
sendtoapi(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
4762
elif(vars.model == "CLUSTER"):
4763
sendtocluster(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
4764
elif(vars.model == "OAI"):
4765
oairequest(utils.decodenewlines(tokenizer.decode(subtxt)), min, max)
4766
elif(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
4767
tpumtjgenerate(subtxt, min, max, found_entries=found_entries)
4768
4769
# For InferKit web API
4770
else:
4771
# Check if we have the action depth to hit our A.N. depth
4772
if(anotetxt != "" and actionlen < vars.andepth):
4773
forceanote = True
4774
4775
if(vars.useprompt):
4776
budget = vars.ikmax - len(vars.comregex_ai.sub('', vars.prompt)) - len(anotetxt) - len(mem) - len(winfo) - 1
4777
else:
4778
budget = vars.ikmax - len(anotetxt) - len(mem) - len(winfo) - 1
4779
4780
subtxt = ""
4781
prompt = vars.comregex_ai.sub('', vars.prompt)
4782
n = 0
4783
for key in reversed(vars.actions):
4784
chunk = vars.actions[key]
4785
4786
if(budget <= 0):
4787
break
4788
actlen = len(chunk)
4789
if(actlen < budget):
4790
subtxt = chunk + subtxt
4791
budget -= actlen
4792
else:
4793
count = budget * -1
4794
subtxt = chunk[count:] + subtxt
4795
budget = 0
4796
break
4797
4798
# If we're not using the prompt every time and there's still budget left,
4799
# add some prompt.
4800
if(not vars.useprompt):
4801
if(budget > 0):
4802
prompt = vars.comregex_ai.sub('', vars.prompt)[-budget:]
4803
else:
4804
prompt = ""
4805
4806
# Inject Author's Note if we've reached the desired depth
4807
if(n == vars.andepth-1):
4808
if(anotetxt != ""):
4809
subtxt = anotetxt + subtxt # A.N. length already deducted from the budget
4810
anoteadded = True
4811
n += 1
4812
4813
# Did we get to add the A.N.? If not, do it here
4814
if(anotetxt != ""):
4815
if((not anoteadded) or forceanote):
4816
subtxt = mem + winfo + anotetxt + prompt + subtxt
4817
else:
4818
subtxt = mem + winfo + prompt + subtxt
4819
else:
4820
subtxt = mem + winfo + prompt + subtxt
4821
4822
# Send it!
4823
ikrequest(subtxt)
4824
4825
#==================================================================#
4826
# Send text to generator and deal with output
4827
#==================================================================#
4828
4829
def _generate(txt, minimum, maximum, found_entries):
4830
if(vars.full_determinism):
4831
torch.manual_seed(vars.seed)
4832
4833
gen_in = torch.tensor(txt, dtype=torch.long)[None]
4834
if(vars.sp is not None):
4835
soft_tokens = torch.arange(
4836
model.config.vocab_size,
4837
model.config.vocab_size + vars.sp.shape[0],
4838
)
4839
gen_in = torch.cat((soft_tokens[None], gen_in), dim=-1)
4840
assert gen_in.shape[-1] + vars.genamt <= vars.max_length
4841
4842
if(vars.hascuda and vars.usegpu):
4843
gen_in = gen_in.to(vars.gpu_device)
4844
elif(vars.hascuda and vars.breakmodel):
4845
gen_in = gen_in.to(breakmodel.primary_device)
4846
else:
4847
gen_in = gen_in.to('cpu')
4848
4849
model.kai_scanner_excluded_world_info = found_entries
4850
4851
vars._actions = vars.actions
4852
vars._prompt = vars.prompt
4853
if(vars.dynamicscan):
4854
vars._actions = vars._actions.copy()
4855
4856
with torch.no_grad():
4857
already_generated = 0
4858
numseqs = vars.numseqs
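# The loop below only repeats when the world info scanner requests regeneration
# (dynamic WI scan); on later passes every sequence is already a separate row of
# gen_in, so only one return sequence per row is requested.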
4859
while True:
4860
genout = generator(
4861
gen_in,
4862
do_sample=True,
4863
max_length=int(2e9),
4864
repetition_penalty=1.0,
4865
bad_words_ids=vars.badwordsids,
4866
use_cache=True,
4867
num_return_sequences=numseqs
4868
)
4869
already_generated += len(genout[0]) - len(gen_in[0])
4870
assert already_generated <= vars.genamt
4871
if(model.kai_scanner.halt or not model.kai_scanner.regeneration_required):
4872
break
4873
assert genout.ndim >= 2
4874
assert genout.shape[0] == vars.numseqs
4875
if(vars.lua_koboldbridge.generated_cols and vars.generated_tkns != vars.lua_koboldbridge.generated_cols):
4876
raise RuntimeError("Inconsistency detected between KoboldAI Python and Lua backends")
4877
if(already_generated != vars.generated_tkns):
4878
raise RuntimeError("WI scanning error")
4879
for r in range(vars.numseqs):
4880
for c in range(already_generated):
4881
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
4882
genout[r][genout.shape[-1] - already_generated + c] = vars.lua_koboldbridge.generated[r+1][c+1]
4883
encoded = []
4884
for i in range(vars.numseqs):
4885
txt = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
4886
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
4887
found_entries[i].update(_found_entries)
4888
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
4889
encoded.append(torch.tensor(txt, dtype=torch.long, device=genout.device))
4890
max_length = len(max(encoded, key=len))
4891
encoded = torch.stack(tuple(torch.nn.functional.pad(e, (max_length - len(e), 0), value=model.config.pad_token_id or model.config.eos_token_id) for e in encoded))
4892
genout = torch.cat(
4893
(
4894
encoded,
4895
genout[..., -already_generated:],
4896
),
4897
dim=-1
4898
)
4899
if(vars.sp is not None):
4900
soft_tokens = torch.arange(
4901
model.config.vocab_size,
4902
model.config.vocab_size + vars.sp.shape[0],
4903
device=genout.device,
4904
)
4905
genout = torch.cat((soft_tokens.tile(vars.numseqs, 1), genout), dim=-1)
4906
assert genout.shape[-1] + vars.genamt - already_generated <= vars.max_length
4907
diff = genout.shape[-1] - gen_in.shape[-1]
4908
minimum += diff
4909
maximum += diff
4910
gen_in = genout
4911
numseqs = 1
4912
4913
return genout, already_generated
4914
4915
4916
def generate(txt, minimum, maximum, found_entries=None):
4917
vars.generated_tkns = 0
4918
4919
if(found_entries is None):
4920
found_entries = set()
4921
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
4922
4923
if not vars.quiet:
4924
logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
4925
logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))
4926
4927
# Store context in memory to use it for comparison with generated content
4928
vars.lastctx = utils.decodenewlines(tokenizer.decode(txt))
4929
4930
# Clear CUDA cache if using GPU
4931
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
4932
gc.collect()
4933
torch.cuda.empty_cache()
4934
4935
# Submit input text to generator
4936
try:
4937
genout, already_generated = tpool.execute(_generate, txt, minimum, maximum, found_entries)
4938
except Exception as e:
4939
if(issubclass(type(e), lupa.LuaError)):
4940
vars.lua_koboldbridge.obliterate_multiverse()
4941
vars.lua_running = False
4942
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
4943
sendUSStatItems()
4944
logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
4945
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
4946
else:
4947
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call; please check console.'}, broadcast=True)
4948
logger.error(traceback.format_exc().replace("\033", ""))
4949
set_aibusy(0)
4950
return
4951
4952
for i in range(vars.numseqs):
4953
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(genout[i, -1].item())
4954
vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(genout[i, -already_generated:]))
4955
4956
execute_outmod()
4957
if(vars.lua_koboldbridge.regeneration_required):
4958
vars.lua_koboldbridge.regeneration_required = False
4959
genout = []
4960
for i in range(vars.numseqs):
4961
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
4962
assert type(genout[-1]["generated_text"]) is str
4963
else:
4964
genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(tokens[-already_generated:]))} for tokens in genout]
4965
4966
if(len(genout) == 1):
4967
genresult(genout[0]["generated_text"])
4968
else:
4969
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
4970
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
4971
else:
4972
genselect(genout)
4973
4974
# Clear CUDA cache again if using GPU
4975
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
4976
del genout
4977
gc.collect()
4978
torch.cuda.empty_cache()
4979
4980
set_aibusy(0)
4981
4982
#==================================================================#
4983
# Deal with a single return sequence from generate()
4984
#==================================================================#
4985
def genresult(genout, flash=True, ignore_formatting=False):
4986
if not vars.quiet:
4987
logger.generation(genout.encode("unicode_escape").decode("utf-8"))
4988
4989
# Format output before continuing
4990
if not ignore_formatting:
4991
genout = applyoutputformatting(genout)
4992
4993
vars.lua_koboldbridge.feedback = genout
4994
4995
if(len(genout) == 0):
4996
return
4997
4998
# Add formatted text to Actions array and refresh the game screen
4999
if(len(vars.prompt.strip()) == 0):
5000
vars.prompt = genout
5001
else:
5002
vars.actions.append(genout)
5003
if vars.actions.get_last_key() not in vars.actions_metadata:
5004
vars.actions_metadata[vars.actions.get_last_key()] = {'Selected Text': genout, 'Alternative Text': []}
5005
else:
5006
vars.actions_metadata[vars.actions.get_last_key()]['Selected Text'] = genout
5007
update_story_chunk('last')
5008
if(flash):
5009
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0}, broadcast=True)
5010
send_debug()
5011
5012
#==================================================================#
5013
# Send generator sequences to the UI for selection
5014
#==================================================================#
5015
def genselect(genout):
5016
i = 0
5017
for result in genout:
5018
# Apply output formatting rules to sequences
5019
result["generated_text"] = applyoutputformatting(result["generated_text"])
5020
if not vars.quiet:
5021
logger.info(f"Generation Result {i}")
5022
logger.generation(result["generated_text"].encode("unicode_escape").decode("utf-8"))
5023
i += 1
5024
5025
# Add the options to the actions metadata
5026
# If we've already generated text for this action but haven't selected an option yet, drop every alternative that isn't pinned, a previous selection, or edited, then add the new ones
5027
if vars.actions.get_next_id() in vars.actions_metadata:
5028
if (vars.actions_metadata[vars.actions.get_next_id()]['Selected Text'] == ""):
5029
vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text'] = [{"Text": item['Text'], "Pinned": item['Pinned'],
5030
"Previous Selection": item["Previous Selection"],
5031
"Edited": item["Edited"]} for item in vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text']
5032
if item['Pinned'] or item["Previous Selection"] or item["Edited"]] + [{"Text": text["generated_text"],
5033
"Pinned": False, "Previous Selection": False, "Edited": False} for text in genout]
5034
else:
5035
vars.actions_metadata[vars.actions.get_next_id()] = {'Selected Text': '', 'Alternative Text': [{"Text": text["generated_text"], "Pinned": False, "Previous Selection": False, "Edited": False} for text in genout]}
5036
else:
5037
vars.actions_metadata[vars.actions.get_next_id()] = {'Selected Text': '', 'Alternative Text': [{"Text": text["generated_text"], "Pinned": False, "Previous Selection": False, "Edited": False} for text in genout]}
5038
5039
genout = [{"generated_text": item['Text']} for item in vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text'] if (item["Previous Selection"]==False) and (item["Edited"]==False)]
5040
5041
# Store sequences in memory until selection is made
5042
vars.genseqs = genout
5043
5044
genout = [[item['Text'], "pinned" if item['Pinned'] else "normal"] for item in vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text'] if (item["Previous Selection"]==False) and (item["Edited"]==False)]
5045
5046
# Send sequences to UI for selection
5047
emit('from_server', {'cmd': 'genseqs', 'data': genout}, broadcast=True)
5048
send_debug()
5049
5050
#==================================================================#
5051
# Send selected sequence to action log and refresh UI
5052
#==================================================================#
5053
def selectsequence(n):
5054
if(len(vars.genseqs) == 0):
5055
return
5056
vars.lua_koboldbridge.feedback = vars.genseqs[int(n)]["generated_text"]
5057
if(len(vars.lua_koboldbridge.feedback) != 0):
5058
vars.actions.append(vars.lua_koboldbridge.feedback)
5059
#We'll want to remove the option from the alternative text and put it in selected text
5060
vars.actions_metadata[vars.actions.get_last_key() ]['Alternative Text'] = [item for item in vars.actions_metadata[vars.actions.get_last_key()]['Alternative Text'] if item['Text'] != vars.lua_koboldbridge.feedback]
5061
vars.actions_metadata[vars.actions.get_last_key() ]['Selected Text'] = vars.lua_koboldbridge.feedback
5062
update_story_chunk('last')
5063
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0}, broadcast=True)
5064
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
5065
vars.genseqs = []
5066
5067
if(vars.lua_koboldbridge.restart_sequence is not None):
5068
actionsubmit("", actionmode=vars.actionmode, force_submit=True, disable_recentrng=True)
5069
send_debug()
5070
5071
#==================================================================#
5072
# Pin/Unpin the selected sequence
5073
#==================================================================#
5074
def pinsequence(n):
5075
if n.isnumeric():
5076
text = vars.genseqs[int(n)]['generated_text']
5077
if text in [item['Text'] for item in vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text']]:
5078
alternatives = vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text']
5079
for i in range(len(alternatives)):
5080
if alternatives[i]['Text'] == text:
5081
alternatives[i]['Pinned'] = not alternatives[i]['Pinned']
5082
break
5083
vars.actions_metadata[vars.actions.get_next_id()]['Alternative Text'] = alternatives
5084
send_debug()
5085
5086
5087
#==================================================================#
5088
# Send transformers-style request to ngrok/colab host
5089
#==================================================================#
5090
def sendtocolab(txt, min, max):
5091
# Log request to console
5092
if not vars.quiet:
5093
print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))
5094
5095
# Store context in memory to use it for comparison with generated content
5096
vars.lastctx = txt
5097
5098
# Build request JSON data
5099
reqdata = {
5100
'text': txt,
5101
'min': min,
5102
'max': max,
5103
'rep_pen': vars.rep_pen,
5104
'rep_pen_slope': vars.rep_pen_slope,
5105
'rep_pen_range': vars.rep_pen_range,
5106
'temperature': vars.temp,
5107
'top_p': vars.top_p,
5108
'top_k': vars.top_k,
5109
'tfs': vars.tfs,
5110
'typical': vars.typical,
5111
'topa': vars.top_a,
5112
'numseqs': vars.numseqs,
5113
'retfultxt': False
5114
}
5115
5116
# Create request
5117
req = requests.post(
5118
vars.colaburl,
5119
json = reqdata
5120
)
5121
5122
# Deal with the response
5123
if(req.status_code == 200):
5124
js = req.json()["data"]
5125
5126
# Try to be backwards compatible with outdated colab
5127
if("text" in js):
5128
genout = [getnewcontent(js["text"])]
5129
else:
5130
genout = js["seqs"]
5131
5132
for i in range(vars.numseqs):
5133
vars.lua_koboldbridge.outputs[i+1] = genout[i]
5134
5135
execute_outmod()
5136
if(vars.lua_koboldbridge.regeneration_required):
5137
vars.lua_koboldbridge.regeneration_required = False
5138
genout = []
5139
for i in range(vars.numseqs):
5140
genout.append(vars.lua_koboldbridge.outputs[i+1])
5141
assert type(genout[-1]) is str
5142
5143
if(len(genout) == 1):
5144
genresult(genout[0])
5145
else:
5146
# Convert torch output format to transformers
5147
seqs = []
5148
for seq in genout:
5149
seqs.append({"generated_text": seq})
5150
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
genresult(seqs[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
else:
genselect(seqs)
5154
5155
# Format output before continuing
5156
#genout = applyoutputformatting(getnewcontent(genout))
5157
5158
# Add formatted text to Actions array and refresh the game screen
5159
#vars.actions.append(genout)
5160
#refresh_story()
5161
#emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0})
5162
5163
set_aibusy(0)
5164
else:
5165
errmsg = "Colab API Error: Failed to get a reply from the server. Please check the colab console."
5166
print("{0}{1}{2}".format(colors.RED, errmsg, colors.END))
5167
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5168
set_aibusy(0)
5169
5170
5171
#==================================================================#
5172
# Send transformers-style request to KoboldAI API
5173
#==================================================================#
5174
def sendtoapi(txt, min, max):
5175
# Log request to console
5176
if not vars.quiet:
5177
print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))
5178
5179
# Store context in memory to use it for comparison with generated content
5180
vars.lastctx = txt
5181
5182
# Build request JSON data
5183
reqdata = {
5184
'prompt': txt,
5185
'max_length': max - min + 1,
5186
'max_context_length': vars.max_length,
5187
'rep_pen': vars.rep_pen,
5188
'rep_pen_slope': vars.rep_pen_slope,
5189
'rep_pen_range': vars.rep_pen_range,
5190
'temperature': vars.temp,
5191
'top_p': vars.top_p,
5192
'top_k': vars.top_k,
5193
'top_a': vars.top_a,
5194
'tfs': vars.tfs,
5195
'typical': vars.typical,
5196
'n': vars.numseqs,
5197
}
5198
5199
# Create request
5200
while True:
5201
req = requests.post(
5202
vars.colaburl[:-8] + "/api/v1/generate",
5203
json=reqdata,
5204
)
5205
if(req.status_code == 503): # Server is currently generating something else so poll until it's our turn
5206
time.sleep(1)
5207
continue
5208
js = req.json()
5209
if(req.status_code != 200):
5210
errmsg = "KoboldAI API Error: Failed to get a reply from the server. Please check the console."
5211
print("{0}{1}{2}".format(colors.RED, json.dumps(js, indent=2), colors.END))
5212
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5213
set_aibusy(0)
5214
return
5215
5216
genout = [obj["text"] for obj in js["results"]]
5217
5218
for i in range(vars.numseqs):
5219
vars.lua_koboldbridge.outputs[i+1] = genout[i]
5220
5221
execute_outmod()
5222
if(vars.lua_koboldbridge.regeneration_required):
5223
vars.lua_koboldbridge.regeneration_required = False
5224
genout = []
5225
for i in range(vars.numseqs):
5226
genout.append(vars.lua_koboldbridge.outputs[i+1])
5227
assert type(genout[-1]) is str
5228
5229
if(len(genout) == 1):
5230
genresult(genout[0])
5231
else:
5232
adjusted_genout = []
5233
for item in genout:
5234
adjusted_genout.append({"generated_text": item})
5235
# Convert torch output format to transformers
5236
seqs = []
5237
for seq in adjusted_genout:
5238
seqs.append({"generated_text": seq})
5239
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
5240
genresult(adjusted_genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
5241
else:
5242
genselect(adjusted_genout)
5243
5244
set_aibusy(0)
5245
return
5246
5247
#==================================================================#
5248
# Send transformers-style request to KoboldAI Cluster
5249
#==================================================================#
5250
def sendtocluster(txt, min, max):
5251
# Log request to console
5252
if not vars.quiet:
5253
logger.debug(f"Tokens Min:{min-1}")
5254
logger.prompt(txt.encode("unicode_escape").decode("utf-8"))
5255
5256
# Store context in memory to use it for comparison with generated content
5257
vars.lastctx = txt
5258
# Build request JSON data
5259
reqdata = {
5260
'max_length': max - min + 1,
5261
'max_context_length': vars.max_length,
5262
'rep_pen': vars.rep_pen,
5263
'rep_pen_slope': vars.rep_pen_slope,
5264
'rep_pen_range': vars.rep_pen_range,
5265
'temperature': vars.temp,
5266
'top_p': vars.top_p,
5267
'top_k': vars.top_k,
5268
'top_a': vars.top_a,
5269
'tfs': vars.tfs,
5270
'typical': vars.typical,
5271
'n': vars.numseqs,
5272
}
5273
cluster_metadata = {
5274
'prompt': txt,
5275
'params': reqdata,
5276
'models': vars.cluster_requested_models,
5277
'trusted_workers': False,
5278
}
5279
client_agent = "KoboldAI:1.19.3:koboldai.org"
5280
cluster_headers = {
5281
'apikey': vars.apikey,
5282
"Client-Agent": client_agent
5283
}
5284
logger.debug(f"Horde Payload: {cluster_metadata}")
5285
try:
5286
# Create request
5287
req = requests.post(
5288
vars.colaburl[:-8] + "/api/v2/generate/text/async",
5289
json=cluster_metadata,
5290
headers=cluster_headers,
5291
)
5292
except requests.exceptions.ConnectionError:
5293
errmsg = f"Horde unavailable. Please try again later"
5294
logger.error(errmsg)
5295
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5296
set_aibusy(0)
5297
return
5298
if(req.status_code == 503):
5299
errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
5300
logger.error(req.text)
5301
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5302
set_aibusy(0)
5303
return
5304
if(not req.ok):
5305
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
5306
logger.error(req.text)
5307
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5308
set_aibusy(0)
5309
return
5310
try:
5311
js = req.json()
5312
except requests.exceptions.JSONDecodeError:
5313
errmsg = f"Unexpected message received from the Horde: '{req.text}'"
5314
logger.error(errmsg)
5315
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5316
set_aibusy(0)
5317
return
5318
5319
request_id = js["id"]
5320
logger.debug("Horde Request ID: {}".format(request_id))
5321
5322
cluster_agent_headers = {
5323
"Client-Agent": client_agent
5324
}
5325
finished = False
5326
5327
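# Poll the Horde status endpoint about once per second until the request is
# reported as done; any error response aborts and notifies the UI.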
while not finished:
5328
try:
5329
req = requests.get(vars.colaburl[:-8] + "/api/v2/generate/text/status/" + request_id, headers=cluster_agent_headers)
5330
except requests.exceptions.ConnectionError:
5331
errmsg = f"Horde unavailable. Please try again later"
5332
logger.error(errmsg)
5333
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5334
set_aibusy(0)
5335
return
5336
5337
if not req.ok:
5338
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
5339
logger.error(req.text)
5340
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5341
set_aibusy(0)
5342
return
5343
5344
try:
5345
req_status = req.json()
5346
except requests.exceptions.JSONDecodeError:
5347
errmsg = f"Unexpected message received from the KoboldAI Horde: '{req.text}'"
5348
logger.error(errmsg)
5349
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5350
set_aibusy(0)
5351
return
5352
5353
if "done" not in req_status:
5354
errmsg = f"Unexpected response received from the KoboldAI Horde: '{js}'"
5355
logger.error(errmsg)
5356
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5357
set_aibusy(0)
5358
return
5359
5360
finished = req_status["done"]
5361
5362
if not finished:
5363
logger.debug(req_status)
5364
time.sleep(1)
5365
5366
logger.debug("Last Horde Status Message: {}".format(js))
5367
if req_status["faulted"]:
5368
errmsg = "Horde Text generation faulted! Please try again"
5369
logger.error(errmsg)
5370
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
5371
set_aibusy(0)
5372
return
5373
5374
generations = req_status['generations']
5375
gen_workers = [(cgen['worker_name'],cgen['worker_id']) for cgen in generations]
5376
logger.info(f"Generations by: {gen_workers}")
5377
# Just in case we want to announce it to the user
5384
if len(generations) == 1:
5385
warnmsg = f"Text generated by {[w[0] for w in gen_workers]}"
5386
emit('from_server', {'cmd': 'warnmsg', 'data': warnmsg}, broadcast=True)
5387
genout = [cgen['text'] for cgen in generations]
5388
5389
for i in range(vars.numseqs):
5390
vars.lua_koboldbridge.outputs[i+1] = genout[i]
5391
5392
execute_outmod()
5393
if(vars.lua_koboldbridge.regeneration_required):
5394
vars.lua_koboldbridge.regeneration_required = False
5395
genout = []
5396
for i in range(vars.numseqs):
5397
genout.append(vars.lua_koboldbridge.outputs[i+1])
5398
assert type(genout[-1]) is str
5399
5400
if(len(genout) == 1):
5401
genresult(genout[0])
5402
else:
5403
adjusted_genout = []
5404
for item in genout:
5405
adjusted_genout.append({"generated_text": item})
5406
# Convert torch output format to transformers
5407
seqs = []
5408
for seq in adjusted_genout:
5409
seqs.append({"generated_text": seq})
5410
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
5411
genresult(adjusted_genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
5412
else:
5413
genselect(adjusted_genout)
5414
5415
set_aibusy(0)
5416
return
5417
5418
#==================================================================#
5419
# Send text to TPU mesh transformer backend
5420
#==================================================================#
5421
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
5422
if(vars.full_determinism):
5423
tpu_mtj_backend.set_rng_seed(vars.seed)
5424
5425
vars.generated_tkns = 0
5426
5427
if(found_entries is None):
5428
found_entries = set()
5429
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
5430
5431
if not vars.quiet:
5432
logger.debug(f"Prompt Min:{minimum}, Max:{maximum}")
5433
logger.prompt(utils.decodenewlines(tokenizer.decode(txt)).encode("unicode_escape").decode("utf-8"))
5434
5435
vars._actions = vars.actions
5436
vars._prompt = vars.prompt
5437
if(vars.dynamicscan):
5438
vars._actions = vars._actions.copy()
5439
5440
# Submit input text to generator
5441
try:
5442
soft_tokens = tpumtjgetsofttokens()
5443
5444
global past
5445
5446
socketio.start_background_task(copy_current_request_context(check_for_backend_compilation))
5447
5448
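# Dynamic WI scanning and Lua generation modifiers need control between tokens,
# so they go through the incremental infer_dynamic path; otherwise a single
# infer_static call handles the whole request.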
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
5449
5450
context = np.tile(np.uint32(txt), (vars.numseqs, 1))
5451
past = np.empty((vars.numseqs, 0), dtype=np.uint32)
5452
5453
while(True):
5454
genout, n_generated, regeneration_required, halt = tpool.execute(
5455
tpu_mtj_backend.infer_dynamic,
5456
context,
5457
gen_len = maximum-minimum+1,
5458
numseqs=vars.numseqs,
5459
soft_embeddings=vars.sp,
5460
soft_tokens=soft_tokens,
5461
excluded_world_info=found_entries,
5462
)
5463
5464
past = np.pad(past, ((0, 0), (0, n_generated)))
5465
for r in range(vars.numseqs):
5466
for c in range(vars.lua_koboldbridge.generated_cols):
5467
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
5468
past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
5469
5470
if(vars.abort or halt or not regeneration_required):
5471
break
5472
print("(regeneration triggered)")
5473
5474
encoded = []
5475
for i in range(vars.numseqs):
5476
txt = utils.decodenewlines(tokenizer.decode(past[i]))
5477
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True, actions=vars._actions)
5478
found_entries[i].update(_found_entries)
5479
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
5480
encoded.append(np.array(txt, dtype=np.uint32))
5481
max_length = len(max(encoded, key=len))
5482
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
5483
context = np.concatenate(
5484
(
5485
encoded,
5486
past,
5487
),
5488
axis=-1,
5489
)
5490
5491
else:
5492
genout = tpool.execute(
5493
tpu_mtj_backend.infer_static,
5494
np.uint32(txt),
5495
gen_len = maximum-minimum+1,
5496
temp=vars.temp,
5497
top_p=vars.top_p,
5498
top_k=vars.top_k,
5499
tfs=vars.tfs,
5500
typical=vars.typical,
5501
top_a=vars.top_a,
5502
numseqs=vars.numseqs,
5503
repetition_penalty=vars.rep_pen,
5504
rpslope=vars.rep_pen_slope,
5505
rprange=vars.rep_pen_range,
5506
soft_embeddings=vars.sp,
5507
soft_tokens=soft_tokens,
5508
sampler_order=vars.sampler_order,
5509
)
5510
past = genout
5511
for i in range(vars.numseqs):
5512
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
5513
vars.lua_koboldbridge.generated_cols = vars.generated_tkns = genout[0].shape[-1]
5514
5515
except Exception as e:
5516
if(issubclass(type(e), lupa.LuaError)):
5517
vars.lua_koboldbridge.obliterate_multiverse()
5518
vars.lua_running = False
5519
emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
5520
sendUSStatItems()
5521
logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
5522
logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
5523
else:
5524
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occurred during generator call; please check console.'}, broadcast=True)
5525
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
5526
set_aibusy(0)
5527
return
5528
5529
for i in range(vars.numseqs):
5530
vars.lua_koboldbridge.outputs[i+1] = utils.decodenewlines(tokenizer.decode(past[i]))
5531
genout = past
5532
5533
execute_outmod()
5534
if(vars.lua_koboldbridge.regeneration_required):
5535
vars.lua_koboldbridge.regeneration_required = False
5536
genout = []
5537
for i in range(vars.numseqs):
5538
genout.append({"generated_text": vars.lua_koboldbridge.outputs[i+1]})
5539
assert type(genout[-1]["generated_text"]) is str
5540
else:
5541
genout = [{"generated_text": utils.decodenewlines(tokenizer.decode(txt))} for txt in genout]
5542
5543
if(len(genout) == 1):
5544
genresult(genout[0]["generated_text"])
5545
else:
5546
if(vars.lua_koboldbridge.restart_sequence is not None and vars.lua_koboldbridge.restart_sequence > 0):
5547
genresult(genout[vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
5548
else:
5549
genselect(genout)
5550
5551
set_aibusy(0)
5552
5553
5554
#==================================================================#
5555
# Replaces returns and newlines with HTML breaks
5556
#==================================================================#
5557
def formatforhtml(txt):
5558
return txt.replace("\\r\\n", "<br/>").replace("\\r", "<br/>").replace("\\n", "<br/>").replace("\r\n", "<br/>").replace('\n', '<br/>').replace('\r', '<br/>').replace('&lt;/s&gt;', '<br/>')
5559
5560
#==================================================================#
5561
# Strips submitted text from the text returned by the AI
5562
#==================================================================#
5563
def getnewcontent(txt):
5564
# If the submitted context was blank, then everything is new
5565
if(vars.lastctx == ""):
5566
return txt
5567
5568
# Tokenize the last context and the generated content
5569
ctxtokens = tokenizer.encode(utils.encodenewlines(vars.lastctx), max_length=int(2e9), truncation=True)
5570
txttokens = tokenizer.encode(utils.encodenewlines(txt), max_length=int(2e9), truncation=True)
5571
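# dif is the negative count of newly generated tokens, so txttokens[dif:] keeps
# only the tokens that were not part of the submitted context.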
dif = (len(txttokens) - len(ctxtokens)) * -1
5572
5573
# Remove the context from the returned text
5574
newtokens = txttokens[dif:]
5575
5576
return utils.decodenewlines(tokenizer.decode(newtokens))
5577
5578
#==================================================================#
5579
# Applies chosen formatting options to text submitted to AI
5580
#==================================================================#
5581
def applyinputformatting(txt):
5582
# Add sentence spacing
5583
if(vars.formatoptns["frmtadsnsp"]):
5584
txt = utils.addsentencespacing(txt, vars)
5585
5586
return txt
5587
5588
#==================================================================#
5589
# Applies chosen formatting options to text returned from AI
5590
#==================================================================#
5591
def applyoutputformatting(txt):
5592
# Use standard quotes and apostrophes
5593
txt = utils.fixquotes(txt)
5594
5595
# Adventure mode clipping of all characters after '>'
5596
if(vars.adventure):
5597
txt = vars.acregex_ai.sub('', txt)
5598
5599
# Trim incomplete sentences
5600
if(vars.formatoptns["frmttriminc"] and not vars.chatmode):
5601
txt = utils.trimincompletesentence(txt)
5602
# Replace blank lines
5603
if(vars.formatoptns["frmtrmblln"] or vars.chatmode):
5604
txt = utils.replaceblanklines(txt)
5605
# Remove special characters
5606
if(vars.formatoptns["frmtrmspch"]):
5607
txt = utils.removespecialchars(txt, vars)
5608
# Single Line Mode
5609
if(vars.formatoptns["singleline"] or vars.chatmode):
5610
txt = utils.singlelineprocessing(txt, vars)
5611
5612
return txt
5613
5614
#==================================================================#
5615
# Sends the current story content to the Game Screen
5616
#==================================================================#
5617
def refresh_story():
5618
text_parts = ['<chunk n="0" id="n0" tabindex="-1">', vars.comregex_ui.sub(lambda m: '\n'.join('<comment>' + l + '</comment>' for l in m.group().split('\n')), html.escape(vars.prompt)), '</chunk>']
5619
for idx in vars.actions:
5620
item = vars.actions[idx]
5621
idx += 1
5622
item = html.escape(item)
5623
item = vars.comregex_ui.sub(lambda m: '\n'.join('<comment>' + l + '</comment>' for l in m.group().split('\n')), item) # Add special formatting to comments
5624
item = vars.acregex_ui.sub('<action>\\1</action>', item) # Add special formatting to adventure actions
5625
text_parts.extend(('<chunk n="', str(idx), '" id="n', str(idx), '" tabindex="-1">', item, '</chunk>'))
5626
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': vars.gamestarted, 'data': formatforhtml(''.join(text_parts))}, broadcast=True)
5627
5628
5629
#==================================================================#
5630
# Signals the Game Screen to update one of the chunks
5631
#==================================================================#
5632
def update_story_chunk(idx: Union[int, str]):
5633
if idx == 'last':
5634
if len(vars.actions) <= 1:
5635
# In this case, we are better off just refreshing the whole thing as the
5636
# prompt might not have been shown yet (with a "Generating story..."
5637
# message instead).
5638
refresh_story()
5639
setgamesaved(False)
5640
return
5641
5642
idx = (vars.actions.get_last_key() if len(vars.actions) else 0) + 1
5643
5644
if idx == 0:
5645
text = vars.prompt
5646
else:
5647
# Actions are 0 based, but in chunks 0 is the prompt.
5648
# So the chunk index is one more than the corresponding action index.
5649
if(idx - 1 not in vars.actions):
5650
return
5651
text = vars.actions[idx - 1]
5652
5653
item = html.escape(text)
5654
item = vars.comregex_ui.sub(lambda m: '\n'.join('<comment>' + l + '</comment>' for l in m.group().split('\n')), item) # Add special formatting to comments
5655
item = vars.acregex_ui.sub('<action>\\1</action>', item) # Add special formatting to adventure actions
5656
5657
chunk_text = f'<chunk n="{idx}" id="n{idx}" tabindex="-1">{formatforhtml(item)}</chunk>'
5658
emit('from_server', {'cmd': 'updatechunk', 'data': {'index': idx, 'html': chunk_text}}, broadcast=True)
5659
5660
setgamesaved(False)
5661
5662
# If we've set the autosave flag, save the file now
5663
if vars.autosave and (".json" in vars.savedir):
5664
save()
5665
5666
5667
#==================================================================#
5668
# Signals the Game Screen to remove one of the chunks
5669
#==================================================================#
5670
def remove_story_chunk(idx: int):
5671
emit('from_server', {'cmd': 'removechunk', 'data': idx}, broadcast=True)
5672
setgamesaved(False)
5673
5674
5675
#==================================================================#
5676
# Sends the current generator settings to the Game Menu
5677
#==================================================================#
5678
def refresh_settings():
5679
# Suppress toggle change events while loading state
5680
emit('from_server', {'cmd': 'allowtoggle', 'data': False}, broadcast=True)
5681
5682
if(vars.model != "InferKit"):
5683
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp}, broadcast=True)
5684
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
5685
emit('from_server', {'cmd': 'updatetopk', 'data': vars.top_k}, broadcast=True)
5686
emit('from_server', {'cmd': 'updatetfs', 'data': vars.tfs}, broadcast=True)
5687
emit('from_server', {'cmd': 'updatetypical', 'data': vars.typical}, broadcast=True)
5688
emit('from_server', {'cmd': 'updatetopa', 'data': vars.top_a}, broadcast=True)
5689
emit('from_server', {'cmd': 'updatereppen', 'data': vars.rep_pen}, broadcast=True)
5690
emit('from_server', {'cmd': 'updatereppenslope', 'data': vars.rep_pen_slope}, broadcast=True)
5691
emit('from_server', {'cmd': 'updatereppenrange', 'data': vars.rep_pen_range}, broadcast=True)
5692
emit('from_server', {'cmd': 'updateoutlen', 'data': vars.genamt}, broadcast=True)
5693
emit('from_server', {'cmd': 'updatetknmax', 'data': vars.max_length}, broadcast=True)
5694
emit('from_server', {'cmd': 'updatenumseq', 'data': vars.numseqs}, broadcast=True)
5695
else:
5696
emit('from_server', {'cmd': 'updatetemp', 'data': vars.temp}, broadcast=True)
5697
emit('from_server', {'cmd': 'updatetopp', 'data': vars.top_p}, broadcast=True)
5698
emit('from_server', {'cmd': 'updateikgen', 'data': vars.ikgen}, broadcast=True)
5699
5700
emit('from_server', {'cmd': 'updateanotedepth', 'data': vars.andepth}, broadcast=True)
5701
emit('from_server', {'cmd': 'updatewidepth', 'data': vars.widepth}, broadcast=True)
5702
emit('from_server', {'cmd': 'updateuseprompt', 'data': vars.useprompt}, broadcast=True)
5703
emit('from_server', {'cmd': 'updateadventure', 'data': vars.adventure}, broadcast=True)
5704
emit('from_server', {'cmd': 'updatechatmode', 'data': vars.chatmode}, broadcast=True)
5705
emit('from_server', {'cmd': 'updatedynamicscan', 'data': vars.dynamicscan}, broadcast=True)
5706
emit('from_server', {'cmd': 'updateautosave', 'data': vars.autosave}, broadcast=True)
5707
emit('from_server', {'cmd': 'updatenopromptgen', 'data': vars.nopromptgen}, broadcast=True)
5708
emit('from_server', {'cmd': 'updaterngpersist', 'data': vars.rngpersist}, broadcast=True)
5709
emit('from_server', {'cmd': 'updatenogenmod', 'data': vars.nogenmod}, broadcast=True)
5710
emit('from_server', {'cmd': 'updatefulldeterminism', 'data': vars.full_determinism}, broadcast=True)
5711
5712
emit('from_server', {'cmd': 'updatefrmttriminc', 'data': vars.formatoptns["frmttriminc"]}, broadcast=True)
5713
emit('from_server', {'cmd': 'updatefrmtrmblln', 'data': vars.formatoptns["frmtrmblln"]}, broadcast=True)
5714
emit('from_server', {'cmd': 'updatefrmtrmspch', 'data': vars.formatoptns["frmtrmspch"]}, broadcast=True)
5715
emit('from_server', {'cmd': 'updatefrmtadsnsp', 'data': vars.formatoptns["frmtadsnsp"]}, broadcast=True)
5716
emit('from_server', {'cmd': 'updatesingleline', 'data': vars.formatoptns["singleline"]}, broadcast=True)
5717
emit('from_server', {'cmd': 'updateoutputstreaming', 'data': vars.output_streaming}, broadcast=True)
5718
emit('from_server', {'cmd': 'updateshowbudget', 'data': vars.show_budget}, broadcast=True)
5719
emit('from_server', {'cmd': 'updateshowprobs', 'data': vars.show_probs}, broadcast=True)
5720
5721
# Allow toggle events again
5722
emit('from_server', {'cmd': 'allowtoggle', 'data': True}, broadcast=True)
5723
5724
#==================================================================#
5725
# Sets the logical and display states for the AI Busy condition
5726
#==================================================================#
5727
def set_aibusy(state):
5728
if(vars.disable_set_aibusy):
5729
return
5730
if(state):
5731
vars.aibusy = True
5732
emit('from_server', {'cmd': 'setgamestate', 'data': 'wait'}, broadcast=True)
5733
else:
5734
vars.aibusy = False
5735
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'}, broadcast=True)
5736
5737
#==================================================================#
5738
#
5739
#==================================================================#
5740
def editrequest(n):
5741
if(n == 0):
5742
txt = vars.prompt
5743
else:
5744
txt = vars.actions[n-1]
5745
5746
vars.editln = n
5747
emit('from_server', {'cmd': 'setinputtext', 'data': txt}, broadcast=True)
5748
emit('from_server', {'cmd': 'enablesubmit', 'data': ''}, broadcast=True)
5749
5750
#==================================================================#
5751
#
5752
#==================================================================#
5753
def editsubmit(data):
5754
vars.recentedit = True
5755
if(vars.editln == 0):
5756
vars.prompt = data
5757
else:
5758
vars.actions_metadata[vars.editln-1]['Alternative Text'] = vars.actions_metadata[vars.editln-1]['Alternative Text'] + [{"Text": vars.actions[vars.editln-1], "Pinned": False,
5759
"Previous Selection": False,
5760
"Edited": True}]
5761
vars.actions_metadata[vars.editln-1]['Selected Text'] = data
5762
vars.actions[vars.editln-1] = data
5763
5764
vars.mode = "play"
5765
update_story_chunk(vars.editln)
5766
emit('from_server', {'cmd': 'texteffect', 'data': vars.editln}, broadcast=True)
5767
emit('from_server', {'cmd': 'editmode', 'data': 'false'})
5768
send_debug()
5769
5770
#==================================================================#
5771
#
5772
#==================================================================#
5773
def deleterequest():
5774
vars.recentedit = True
5775
# Don't delete prompt
5776
if(vars.editln == 0):
5777
# Send error message
5778
pass
5779
else:
5780
vars.actions_metadata[vars.editln-1]['Alternative Text'] = [{"Text": vars.actions[vars.editln-1], "Pinned": False,
5781
"Previous Selection": True, "Edited": False}] + vars.actions_metadata[vars.editln-1]['Alternative Text']
5782
vars.actions_metadata[vars.editln-1]['Selected Text'] = ''
5783
vars.actions[vars.editln-1] = ''
5784
vars.mode = "play"
5785
remove_story_chunk(vars.editln)
5786
emit('from_server', {'cmd': 'editmode', 'data': 'false'})
5787
send_debug()
5788
5789
#==================================================================#
5790
#
5791
#==================================================================#
5792
def inlineedit(chunk, data):
5793
vars.recentedit = True
5794
chunk = int(chunk)
5795
if(chunk == 0):
5796
if(len(data.strip()) == 0):
5797
return
5798
vars.prompt = data
5799
else:
5800
if(chunk-1 in vars.actions):
5801
vars.actions_metadata[chunk-1]['Alternative Text'] = vars.actions_metadata[chunk-1]['Alternative Text'] + [{"Text": vars.actions[chunk-1], "Pinned": False,
5802
"Previous Selection": False,
5803
"Edited": True}]
5804
vars.actions_metadata[chunk-1]['Selected Text'] = data
5805
vars.actions[chunk-1] = data
5806
else:
5807
logger.warning(f"Attempted to edit non-existent chunk {chunk}")
5808
5809
setgamesaved(False)
5810
update_story_chunk(chunk)
5811
emit('from_server', {'cmd': 'texteffect', 'data': chunk}, broadcast=True)
5812
emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
5813
send_debug()
5814
5815
#==================================================================#
5816
#
5817
#==================================================================#
5818
def inlinedelete(chunk):
5819
vars.recentedit = True
5820
chunk = int(chunk)
5821
# Don't delete prompt
5822
if(chunk == 0):
5823
# Send error message
5824
update_story_chunk(chunk)
5825
emit('from_server', {'cmd': 'errmsg', 'data': "Cannot delete the prompt."})
5826
emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
5827
else:
5828
if(chunk-1 in vars.actions):
5829
vars.actions_metadata[chunk-1]['Alternative Text'] = [{"Text": vars.actions[chunk-1], "Pinned": False,
5830
"Previous Selection": True,
5831
"Edited": False}] + vars.actions_metadata[chunk-1]['Alternative Text']
5832
vars.actions_metadata[chunk-1]['Selected Text'] = ''
5833
del vars.actions[chunk-1]
5834
else:
5835
logger.warning(f"Attempted to delete non-existent chunk {chunk}")
5836
setgamesaved(False)
5837
remove_story_chunk(chunk)
5838
emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
5839
send_debug()
5840
5841
#==================================================================#
5842
# Toggles the game mode for memory editing and sends UI commands
5843
#==================================================================#
5844
def togglememorymode():
5845
if(vars.mode == "play"):
5846
vars.mode = "memory"
5847
emit('from_server', {'cmd': 'memmode', 'data': 'true'}, broadcast=True)
5848
emit('from_server', {'cmd': 'setinputtext', 'data': vars.memory}, broadcast=True)
5849
emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
5850
emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
5851
elif(vars.mode == "memory"):
5852
vars.mode = "play"
5853
emit('from_server', {'cmd': 'memmode', 'data': 'false'}, broadcast=True)
5854
5855
#==================================================================#
5856
# Toggles the game mode for WI editing and sends UI commands
5857
#==================================================================#
5858
def togglewimode():
5859
if(vars.mode == "play"):
5860
vars.mode = "wi"
5861
emit('from_server', {'cmd': 'wimode', 'data': 'true'}, broadcast=True)
5862
elif(vars.mode == "wi"):
5863
# Commit WI fields first
5864
requestwi()
5865
# Then set UI state back to Play
5866
vars.mode = "play"
5867
emit('from_server', {'cmd': 'wimode', 'data': 'false'}, broadcast=True)
5868
sendwi()
5869
5870
#==================================================================#
5871
#
5872
#==================================================================#
5873
def addwiitem(folder_uid=None):
5874
assert folder_uid is None or folder_uid in vars.wifolders_d
5875
ob = {"key": "", "keysecondary": "", "content": "", "comment": "", "folder": folder_uid, "num": len(vars.worldinfo), "init": False, "selective": False, "constant": False}
5876
vars.worldinfo.append(ob)
5877
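# Draw random signed 32-bit UIDs until one that is not already in use is found.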
while(True):
5878
uid = int.from_bytes(os.urandom(4), "little", signed=True)
5879
if(uid not in vars.worldinfo_u):
5880
break
5881
vars.worldinfo_u[uid] = vars.worldinfo[-1]
5882
vars.worldinfo[-1]["uid"] = uid
5883
if(folder_uid is not None):
5884
vars.wifolders_u[folder_uid].append(vars.worldinfo[-1])
5885
emit('from_server', {'cmd': 'addwiitem', 'data': ob}, broadcast=True)
5886
5887
#==================================================================#
5888
# Creates a new WI folder with an unused cryptographically secure random UID
5889
#==================================================================#
5890
def addwifolder():
5891
while(True):
5892
uid = int.from_bytes(os.urandom(4), "little", signed=True)
5893
if(uid not in vars.wifolders_d):
5894
break
5895
ob = {"name": "", "collapsed": False}
5896
vars.wifolders_d[uid] = ob
5897
vars.wifolders_l.append(uid)
5898
vars.wifolders_u[uid] = []
5899
emit('from_server', {'cmd': 'addwifolder', 'uid': uid, 'data': ob}, broadcast=True)
5900
addwiitem(folder_uid=uid)
5901
5902
#==================================================================#
5903
# Move the WI entry with UID src so that it immediately precedes
5904
# the WI entry with UID dst
5905
#==================================================================#
5906
def movewiitem(dst, src):
5907
setgamesaved(False)
5908
if(vars.worldinfo_u[src]["folder"] is not None):
5909
for i, e in enumerate(vars.wifolders_u[vars.worldinfo_u[src]["folder"]]):
5910
if(e is vars.worldinfo_u[src]):
5911
vars.wifolders_u[vars.worldinfo_u[src]["folder"]].pop(i)
5912
break
5913
if(vars.worldinfo_u[dst]["folder"] is not None):
5914
vars.wifolders_u[vars.worldinfo_u[dst]["folder"]].append(vars.worldinfo_u[src])
5915
vars.worldinfo_u[src]["folder"] = vars.worldinfo_u[dst]["folder"]
5916
for i, e in enumerate(vars.worldinfo):
5917
if(e is vars.worldinfo_u[src]):
5918
_src = i
5919
elif(e is vars.worldinfo_u[dst]):
5920
_dst = i
5921
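# Popping the source entry shifts every later index down by one, so subtract one
# from the destination index when the destination comes after the source.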
vars.worldinfo.insert(_dst - (_dst >= _src), vars.worldinfo.pop(_src))
5922
sendwi()
5923
5924
#==================================================================#
5925
# Move the WI folder with UID src so that it immediately precedes
5926
# the WI folder with UID dst
5927
#==================================================================#
5928
def movewifolder(dst, src):
5929
setgamesaved(False)
5930
vars.wifolders_l.remove(src)
5931
if(dst is None):
5932
# If dst is None, that means we should move src to be the last folder
5933
vars.wifolders_l.append(src)
5934
else:
5935
vars.wifolders_l.insert(vars.wifolders_l.index(dst), src)
5936
sendwi()
5937
5938
#==================================================================#
5939
#
5940
#==================================================================#
5941
def sendwi():
5942
# Cache len of WI
5943
ln = len(vars.worldinfo)
5944
5945
# Clear contents of WI container
5946
emit('from_server', {'cmd': 'wistart', 'wifolders_d': vars.wifolders_d, 'wifolders_l': vars.wifolders_l, 'data': ''}, broadcast=True)
5947
5948
# Stable-sort WI entries in order of folder
5949
stablesortwi()
5950
5951
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
5952
5953
# If there are no WI entries, send an empty WI object
5954
if(ln == 0):
5955
addwiitem()
5956
else:
5957
# Send contents of WI array
5958
last_folder = ...
5959
for wi in vars.worldinfo:
5960
if(wi["folder"] != last_folder):
5961
emit('from_server', {'cmd': 'addwifolder', 'uid': wi["folder"], 'data': vars.wifolders_d[wi["folder"]] if wi["folder"] is not None else None}, broadcast=True)
5962
last_folder = wi["folder"]
5963
ob = wi
5964
emit('from_server', {'cmd': 'addwiitem', 'data': ob}, broadcast=True)
5965
5966
emit('from_server', {'cmd': 'wifinish', 'data': ''}, broadcast=True)
5967
5968
#==================================================================#
5969
# Request current contents of all WI HTML elements
5970
#==================================================================#
5971
def requestwi():
5972
list = []
5973
for wi in vars.worldinfo:
5974
list.append(wi["num"])
5975
emit('from_server', {'cmd': 'requestwiitem', 'data': list})
5976
5977
#==================================================================#
5978
# Stable-sort WI items so that items in the same folder are adjacent,
5979
# and items in different folders are sorted based on the order of the folders
5980
#==================================================================#
5981
def stablesortwi():
5982
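# Sort entries by the position of their folder in the folder list; entries outside
# any folder sort to the end. Python's sort is stable, so the relative order of
# entries within each folder is preserved.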
mapping = {uid: index for index, uid in enumerate(vars.wifolders_l)}
5983
vars.worldinfo.sort(key=lambda x: mapping[x["folder"]] if x["folder"] is not None else float("inf"))
5984
last_folder = ...
5985
last_wi = None
5986
for i, wi in enumerate(vars.worldinfo):
5987
wi["num"] = i
5988
wi["init"] = True
5989
if(wi["folder"] != last_folder):
5990
if(last_wi is not None and last_folder is not ...):
5991
last_wi["init"] = False
5992
last_folder = wi["folder"]
5993
last_wi = wi
5994
if(last_wi is not None):
5995
last_wi["init"] = False
5996
for folder in vars.wifolders_u:
5997
vars.wifolders_u[folder].sort(key=lambda x: x["num"])
5998
5999
#==================================================================#
6000
# Extract object from server and send it to WI objects
6001
#==================================================================#
6002
def commitwi(ar):
6003
for ob in ar:
6004
ob["uid"] = int(ob["uid"])
6005
vars.worldinfo_u[ob["uid"]]["key"] = ob["key"]
6006
vars.worldinfo_u[ob["uid"]]["keysecondary"] = ob["keysecondary"]
6007
vars.worldinfo_u[ob["uid"]]["content"] = ob["content"]
6008
vars.worldinfo_u[ob["uid"]]["comment"] = ob.get("comment", "")
6009
vars.worldinfo_u[ob["uid"]]["folder"] = ob.get("folder", None)
6010
vars.worldinfo_u[ob["uid"]]["selective"] = ob["selective"]
6011
vars.worldinfo_u[ob["uid"]]["constant"] = ob.get("constant", False)
6012
stablesortwi()
6013
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
6014
6015
#==================================================================#
6016
#
6017
#==================================================================#
6018
def deletewi(uid):
6019
if(uid in vars.worldinfo_u):
6020
setgamesaved(False)
6021
# Store UID of deletion request
6022
vars.deletewi = uid
6023
if(vars.deletewi is not None):
6024
if(vars.worldinfo_u[vars.deletewi]["folder"] is not None):
6025
for i, e in enumerate(vars.wifolders_u[vars.worldinfo_u[vars.deletewi]["folder"]]):
6026
if(e is vars.worldinfo_u[vars.deletewi]):
6027
vars.wifolders_u[vars.worldinfo_u[vars.deletewi]["folder"]].pop(i)
6028
for i, e in enumerate(vars.worldinfo):
6029
if(e is vars.worldinfo_u[vars.deletewi]):
6030
del vars.worldinfo[i]
6031
break
6032
del vars.worldinfo_u[vars.deletewi]
6033
# Send the new WI array structure
6034
sendwi()
6035
# And reset deletewi
6036
vars.deletewi = None
6037
6038
#==================================================================#
6039
#
6040
#==================================================================#
6041
def deletewifolder(uid):
6042
uid = int(uid)
6043
del vars.wifolders_u[uid]
6044
del vars.wifolders_d[uid]
6045
del vars.wifolders_l[vars.wifolders_l.index(uid)]
6046
setgamesaved(False)
6047
# Delete uninitialized entries in the folder we're going to delete
6048
vars.worldinfo = [wi for wi in vars.worldinfo if wi["folder"] != uid or wi["init"]]
6049
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
6050
# Move WI entries that are inside of the folder we're going to delete
6051
# so that they're outside of all folders
6052
for wi in vars.worldinfo:
6053
if(wi["folder"] == uid):
6054
wi["folder"] = None
6055
6056
sendwi()
6057
6058
#==================================================================#
6059
# Look for WI keys in text to generator
6060
#==================================================================#
6061
def checkworldinfo(txt, allowed_entries=None, allowed_folders=None, force_use_txt=False, scan_story=True, actions=None):
6062
original_txt = txt
6063
6064
if(actions is None):
6065
actions = vars.actions
6066
6067
# Don't go any further if WI is empty
6068
if(len(vars.worldinfo) == 0):
6069
return "", set()
6070
6071
# Cache actions length
6072
ln = len(actions)
6073
6074
# Don't bother calculating action history if widepth is 0
6075
if(vars.widepth > 0 and scan_story):
6076
depth = vars.widepth
6077
# If this is not a continue, add 1 to widepth since submitted
6078
# text is already in action history @ -1
6079
if(not force_use_txt and (txt != "" and vars.prompt != txt)):
6080
txt = ""
6081
depth += 1
6082
6083
if(ln > 0):
6084
chunks = collections.deque()
6085
i = 0
6086
for key in reversed(actions):
6087
chunk = actions[key]
6088
chunks.appendleft(chunk)
6089
i += 1
6090
if(i == depth):
6091
break
6092
6093
if(ln >= depth):
6094
txt = "".join(chunks)
6095
elif(ln > 0):
6096
txt = vars.comregex_ai.sub('', vars.prompt) + "".join(chunks)
6097
elif(ln == 0):
6098
txt = vars.comregex_ai.sub('', vars.prompt)
6099
6100
if(force_use_txt):
6101
txt += original_txt
6102
6103
# Scan text for matches on WI keys
6104
wimem = ""
6105
found_entries = set()
6106
for wi in vars.worldinfo:
6107
if(allowed_entries is not None and wi["uid"] not in allowed_entries):
6108
continue
6109
if(allowed_folders is not None and wi["folder"] not in allowed_folders):
6110
continue
6111
6112
if(wi.get("constant", False)):
6113
wimem = wimem + wi["content"] + "\n"
6114
found_entries.add(id(wi))
6115
continue
6116
6117
if(len(wi["key"].strip()) > 0 and (not wi.get("selective", False) or len(wi.get("keysecondary", "").strip()) > 0)):
6118
# Split comma-separated keys
6119
keys = wi["key"].split(",")
6120
keys_secondary = wi.get("keysecondary", "").split(",")
6121
6122
for k in keys:
6123
ky = k
6124
# Remove leading/trailing spaces if the option is enabled
6125
if(vars.wirmvwhtsp):
6126
ky = k.strip()
6127
if ky.lower() in txt.lower():
6128
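# Selective entries only trigger when both a primary and a secondary key are
# found in the scanned text.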
if wi.get("selective", False) and len(keys_secondary):
6129
found = False
6130
for ks in keys_secondary:
6131
ksy = ks
6132
if(vars.wirmvwhtsp):
6133
ksy = ks.strip()
6134
if ksy.lower() in txt.lower():
6135
wimem = wimem + wi["content"] + "\n"
6136
found_entries.add(id(wi))
6137
found = True
6138
break
6139
if found:
6140
break
6141
else:
6142
wimem = wimem + wi["content"] + "\n"
6143
found_entries.add(id(wi))
6144
break
6145
6146
return wimem, found_entries
6147
6148
#==================================================================#
6149
# Commit changes to Memory storage
6150
#==================================================================#
6151
def memsubmit(data):
6152
emit('from_server', {'cmd': 'setinputtext', 'data': data}, broadcast=True)
6153
# Maybe check for length at some point
6154
# For now just send it to storage
6155
if(data != vars.memory):
6156
setgamesaved(False)
6157
vars.memory = data
6158
vars.mode = "play"
6159
emit('from_server', {'cmd': 'memmode', 'data': 'false'}, broadcast=True)
6160
6161
# Ask for contents of Author's Note field
6162
emit('from_server', {'cmd': 'getanote', 'data': ''})
6163
6164
#==================================================================#
6165
# Commit changes to Author's Note
6166
#==================================================================#
6167
def anotesubmit(data, template=""):
6168
assert type(data) is str and type(template) is str
6169
# Maybe check for length at some point
6170
# For now just send it to storage
6171
if(data != vars.authornote):
6172
setgamesaved(False)
6173
vars.authornote = data
6174
6175
if(vars.authornotetemplate != template):
6176
vars.setauthornotetemplate = template
6177
settingschanged()
6178
vars.authornotetemplate = template
6179
6180
emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
6181
emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
6182
6183
#==================================================================#
6184
# Assembles game data into a request to InferKit API
6185
#==================================================================#
6186
def ikrequest(txt):
6187
# Log request to console
6188
if not vars.quiet:
6189
print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))
6190
6191
# Build request JSON data
6192
reqdata = {
6193
'forceNoEnd': True,
6194
'length': vars.ikgen,
6195
'prompt': {
6196
'isContinuation': False,
6197
'text': txt
6198
},
6199
'startFromBeginning': False,
6200
'streamResponse': False,
6201
'temperature': vars.temp,
6202
'topP': vars.top_p
6203
}
6204
6205
# Create request
6206
req = requests.post(
6207
vars.url,
6208
json = reqdata,
6209
headers = {
6210
'Authorization': 'Bearer '+vars.apikey
6211
}
6212
)
6213
6214
# Deal with the response
6215
if(req.status_code == 200):
6216
genout = req.json()["data"]["text"]
6217
6218
vars.lua_koboldbridge.outputs[1] = genout
6219
6220
execute_outmod()
6221
if(vars.lua_koboldbridge.regeneration_required):
6222
vars.lua_koboldbridge.regeneration_required = False
6223
genout = vars.lua_koboldbridge.outputs[1]
6224
assert type(genout) is str
6225
6226
if not vars.quiet:
6227
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
6228
vars.actions.append(genout)
6229
if vars.actions.get_last_key() not in vars.actions_metadata:
6230
vars.actions_metadata[vars.actions.get_last_key()] = {"Selected Text": genout, "Alternative Text": []}
6231
else:
6232
# 2. We've selected a chunk of text that was presented previously
6233
alternatives = [item['Text'] for item in vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"]]
6234
if genout in alternatives:
6235
alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] if item['Text'] != genout]
6236
vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives
6237
vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = genout
6238
update_story_chunk('last')
6239
emit('from_server', {'cmd': 'texteffect', 'data': vars.actions.get_last_key() + 1 if len(vars.actions) else 0}, broadcast=True)
6240
send_debug()
6241
set_aibusy(0)
6242
else:
6243
# Send error message to web client
6244
er = req.json()
6245
if("error" in er):
6246
code = er["error"]["extensions"]["code"]
6247
elif("errors" in er):
6248
code = er["errors"][0]["extensions"]["code"]
6249
6250
errmsg = "InferKit API Error: {0} - {1}".format(req.status_code, code)
6251
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
6252
set_aibusy(0)
6253
6254
#==================================================================#
6255
# Assembles game data into a request to OpenAI API
6256
#==================================================================#
6257
def oairequest(txt, min, max):
6258
# Log request to console
6259
if not vars.quiet:
6260
print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))
6261
6262
# Store context in memory to use it for comparison with generated content
6263
vars.lastctx = txt
6264
6265
# Build request JSON data
6266
# GooseAI is a subtype of OAI, so to detect it we check the configname as a
# workaround, since vars.model will always be OAI
6268
if 'GooseAI' in vars.configname:
6269
reqdata = {
6270
'prompt': txt,
6271
'max_tokens': vars.genamt,
6272
'temperature': vars.temp,
6273
'top_a': vars.top_a,
6274
'top_p': vars.top_p,
6275
'top_k': vars.top_k,
6276
'tfs': vars.tfs,
6277
'typical_p': vars.typical,
6278
'repetition_penalty': vars.rep_pen,
6279
'repetition_penalty_slope': vars.rep_pen_slope,
6280
'repetition_penalty_range': vars.rep_pen_range,
6281
'n': vars.numseqs,
6282
'stream': False
6283
}
6284
else:
6285
reqdata = {
6286
'prompt': txt,
6287
'max_tokens': vars.genamt,
6288
'temperature': vars.temp,
6289
'top_p': vars.top_p,
6290
'n': vars.numseqs,
6291
'stream': False
6292
}
6293
6294
req = requests.post(
6295
vars.oaiurl,
6296
json = reqdata,
6297
headers = {
6298
'Authorization': 'Bearer '+vars.oaiapikey,
6299
'Content-Type': 'application/json'
6300
}
6301
)
6302
6303
# Deal with the response
6304
if(req.status_code == 200):
6305
outputs = [out["text"] for out in req.json()["choices"]]
6306
6307
for idx in range(len(outputs)):
6308
vars.lua_koboldbridge.outputs[idx+1] = outputs[idx]
6309
6310
execute_outmod()
6311
if (vars.lua_koboldbridge.regeneration_required):
6312
vars.lua_koboldbridge.regeneration_required = False
6313
genout = []
6314
for i in range(len(outputs)):
6315
genout.append(
6316
{"generated_text": vars.lua_koboldbridge.outputs[i + 1]})
6317
assert type(genout[-1]["generated_text"]) is str
6318
else:
6319
genout = [
6320
{"generated_text": utils.decodenewlines(txt)}
6321
for txt in outputs]
6322
6323
if vars.actions.get_last_key() not in vars.actions_metadata:
6324
vars.actions_metadata[vars.actions.get_last_key()] = {
6325
"Selected Text": genout[0], "Alternative Text": []}
6326
else:
6327
# 2. We've selected a chunk of text that was presented previously
6328
try:
6329
alternatives = [item['Text'] for item in vars.actions_metadata[len(vars.actions)-1]["Alternative Text"]]
6330
except:
6331
print(len(vars.actions))
6332
print(vars.actions_metadata)
6333
raise
6334
if genout in alternatives:
6335
alternatives = [item for item in vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] if item['Text'] != genout]
6336
vars.actions_metadata[vars.actions.get_last_key()]["Alternative Text"] = alternatives
6337
vars.actions_metadata[vars.actions.get_last_key()]["Selected Text"] = genout
6338
6339
if (len(genout) == 1):
6340
genresult(genout[0]["generated_text"])
6341
else:
6342
if (vars.lua_koboldbridge.restart_sequence is not None and
6343
vars.lua_koboldbridge.restart_sequence > 0):
6344
genresult(genout[vars.lua_koboldbridge.restart_sequence - 1][
6345
"generated_text"])
6346
else:
6347
genselect(genout)
6348
6349
if not vars.quiet:
6350
print("{0}{1}{2}".format(colors.CYAN, genout, colors.END))
6351
6352
set_aibusy(0)
6353
else:
6354
# Send error message to web client
6355
er = req.json()
6356
if("error" in er):
6357
type = er["error"]["type"]
6358
message = er["error"]["message"]
6359
6360
errmsg = "OpenAI API Error: {0} - {1}".format(type, message)
6361
emit('from_server', {'cmd': 'errmsg', 'data': errmsg}, broadcast=True)
6362
set_aibusy(0)
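# Sketch of the response shapes this handler assumes (hedged; inferred only from the
# parsing above, not from the provider's documentation):
#   success: {"choices": [{"text": " ...generated text..."}, ...]}
#   error:   {"error": {"type": "invalid_request_error", "message": "..."}}
# The example field values are hypothetical; only the key names are taken from the code.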
6363
6364
#==================================================================#
6365
# Forces UI to Play mode
6366
#==================================================================#
6367
def exitModes():
6368
if(vars.mode == "edit"):
6369
emit('from_server', {'cmd': 'editmode', 'data': 'false'}, broadcast=True)
6370
elif(vars.mode == "memory"):
6371
emit('from_server', {'cmd': 'memmode', 'data': 'false'}, broadcast=True)
6372
elif(vars.mode == "wi"):
6373
emit('from_server', {'cmd': 'wimode', 'data': 'false'}, broadcast=True)
6374
vars.mode = "play"
6375
6376
#==================================================================#
6377
# Launch in-browser save prompt
6378
#==================================================================#
6379
def saveas(data):
6380
6381
name = data['name']
6382
savepins = data['pins']
6383
# Check if filename exists already
6384
name = utils.cleanfilename(name)
6385
if(not fileops.saveexists(name) or (vars.saveow and vars.svowname == name)):
6386
# All clear to save
6387
e = saveRequest(fileops.storypath(name), savepins=savepins)
6388
vars.saveow = False
6389
vars.svowname = ""
6390
if(e is None):
6391
emit('from_server', {'cmd': 'hidesaveas', 'data': ''})
6392
else:
6393
print("{0}{1}{2}".format(colors.RED, str(e), colors.END))
6394
emit('from_server', {'cmd': 'popuperror', 'data': str(e)})
6395
else:
6396
# File exists, prompt for overwrite
6397
vars.saveow = True
6398
vars.svowname = name
6399
emit('from_server', {'cmd': 'askforoverwrite', 'data': ''})
6400
6401
#==================================================================#
6402
# Launch in-browser story-delete prompt
6403
#==================================================================#
6404
def deletesave(name):
6405
name = utils.cleanfilename(name)
6406
e = fileops.deletesave(name)
6407
if(e is None):
6408
if(vars.smandelete):
6409
emit('from_server', {'cmd': 'hidepopupdelete', 'data': ''})
6410
getloadlist()
6411
else:
6412
emit('from_server', {'cmd': 'popuperror', 'data': "The server denied your request to delete this story"})
6413
else:
6414
print("{0}{1}{2}".format(colors.RED, str(e), colors.END))
6415
emit('from_server', {'cmd': 'popuperror', 'data': str(e)})
6416
6417
#==================================================================#
6418
# Launch in-browser story-rename prompt
6419
#==================================================================#
6420
def renamesave(name, newname):
6421
# Check if filename exists already
6422
name = utils.cleanfilename(name)
6423
newname = utils.cleanfilename(newname)
6424
if(not fileops.saveexists(newname) or name == newname or (vars.saveow and vars.svowname == newname)):
6425
e = fileops.renamesave(name, newname)
6426
vars.saveow = False
6427
vars.svowname = ""
6428
if(e is None):
6429
if(vars.smanrename):
6430
emit('from_server', {'cmd': 'hidepopuprename', 'data': ''})
6431
getloadlist()
6432
else:
6433
emit('from_server', {'cmd': 'popuperror', 'data': "The server denied your request to rename this story"})
6434
else:
6435
print("{0}{1}{2}".format(colors.RED, str(e), colors.END))
6436
emit('from_server', {'cmd': 'popuperror', 'data': str(e)})
6437
else:
6438
# File exists, prompt for overwrite
6439
vars.saveow = True
6440
vars.svowname = newname
6441
emit('from_server', {'cmd': 'askforoverwrite', 'data': ''})
6442
6443
#==================================================================#
6444
# Save the currently running story
6445
#==================================================================#
6446
def save():
6447
# Check if a file is currently open
6448
if(".json" in vars.savedir):
6449
saveRequest(vars.savedir)
6450
else:
6451
emit('from_server', {'cmd': 'saveas', 'data': ''})
6452
6453
#==================================================================#
6454
# Save the story via file browser
6455
#==================================================================#
6456
def savetofile():
6457
savpath = fileops.getsavepath(vars.savedir, "Save Story As", [("Json", "*.json")])
6458
saveRequest(savpath)
6459
6460
#==================================================================#
6461
# Save the story to specified path
6462
#==================================================================#
6463
def saveRequest(savpath, savepins=True):
6464
if(savpath):
6465
# Leave Edit/Memory mode before continuing
6466
exitModes()
6467
6468
# Save path for future saves
6469
vars.savedir = savpath
6470
txtpath = os.path.splitext(savpath)[0] + ".txt"
6471
# Build json to write
6472
js = {}
6473
js["gamestarted"] = vars.gamestarted
6474
js["prompt"] = vars.prompt
6475
js["memory"] = vars.memory
6476
js["authorsnote"] = vars.authornote
6477
js["anotetemplate"] = vars.authornotetemplate
6478
js["actions"] = tuple(vars.actions.values())
6479
if savepins:
6480
js["actions_metadata"] = vars.actions_metadata
6481
js["worldinfo"] = []
6482
js["wifolders_d"] = vars.wifolders_d
6483
js["wifolders_l"] = vars.wifolders_l
6484
6485
# Extract only the important bits of WI
6486
for wi in vars.worldinfo_i:
6487
if(True):
6488
js["worldinfo"].append({
6489
"key": wi["key"],
6490
"keysecondary": wi["keysecondary"],
6491
"content": wi["content"],
6492
"comment": wi["comment"],
6493
"folder": wi["folder"],
6494
"selective": wi["selective"],
6495
"constant": wi["constant"]
6496
})
6497
6498
txt = vars.prompt + "".join(vars.actions.values())
6499
6500
# Write it
6501
try:
6502
file = open(savpath, "w")
6503
except Exception as e:
6504
return e
6505
try:
6506
file.write(json.dumps(js, indent=3))
6507
except Exception as e:
6508
file.close()
6509
return e
6510
file.close()
6511
6512
try:
6513
file = open(txtpath, "w")
6514
except Exception as e:
6515
return e
6516
try:
6517
file.write(txt)
6518
except Exception as e:
6519
file.close()
6520
return e
6521
file.close()
6522
6523
filename = path.basename(savpath)
6524
if(filename.endswith('.json')):
6525
filename = filename[:-5]
6526
vars.laststory = filename
6527
emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
6528
setgamesaved(True)
6529
print("{0}Story saved to {1}!{2}".format(colors.GREEN, path.basename(savpath), colors.END))
6530
6531
#==================================================================#
6532
# Show list of saved stories
6533
#==================================================================#
6534
def getloadlist():
6535
emit('from_server', {'cmd': 'buildload', 'data': fileops.getstoryfiles()})
6536
6537
#==================================================================#
6538
# Show list of soft prompts
6539
#==================================================================#
6540
def getsplist():
6541
if(vars.allowsp):
6542
emit('from_server', {'cmd': 'buildsp', 'data': fileops.getspfiles(vars.modeldim)})
6543
6544
#==================================================================#
6545
# Get list of userscripts
6546
#==================================================================#
6547
def getuslist():
6548
files = {i: v for i, v in enumerate(fileops.getusfiles())}
6549
loaded = []
6550
unloaded = []
6551
userscripts = set(vars.userscripts)
6552
for i in range(len(files)):
6553
if files[i]["filename"] not in userscripts:
6554
unloaded.append(files[i])
6555
files = {files[k]["filename"]: files[k] for k in files}
6556
userscripts = set(files.keys())
6557
for filename in vars.userscripts:
6558
if filename in userscripts:
6559
loaded.append(files[filename])
6560
return unloaded, loaded
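# Hedged usage note: getuslist() does not emit anything itself; callers are expected to
# send the two lists to the client, e.g. (sketch only; the 'buildus' command name is an
# assumption for illustration, not taken from this function):
#   unloaded, loaded = getuslist()
#   emit('from_server', {'cmd': 'buildus', 'data': {'unloaded': unloaded, 'loaded': loaded}})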
6561
6562
#==================================================================#
6563
# Load a saved story via file browser
6564
#==================================================================#
6565
def loadfromfile():
6566
loadpath = fileops.getloadpath(vars.savedir, "Select Story File", [("Json", "*.json")])
6567
loadRequest(loadpath)
6568
6569
#==================================================================#
6570
# Load a stored story from a file
6571
#==================================================================#
6572
def loadRequest(loadpath, filename=None):
6573
if(loadpath):
6574
# Leave Edit/Memory mode before continuing
6575
exitModes()
6576
6577
# Read file contents into JSON object
6578
if(isinstance(loadpath, str)):
6579
with open(loadpath, "r") as file:
6580
js = json.load(file)
6581
if(filename is None):
6582
filename = path.basename(loadpath)
6583
else:
6584
js = loadpath
6585
if(filename is None):
6586
filename = "untitled.json"
6587
6588
# Copy file contents to vars
6589
vars.gamestarted = js["gamestarted"]
6590
vars.prompt = js["prompt"]
6591
vars.memory = js["memory"]
6592
vars.worldinfo = []
6593
vars.worldinfo_i = []
6594
vars.worldinfo_u = {}
6595
vars.wifolders_d = {int(k): v for k, v in js.get("wifolders_d", {}).items()}
6596
vars.wifolders_l = js.get("wifolders_l", [])
6597
vars.wifolders_u = {uid: [] for uid in vars.wifolders_d}
6598
vars.lastact = ""
6599
vars.submission = ""
6600
vars.lastctx = ""
6601
vars.genseqs = []
6602
6603
del vars.actions
6604
vars.actions = structures.KoboldStoryRegister()
6605
actions = collections.deque(js["actions"])
6606
6607
6608
if "actions_metadata" in js:
6609
6610
if type(js["actions_metadata"]) == dict:
6611
temp = js["actions_metadata"]
6612
vars.actions_metadata = {}
6613
#we need to redo the numbering of the actions_metadata since the actions list doesn't preserve its numbering on saving
6614
if len(temp) > 0:
6615
counter = 0
6616
temp = {int(k):v for k,v in temp.items()}
6617
for i in range(max(temp)+1):
6618
if i in temp:
6619
vars.actions_metadata[counter] = temp[i]
6620
counter += 1
6621
del temp
6622
else:
6623
#fix if we're using the old metadata format
6624
vars.actions_metadata = {}
6625
i = 0
6626
6627
for text in js['actions']:
6628
vars.actions_metadata[i] = {'Selected Text': text, 'Alternative Text': []}
6629
i+=1
6630
else:
6631
vars.actions_metadata = {}
6632
i = 0
6633
6634
for text in js['actions']:
6635
vars.actions_metadata[i] = {'Selected Text': text, 'Alternative Text': []}
6636
i+=1
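# Worked example of the renumbering above (illustrative values): a saved
# actions_metadata of {"0": {...}, "3": {...}} is re-keyed to {0: {...}, 1: {...}},
# i.e. gaps left by deleted actions are closed so the keys line up with the freshly
# rebuilt actions list. Old-format saves (plain lists) instead get a fresh entry per
# action with an empty "Alternative Text" list.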
6637
6638
footer = ""
6639
6640
if(len(vars.prompt.strip()) == 0):
6641
while(len(actions)):
6642
action = actions.popleft()
6643
if(len(action.strip()) != 0):
6644
vars.prompt = action
6645
break
6646
else:
6647
vars.gamestarted = False
6648
vars.prompt = vars.prompt.lstrip()
6649
ln = len(vars.prompt.rstrip())
6650
footer += vars.prompt[ln:]
6651
vars.prompt = vars.prompt[:ln]
6652
if(vars.gamestarted):
6653
for s in actions:
6654
if(len(s.strip()) == 0):
6655
# If this action only contains whitespace, we merge it with the next action
6656
footer += s
6657
continue
6658
vars.actions.append(footer + s)
6659
footer = ""
6660
# If there is trailing whitespace at the end of an action, we move that whitespace to the beginning of the next action
6661
ln = len(vars.actions[vars.actions.get_last_key()].rstrip())
6662
footer += vars.actions[vars.actions.get_last_key()][ln:]
6663
vars.actions[vars.actions.get_last_key()] = vars.actions[vars.actions.get_last_key()][:ln]
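# Illustrative example of the whitespace handling above: with actions
# ["Once upon a time.  ", "   ", "The end."], the pure-whitespace chunk is merged into
# the following one and trailing spaces of each chunk are carried forward, so the
# stored actions become ["Once upon a time.", "     The end."]. The values are made up
# to show the mechanism, not taken from a real save.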
6664
6665
# Try not to break older save files
6666
if("authorsnote" in js):
6667
vars.authornote = js["authorsnote"]
6668
else:
6669
vars.authornote = ""
6670
if("anotetemplate" in js):
6671
vars.authornotetemplate = js["anotetemplate"]
6672
else:
6673
vars.authornotetemplate = "[Author's note: <|>]"
6674
6675
if("worldinfo" in js):
6676
num = 0
6677
for wi in js["worldinfo"]:
6678
vars.worldinfo.append({
6679
"key": wi["key"],
6680
"keysecondary": wi.get("keysecondary", ""),
6681
"content": wi["content"],
6682
"comment": wi.get("comment", ""),
6683
"folder": wi.get("folder", None),
6684
"num": num,
6685
"init": True,
6686
"selective": wi.get("selective", False),
6687
"constant": wi.get("constant", False),
6688
"uid": None,
6689
})
6690
while(True):
6691
uid = int.from_bytes(os.urandom(4), "little", signed=True)
6692
if(uid not in vars.worldinfo_u):
6693
break
6694
vars.worldinfo_u[uid] = vars.worldinfo[-1]
6695
vars.worldinfo[-1]["uid"] = uid
6696
if(vars.worldinfo[-1]["folder"] is not None):
6697
vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
6698
num += 1
6699
6700
for uid in vars.wifolders_l + [None]:
6701
vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None})
6702
while(True):
6703
uid = int.from_bytes(os.urandom(4), "little", signed=True)
6704
if(uid not in vars.worldinfo_u):
6705
break
6706
vars.worldinfo_u[uid] = vars.worldinfo[-1]
6707
vars.worldinfo[-1]["uid"] = uid
6708
if(vars.worldinfo[-1]["folder"] is not None):
6709
vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
6710
stablesortwi()
6711
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
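# Note on the uid loops above (explanatory only): each world info entry/folder gets a
# random 32-bit signed integer uid, re-drawn until it does not collide with anything in
# vars.worldinfo_u. A standalone sketch of the same idea (hypothetical helper, not
# defined or used anywhere in this file):
#   def _fresh_uid(existing):
#       while True:
#           uid = int.from_bytes(os.urandom(4), "little", signed=True)
#           if uid not in existing:
#               return uid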
6712
6713
# Save path for save button
6714
vars.savedir = loadpath
6715
6716
# Clear loadselect var
6717
vars.loadselect = ""
6718
6719
# Refresh game screen
6720
_filename = filename
6721
if(filename.endswith('.json')):
6722
_filename = filename[:-5]
6723
vars.laststory = _filename
6724
emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
6725
setgamesaved(True)
6726
sendwi()
6727
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
6728
emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
6729
emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
6730
refresh_story()
6731
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'}, broadcast=True)
6732
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
6733
print("{0}Story loaded from {1}!{2}".format(colors.GREEN, filename, colors.END))
6734
6735
send_debug()
6736
6737
#==================================================================#
6738
# Import an AIDungeon game exported with Mimi's tool
6739
#==================================================================#
6740
def importRequest():
6741
importpath = fileops.getloadpath(vars.savedir, "Select AID CAT File", [("Json", "*.json")])
6742
6743
if(importpath):
6744
# Leave Edit/Memory mode before continuing
6745
exitModes()
6746
6747
# Read file contents into JSON object
6748
file = open(importpath, "rb")
6749
vars.importjs = json.load(file)
6750
6751
# If a bundle file is being imported, select just the Adventures object
6752
if type(vars.importjs) is dict and "stories" in vars.importjs:
6753
vars.importjs = vars.importjs["stories"]
6754
6755
# Clear Popup Contents
6756
emit('from_server', {'cmd': 'clearpopup', 'data': ''}, broadcast=True)
6757
6758
# Initialize vars
6759
num = 0
6760
vars.importnum = -1
6761
6762
# Get list of stories
6763
for story in vars.importjs:
6764
ob = {}
6765
ob["num"] = num
6766
if(story["title"] != "" and story["title"] != None):
6767
ob["title"] = story["title"]
6768
else:
6769
ob["title"] = "(No Title)"
6770
if(story["description"] != "" and story["description"] != None):
6771
ob["descr"] = story["description"]
6772
else:
6773
ob["descr"] = "(No Description)"
6774
if("actions" in story):
6775
ob["acts"] = len(story["actions"])
6776
elif("actionWindow" in story):
6777
ob["acts"] = len(story["actionWindow"])
6778
emit('from_server', {'cmd': 'addimportline', 'data': ob})
6779
num += 1
6780
6781
# Show Popup
6782
emit('from_server', {'cmd': 'popupshow', 'data': True})
6783
6784
#==================================================================#
6785
# Import an AIDungeon game selected in popup
6786
#==================================================================#
6787
def importgame():
6788
if(vars.importnum >= 0):
6789
# Cache reference to selected game
6790
ref = vars.importjs[vars.importnum]
6791
6792
# Copy game contents to vars
6793
vars.gamestarted = True
6794
6795
# Support for different versions of export script
6796
if("actions" in ref):
6797
if(len(ref["actions"]) > 0):
6798
vars.prompt = ref["actions"][0]["text"]
6799
else:
6800
vars.prompt = ""
6801
elif("actionWindow" in ref):
6802
if(len(ref["actionWindow"]) > 0):
6803
vars.prompt = ref["actionWindow"][0]["text"]
6804
else:
6805
vars.prompt = ""
6806
else:
6807
vars.prompt = ""
6808
vars.memory = ref["memory"]
6809
vars.authornote = ref["authorsNote"] if type(ref["authorsNote"]) is str else ""
6810
vars.authornotetemplate = "[Author's note: <|>]"
6811
vars.actions = structures.KoboldStoryRegister()
6812
vars.actions_metadata = {}
6813
vars.worldinfo = []
6814
vars.worldinfo_i = []
6815
vars.worldinfo_u = {}
6816
vars.wifolders_d = {}
6817
vars.wifolders_l = []
6818
vars.wifolders_u = {uid: [] for uid in vars.wifolders_d}
6819
vars.lastact = ""
6820
vars.submission = ""
6821
vars.lastctx = ""
6822
6823
# Get all actions except for prompt
6824
if("actions" in ref):
6825
if(len(ref["actions"]) > 1):
6826
for act in ref["actions"][1:]:
6827
vars.actions.append(act["text"])
6828
elif("actionWindow" in ref):
6829
if(len(ref["actionWindow"]) > 1):
6830
for act in ref["actionWindow"][1:]:
6831
vars.actions.append(act["text"])
6832
6833
# Get just the important parts of world info
6834
if(ref["worldInfo"] != None):
6835
if(len(ref["worldInfo"]) > 1):
6836
num = 0
6837
for wi in ref["worldInfo"]:
6838
vars.worldinfo.append({
6839
"key": wi["keys"],
6840
"keysecondary": wi.get("keysecondary", ""),
6841
"content": wi["entry"],
6842
"comment": wi.get("comment", ""),
6843
"folder": wi.get("folder", None),
6844
"num": num,
6845
"init": True,
6846
"selective": wi.get("selective", False),
6847
"constant": wi.get("constant", False),
6848
"uid": None,
6849
})
6850
while(True):
6851
uid = int.from_bytes(os.urandom(4), "little", signed=True)
6852
if(uid not in vars.worldinfo_u):
6853
break
6854
vars.worldinfo_u[uid] = vars.worldinfo[-1]
6855
vars.worldinfo[-1]["uid"] = uid
6856
if(vars.worldinfo[-1]["folder"]) is not None:
6857
vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
6858
num += 1
6859
6860
for uid in vars.wifolders_l + [None]:
6861
vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None})
6862
while(True):
6863
uid = int.from_bytes(os.urandom(4), "little", signed=True)
6864
if(uid not in vars.worldinfo_u):
6865
break
6866
vars.worldinfo_u[uid] = vars.worldinfo[-1]
6867
vars.worldinfo[-1]["uid"] = uid
6868
if(vars.worldinfo[-1]["folder"] is not None):
6869
vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
6870
stablesortwi()
6871
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
6872
6873
# Clear import data
6874
vars.importjs = {}
6875
6876
# Reset current save
6877
vars.savedir = getcwd()+"\\stories"
6878
6879
# Refresh game screen
6880
vars.laststory = None
6881
emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
6882
setgamesaved(False)
6883
sendwi()
6884
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
6885
emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
6886
emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
6887
refresh_story()
6888
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'}, broadcast=True)
6889
emit('from_server', {'cmd': 'hidegenseqs', 'data': ''}, broadcast=True)
6890
6891
#==================================================================#
6892
# Import an aidg.club prompt and start a new game with it.
6893
#==================================================================#
6894
def importAidgRequest(id):
6895
exitModes()
6896
6897
urlformat = "https://aetherroom.club/api/"
6898
req = requests.get(urlformat+id)
6899
6900
if(req.status_code == 200):
6901
js = req.json()
6902
6903
# Import game state
6904
vars.gamestarted = True
6905
vars.prompt = js["promptContent"]
6906
vars.memory = js["memory"]
6907
vars.authornote = js["authorsNote"]
6908
vars.authornotetemplate = "[Author's note: <|>]"
6909
vars.actions = structures.KoboldStoryRegister()
6910
vars.actions_metadata = {}
6911
vars.worldinfo = []
6912
vars.worldinfo_i = []
6913
vars.worldinfo_u = {}
6914
vars.wifolders_d = {}
6915
vars.wifolders_l = []
6916
vars.wifolders_u = {uid: [] for uid in vars.wifolders_d}
6917
vars.lastact = ""
6918
vars.submission = ""
6919
vars.lastctx = ""
6920
6921
if not vars.memory:
6922
vars.memory = ""
6923
if not vars.authornote:
6924
vars.authornote = ""
6925
6926
num = 0
6927
for wi in js["worldInfos"]:
6928
vars.worldinfo.append({
6929
"key": wi["keys"],
6930
"keysecondary": wi.get("keysecondary", ""),
6931
"content": wi["entry"],
6932
"comment": wi.get("comment", ""),
6933
"folder": wi.get("folder", None),
6934
"num": num,
6935
"init": True,
6936
"selective": wi.get("selective", False),
6937
"constant": wi.get("constant", False),
6938
"uid": None,
6939
})
6940
while(True):
6941
uid = int.from_bytes(os.urandom(4), "little", signed=True)
6942
if(uid not in vars.worldinfo_u):
6943
break
6944
vars.worldinfo_u[uid] = vars.worldinfo[-1]
6945
vars.worldinfo[-1]["uid"] = uid
6946
if(vars.worldinfo[-1]["folder"]) is not None:
6947
vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
6948
num += 1
6949
6950
for uid in vars.wifolders_l + [None]:
6951
vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None})
6952
while(True):
6953
uid = int.from_bytes(os.urandom(4), "little", signed=True)
6954
if(uid not in vars.worldinfo_u):
6955
break
6956
vars.worldinfo_u[uid] = vars.worldinfo[-1]
6957
vars.worldinfo[-1]["uid"] = uid
6958
if(vars.worldinfo[-1]["folder"] is not None):
6959
vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
6960
stablesortwi()
6961
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
6962
6963
# Reset current save
6964
vars.savedir = getcwd()+"\\stories"
6965
6966
# Refresh game screen
6967
vars.laststory = None
6968
emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
6969
setgamesaved(False)
6970
sendwi()
6971
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
6972
emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
6973
emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
6974
refresh_story()
6975
emit('from_server', {'cmd': 'setgamestate', 'data': 'ready'}, broadcast=True)
6976
6977
#==================================================================#
6978
# Import World Info JSON file
6979
#==================================================================#
6980
def wiimportrequest():
6981
importpath = fileops.getloadpath(vars.savedir, "Select World Info File", [("Json", "*.json")])
6982
if(importpath):
6983
file = open(importpath, "rb")
6984
js = json.load(file)
6985
if(len(js) > 0):
6986
# If the most recent WI entry is blank, remove it.
6987
if(not vars.worldinfo[-1]["init"]):
6988
del vars.worldinfo[-1]
6989
# Now grab the new stuff
6990
num = len(vars.worldinfo)
6991
for wi in js:
6992
vars.worldinfo.append({
6993
"key": wi["keys"],
6994
"keysecondary": wi.get("keysecondary", ""),
6995
"content": wi["entry"],
6996
"comment": wi.get("comment", ""),
6997
"folder": wi.get("folder", None),
6998
"num": num,
6999
"init": True,
7000
"selective": wi.get("selective", False),
7001
"constant": wi.get("constant", False),
7002
"uid": None,
7003
})
7004
while(True):
7005
uid = int.from_bytes(os.urandom(4), "little", signed=True)
7006
if(uid not in vars.worldinfo_u):
7007
break
7008
vars.worldinfo_u[uid] = vars.worldinfo[-1]
7009
vars.worldinfo[-1]["uid"] = uid
7010
if(vars.worldinfo[-1]["folder"]) is not None:
7011
vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
7012
num += 1
7013
for uid in [None]:
7014
vars.worldinfo.append({"key": "", "keysecondary": "", "content": "", "comment": "", "folder": uid, "num": None, "init": False, "selective": False, "constant": False, "uid": None})
7015
while(True):
7016
uid = int.from_bytes(os.urandom(4), "little", signed=True)
7017
if(uid not in vars.worldinfo_u):
7018
break
7019
vars.worldinfo_u[uid] = vars.worldinfo[-1]
7020
vars.worldinfo[-1]["uid"] = uid
7021
if(vars.worldinfo[-1]["folder"] is not None):
7022
vars.wifolders_u[vars.worldinfo[-1]["folder"]].append(vars.worldinfo[-1])
7023
7024
if not vars.quiet:
7025
print("{0}".format(vars.worldinfo[0]))
7026
7027
# Refresh game screen
7028
setgamesaved(False)
7029
sendwi()
7030
7031
#==================================================================#
7032
# Starts a new story
7033
#==================================================================#
7034
def newGameRequest():
7035
# Leave Edit/Memory mode before continuing
7036
exitModes()
7037
7038
# Clear vars values
7039
vars.gamestarted = False
7040
vars.prompt = ""
7041
vars.memory = ""
7042
vars.actions = structures.KoboldStoryRegister()
7043
vars.actions_metadata = {}
7044
7045
vars.authornote = ""
7046
vars.authornotetemplate = vars.setauthornotetemplate
7047
vars.worldinfo = []
7048
vars.worldinfo_i = []
7049
vars.worldinfo_u = {}
7050
vars.wifolders_d = {}
7051
vars.wifolders_l = []
7052
vars.lastact = ""
7053
vars.submission = ""
7054
vars.lastctx = ""
7055
7056
# Reset current save
7057
vars.savedir = getcwd()+"\\stories"
7058
7059
# Refresh game screen
7060
vars.laststory = None
7061
emit('from_server', {'cmd': 'setstoryname', 'data': vars.laststory}, broadcast=True)
7062
setgamesaved(True)
7063
sendwi()
7064
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
7065
emit('from_server', {'cmd': 'setanote', 'data': vars.authornote}, broadcast=True)
7066
emit('from_server', {'cmd': 'setanotetemplate', 'data': vars.authornotetemplate}, broadcast=True)
7067
setStartState()
7068
7069
def randomGameRequest(topic, memory=""):
7070
if(vars.noai):
7071
newGameRequest()
7072
vars.memory = memory
7073
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
7074
return
7075
vars.recentrng = topic
7076
vars.recentrngm = memory
7077
newGameRequest()
7078
setgamesaved(False)
7079
_memory = memory
7080
if(len(memory) > 0):
7081
_memory = memory.rstrip() + "\n\n"
7082
vars.memory = _memory + "You generate the following " + topic + " story concept :"
7083
vars.lua_koboldbridge.feedback = None
7084
actionsubmit("", force_submit=True, force_prompt_gen=True)
7085
vars.memory = memory
7086
emit('from_server', {'cmd': 'setmemory', 'data': vars.memory}, broadcast=True)
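# Illustrative note: for a topic of "pirate" and empty memory, the temporary memory set
# above would read "You generate the following pirate story concept :"; it is only used
# for the forced prompt generation and then restored to the caller-supplied memory. The
# topic value is a made-up example.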
7087
7088
def final_startup():
7089
# Prevent tokenizer from taking extra time the first time it's used
7090
def __preempt_tokenizer():
7091
if("tokenizer" not in globals()):
7092
return
7093
utils.decodenewlines(tokenizer.decode([25678, 559]))
7094
tokenizer.encode(utils.encodenewlines("eunoia"))
7095
threading.Thread(target=__preempt_tokenizer).start()
7096
7097
# Load soft prompt specified by the settings file, if applicable
7098
if(path.exists(get_config_filename())):
7099
file = open(get_config_filename(), "r")
7100
js = json.load(file)
7101
if(vars.allowsp and "softprompt" in js and type(js["softprompt"]) is str and all(q not in js["softprompt"] for q in ("..", ":")) and (len(js["softprompt"]) == 0 or all(js["softprompt"][0] not in q for q in ("/", "\\")))):
7102
spRequest(js["softprompt"])
7103
else:
7104
vars.spfilename = ""
7105
file.close()
7106
7107
# Precompile TPU backend if required
7108
if(vars.use_colab_tpu or vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")):
7109
soft_tokens = tpumtjgetsofttokens()
7110
if(vars.dynamicscan or (not vars.nogenmod and vars.has_genmod)):
7111
threading.Thread(
7112
target=tpu_mtj_backend.infer_dynamic,
7113
args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
7114
kwargs={
7115
"soft_embeddings": vars.sp,
7116
"soft_tokens": soft_tokens,
7117
"gen_len": 1,
7118
"use_callback": False,
7119
"numseqs": vars.numseqs,
7120
"excluded_world_info": list(set() for _ in range(vars.numseqs)),
7121
},
7122
).start()
7123
else:
7124
threading.Thread(
7125
target=tpu_mtj_backend.infer_static,
7126
args=(np.uint32((23403, 727, 20185)),),
7127
kwargs={
7128
"soft_embeddings": vars.sp,
7129
"soft_tokens": soft_tokens,
7130
"gen_len": 1,
7131
"numseqs": vars.numseqs,
7132
},
7133
).start()
7134
7135
# Set the initial RNG seed
7136
if(vars.seed is not None):
7137
if(vars.use_colab_tpu):
7138
if(vars.seed_specified):
7139
__import__("tpu_mtj_backend").set_rng_seed(vars.seed)
7140
else:
7141
__import__("tpu_mtj_backend").randomize_rng_seed()
7142
else:
7143
if(vars.seed_specified):
7144
__import__("torch").manual_seed(vars.seed)
7145
else:
7146
__import__("torch").seed()
7147
vars.seed = __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()
7148
7149
def send_debug():
7150
if vars.debug:
7151
debug_info = ""
7152
try:
7153
debug_info = "{}Seed: {} ({})\n".format(debug_info, repr(__import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()), "specified by user in settings file" if vars.seed_specified else "randomly generated")
7154
except:
7155
pass
7156
try:
7157
debug_info = "{}Newline Mode: {}\n".format(debug_info, vars.newlinemode)
7158
except:
7159
pass
7160
try:
7161
debug_info = "{}Action Length: {}\n".format(debug_info, vars.actions.get_last_key())
7162
except:
7163
pass
7164
try:
7165
debug_info = "{}Actions Metadata Length: {}\n".format(debug_info, max(vars.actions_metadata) if len(vars.actions_metadata) > 0 else 0)
7166
except:
7167
pass
7168
try:
7169
debug_info = "{}Actions: {}\n".format(debug_info, [k for k in vars.actions])
7170
except:
7171
pass
7172
try:
7173
debug_info = "{}Actions Metadata: {}\n".format(debug_info, [k for k in vars.actions_metadata])
7174
except:
7175
pass
7176
try:
7177
debug_info = "{}Last Action: {}\n".format(debug_info, vars.actions[vars.actions.get_last_key()])
7178
except:
7179
pass
7180
try:
7181
debug_info = "{}Last Metadata: {}\n".format(debug_info, vars.actions_metadata[max(vars.actions_metadata)])
7182
except:
7183
pass
7184
7185
emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)
7186
7187
#==================================================================#
7188
# Load file browser for soft prompts
7189
#==================================================================#
7190
@socketio.on('show_folder_soft_prompt')
7191
def show_folder_soft_prompt(data):
7192
file_popup("Load Softprompt", "./softprompts", "", renameable=True, folder_only=False, editable=False, deleteable=True, jailed=True, item_check=None)
7193
7194
#==================================================================#
7195
# Load file browser for user scripts
7196
#==================================================================#
7197
@socketio.on('show_folder_usersripts')
7198
def show_folder_usersripts(data):
7199
file_popup("Load Softprompt", "./userscripts", "", renameable=True, folder_only=False, editable=True, deleteable=True, jailed=True, item_check=None)
7200
7201
7202
7203
#==================================================================#
7204
# File Popup options
7205
#==================================================================#
7206
7207
@socketio.on('upload_file')
7208
def upload_file(data):
7209
print("upload_file {}".format(data['filename']))
7210
print('current_folder' in session)
7211
print('popup_jailed_dir' not in session)
7212
print(session['popup_jailed_dir'])
7213
print(session['current_folder'])
7214
if 'current_folder' in session:
7215
path = os.path.abspath(os.path.join(session['current_folder'], data['filename']).replace("\\", "/")).replace("\\", "/")
7216
print(path)
7217
print(os.path.exists(path))
7218
if 'popup_jailed_dir' not in session:
7219
print("Someone is trying to upload a file to your server. Blocked.")
7220
elif session['popup_jailed_dir'] is None:
7221
if os.path.exists(path):
7222
print("popup error")
7223
emit("error_popup", "The file already exists. Please delete it or rename the file before uploading", room="UI_2");
7224
else:
7225
with open(path, "wb") as f:
7226
f.write(data['data'])
7227
get_files_folders(session['current_folder'])
7228
print("saved")
7229
elif session['popup_jailed_dir'] in session['current_folder']:
7230
if os.path.exists(path):
7231
print("popup error")
7232
emit("error_popup", "The file already exists. Please delete it or rename the file before uploading", room="UI_2");
7233
else:
7234
with open(path, "wb") as f:
7235
f.write(data['data'])
7236
get_files_folders(session['current_folder'])
7237
print("saved")
7238
7239
@socketio.on('popup_change_folder')
7240
def popup_change_folder(data):
7241
print("Doing popup change folder: {}".format(data))
7242
if 'popup_jailed_dir' not in session:
7243
print("Someone is trying to get at files in your server. Blocked.")
7244
return
7245
if session['popup_jailed_dir'] is None:
7246
get_files_folders(data)
7247
elif session['popup_jailed_dir'] in data:
7248
get_files_folders(data)
7249
else:
7250
print("User is trying to get at files in your server outside the jail. Blocked. Jailed Dir: {} Requested Dir: {}".format(session['popup_jailed_dir'], data))
7251
7252
@socketio.on('popup_rename')
7253
def popup_rename(data):
7254
if 'popup_renameable' not in session:
7255
print("Someone is trying to rename a file in your server. Blocked.")
7256
return
7257
if not session['popup_renameable']:
7258
print("Someone is trying to rename a file in your server. Blocked.")
7259
return
7260
7261
if session['popup_jailed_dir'] is None:
7262
os.rename(data['file'], data['new_name'])
7263
get_files_folders(os.path.dirname(data['file']))
7264
elif session['popup_jailed_dir'] in data:
7265
os.rename(data['file'], data['new_name'])
7266
get_files_folders(os.path.dirname(data['file']))
7267
else:
7268
print("User is trying to rename files in your server outside the jail. Blocked. Jailed Dir: {} Requested Dir: {}".format(session['popup_jailed_dir'], data['file']))
7269
7270
7271
@socketio.on('popup_delete')
7272
def popup_delete(data):
7273
if 'popup_deletable' not in session:
7274
print("Someone is trying to delete a file in your server. Blocked.")
7275
return
7276
if not session['popup_deletable']:
7277
print("Someone is trying to delete a file in your server. Blocked.")
7278
return
7279
7280
if session['popup_jailed_dir'] is None:
7281
import shutil
7282
if os.path.isdir(data):
7283
shutil.rmtree(data)
7284
else:
7285
os.remove(data)
7286
path = os.path.abspath(data).replace("\\", "/")
7287
if path[-1] == "/":
7288
path = path[:-1]
7289
path = "/".join(path.split("/")[:-1])
7290
get_files_folders(path)
7291
elif session['popup_jailed_dir'] in data:
7292
import shutil
7293
if os.path.isdir(data):
7294
shutil.rmtree(data)
7295
else:
7296
os.remove(data)
7297
path = os.path.abspath(data).replace("\\", "/")
7298
if path[-1] == "/":
7299
path = path[:-1]
7300
path = "/".join(path.split("/")[:-1])
7301
get_files_folders(path)
7302
else:
7303
print("User is trying to delete files in your server outside the jail. Blocked. Jailed Dir: {} Requested Dir: {}".format(session['popup_jailed_dir'], data))
7304
7305
@socketio.on('popup_edit')
7306
def popup_edit(data):
7307
if 'popup_editable' not in session:
7308
print("Someone is trying to edit a file in your server. Blocked.")
7309
return
7310
if not session['popup_editable']:
7311
print("Someone is trying to edit a file in your server. Blocked.")
7312
return
7313
7314
if session['popup_jailed_dir'] is None:
7315
emit("popup_edit_file", {"file": data, "text": open(data, 'r', encoding='utf-8').read()});
7316
elif session['popup_jailed_dir'] in data:
7317
emit("popup_edit_file", {"file": data, "text": open(data, 'r', encoding='utf-8').read()});
7318
else:
7319
print("User is trying to delete files in your server outside the jail. Blocked. Jailed Dir: {} Requested Dir: {}".format(session['popup_jailed_dir'], data))
7320
7321
@socketio.on('popup_change_file')
7322
def popup_change_file(data):
7323
if 'popup_editable' not in session:
7324
print("Someone is trying to edit a file in your server. Blocked.")
7325
return
7326
if not session['popup_editable']:
7327
print("Someone is trying to edit a file in your server. Blocked.")
7328
return
7329
7330
if session['popup_jailed_dir'] is None:
7331
with open(data['file'], 'w') as f:
7332
f.write(data['data'])
7333
elif session['popup_jailed_dir'] in data['file']:
7334
with open(data['file'], 'w') as f:
7335
f.write(data['data'])
7336
else:
7337
print("User is trying to delete files in your server outside the jail. Blocked. Jailed Dir: {} Requested Dir: {}".format(session['popup_jailed_dir'], data))
7338
7339
def file_popup(popup_title, starting_folder, return_event, upload=True, jailed=True, folder_only=True, renameable=False, deleteable=False, editable=False, show_breadcrumbs=True, item_check=None, show_hidden=False):
7340
#starting_folder = The folder we're going to get folders and/or items from
7341
#return_event = the socketio event that will be emitted when the load button is clicked
7342
#jailed = if set to true, stores the starting folder as session['popup_jailed_dir'] and prevents navigation outside of that folder
7343
#folder_only = will only show folders, no files
7344
#deleteable = will show the delete icons/methods.
7345
#editable = will show the edit icons/methods
7346
#show_breadcrumbs = will show the breadcrumbs at the top of the screen
7347
#item_check = if not None, this function is called to check whether an item is valid as a selection; it receives the item's absolute path as its only argument
7348
#show_hidden = ... really, you have to ask?
7349
if jailed:
7350
session['popup_jailed_dir'] = os.path.abspath(starting_folder).replace("\\", "/")
7351
else:
7352
session['popup_jailed_dir'] = None
7353
session['popup_deletable'] = deleteable
7354
session['popup_renameable'] = renameable
7355
session['popup_editable'] = editable
7356
session['popup_show_hidden'] = show_hidden
7357
session['popup_item_check'] = item_check
7358
session['popup_folder_only'] = folder_only
7359
session['popup_show_breadcrumbs'] = show_breadcrumbs
7360
session['upload'] = upload
7361
7362
socketio.emit("load_popup", {"popup_title": popup_title, "call_back": return_event, "renameable": renameable, "deleteable": deleteable, "editable": editable, 'upload': upload}, broadcast=True)
7363
7364
get_files_folders(starting_folder)
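# Hedged usage sketch (mirrors the show_folder_soft_prompt handler above; the
# 'load_story_popup' event name is a hypothetical placeholder, not one registered in this file):
#   file_popup("Select Story", "./stories", "load_story_popup",
#              jailed=True, folder_only=False, renameable=False, deleteable=False)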
7365
7366
7367
def get_files_folders(starting_folder):
7368
import stat
7369
session['current_folder'] = os.path.abspath(starting_folder).replace("\\", "/")
7370
item_check = session['popup_item_check']
7371
show_breadcrumbs = session['popup_show_breadcrumbs']
7372
show_hidden = session['popup_show_hidden']
7373
folder_only = session['popup_folder_only']
7374
7375
if starting_folder == 'This PC':
7376
breadcrumbs = [['This PC', 'This PC']]
7377
items = [["{}:/".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]
7378
else:
7379
path = os.path.abspath(starting_folder).replace("\\", "/")
7380
if path[-1] == "/":
7381
path = path[:-1]
7382
breadcrumbs = []
7383
for i in range(len(path.split("/"))):
7384
breadcrumbs.append(["/".join(path.split("/")[:i+1]),
7385
path.split("/")[i]])
7386
if len(breadcrumbs) == 1:
7387
breadcrumbs = [["{}:/".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]
7388
else:
7389
if len([["{}:/".format(chr(i)), "{}:\\".format(chr(i))] for i in range(65, 91) if os.path.exists("{}:".format(chr(i)))]) > 0:
7390
breadcrumbs.insert(0, ['This PC', 'This PC'])
7391
7392
#if we're jailed, remove the stuff before the jail from the breadcrumbs
7393
if session['popup_jailed_dir'] is not None:
7394
7395
breadcrumbs = breadcrumbs[len(session['popup_jailed_dir'].split("/")):]
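# Worked example of the breadcrumb construction above (illustrative path): for
# starting_folder "C:/KoboldAI/stories" the loop yields
#   [["C:", "C:"], ["C:/KoboldAI", "KoboldAI"], ["C:/KoboldAI/stories", "stories"]]
# and ["This PC", "This PC"] is prepended when drive letters are detected; if the popup
# is jailed, the first len(session['popup_jailed_dir'].split("/")) entries are then
# sliced off the front of the list.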
7396
7397
folders = []
7398
files = []
7399
base_path = os.path.abspath(starting_folder).replace("\\", "/")
7400
for item in os.listdir(base_path):
7401
item_full_path = os.path.join(base_path, item).replace("\\", "/")
7402
if hasattr(os.stat(item_full_path), "st_file_attributes"):
7403
hidden = bool(os.stat(item_full_path).st_file_attributes & stat.FILE_ATTRIBUTE_HIDDEN)
7404
else:
7405
hidden = item[0] == "."
7406
if item_check is None:
7407
valid_selection = True
7408
else:
7409
valid_selection = item_check(item_full_path)
7410
7411
if (show_hidden and hidden) or not hidden:
7412
if os.path.isdir(os.path.join(base_path, item)):
7413
folders.append([True, item_full_path, item, valid_selection])
7414
else:
7415
files.append([False, item_full_path, item, valid_selection])
7416
items = folders
7417
if not folder_only:
7418
items += files
7419
7420
socketio.emit("popup_items", items, broadcast=True, include_self=True)
7421
if show_breadcrumbs:
7422
socketio.emit("popup_breadcrumbs", breadcrumbs, broadcast=True)
7423
7424
7425
class EmptySchema(KoboldSchema):
7426
pass
7427
7428
class BasicTextResultInnerSchema(KoboldSchema):
7429
text: str = fields.String(required=True)
7430
7431
class BasicTextResultSchema(KoboldSchema):
7432
result: BasicTextResultInnerSchema = fields.Nested(BasicTextResultInnerSchema)
7433
7434
class BasicResultInnerSchema(KoboldSchema):
7435
result: str = fields.String(required=True)
7436
7437
class BasicResultSchema(KoboldSchema):
7438
result: BasicResultInnerSchema = fields.Nested(BasicResultInnerSchema, required=True)
7439
7440
class BasicResultsSchema(KoboldSchema):
7441
results: BasicResultInnerSchema = fields.List(fields.Nested(BasicResultInnerSchema), required=True)
7442
7443
class BasicStringSchema(KoboldSchema):
7444
value: str = fields.String(required=True)
7445
7446
class BasicBooleanSchema(KoboldSchema):
7447
value: bool = fields.Boolean(required=True)
7448
7449
class BasicUIDSchema(KoboldSchema):
7450
uid: int = fields.Integer(required=True, validate=validate.Range(min=-2147483648, max=2147483647), metadata={"description": "32-bit signed integer unique to this world info entry/folder."})
7451
7452
class BasicErrorSchema(KoboldSchema):
7453
msg: str = fields.String(required=True)
7454
type: str = fields.String(required=True)
7455
7456
class StoryEmptyErrorSchema(KoboldSchema):
7457
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
7458
7459
class StoryTooShortErrorSchema(KoboldSchema):
7460
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
7461
7462
class OutOfMemoryErrorSchema(KoboldSchema):
7463
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
7464
7465
class NotFoundErrorSchema(KoboldSchema):
7466
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
7467
7468
api_out_of_memory_response = """507:
7469
description: Out of memory
7470
content:
7471
application/json:
7472
schema: OutOfMemoryErrorSchema
7473
examples:
7474
gpu.cuda:
7475
value:
7476
detail:
7477
msg: "KoboldAI ran out of memory: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.97 GiB already allocated; 0 bytes free; 2.99 GiB reserved in total by PyTorch)"
7478
type: out_of_memory.gpu.cuda
7479
gpu.hip:
7480
value:
7481
detail:
7482
msg: "KoboldAI ran out of memory: HIP out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.97 GiB already allocated; 0 bytes free; 2.99 GiB reserved in total by PyTorch)"
7483
type: out_of_memory.gpu.hip
7484
tpu.hbm:
7485
value:
7486
detail:
7487
msg: "KoboldAI ran out of memory: Compilation failed: Compilation failure: Ran out of memory in memory space hbm. Used 8.83G of 8.00G hbm. Exceeded hbm capacity by 848.88M."
7488
type: out_of_memory.tpu.hbm
7489
cpu.default_cpu_allocator:
7490
value:
7491
detail:
7492
msg: "KoboldAI ran out of memory: DefaultCPUAllocator: not enough memory: you tried to allocate 209715200 bytes."
7493
type: out_of_memory.cpu.default_cpu_allocator
7494
unknown.unknown:
7495
value:
7496
detail:
7497
msg: "KoboldAI ran out of memory."
7498
type: out_of_memory.unknown.unknown"""
7499
7500
class ValidationErrorSchema(KoboldSchema):
7501
detail: Dict[str, List[str]] = fields.Dict(keys=fields.String(), values=fields.List(fields.String(), validate=validate.Length(min=1)), required=True)
7502
7503
api_validation_error_response = """422:
7504
description: Validation error
7505
content:
7506
application/json:
7507
schema: ValidationErrorSchema"""
7508
7509
class ServerBusyErrorSchema(KoboldSchema):
7510
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
7511
7512
api_server_busy_response = """503:
7513
description: Server is busy
7514
content:
7515
application/json:
7516
schema: ServerBusyErrorSchema
7517
example:
7518
detail:
7519
msg: Server is busy; please try again later.
7520
type: service_unavailable"""
7521
7522
class NotImplementedErrorSchema(KoboldSchema):
7523
detail: BasicErrorSchema = fields.Nested(BasicErrorSchema, required=True)
7524
7525
api_not_implemented_response = """501:
7526
description: Not implemented
7527
content:
7528
application/json:
7529
schema: NotImplementedErrorSchema
7530
example:
7531
detail:
7532
msg: API generation is not supported in read-only mode; please load a model and then try again.
7533
type: not_implemented"""
7534
7535
class SamplerSettingsSchema(KoboldSchema):
7536
rep_pen: Optional[float] = fields.Float(validate=validate.Range(min=1), metadata={"description": "Base repetition penalty value."})
7537
rep_pen_range: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Repetition penalty range."})
7538
rep_pen_slope: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Repetition penalty slope."})
7539
top_k: Optional[int] = fields.Integer(validate=validate.Range(min=0), metadata={"description": "Top-k sampling value."})
7540
top_a: Optional[float] = fields.Float(validate=validate.Range(min=0), metadata={"description": "Top-a sampling value."})
7541
top_p: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Top-p sampling value."})
7542
tfs: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Tail free sampling value."})
7543
typical: Optional[float] = fields.Float(validate=validate.Range(min=0, max=1), metadata={"description": "Typical sampling value."})
7544
temperature: Optional[float] = fields.Float(validate=validate.Range(min=0, min_inclusive=False), metadata={"description": "Temperature value."})
7545
7546
def soft_prompt_validator(soft_prompt: str):
7547
if len(soft_prompt.strip()) == 0:
7548
return
7549
if not vars.allowsp:
7550
raise ValidationError("Cannot use soft prompts with current backend.")
7551
if any(q in soft_prompt for q in ("/", "\\")):
7552
return
7553
z, _, _, _, _ = fileops.checksp(soft_prompt.strip(), vars.modeldim)
7554
if isinstance(z, int):
7555
raise ValidationError("Must be a valid soft prompt name.")
7556
z.close()
7557
return True
7558
7559
def story_load_validator(name: str):
7560
if any(q in name for q in ("/", "\\")):
7561
return
7562
if len(name.strip()) == 0 or not os.path.isfile(fileops.storypath(name)):
7563
raise ValidationError("Must be a valid story name.")
7564
return True
7565
7566
def permutation_validator(lst: list):
7567
if any(not isinstance(e, int) for e in lst):
7568
return
7569
if min(lst) != 0 or max(lst) != len(lst) - 1 or len(set(lst)) != len(lst):
7570
raise ValidationError("Must be a permutation of the first N non-negative integers, where N is the length of this array")
7571
return True
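# Illustrative examples for the validator above: [0, 2, 1, 3, 4, 5] passes (a permutation
# of 0..5), while [1, 2, 3] or [0, 0, 1] raise a ValidationError; non-integer elements are
# simply skipped and left to the field's own type validation.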
7572
7573
class GenerationInputSchema(SamplerSettingsSchema):
7574
prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
7575
use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
7576
use_story: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the story from the KoboldAI GUI when generating text."})
7577
use_authors_note: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the author's note from the KoboldAI GUI when generating text. This has no effect unless `use_story` is also enabled."})
7578
use_world_info: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the world info from the KoboldAI GUI when generating text."})
7579
use_userscripts: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the userscripts from the KoboldAI GUI when generating text."})
7580
soft_prompt: Optional[str] = fields.String(metadata={"description": "Soft prompt to use when generating. If set to the empty string or any other string containing no non-whitespace characters, uses no soft prompt."}, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")])
7581
max_length: int = fields.Integer(validate=validate.Range(min=1, max=512), metadata={"description": "Number of tokens to generate."})
7582
max_context_length: int = fields.Integer(validate=validate.Range(min=512, max=2048), metadata={"description": "Maximum number of tokens to send to the model."})
7583
n: int = fields.Integer(validate=validate.Range(min=1, max=5), metadata={"description": "Number of outputs to generate."})
7584
disable_output_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, all output formatting options default to `false` instead of the value in the KoboldAI GUI."})
7585
frmttriminc: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes some characters from the end of the output such that the output doesn't end in the middle of a sentence. If the output is less than one sentence long, does nothing.\n\nIf `disable_output_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
7586
frmtrmblln: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, replaces all occurrences of two or more consecutive newlines in the output with one newline.\n\nIf `disable_output_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
7587
frmtrmspch: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes `#/@%{}+=~|\^<>` from the output.\n\nIf `disable_output_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
7588
singleline: Optional[bool] = fields.Boolean(metadata={"description": "Output formatting option. When enabled, removes everything after the first line of the output, including the newline.\n\nIf `disable_output_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
7589
disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, all input formatting options default to `false` instead of the value in the KoboldAI GUI"})
7590
frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action.\n\nIf `disable_input_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
7591
quiet: Optional[bool] = fields.Boolean(metadata={"description": "When enabled, generated output will not be displayed in the console."})
7592
sampler_order: Optional[List[int]] = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], metadata={"description": "Sampler order to be used. If N is the length of this array, then N must be greater than or equal to 6 and the array must be a permutation of the first N non-negative integers."})
7593
sampler_seed: Optional[int] = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), metadata={"description": "RNG seed to use for sampling. If not specified, the global RNG will be used."})
7594
sampler_full_determinism: Optional[bool] = fields.Boolean(metadata={"description": "If enabled, the generated text will always be the same as long as you use the same RNG seed, input and settings. If disabled, only the *sequence* of generated texts that you get when repeatedly generating text will be the same given the same RNG seed, input and settings."})
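# Hedged example of a request body this schema would accept (the values are illustrative;
# only the field names and ranges come from the schema definitions above):
#   {"prompt": "Niko the kobold stalked carefully down the alley",
#    "max_length": 80, "max_context_length": 1024, "n": 1,
#    "temperature": 0.5, "top_p": 0.9, "sampler_seed": 42,
#    "use_memory": false, "use_story": false}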
7595
7596
class GenerationResultSchema(KoboldSchema):
7597
text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
7598
7599
class GenerationOutputSchema(KoboldSchema):
7600
results: List[GenerationResultSchema] = fields.List(fields.Nested(GenerationResultSchema), required=True, metadata={"description": "Array of generated outputs."})
7601
7602
class StoryNumsChunkSchema(KoboldSchema):
7603
num: int = fields.Integer(required=True, metadata={"description": "Guaranteed to not equal the `num` of any other active story chunk. Equals 0 iff this is the first action of the story (the prompt)."})
7604
7605
class StoryChunkSchema(StoryNumsChunkSchema, KoboldSchema):
7606
text: str = fields.String(required=True, metadata={"description": "The text inside this story chunk."})
7607
7608
class StorySchema(KoboldSchema):
7609
results: List[StoryChunkSchema] = fields.List(fields.Nested(StoryChunkSchema), required=True, metadata={"description": "Array of story actions. The array is sorted such that actions closer to the end of this array are closer to the end of the story."})
7610
7611
class BasicBooleanSchema(KoboldSchema):
7612
result: bool = fields.Boolean(required=True)
7613
7614
class StoryNumsSchema(KoboldSchema):
7615
results: List[int] = fields.List(fields.Integer(), required=True, metadata={"description": "Array of story action nums. The array is sorted such that actions closer to the end of this array are closer to the end of the story."})
7616
7617
class StoryChunkResultSchema(KoboldSchema):
7618
result: StoryChunkSchema = fields.Nested(StoryChunkSchema, required=True)
7619
7620
class StoryChunkNumSchema(KoboldSchema):
7621
value: int = fields.Integer(required=True)
7622
7623
class StoryChunkTextSchema(KoboldSchema):
7624
value: str = fields.String(required=True)
7625
7626
class StoryChunkSetTextSchema(KoboldSchema):
7627
value: str = fields.String(required=True, validate=validate.Regexp(r"^(.|\n)*\S$"))
7628
7629
class StoryLoadSchema(KoboldSchema):
7630
name: str = fields.String(required=True, validate=[story_load_validator, validate.Regexp(r"^[^/\\]*$")])
7631
7632
class StorySaveSchema(KoboldSchema):
7633
name: str = fields.String(required=True, validate=validate.Regexp(r"^(?=.*\S)(?!.*[/\\]).*$"))
7634
7635
class WorldInfoEntrySchema(KoboldSchema):
7636
uid: int = fields.Integer(required=True, validate=validate.Range(min=-2147483648, max=2147483647), metadata={"description": "32-bit signed integer unique to this world info entry."})
7637
content: str = fields.String(required=True, metadata={"description": "The \"What To Remember\" for this entry."})
7638
key: str = fields.String(required=True, metadata={"description": "Comma-separated list of keys, or of primary keys if selective mode is enabled."})
7639
keysecondary: str = fields.String(metadata={"description": "Comma-separated list of secondary keys if selective mode is enabled."})
7640
selective: bool = fields.Boolean(required=True, metadata={"description": "Whether or not selective mode is enabled for this world info entry."})
7641
constant: bool = fields.Boolean(required=True, metadata={"description": "Whether or not constant mode is enabled for this world info entry."})
7642
comment: str = fields.String(required=True, metadata={"description": "The comment/description/title for this world info entry."})
7643
7644
class WorldInfoEntryResultSchema(KoboldSchema):
7645
result: WorldInfoEntrySchema = fields.Nested(WorldInfoEntrySchema, required=True)
7646
7647
class WorldInfoFolderBasicSchema(KoboldSchema):
7648
uid: int = fields.Integer(required=True, validate=validate.Range(min=-2147483648, max=2147483647), metadata={"description": "32-bit signed integer unique to this world info folder."})
7649
name: str = fields.String(required=True, metadata={"description": "Name of this world info folder."})
7650
7651
class WorldInfoFolderSchema(WorldInfoFolderBasicSchema):
7652
entries: List[WorldInfoEntrySchema] = fields.List(fields.Nested(WorldInfoEntrySchema), required=True)
7653
7654
class WorldInfoFolderUIDsSchema(KoboldSchema):
7655
uid: int = fields.Integer(required=True, validate=validate.Range(min=-2147483648, max=2147483647), metadata={"description": "32-bit signed integer unique to this world info folder."})
7656
entries: List[int] = fields.List(fields.Integer(required=True, validate=validate.Range(min=-2147483648, max=2147483647), metadata={"description": "32-bit signed integer unique to this world info entry."}), required=True)
7657
7658
class WorldInfoEntriesSchema(KoboldSchema):
7659
entries: List[WorldInfoEntrySchema] = fields.List(fields.Nested(WorldInfoEntrySchema), required=True)
7660
7661
class WorldInfoFoldersSchema(KoboldSchema):
7662
folders: List[WorldInfoFolderBasicSchema] = fields.List(fields.Nested(WorldInfoFolderBasicSchema), required=True)
7663
7664
class WorldInfoSchema(WorldInfoEntriesSchema):
7665
folders: List[WorldInfoFolderSchema] = fields.List(fields.Nested(WorldInfoFolderSchema), required=True)
7666
7667
class WorldInfoEntriesUIDsSchema(KoboldSchema):
7668
entries: List[int] = fields.List(fields.Integer(required=True, validate=validate.Range(min=-2147483648, max=2147483647), metadata={"description": "32-bit signed integer unique to this world info entry."}), required=True)
7669
7670
class WorldInfoFoldersUIDsSchema(KoboldSchema):
7671
folders: List[int] = fields.List(fields.Integer(required=True, validate=validate.Range(min=-2147483648, max=2147483647), metadata={"description": "32-bit signed integer unique to this world info folder."}), required=True)
7672
7673
class WorldInfoUIDsSchema(WorldInfoEntriesUIDsSchema):
7674
folders: List[WorldInfoFolderUIDsSchema] = fields.List(fields.Nested(WorldInfoFolderUIDsSchema), required=True)
7675
7676
class ModelSelectionSchema(KoboldSchema):
7677
model: str = fields.String(required=True, validate=validate.Regexp(r"^(?!\s*NeoCustom)(?!\s*GPT2Custom)(?!\s*TPUMeshTransformerGPTJ)(?!\s*TPUMeshTransformerGPTNeoX)(?!\s*GooseAI)(?!\s*OAI)(?!\s*InferKit)(?!\s*Colab)(?!\s*API).*$"), metadata={"description": 'Hugging Face model ID, the path to a model folder (relative to the "models" folder in the KoboldAI root folder) or "ReadOnly" for no model'})
7678
7679
def _generate_text(body: GenerationInputSchema):
7680
if vars.aibusy or vars.genseqs:
7681
abort(Response(json.dumps({"detail": {
7682
"msg": "Server is busy; please try again later.",
7683
"type": "service_unavailable",
7684
}}), mimetype="application/json", status=503))
7685
if vars.use_colab_tpu:
7686
import tpu_mtj_backend
7687
if hasattr(body, "sampler_seed"):
7688
# If a seed was specified, we need to save the global RNG state so we
7689
# can restore it later
7690
old_seed = vars.seed
7691
old_rng_state = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
7692
vars.seed = body.sampler_seed
7693
# We should try to use a previously saved RNG state with the same seed
7694
if body.sampler_seed in vars.rng_states:
7695
if vars.use_colab_tpu:
7696
tpu_mtj_backend.set_rng_state(vars.rng_states[body.sampler_seed])
7697
else:
7698
torch.set_rng_state(vars.rng_states[body.sampler_seed])
7699
else:
7700
if vars.use_colab_tpu:
7701
tpu_mtj_backend.set_rng_state(tpu_mtj_backend.new_rng_state(body.sampler_seed))
7702
else:
7703
torch.manual_seed(body.sampler_seed)
7704
vars.rng_states[body.sampler_seed] = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
7705
if hasattr(body, "sampler_order"):
7706
if len(body.sampler_order) < 7:
7707
body.sampler_order = [6] + body.sampler_order
7708
# This maps each generation setting that the request body can override to the
# object that normally holds its value. This lets us apply a setting only for
# this API generation and then restore the original value afterwards.
7712
mapping = {
7713
"disable_input_formatting": ("vars", "disable_input_formatting", None),
7714
"disable_output_formatting": ("vars", "disable_output_formatting", None),
7715
"rep_pen": ("vars", "rep_pen", None),
7716
"rep_pen_range": ("vars", "rep_pen_range", None),
7717
"rep_pen_slope": ("vars", "rep_pen_slope", None),
7718
"top_k": ("vars", "top_k", None),
7719
"top_a": ("vars", "top_a", None),
7720
"top_p": ("vars", "top_p", None),
7721
"tfs": ("vars", "tfs", None),
7722
"typical": ("vars", "typical", None),
7723
"temperature": ("vars", "temp", None),
7724
"frmtadsnsp": ("vars.formatoptns", "@frmtadsnsp", "input"),
7725
"frmttriminc": ("vars.formatoptns", "@frmttriminc", "output"),
7726
"frmtrmblln": ("vars.formatoptns", "@frmtrmblln", "output"),
7727
"frmtrmspch": ("vars.formatoptns", "@frmtrmspch", "output"),
7728
"singleline": ("vars.formatoptns", "@singleline", "output"),
7729
"max_length": ("vars", "genamt", None),
7730
"max_context_length": ("vars", "max_length", None),
7731
"n": ("vars", "numseqs", None),
7732
"quiet": ("vars", "quiet", None),
7733
"sampler_order": ("vars", "sampler_order", None),
7734
"sampler_full_determinism": ("vars", "full_determinism", None),
7735
}
7736
saved_settings = {}
7737
set_aibusy(1)
7738
disable_set_aibusy = vars.disable_set_aibusy
7739
vars.disable_set_aibusy = True
7740
_standalone = vars.standalone
7741
vars.standalone = True
7742
show_probs = vars.show_probs
7743
vars.show_probs = False
7744
output_streaming = vars.output_streaming
7745
vars.output_streaming = False
7746
for key, entry in mapping.items():
7747
obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[entry[0]]
7748
if entry[2] == "input" and vars.disable_input_formatting and not hasattr(body, key):
7749
setattr(body, key, False)
7750
if entry[2] == "output" and vars.disable_output_formatting and not hasattr(body, key):
7751
setattr(body, key, False)
7752
if getattr(body, key, None) is not None:
7753
if entry[1].startswith("@"):
7754
saved_settings[key] = obj[entry[1][1:]]
7755
obj[entry[1][1:]] = getattr(body, key)
7756
else:
7757
saved_settings[key] = getattr(obj, entry[1])
7758
setattr(obj, entry[1], getattr(body, key))
7759
try:
7760
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
7761
if any(q in body.soft_prompt for q in ("/", "\\")):
7762
raise RuntimeError
7763
old_spfilename = vars.spfilename
7764
spRequest(body.soft_prompt.strip())
7765
genout = apiactionsubmit(body.prompt, use_memory=body.use_memory, use_story=body.use_story, use_world_info=body.use_world_info, use_authors_note=body.use_authors_note)
7766
output = {"results": [{"text": txt} for txt in genout]}
7767
finally:
7768
for key in saved_settings:
7769
entry = mapping[key]
7770
obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[entry[0]]
7771
if getattr(body, key, None) is not None:
7772
if entry[1].startswith("@"):
7773
if obj[entry[1][1:]] == getattr(body, key):
7774
obj[entry[1][1:]] = saved_settings[key]
7775
else:
7776
if getattr(obj, entry[1]) == getattr(body, key):
7777
setattr(obj, entry[1], saved_settings[key])
7778
vars.disable_set_aibusy = disable_set_aibusy
7779
vars.standalone = _standalone
7780
vars.show_probs = show_probs
7781
vars.output_streaming = output_streaming
7782
if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
7783
spRequest(old_spfilename)
7784
if hasattr(body, "sampler_seed"):
7785
vars.seed = old_seed
7786
if vars.use_colab_tpu:
7787
tpu_mtj_backend.set_rng_state(old_rng_state)
7788
else:
7789
torch.set_rng_state(old_rng_state)
7790
set_aibusy(0)
7791
return output
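
# Illustrative sketch (never called): the save -> apply -> restore pattern used by
# _generate_text above, reduced to a toy example. The names `settings` and
# `overrides` are hypothetical stand-ins for `vars`/`vars.formatoptns` and the
# request body; this only demonstrates the flow, not the real generation code.
def _example_override_restore(settings: dict, overrides: dict) -> dict:
    saved = {}
    for key, value in overrides.items():
        if value is not None:
            saved[key] = settings.get(key)   # remember the original value
            settings[key] = value            # apply the per-request override
    try:
        result = {"used_settings": dict(settings)}  # stand-in for the actual generation step
    finally:
        for key, original in saved.items():
            # only revert if nothing else changed the setting in the meantime,
            # mirroring the equality check in _generate_text's finally block
            if settings.get(key) == overrides[key]:
                settings[key] = original
    return result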
7792
7793
7794
@api_v1.get("/info/version")
7795
@api_schema_wrap
7796
def get_version():
7797
"""---
7798
get:
7799
summary: Current API version
7800
tags:
7801
- info
7802
description: |-2
7803
Returns the version of the API that you are currently using.
7804
responses:
7805
200:
7806
description: Successful request
7807
content:
7808
application/json:
7809
schema: BasicResultSchema
7810
example:
7811
result: 1.0.0
7812
"""
7813
return {"result": api_version}
7814
7815
7816
@api_v1.get("/info/version/latest")
7817
@api_schema_wrap
7818
def get_version_latest():
7819
"""---
7820
get:
7821
summary: Latest API version
7822
tags:
7823
- info
7824
description: |-2
7825
Returns the latest API version available.
7826
responses:
7827
200:
7828
description: Successful request
7829
content:
7830
application/json:
7831
schema: BasicResultSchema
7832
example:
7833
result: 1.0.0
7834
"""
7835
return {"result": api_versions[-1]}
7836
7837
7838
@api_v1.get("/info/version/list")
7839
@api_schema_wrap
7840
def get_version_list():
7841
"""---
7842
get:
7843
summary: List API versions
7844
tags:
7845
- info
7846
description: |-2
7847
Returns a list of available API versions sorted in ascending order.
7848
responses:
7849
200:
7850
description: Successful request
7851
content:
7852
application/json:
7853
schema: BasicResultsSchema
7854
example:
7855
results:
7856
- 1.0.0
7857
"""
7858
return {"results": api_versions}
7859
7860
7861
@api_v1.post("/generate")
7862
@api_schema_wrap
7863
def post_generate(body: GenerationInputSchema):
7864
"""---
7865
post:
7866
summary: Generate text
7867
tags:
7868
- generate
7869
description: |-2
7870
Generates text given a submission, sampler settings, soft prompt and number of return sequences.
7871
7872
By default, the story, userscripts, memory, author's note and world info are disabled.
7873
7874
Unless otherwise specified, optional values default to the values in the KoboldAI GUI.
7875
requestBody:
7876
required: true
7877
content:
7878
application/json:
7879
schema: GenerationInputSchema
7880
example:
7881
prompt: |-2
7882
Niko the kobold stalked carefully down the alley, his small scaly figure obscured by a dusky cloak that fluttered lightly in the cold winter breeze.
7883
top_p: 0.9
7884
temperature: 0.5
7885
responses:
7886
200:
7887
description: Successful request
7888
content:
7889
application/json:
7890
schema: GenerationOutputSchema
7891
example:
7892
results:
7893
- text: |-2
7894
Holding up his tail to keep it from dragging in the dirty snow that covered the cobblestone, he waited patiently for the butcher to turn his attention from his stall so that he could pilfer his next meal: a tender-looking chicken.
7895
{api_validation_error_response}
7896
{api_not_implemented_response}
7897
{api_server_busy_response}
7898
{api_out_of_memory_response}
7899
"""
7900
return _generate_text(body)
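
# Illustrative sketch (never called): how a client might call POST /generate with
# the `requests` library. The base URL is an assumption -- the actual host, port
# and API prefix depend on how this server was started.
def _example_client_generate():
    import requests  # imported locally so the sketch stands on its own
    base_url = "http://localhost:5000/api/v1"  # assumed address; adjust to your setup
    payload = {
        "prompt": "Niko the kobold stalked carefully down the alley.",
        "top_p": 0.9,
        "temperature": 0.5,
        "max_length": 80,  # number of tokens to generate (maps to vars.genamt)
    }
    r = requests.post(f"{base_url}/generate", json=payload)
    r.raise_for_status()
    for result in r.json()["results"]:
        print(result["text"])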
7901
7902
7903
@api_v1.get("/model")
7904
@api_schema_wrap
7905
def get_model():
7906
"""---
7907
get:
7908
summary: Retrieve the current model string
7909
description: |-2
7910
Gets the current model string, which is shown in the title of the KoboldAI GUI in parentheses, e.g. "KoboldAI Client (KoboldAI/fairseq-dense-13B-Nerys-v2)".
7911
tags:
7912
- model
7913
responses:
7914
200:
7915
description: Successful request
7916
content:
7917
application/json:
7918
schema: BasicResultSchema
7919
example:
7920
result: KoboldAI/fairseq-dense-13B-Nerys-v2
7921
"""
7922
return {"result": vars.model}
7923
7924
7925
@api_v1.put("/model")
7926
@api_schema_wrap
7927
def put_model(body: ModelSelectionSchema):
7928
"""---
7929
put:
7930
summary: Load a model
7931
description: |-2
7932
Loads a model given its Hugging Face model ID, the path to a model folder (relative to the "models" folder in the KoboldAI root folder) or "ReadOnly" for no model.
7933
tags:
7934
- model
7935
requestBody:
7936
required: true
7937
content:
7938
application/json:
7939
schema: ModelSelectionSchema
7940
example:
7941
model: ReadOnly
7942
responses:
7943
200:
7944
description: Successful request
7945
content:
7946
application/json:
7947
schema: EmptySchema
7948
{api_validation_error_response}
7949
{api_server_busy_response}
7950
"""
7951
if vars.aibusy or vars.genseqs:
7952
abort(Response(json.dumps({"detail": {
7953
"msg": "Server is busy; please try again later.",
7954
"type": "service_unavailable",
7955
}}), mimetype="application/json", status=503))
7956
set_aibusy(1)
7957
old_model = vars.model
7958
vars.model = body.model.strip()
7959
try:
7960
load_model(use_breakmodel_args=True, breakmodel_args_default_to_cpu=True)
7961
except Exception as e:
7962
vars.model = old_model
7963
raise e
7964
set_aibusy(0)
7965
return {}
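
# Illustrative sketch (never called): reading the current model string and switching
# the loaded model through the API. The base URL is an assumption; "ReadOnly" comes
# from the docstring example above and loads no model.
def _example_client_model():
    import requests
    base_url = "http://localhost:5000/api/v1"  # assumed address
    print(requests.get(f"{base_url}/model").json()["result"])          # current model string
    r = requests.put(f"{base_url}/model", json={"model": "ReadOnly"})  # switch to no model
    r.raise_for_status()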
7966
7967
7968
def prompt_validator(prompt: str):
7969
if len(prompt.strip()) == 0:
7970
raise ValidationError("String does not match expected pattern.")
7971
7972
class SubmissionInputSchema(KoboldSchema):
7973
prompt: str = fields.String(required=True, validate=prompt_validator, metadata={"pattern": r"^[\S\s]*\S[\S\s]*$", "description": "This is the submission."})
7974
disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, disables all input formatting options, overriding their individual enabled/disabled states."})
7975
frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action."})
7976
7977
@api_v1.post("/story/end")
7978
@api_schema_wrap
7979
def post_story_end(body: SubmissionInputSchema):
7980
"""---
7981
post:
7982
summary: Add an action to the end of the story
7983
tags:
7984
- story
7985
description: |-2
7986
Inserts a single action at the end of the story in the KoboldAI GUI without generating text.
7987
requestBody:
7988
required: true
7989
content:
7990
application/json:
7991
schema: SubmissionInputSchema
7992
example:
7993
prompt: |-2
7994
This is some text to put at the end of the story.
7995
responses:
7996
200:
7997
description: Successful request
7998
content:
7999
application/json:
8000
schema: EmptySchema
8001
{api_validation_error_response}
8002
{api_server_busy_response}
8003
"""
8004
if vars.aibusy or vars.genseqs:
8005
abort(Response(json.dumps({"detail": {
8006
"msg": "Server is busy; please try again later.",
8007
"type": "service_unavailable",
8008
}}), mimetype="application/json", status=503))
8009
set_aibusy(1)
8010
disable_set_aibusy = vars.disable_set_aibusy
8011
vars.disable_set_aibusy = True
8012
_standalone = vars.standalone
8013
vars.standalone = True
8014
numseqs = vars.numseqs
8015
vars.numseqs = 1
8016
try:
8017
actionsubmit(body.prompt, force_submit=True, no_generate=True, ignore_aibusy=True)
8018
finally:
8019
vars.disable_set_aibusy = disable_set_aibusy
8020
vars.standalone = _standalone
8021
vars.numseqs = numseqs
8022
set_aibusy(0)
8023
return {}
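
# Illustrative sketch (never called): appending an action to the story without
# generating text, then reading back the last chunk. The base URL is an assumption.
def _example_client_story_end():
    import requests
    base_url = "http://localhost:5000/api/v1"  # assumed address
    requests.post(
        f"{base_url}/story/end",
        json={"prompt": "This is some text to put at the end of the story."},
    ).raise_for_status()
    last = requests.get(f"{base_url}/story/end").json()["result"]
    print(last["num"], last["text"])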
8024
8025
8026
@api_v1.get("/story/end")
8027
@api_schema_wrap
8028
def get_story_end():
8029
"""---
8030
get:
8031
summary: Retrieve the last action of the story
8032
tags:
8033
- story
8034
description: |-2
8035
Returns the last action of the story in the KoboldAI GUI.
8036
responses:
8037
200:
8038
description: Successful request
8039
content:
8040
application/json:
8041
schema: StoryChunkResultSchema
8042
510:
8043
description: Story is empty
8044
content:
8045
application/json:
8046
schema: StoryEmptyErrorSchema
8047
example:
8048
detail:
8049
msg: Could not retrieve the last action of the story because the story is empty.
8050
type: story_empty
8051
"""
8052
if not vars.gamestarted:
8053
abort(Response(json.dumps({"detail": {
8054
"msg": "Could not retrieve the last action of the story because the story is empty.",
8055
"type": "story_empty",
8056
}}), mimetype="application/json", status=510))
8057
if len(vars.actions) == 0:
8058
return {"result": {"text": vars.prompt, "num": 0}}
8059
return {"result": {"text": vars.actions[vars.actions.get_last_key()], "num": vars.actions.get_last_key() + 1}}
8060
8061
8062
@api_v1.get("/story/end/num")
8063
@api_schema_wrap
8064
def get_story_end_num():
8065
"""---
8066
get:
8067
summary: Retrieve the num of the last action of the story
8068
tags:
8069
- story
8070
description: |-2
8071
Returns the `num` of the last action of the story in the KoboldAI GUI.
8072
responses:
8073
200:
8074
description: Successful request
8075
content:
8076
application/json:
8077
schema: StoryChunkNumSchema
8078
510:
8079
description: Story is empty
8080
content:
8081
application/json:
8082
schema: StoryEmptyErrorSchema
8083
example:
8084
detail:
8085
msg: Could not retrieve the last action of the story because the story is empty.
8086
type: story_empty
8087
"""
8088
if not vars.gamestarted:
8089
abort(Response(json.dumps({"detail": {
8090
"msg": "Could not retrieve the last action of the story because the story is empty.",
8091
"type": "story_empty",
8092
}}), mimetype="application/json", status=510))
8093
if len(vars.actions) == 0:
8094
return {"result": {"text": 0}}
8095
return {"result": {"text": vars.actions.get_last_key() + 1}}
8096
8097
8098
@api_v1.get("/story/end/text")
8099
@api_schema_wrap
8100
def get_story_end_text():
8101
"""---
8102
get:
8103
summary: Retrieve the text of the last action of the story
8104
tags:
8105
- story
8106
description: |-2
8107
Returns the text of the last action of the story in the KoboldAI GUI.
8108
responses:
8109
200:
8110
description: Successful request
8111
content:
8112
application/json:
8113
schema: StoryChunkTextSchema
8114
510:
8115
description: Story is empty
8116
content:
8117
application/json:
8118
schema: StoryEmptyErrorSchema
8119
example:
8120
detail:
8121
msg: Could not retrieve the last action of the story because the story is empty.
8122
type: story_empty
8123
"""
8124
if not vars.gamestarted:
8125
abort(Response(json.dumps({"detail": {
8126
"msg": "Could not retrieve the last action of the story because the story is empty.",
8127
"type": "story_empty",
8128
}}), mimetype="application/json", status=510))
8129
if len(vars.actions) == 0:
8130
return {"result": {"text": vars.prompt}}
8131
return {"result": {"text": vars.actions[vars.actions.get_last_key()]}}
8132
8133
8134
@api_v1.put("/story/end/text")
8135
@api_schema_wrap
8136
def put_story_end_text(body: StoryChunkSetTextSchema):
8137
"""---
8138
put:
8139
summary: Set the text of the last action of the story
8140
tags:
8141
- story
8142
description: |-2
8143
Sets the text of the last action of the story in the KoboldAI GUI to the desired value.
8144
requestBody:
8145
required: true
8146
content:
8147
application/json:
8148
schema: StoryChunkSetTextSchema
8149
example:
8150
value: string
8151
responses:
8152
200:
8153
description: Successful request
8154
content:
8155
application/json:
8156
schema: EmptySchema
8157
510:
8158
description: Story is empty
8159
content:
8160
application/json:
8161
schema: StoryEmptyErrorSchema
8162
example:
8163
detail:
8164
msg: Could not retrieve the last action of the story because the story is empty.
8165
type: story_empty
8166
{api_validation_error_response}
8167
"""
8168
if not vars.gamestarted:
8169
abort(Response(json.dumps({"detail": {
8170
"msg": "Could not retrieve the last action of the story because the story is empty.",
8171
"type": "story_empty",
8172
}}), mimetype="application/json", status=510))
8173
value = body.value.rstrip()
8174
if len(vars.actions) == 0:
8175
inlineedit(0, value)
8176
else:
8177
inlineedit(vars.actions.get_last_key() + 1, value)
8178
return {}
8179
8180
8181
@api_v1.post("/story/end/delete")
8182
@api_schema_wrap
8183
def post_story_end_delete(body: EmptySchema):
8184
"""---
8185
post:
8186
summary: Remove the last action of the story
8187
tags:
8188
- story
8189
description: |-2
8190
Removes the last action of the story in the KoboldAI GUI.
8191
requestBody:
8192
required: true
8193
content:
8194
application/json:
8195
schema: EmptySchema
8196
responses:
8197
200:
8198
description: Successful request
8199
content:
8200
application/json:
8201
schema: EmptySchema
8202
510:
8203
description: Story too short
8204
content:
8205
application/json:
8206
schema: StoryTooShortErrorSchema
8207
example:
8208
detail:
8209
msg: Could not delete the last action of the story because the number of actions in the story is less than or equal to 1.
8210
type: story_too_short
8211
{api_validation_error_response}
8212
{api_server_busy_response}
8213
"""
8214
if vars.aibusy or vars.genseqs:
8215
abort(Response(json.dumps({"detail": {
8216
"msg": "Server is busy; please try again later.",
8217
"type": "service_unavailable",
8218
}}), mimetype="application/json", status=503))
8219
if not vars.gamestarted or not len(vars.actions):
8220
abort(Response(json.dumps({"detail": {
8221
"msg": "Could not delete the last action of the story because the number of actions in the story is less than or equal to 1.",
8222
"type": "story_too_short",
8223
}}), mimetype="application/json", status=510))
8224
actionback()
8225
return {}
8226
8227
8228
@api_v1.get("/story")
8229
@api_schema_wrap
8230
def get_story():
8231
"""---
8232
get:
8233
summary: Retrieve the entire story
8234
tags:
8235
- story
8236
description: |-2
8237
Returns the entire story currently shown in the KoboldAI GUI.
8238
responses:
8239
200:
8240
description: Successful request
8241
content:
8242
application/json:
8243
schema: StorySchema
8244
"""
8245
chunks = []
8246
if vars.gamestarted:
8247
chunks.append({"num": 0, "text": vars.prompt})
8248
for num, action in vars.actions.items():
8249
chunks.append({"num": num + 1, "text": action})
8250
return {"results": chunks}
8251
8252
8253
@api_v1.get("/story/nums")
8254
@api_schema_wrap
8255
def get_story_nums():
8256
"""---
8257
get:
8258
summary: Retrieve a list of the nums of the chunks in the current story
8259
tags:
8260
- story
8261
description: |-2
8262
Returns the `num`s of the story chunks currently shown in the KoboldAI GUI.
8263
responses:
8264
200:
8265
description: Successful request
8266
content:
8267
application/json:
8268
schema: StoryNumsSchema
8269
"""
8270
chunks = []
8271
if vars.gamestarted:
8272
chunks.append(0)
8273
for num in vars.actions.keys():
8274
chunks.append(num + 1)
8275
return {"results": chunks}
8276
8277
8278
@api_v1.get("/story/nums/<int(signed=True):num>")
8279
@api_schema_wrap
8280
def get_story_nums_num(num: int):
8281
"""---
8282
get:
8283
summary: Determine whether or not there is a story chunk with the given num
8284
tags:
8285
- story
8286
parameters:
8287
- name: num
8288
in: path
8289
description: |-2
8290
`num` of the desired story chunk.
8291
schema:
8292
type: integer
8293
responses:
8294
200:
8295
description: Successful request
8296
content:
8297
application/json:
8298
schema: BasicBooleanSchema
8299
"""
8300
if num == 0:
8301
return {"result": vars.gamestarted}
8302
return {"result": num - 1 in vars.actions}
8303
8304
8305
@api_v1.get("/story/<int(signed=True):num>")
8306
@api_schema_wrap
8307
def get_story_num(num: int):
8308
"""---
8309
get:
8310
summary: Retrieve a story chunk
8311
tags:
8312
- story
8313
description: |-2
8314
Returns information about a story chunk given its `num`.
8315
parameters:
8316
- name: num
8317
in: path
8318
description: |-2
8319
`num` of the desired story chunk.
8320
schema:
8321
type: integer
8322
responses:
8323
200:
8324
description: Successful request
8325
content:
8326
application/json:
8327
schema: StoryChunkResultSchema
8328
404:
8329
description: Not found
8330
content:
8331
application/json:
8332
schema: NotFoundErrorSchema
8333
example:
8334
detail:
8335
msg: No chunk with the given num exists.
8336
type: key_error
8337
"""
8338
if num == 0:
8339
if not vars.gamestarted:
8340
abort(Response(json.dumps({"detail": {
8341
"msg": "No chunk with the given num exists.",
8342
"type": "key_error",
8343
}}), mimetype="application/json", status=404))
8344
return {"result": {"text": vars.prompt, "num": num}}
8345
if num - 1 not in vars.actions:
8346
abort(Response(json.dumps({"detail": {
8347
"msg": "No chunk with the given num exists.",
8348
"type": "key_error",
8349
}}), mimetype="application/json", status=404))
8350
return {"result": {"text": vars.actions[num - 1], "num": num}}
8351
8352
8353
@api_v1.get("/story/<int(signed=True):num>/text")
8354
@api_schema_wrap
8355
def get_story_num_text(num: int):
8356
"""---
8357
get:
8358
summary: Retrieve the text of a story chunk
8359
tags:
8360
- story
8361
description: |-2
8362
Returns the text inside a story chunk given its `num`.
8363
parameters:
8364
- name: num
8365
in: path
8366
description: |-2
8367
`num` of the desired story chunk.
8368
schema:
8369
type: integer
8370
responses:
8371
200:
8372
description: Successful request
8373
content:
8374
application/json:
8375
schema: StoryChunkTextSchema
8376
404:
8377
description: Not found
8378
content:
8379
application/json:
8380
schema: NotFoundErrorSchema
8381
example:
8382
detail:
8383
msg: No chunk with the given num exists.
8384
type: key_error
8385
"""
8386
if num == 0:
8387
if not vars.gamestarted:
8388
abort(Response(json.dumps({"detail": {
8389
"msg": "No chunk with the given num exists.",
8390
"type": "key_error",
8391
}}), mimetype="application/json", status=404))
8392
return {"value": vars.prompt}
8393
if num - 1 not in vars.actions:
8394
abort(Response(json.dumps({"detail": {
8395
"msg": "No chunk with the given num exists.",
8396
"type": "key_error",
8397
}}), mimetype="application/json", status=404))
8398
return {"value": vars.actions[num - 1]}
8399
8400
8401
@api_v1.put("/story/<int(signed=True):num>/text")
8402
@api_schema_wrap
8403
def put_story_num_text(body: StoryChunkSetTextSchema, num: int):
8404
"""---
8405
put:
8406
summary: Set the text of a story chunk
8407
tags:
8408
- story
8409
description: |-2
8410
Sets the text inside a story chunk given its `num`.
8411
parameters:
8412
- name: num
8413
in: path
8414
description: |-2
8415
`num` of the desired story chunk.
8416
schema:
8417
type: integer
8418
requestBody:
8419
required: true
8420
content:
8421
application/json:
8422
schema: StoryChunkSetTextSchema
8423
example:
8424
value: string
8425
responses:
8426
200:
8427
description: Successful request
8428
content:
8429
application/json:
8430
schema: EmptySchema
8431
404:
8432
description: Not found
8433
content:
8434
application/json:
8435
schema: NotFoundErrorSchema
8436
example:
8437
detail:
8438
msg: No chunk with the given num exists.
8439
type: key_error
8440
{api_validation_error_response}
8441
"""
8442
if num == 0:
8443
if not vars.gamestarted:
8444
abort(Response(json.dumps({"detail": {
8445
"msg": "No chunk with the given num exists.",
8446
"type": "key_error",
8447
}}), mimetype="application/json", status=404))
8448
inlineedit(0, body.value.rstrip())
8449
return {}
8450
if num - 1 not in vars.actions:
8451
abort(Response(json.dumps({"detail": {
8452
"msg": "No chunk with the given num exists.",
8453
"type": "key_error",
8454
}}), mimetype="application/json", status=404))
8455
inlineedit(num, body.value.rstrip())
8456
return {}
8457
8458
8459
@api_v1.delete("/story/<int(signed=True):num>")
8460
@api_schema_wrap
8461
def post_story_num_delete(num: int):
8462
"""---
8463
delete:
8464
summary: Remove a story chunk
8465
tags:
8466
- story
8467
description: |-2
8468
Removes a story chunk from the story in the KoboldAI GUI given its `num`. Cannot be used to delete the first action (the prompt).
8469
parameters:
8470
- name: num
8471
in: path
8472
description: |-2
8473
`num` of the desired story chunk. Must be larger than or equal to 1.
8474
schema:
8475
type: integer
8476
minimum: 1
8477
responses:
8478
200:
8479
description: Successful request
8480
content:
8481
application/json:
8482
schema: EmptySchema
8483
404:
8484
description: Not found
8485
content:
8486
application/json:
8487
schema: NotFoundErrorSchema
8488
example:
8489
detail:
8490
msg: No chunk with the given num exists.
8491
type: key_error
8492
{api_server_busy_response}
8493
"""
8494
if num < 1:
8495
abort(Response(json.dumps({"detail": {
8496
"num": ["Must be greater than or equal to 1."],
8497
}}), mimetype="application/json", status=422))
8498
if num - 1 not in vars.actions:
8499
abort(Response(json.dumps({"detail": {
8500
"msg": "No chunk with the given num exists.",
8501
"type": "key_error",
8502
}}), mimetype="application/json", status=404))
8503
if vars.aibusy or vars.genseqs:
8504
abort(Response(json.dumps({"detail": {
8505
"msg": "Server is busy; please try again later.",
8506
"type": "service_unavailable",
8507
}}), mimetype="application/json", status=503))
8508
inlinedelete(num)
8509
return {}
8510
8511
8512
@api_v1.delete("/story")
8513
@api_schema_wrap
8514
def delete_story():
8515
"""---
8516
delete:
8517
summary: Clear the story
8518
tags:
8519
- story
8520
description: |-2
8521
Starts a new blank story.
8522
responses:
8523
200:
8524
description: Successful request
8525
content:
8526
application/json:
8527
schema: EmptySchema
8528
{api_server_busy_response}
8529
"""
8530
if vars.aibusy or vars.genseqs:
8531
abort(Response(json.dumps({"detail": {
8532
"msg": "Server is busy; please try again later.",
8533
"type": "service_unavailable",
8534
}}), mimetype="application/json", status=503))
8535
newGameRequest()
8536
return {}
8537
8538
8539
@api_v1.put("/story/load")
8540
@api_schema_wrap
8541
def put_story_load(body: StoryLoadSchema):
8542
"""---
8543
put:
8544
summary: Load a story
8545
tags:
8546
- story
8547
description: |-2
8548
Loads a story given its filename (without the .json).
8549
requestBody:
8550
required: true
8551
content:
8552
application/json:
8553
schema: StoryLoadSchema
8554
example:
8555
name: string
8556
responses:
8557
200:
8558
description: Successful request
8559
content:
8560
application/json:
8561
schema: EmptySchema
8562
{api_validation_error_response}
8563
{api_server_busy_response}
8564
"""
8565
if vars.aibusy or vars.genseqs:
8566
abort(Response(json.dumps({"detail": {
8567
"msg": "Server is busy; please try again later.",
8568
"type": "service_unavailable",
8569
}}), mimetype="application/json", status=503))
8570
loadRequest(fileops.storypath(body.name.strip()))
8571
return {}
8572
8573
8574
@api_v1.put("/story/save")
8575
@api_schema_wrap
8576
def put_story_save(body: StorySaveSchema):
8577
"""---
8578
put:
8579
summary: Save the current story
8580
tags:
8581
- story
8582
description: |-2
8583
Saves the current story given its destination filename (without the .json).
8584
requestBody:
8585
required: true
8586
content:
8587
application/json:
8588
schema: StorySaveSchema
8589
example:
8590
name: string
8591
responses:
8592
200:
8593
description: Successful request
8594
content:
8595
application/json:
8596
schema: EmptySchema
8597
{api_validation_error_response}
8598
"""
8599
saveRequest(fileops.storypath(body.name.strip()))
8600
return {}
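
# Illustrative sketch (never called): saving the current story and loading it again
# by name (the filename without the .json extension). The base URL and the story
# name "my_story" are assumptions.
def _example_client_story_save_load():
    import requests
    base_url = "http://localhost:5000/api/v1"  # assumed address
    requests.put(f"{base_url}/story/save", json={"name": "my_story"}).raise_for_status()
    requests.put(f"{base_url}/story/load", json={"name": "my_story"}).raise_for_status()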
8601
8602
8603
@api_v1.get("/world_info")
8604
@api_schema_wrap
8605
def get_world_info():
8606
"""---
8607
get:
8608
summary: Retrieve all world info entries
8609
tags:
8610
- world_info
8611
description: |-2
8612
Returns all world info entries currently shown in the KoboldAI GUI.
8613
8614
The `folders` are sorted in the same order as they are in the GUI and the `entries` within the folders and within the parent `result` object are all sorted in the same order as they are in their respective parts of the GUI.
8615
responses:
8616
200:
8617
description: Successful request
8618
content:
8619
application/json:
8620
schema: WorldInfoSchema
8621
"""
8622
folders = []
8623
entries = []
8624
ln = len(vars.worldinfo)
8625
stablesortwi()
8626
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
8627
folder: Optional[list] = None
8628
if ln:
8629
last_folder = ...
8630
for wi in vars.worldinfo_i:
8631
if wi["folder"] != last_folder:
8632
folder = []
8633
if wi["folder"] is not None:
8634
folders.append({"uid": wi["folder"], "name": vars.wifolders_d[wi["folder"]]["name"], "entries": folder})
8635
last_folder = wi["folder"]
8636
(folder if wi["folder"] is not None else entries).append({k: v for k, v in wi.items() if k not in ("init", "folder", "num") and (wi["selective"] or k != "keysecondary")})
8637
return {"folders": folders, "entries": entries}
8638
8639
@api_v1.get("/world_info/uids")
8640
@api_schema_wrap
8641
def get_world_info_uids():
8642
"""---
8643
get:
8644
summary: Retrieve the UIDs of all world info entries
8645
tags:
8646
- world_info
8647
description: |-2
8648
Returns the same structure as GET /world_info, except only the `uid`s are returned.
8649
responses:
8650
200:
8651
description: Successful request
8652
content:
8653
application/json:
8654
schema: WorldInfoUIDsSchema
8655
"""
8656
folders = []
8657
entries = []
8658
ln = len(vars.worldinfo)
8659
stablesortwi()
8660
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
8661
folder: Optional[list] = None
8662
if ln:
8663
last_folder = ...
8664
for wi in vars.worldinfo_i:
8665
if wi["folder"] != last_folder:
8666
folder = []
8667
if wi["folder"] is not None:
8668
folders.append({"uid": wi["folder"], "entries": folder})
8669
last_folder = wi["folder"]
8670
(folder if wi["folder"] is not None else entries).append(wi["uid"])
8671
return {"folders": folders, "entries": entries}
8672
8673
8674
@api_v1.get("/world_info/uids/<int(signed=True):uid>")
8675
@api_schema_wrap
8676
def get_world_info_uids_uid(uid: int):
8677
"""---
8678
get:
8679
summary: Determine whether or not there is a world info entry with the given UID
8680
tags:
8681
- world_info
8682
parameters:
8683
- name: uid
8684
in: path
8685
description: |-2
8686
`uid` of the desired world info entry.
8687
schema:
8688
type: integer
8689
minimum: -2147483648
8690
maximum: 2147483647
8691
responses:
8692
200:
8693
description: Successful request
8694
content:
8695
application/json:
8696
schema: BasicBooleanSchema
8697
"""
8698
return {"result": uid in vars.worldinfo_u and vars.worldinfo_u[uid]["init"]}
8699
8700
8701
@api_v1.get("/world_info/folders")
8702
@api_schema_wrap
8703
def get_world_info_folders():
8704
"""---
8705
get:
8706
summary: Retrieve all world info folders
8707
tags:
8708
- world_info
8709
description: |-2
8710
Returns details about all world info folders currently shown in the KoboldAI GUI.
8711
8712
The `folders` are sorted in the same order as they are in the GUI.
8713
responses:
8714
200:
8715
description: Successful request
8716
content:
8717
application/json:
8718
schema: WorldInfoFoldersSchema
8719
"""
8720
stablesortwi()
8721
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
8722
return {"folders": [{"uid": folder, **{k: v for k, v in vars.wifolders_d[folder].items() if k != "collapsed"}} for folder in vars.wifolders_l]}
8723
8724
8725
@api_v1.get("/world_info/folders/uids")
8726
@api_schema_wrap
8727
def get_world_info_folders_uids():
8728
"""---
8729
get:
8730
summary: Retrieve the UIDs of all world info folders
8731
tags:
8732
- world_info
8733
description: |-2
8734
Returns the `uid`s of all world info folders currently shown in the KoboldAI GUI.
8735
8736
The `folders` are sorted in the same order as they are in the GUI.
8737
responses:
8738
200:
8739
description: Successful request
8740
content:
8741
application/json:
8742
schema: WorldInfoFoldersUIDsSchema
8743
"""
8744
stablesortwi()
8745
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
8746
return {"folders": vars.wifolders_l}
8747
8748
8749
@api_v1.get("/world_info/folders/none")
8750
@api_schema_wrap
8751
def get_world_info_folders_none():
8752
"""---
8753
get:
8754
summary: Retrieve all world info entries not in a folder
8755
tags:
8756
- world_info
8757
description: |-2
8758
Returns all world info entries that are not in a world info folder.
8759
8760
The `entries` are sorted in the same order as they are in the KoboldAI GUI.
8761
responses:
8762
200:
8763
description: Successful request
8764
content:
8765
application/json:
8766
schema: WorldInfoEntriesSchema
8767
"""
8768
entries = []
8769
stablesortwi()
8770
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
8771
for wi in reversed(vars.worldinfo_i):
8772
if wi["folder"] is not None:
8773
break
8774
entries.append({k: v for k, v in wi.items() if k not in ("init", "folder", "num") and (wi["selective"] or k != "keysecondary")})
8775
return {"entries": list(reversed(entries))}
8776
8777
8778
@api_v1.get("/world_info/folders/none/uids")
8779
@api_schema_wrap
8780
def get_world_info_folders_none_uids():
8781
"""---
8782
get:
8783
summary: Retrieve the UIDs of all world info entries not in a folder
8784
tags:
8785
- world_info
8786
description: |-2
8787
Returns the `uid`s of all world info entries that are not in a world info folder.
8788
8789
The `entries` are sorted in the same order as they are in the KoboldAI GUI.
8790
responses:
8791
200:
8792
description: Successful request
8793
content:
8794
application/json:
8795
schema: WorldInfoEntriesUIDsSchema
8796
"""
8797
entries = []
8798
stablesortwi()
8799
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
8800
for wi in reversed(vars.worldinfo_i):
8801
if wi["folder"] is not None:
8802
break
8803
entries.append(wi["uid"])
8804
return {"entries": list(reversed(entries))}
8805
8806
8807
@api_v1.get("/world_info/folders/none/uids/<int(signed=True):uid>")
8808
@api_schema_wrap
8809
def get_world_info_folders_none_uids_uid(uid: int):
8810
"""---
8811
get:
8812
summary: Determine whether or not there is a world info entry with the given UID that is not in a world info folder
8813
tags:
8814
- world_info
8815
parameters:
8816
- name: uid
8817
in: path
8818
description: |-2
8819
`uid` of the desired world info entry.
8820
schema:
8821
type: integer
8822
minimum: -2147483648
8823
maximum: 2147483647
8824
responses:
8825
200:
8826
description: Successful request
8827
content:
8828
application/json:
8829
schema: BasicBooleanSchema
8830
"""
8831
return {"result": uid in vars.worldinfo_u and vars.worldinfo_u[uid]["folder"] is None and vars.worldinfo_u[uid]["init"]}
8832
8833
8834
@api_v1.get("/world_info/folders/<int(signed=True):uid>")
8835
@api_schema_wrap
8836
def get_world_info_folders_uid(uid: int):
8837
"""---
8838
get:
8839
summary: Retrieve all world info entries in the given folder
8840
tags:
8841
- world_info
8842
parameters:
8843
- name: uid
8844
in: path
8845
description: |-2
8846
`uid` of the desired world info folder.
8847
schema:
8848
type: integer
8849
minimum: -2147483648
8850
maximum: 2147483647
8851
description: |-2
8852
Returns all world info entries that are in the world info folder with the given `uid`.
8853
8854
The `entries` are sorted in the same order as they are in the KoboldAI GUI.
8855
responses:
8856
200:
8857
description: Successful request
8858
content:
8859
application/json:
8860
schema: WorldInfoEntriesSchema
8861
404:
8862
description: Not found
8863
content:
8864
application/json:
8865
schema: NotFoundErrorSchema
8866
example:
8867
detail:
8868
msg: No world info folder with the given uid exists.
8869
type: key_error
8870
"""
8871
if uid not in vars.wifolders_d:
8872
abort(Response(json.dumps({"detail": {
8873
"msg": "No world info folder with the given uid exists.",
8874
"type": "key_error",
8875
}}), mimetype="application/json", status=404))
8876
entries = []
8877
stablesortwi()
8878
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
8879
for wi in vars.wifolders_u[uid]:
8880
if wi["init"]:
8881
entries.append({k: v for k, v in wi.items() if k not in ("init", "folder", "num") and (wi["selective"] or k != "keysecondary")})
8882
return {"entries": entries}
8883
8884
8885
@api_v1.get("/world_info/folders/<int(signed=True):uid>/uids")
8886
@api_schema_wrap
8887
def get_world_info_folders_uid_uids(uid: int):
8888
"""---
8889
get:
8890
summary: Retrieve the UIDs of all world info entries in the given folder
8891
tags:
8892
- world_info
8893
parameters:
8894
- name: uid
8895
in: path
8896
description: |-2
8897
`uid` of the desired world info folder.
8898
schema:
8899
type: integer
8900
minimum: -2147483648
8901
maximum: 2147483647
8902
description: |-2
8903
Returns the `uid`s of all world info entries that are in the world info folder with the given `uid`.
8904
8905
The `entries` are sorted in the same order as they are in the KoboldAI GUI.
8906
responses:
8907
200:
8908
description: Successful request
8909
content:
8910
application/json:
8911
schema: WorldInfoEntriesUIDsSchema
8912
404:
8913
description: Not found
8914
content:
8915
application/json:
8916
schema: NotFoundErrorSchema
8917
example:
8918
detail:
8919
msg: No world info folder with the given uid exists.
8920
type: key_error
8921
"""
8922
if uid not in vars.wifolders_d:
8923
abort(Response(json.dumps({"detail": {
8924
"msg": "No world info folder with the given uid exists.",
8925
"type": "key_error",
8926
}}), mimetype="application/json", status=404))
8927
entries = []
8928
stablesortwi()
8929
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
8930
for wi in vars.wifolders_u[uid]:
8931
if wi["init"]:
8932
entries.append(wi["uid"])
8933
return {"entries": entries}
8934
8935
8936
@api_v1.get("/world_info/folders/<int(signed=True):folder_uid>/uids/<int(signed=True):entry_uid>")
8937
@api_schema_wrap
8938
def get_world_info_folders_folder_uid_uids_entry_uid(folder_uid: int, entry_uid: int):
8939
"""---
8940
get:
8941
summary: Determine whether or not there is a world info entry with the given UID in the world info folder with the given UID
8942
tags:
8943
- world_info
8944
parameters:
8945
- name: folder_uid
8946
in: path
8947
description: |-2
8948
`uid` of the desired world info folder.
8949
schema:
8950
type: integer
8951
minimum: -2147483648
8952
maximum: 2147483647
8953
- name: entry_uid
8954
in: path
8955
description: |-2
8956
`uid` of the desired world info entry.
8957
schema:
8958
type: integer
8959
minimum: -2147483648
8960
maximum: 2147483647
8961
responses:
8962
200:
8963
description: Successful request
8964
content:
8965
application/json:
8966
schema: BasicBooleanSchema
8967
"""
8968
return {"result": entry_uid in vars.worldinfo_u and vars.worldinfo_u[entry_uid]["folder"] == folder_uid and vars.worldinfo_u[entry_uid]["init"]}
8969
8970
8971
@api_v1.get("/world_info/folders/<int(signed=True):uid>/name")
8972
@api_schema_wrap
8973
def get_world_info_folders_uid_name(uid: int):
8974
"""---
8975
get:
8976
summary: Retrieve the name of the world info folder with the given UID
8977
tags:
8978
- world_info
8979
parameters:
8980
- name: uid
8981
in: path
8982
description: |-2
8983
`uid` of the desired world info folder.
8984
schema:
8985
type: integer
8986
minimum: -2147483648
8987
maximum: 2147483647
8988
responses:
8989
200:
8990
description: Successful request
8991
content:
8992
application/json:
8993
schema: BasicStringSchema
8994
404:
8995
description: Not found
8996
content:
8997
application/json:
8998
schema: NotFoundErrorSchema
8999
example:
9000
detail:
9001
msg: No world info folder with the given uid exists.
9002
type: key_error
9003
"""
9004
if uid not in vars.wifolders_d:
9005
abort(Response(json.dumps({"detail": {
9006
"msg": "No world info folder with the given uid exists.",
9007
"type": "key_error",
9008
}}), mimetype="application/json", status=404))
9009
return {"value": vars.wifolders_d[uid]["name"]}
9010
9011
9012
@api_v1.put("/world_info/folders/<int(signed=True):uid>/name")
9013
@api_schema_wrap
9014
def put_world_info_folders_uid_name(body: BasicStringSchema, uid: int):
9015
"""---
9016
put:
9017
summary: Set the name of the world info folder with the given UID to the specified value
9018
tags:
9019
- world_info
9020
parameters:
9021
- name: uid
9022
in: path
9023
description: |-2
9024
`uid` of the desired world info folder.
9025
schema:
9026
type: integer
9027
minimum: -2147483648
9028
maximum: 2147483647
9029
requestBody:
9030
required: true
9031
content:
9032
application/json:
9033
schema: BasicStringSchema
9034
example:
9035
value: string
9036
responses:
9037
200:
9038
description: Successful request
9039
content:
9040
application/json:
9041
schema: EmptySchema
9042
404:
9043
description: Not found
9044
content:
9045
application/json:
9046
schema: NotFoundErrorSchema
9047
example:
9048
detail:
9049
msg: No world info folder with the given uid exists.
9050
type: key_error
9051
{api_validation_error_response}
9052
"""
9053
if uid not in vars.wifolders_d:
9054
abort(Response(json.dumps({"detail": {
9055
"msg": "No world info folder with the given uid exists.",
9056
"type": "key_error",
9057
}}), mimetype="application/json", status=404))
9058
vars.wifolders_d[uid]["name"] = body.value
9059
setgamesaved(False)
9060
return {}
9061
9062
9063
@api_v1.get("/world_info/<int(signed=True):uid>")
9064
@api_schema_wrap
9065
def get_world_info_uid(uid: int):
9066
"""---
9067
get:
9068
summary: Retrieve information about the world info entry with the given UID
9069
tags:
9070
- world_info
9071
parameters:
9072
- name: uid
9073
in: path
9074
description: |-2
9075
`uid` of the desired world info entry.
9076
schema:
9077
type: integer
9078
minimum: -2147483648
9079
maximum: 2147483647
9080
responses:
9081
200:
9082
description: Successful request
9083
content:
9084
application/json:
9085
schema: WorldInfoEntrySchema
9086
404:
9087
description: Not found
9088
content:
9089
application/json:
9090
schema: NotFoundErrorSchema
9091
example:
9092
detail:
9093
msg: No world info entry with the given uid exists.
9094
type: key_error
9095
"""
9096
if uid not in vars.worldinfo_u:
9097
abort(Response(json.dumps({"detail": {
9098
"msg": "No world info entry with the given uid exists.",
9099
"type": "key_error",
9100
}}), mimetype="application/json", status=404))
9101
wi = vars.worldinfo_u[uid]
9102
return {k: v for k, v in wi.items() if k not in ("init", "folder", "num") and (wi["selective"] or k != "keysecondary")}
9103
9104
9105
@api_v1.get("/world_info/<int(signed=True):uid>/comment")
9106
@api_schema_wrap
9107
def get_world_info_uid_comment(uid: int):
9108
"""---
9109
get:
9110
summary: Retrieve the comment of the world info entry with the given UID
9111
tags:
9112
- world_info
9113
parameters:
9114
- name: uid
9115
in: path
9116
description: |-2
9117
`uid` of the desired world info entry.
9118
schema:
9119
type: integer
9120
minimum: -2147483648
9121
maximum: 2147483647
9122
responses:
9123
200:
9124
description: Successful request
9125
content:
9126
application/json:
9127
schema: BasicStringSchema
9128
404:
9129
description: Not found
9130
content:
9131
application/json:
9132
schema: NotFoundErrorSchema
9133
example:
9134
detail:
9135
msg: No world info entry with the given uid exists.
9136
type: key_error
9137
"""
9138
if uid not in vars.worldinfo_u:
9139
abort(Response(json.dumps({"detail": {
9140
"msg": "No world info entry with the given uid exists.",
9141
"type": "key_error",
9142
}}), mimetype="application/json", status=404))
9143
return {"value": vars.worldinfo_u[uid]["comment"]}
9144
9145
9146
@api_v1.put("/world_info/<int(signed=True):uid>/comment")
9147
@api_schema_wrap
9148
def put_world_info_uid_comment(body: BasicStringSchema, uid: int):
9149
"""---
9150
put:
9151
summary: Set the comment of the world info entry with the given UID to the specified value
9152
tags:
9153
- world_info
9154
parameters:
9155
- name: uid
9156
in: path
9157
description: |-2
9158
`uid` of the desired world info entry.
9159
schema:
9160
type: integer
9161
minimum: -2147483648
9162
maximum: 2147483647
9163
requestBody:
9164
required: true
9165
content:
9166
application/json:
9167
schema: BasicStringSchema
9168
example:
9169
value: string
9170
responses:
9171
200:
9172
description: Successful request
9173
content:
9174
application/json:
9175
schema: EmptySchema
9176
404:
9177
description: Not found
9178
content:
9179
application/json:
9180
schema: NotFoundErrorSchema
9181
example:
9182
detail:
9183
msg: No world info entry with the given uid exists.
9184
type: key_error
9185
{api_validation_error_response}
9186
"""
9187
if uid not in vars.worldinfo_u:
9188
abort(Response(json.dumps({"detail": {
9189
"msg": "No world info entry with the given uid exists.",
9190
"type": "key_error",
9191
}}), mimetype="application/json", status=404))
9192
vars.worldinfo_u[uid]["comment"] = body.value
9193
setgamesaved(False)
9194
return {}
9195
9196
9197
@api_v1.get("/world_info/<int(signed=True):uid>/content")
9198
@api_schema_wrap
9199
def get_world_info_uid_content(uid: int):
9200
"""---
9201
get:
9202
summary: Retrieve the content of the world info entry with the given UID
9203
tags:
9204
- world_info
9205
parameters:
9206
- name: uid
9207
in: path
9208
description: |-2
9209
`uid` of the desired world info entry.
9210
schema:
9211
type: integer
9212
minimum: -2147483648
9213
maximum: 2147483647
9214
responses:
9215
200:
9216
description: Successful request
9217
content:
9218
application/json:
9219
schema: BasicStringSchema
9220
404:
9221
description: Not found
9222
content:
9223
application/json:
9224
schema: NotFoundErrorSchema
9225
example:
9226
detail:
9227
msg: No world info entry with the given uid exists.
9228
type: key_error
9229
"""
9230
if uid not in vars.worldinfo_u:
9231
abort(Response(json.dumps({"detail": {
9232
"msg": "No world info entry with the given uid exists.",
9233
"type": "key_error",
9234
}}), mimetype="application/json", status=404))
9235
return {"value": vars.worldinfo_u[uid]["content"]}
9236
9237
9238
@api_v1.put("/world_info/<int(signed=True):uid>/content")
9239
@api_schema_wrap
9240
def put_world_info_uid_content(body: BasicStringSchema, uid: int):
9241
"""---
9242
put:
9243
summary: Set the content of the world info entry with the given UID to the specified value
9244
tags:
9245
- world_info
9246
parameters:
9247
- name: uid
9248
in: path
9249
description: |-2
9250
`uid` of the desired world info entry.
9251
schema:
9252
type: integer
9253
minimum: -2147483648
9254
maximum: 2147483647
9255
requestBody:
9256
required: true
9257
content:
9258
application/json:
9259
schema: BasicStringSchema
9260
example:
9261
value: string
9262
responses:
9263
200:
9264
description: Successful request
9265
content:
9266
application/json:
9267
schema: EmptySchema
9268
404:
9269
description: Not found
9270
content:
9271
application/json:
9272
schema: NotFoundErrorSchema
9273
example:
9274
detail:
9275
msg: No world info entry with the given uid exists.
9276
type: key_error
9277
{api_validation_error_response}
9278
"""
9279
if uid not in vars.worldinfo_u:
9280
abort(Response(json.dumps({"detail": {
9281
"msg": "No world info entry with the given uid exists.",
9282
"type": "key_error",
9283
}}), mimetype="application/json", status=404))
9284
vars.worldinfo_u[uid]["content"] = body.value
9285
setgamesaved(False)
9286
return {}
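
# Illustrative sketch (never called): reading and updating the content of a single
# world info entry by uid. The base URL and the default uid value are assumptions.
def _example_client_world_info_content(uid: int = 1234):
    import requests
    base_url = "http://localhost:5000/api/v1"  # assumed address
    print(requests.get(f"{base_url}/world_info/{uid}/content").json()["value"])
    requests.put(
        f"{base_url}/world_info/{uid}/content",
        json={"value": "Niko is a small, cloaked kobold."},
    ).raise_for_status()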
9287
9288
9289
@api_v1.get("/world_info/<int(signed=True):uid>/key")
9290
@api_schema_wrap
9291
def get_world_info_uid_key(uid: int):
9292
"""---
9293
get:
9294
summary: Retrieve the keys or primary keys of the world info entry with the given UID
9295
tags:
9296
- world_info
9297
parameters:
9298
- name: uid
9299
in: path
9300
description: |-2
9301
`uid` of the desired world info entry.
9302
schema:
9303
type: integer
9304
minimum: -2147483648
9305
maximum: 2147483647
9306
responses:
9307
200:
9308
description: Successful request
9309
content:
9310
application/json:
9311
schema: BasicStringSchema
9312
404:
9313
description: Not found
9314
content:
9315
application/json:
9316
schema: NotFoundErrorSchema
9317
example:
9318
detail:
9319
msg: No world info entry with the given uid exists.
9320
type: key_error
9321
"""
9322
if uid not in vars.worldinfo_u:
9323
abort(Response(json.dumps({"detail": {
9324
"msg": "No world info entry with the given uid exists.",
9325
"type": "key_error",
9326
}}), mimetype="application/json", status=404))
9327
return {"value": vars.worldinfo_u[uid]["key"]}
9328
9329
9330
@api_v1.put("/world_info/<int(signed=True):uid>/key")
9331
@api_schema_wrap
9332
def put_world_info_uid_key(body: BasicStringSchema, uid: int):
9333
"""---
9334
put:
9335
summary: Set the keys or primary keys of the world info entry with the given UID to the specified value
9336
tags:
9337
- world_info
9338
parameters:
9339
- name: uid
9340
in: path
9341
description: |-2
9342
`uid` of the desired world info entry.
9343
schema:
9344
type: integer
9345
minimum: -2147483648
9346
maximum: 2147483647
9347
requestBody:
9348
required: true
9349
content:
9350
application/json:
9351
schema: BasicStringSchema
9352
example:
9353
value: string
9354
responses:
9355
200:
9356
description: Successful request
9357
content:
9358
application/json:
9359
schema: EmptySchema
9360
404:
9361
description: Not found
9362
content:
9363
application/json:
9364
schema: NotFoundErrorSchema
9365
example:
9366
detail:
9367
msg: No world info entry with the given uid exists.
9368
type: key_error
9369
{api_validation_error_response}
9370
"""
9371
if uid not in vars.worldinfo_u:
9372
abort(Response(json.dumps({"detail": {
9373
"msg": "No world info entry with the given uid exists.",
9374
"type": "key_error",
9375
}}), mimetype="application/json", status=404))
9376
vars.worldinfo_u[uid]["key"] = body.value
9377
setgamesaved(False)
9378
return {}
9379
9380
9381
@api_v1.get("/world_info/<int(signed=True):uid>/keysecondary")
9382
@api_schema_wrap
9383
def get_world_info_uid_keysecondary(uid: int):
9384
"""---
9385
get:
9386
summary: Retrieve the secondary keys of the world info entry with the given UID
9387
tags:
9388
- world_info
9389
parameters:
9390
- name: uid
9391
in: path
9392
description: |-2
9393
`uid` of the desired world info entry.
9394
schema:
9395
type: integer
9396
minimum: -2147483648
9397
maximum: 2147483647
9398
responses:
9399
200:
9400
description: Successful request
9401
content:
9402
application/json:
9403
schema: BasicStringSchema
9404
404:
9405
description: Not found
9406
content:
9407
application/json:
9408
schema: NotFoundErrorSchema
9409
example:
9410
detail:
9411
msg: No world info entry with the given uid exists.
9412
type: key_error
9413
"""
9414
if uid not in vars.worldinfo_u:
9415
abort(Response(json.dumps({"detail": {
9416
"msg": "No world info entry with the given uid exists.",
9417
"type": "key_error",
9418
}}), mimetype="application/json", status=404))
9419
return {"value": vars.worldinfo_u[uid]["keysecondary"]}
9420
9421
9422
@api_v1.put("/world_info/<int(signed=True):uid>/keysecondary")
9423
@api_schema_wrap
9424
def put_world_info_uid_keysecondary(body: BasicStringSchema, uid: int):
9425
"""---
9426
put:
9427
summary: Set the secondary keys of the world info entry with the given UID to the specified value
9428
tags:
9429
- world_info
9430
parameters:
9431
- name: uid
9432
in: path
9433
description: |-2
9434
`uid` of the desired world info entry.
9435
schema:
9436
type: integer
9437
minimum: -2147483648
9438
maximum: 2147483647
9439
requestBody:
9440
required: true
9441
content:
9442
application/json:
9443
schema: BasicStringSchema
9444
example:
9445
value: string
9446
responses:
9447
200:
9448
description: Successful request
9449
content:
9450
application/json:
9451
schema: EmptySchema
9452
404:
9453
description: Not found
9454
content:
9455
application/json:
9456
schema: NotFoundErrorSchema
9457
example:
9458
detail:
9459
msg: No world info entry with the given uid exists.
9460
type: key_error
9461
{api_validation_error_response}
9462
"""
9463
if uid not in vars.worldinfo_u:
9464
abort(Response(json.dumps({"detail": {
9465
"msg": "No world info entry with the given uid exists.",
9466
"type": "key_error",
9467
}}), mimetype="application/json", status=404))
9468
vars.worldinfo_u[uid]["keysecondary"] = body.value
9469
setgamesaved(False)
9470
return {}
9471
9472
9473
@api_v1.get("/world_info/<int(signed=True):uid>/selective")
9474
@api_schema_wrap
9475
def get_world_info_uid_selective(uid: int):
9476
"""---
9477
get:
9478
summary: Retrieve the selective mode state of the world info entry with the given UID
9479
tags:
9480
- world_info
9481
parameters:
9482
- name: uid
9483
in: path
9484
description: |-2
9485
`uid` of the desired world info entry.
9486
schema:
9487
type: integer
9488
minimum: -2147483648
9489
maximum: 2147483647
9490
responses:
9491
200:
9492
description: Successful request
9493
content:
9494
application/json:
9495
schema: BasicBooleanSchema
9496
404:
9497
description: Not found
9498
content:
9499
application/json:
9500
schema: NotFoundErrorSchema
9501
example:
9502
detail:
9503
msg: No world info entry with the given uid exists.
9504
type: key_error
9505
"""
9506
if uid not in vars.worldinfo_u:
9507
abort(Response(json.dumps({"detail": {
9508
"msg": "No world info entry with the given uid exists.",
9509
"type": "key_error",
9510
}}), mimetype="application/json", status=404))
9511
return {"value": vars.worldinfo_u[uid]["selective"]}
9512
9513
9514
@api_v1.put("/world_info/<int(signed=True):uid>/selective")
9515
@api_schema_wrap
9516
def put_world_info_uid_selective(body: BasicBooleanSchema, uid: int):
9517
"""---
9518
put:
9519
summary: Set the selective mode state of the world info entry with the given UID to the specified value
9520
tags:
9521
- world_info
9522
parameters:
9523
- name: uid
9524
in: path
9525
description: |-2
9526
`uid` of the desired world info entry.
9527
schema:
9528
type: integer
9529
minimum: -2147483648
9530
maximum: 2147483647
9531
requestBody:
9532
required: true
9533
content:
9534
application/json:
9535
schema: BasicBooleanSchema
9536
example:
9537
value: true
9538
responses:
9539
200:
9540
description: Successful request
9541
content:
9542
application/json:
9543
schema: EmptySchema
9544
404:
9545
description: Not found
9546
content:
9547
application/json:
9548
schema: NotFoundErrorSchema
9549
example:
9550
detail:
9551
msg: No world info entry with the given uid exists.
9552
type: key_error
9553
{api_validation_error_response}
9554
"""
9555
if uid not in vars.worldinfo_u:
9556
abort(Response(json.dumps({"detail": {
9557
"msg": "No world info entry with the given uid exists.",
9558
"type": "key_error",
9559
}}), mimetype="application/json", status=404))
9560
vars.worldinfo_u[uid]["selective"] = body.value
9561
setgamesaved(False)
9562
return {}
9563
9564
9565
@api_v1.get("/world_info/<int(signed=True):uid>/constant")
9566
@api_schema_wrap
9567
def get_world_info_uid_constant(uid: int):
9568
"""---
9569
get:
9570
summary: Retrieve the constant mode state of the world info entry with the given UID
9571
tags:
9572
- world_info
9573
parameters:
9574
- name: uid
9575
in: path
9576
description: |-2
9577
`uid` of the desired world info entry.
9578
schema:
9579
type: integer
9580
minimum: -2147483648
9581
maximum: 2147483647
9582
responses:
9583
200:
9584
description: Successful request
9585
content:
9586
application/json:
9587
schema: BasicBooleanSchema
9588
404:
9589
description: Not found
9590
content:
9591
application/json:
9592
schema: NotFoundErrorSchema
9593
example:
9594
detail:
9595
msg: No world info entry with the given uid exists.
9596
type: key_error
9597
"""
9598
if uid not in vars.worldinfo_u:
9599
abort(Response(json.dumps({"detail": {
9600
"msg": "No world info entry with the given uid exists.",
9601
"type": "key_error",
9602
}}), mimetype="application/json", status=404))
9603
return {"value": vars.worldinfo_u[uid]["constant"]}
9604
9605
9606
@api_v1.put("/world_info/<int(signed=True):uid>/constant")
9607
@api_schema_wrap
9608
def put_world_info_uid_constant(body: BasicBooleanSchema, uid: int):
9609
"""---
9610
put:
9611
summary: Set the constant mode state of the world info entry with the given UID to the specified value
9612
tags:
9613
- world_info
9614
parameters:
9615
- name: uid
9616
in: path
9617
description: |-2
9618
`uid` of the desired world info entry.
9619
schema:
9620
type: integer
9621
minimum: -2147483648
9622
maximum: 2147483647
9623
requestBody:
9624
required: true
9625
content:
9626
application/json:
9627
schema: BasicBooleanSchema
9628
example:
9629
value: true
9630
responses:
9631
200:
9632
description: Successful request
9633
content:
9634
application/json:
9635
schema: EmptySchema
9636
404:
9637
description: Not found
9638
content:
9639
application/json:
9640
schema: NotFoundErrorSchema
9641
example:
9642
detail:
9643
msg: No world info entry with the given uid exists.
9644
type: key_error
9645
{api_validation_error_response}
9646
"""
9647
if uid not in vars.worldinfo_u:
9648
abort(Response(json.dumps({"detail": {
9649
"msg": "No world info entry with the given uid exists.",
9650
"type": "key_error",
9651
}}), mimetype="application/json", status=404))
9652
vars.worldinfo_u[uid]["constant"] = body.value
9653
setgamesaved(False)
9654
return {}
9655
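# Illustrative sketch (not part of the original file): the selective and
# constant endpoints above follow the same pattern with a boolean payload
# (BasicBooleanSchema). Host, port, /api/v1 prefix and the uid are assumptions.
#
#   import requests
#   BASE = "http://127.0.0.1:5000/api/v1"
#   uid = 1
#   requests.put(f"{BASE}/world_info/{uid}/selective", json={"value": True})
#   requests.put(f"{BASE}/world_info/{uid}/constant", json={"value": False})
#   print(requests.get(f"{BASE}/world_info/{uid}/constant").json())  # {"value": False}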
9656
9657
@api_v1.post("/world_info/folders/none")
9658
@api_schema_wrap
9659
def post_world_info_folders_none(body: EmptySchema):
9660
"""---
9661
post:
9662
summary: Create a new world info entry outside of a world info folder, at the end of the world info
9663
tags:
9664
- world_info
9665
requestBody:
9666
required: true
9667
content:
9668
application/json:
9669
schema: EmptySchema
9670
responses:
9671
200:
9672
description: Successful request
9673
content:
9674
application/json:
9675
schema: BasicUIDSchema
9676
{api_validation_error_response}
9677
"""
9678
stablesortwi()
9679
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
9680
setgamesaved(False)
9681
emit('from_server', {'cmd': 'wiexpand', 'data': vars.worldinfo[-1]["num"]}, broadcast=True)
9682
vars.worldinfo[-1]["init"] = True
9683
addwiitem(folder_uid=None)
9684
return {"uid": vars.worldinfo[-2]["uid"]}
9685
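# Illustrative sketch (not part of the original file): creating a top-level
# world info entry. The handler above initialises the trailing blank entry,
# appends a fresh placeholder via addwiitem(), and returns the uid of the
# entry that was just initialised. Host, port and /api/v1 prefix are assumptions.
#
#   import requests
#   r = requests.post("http://127.0.0.1:5000/api/v1/world_info/folders/none",
#                     json={})
#   new_uid = r.json()["uid"]
#   # new_uid can then be used with the /world_info/<uid>/... endpoints above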
9686
9687
@api_v1.post("/world_info/folders/<int(signed=True):uid>")
9688
@api_schema_wrap
9689
def post_world_info_folders_uid(body: EmptySchema, uid: int):
9690
"""---
9691
post:
9692
summary: Create a new world info entry at the end of the world info folder with the given UID
9693
tags:
9694
- world_info
9695
parameters:
9696
- name: uid
9697
in: path
9698
description: |-2
9699
`uid` of the desired world info folder.
9700
schema:
9701
type: integer
9702
minimum: -2147483648
9703
maximum: 2147483647
9704
requestBody:
9705
required: true
9706
content:
9707
application/json:
9708
schema: EmptySchema
9709
responses:
9710
200:
9711
description: Successful request
9712
content:
9713
application/json:
9714
schema: BasicUIDSchema
9715
404:
9716
description: Not found
9717
content:
9718
application/json:
9719
schema: NotFoundErrorSchema
9720
example:
9721
detail:
9722
msg: No world info folder with the given uid exists.
9723
type: key_error
9724
{api_validation_error_response}
9725
"""
9726
if uid not in vars.wifolders_d:
9727
abort(Response(json.dumps({"detail": {
9728
"msg": "No world info folder with the given uid exists.",
9729
"type": "key_error",
9730
}}), mimetype="application/json", status=404))
9731
stablesortwi()
9732
vars.worldinfo_i = [wi for wi in vars.worldinfo if wi["init"]]
9733
setgamesaved(False)
9734
emit('from_server', {'cmd': 'wiexpand', 'data': vars.wifolders_u[uid][-1]["num"]}, broadcast=True)
9735
vars.wifolders_u[uid][-1]["init"] = True
9736
addwiitem(folder_uid=uid)
9737
return {"uid": vars.wifolders_u[uid][-2]["uid"]}
9738
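# Illustrative sketch (not part of the original file): same as above, but the
# new entry is appended inside the folder with the given uid; a 404 error body
# is returned when no such folder exists. Host, port, prefix and folder uid
# are assumptions.
#
#   import requests
#   folder_uid = 2  # assumed existing world info folder uid
#   r = requests.post(
#       f"http://127.0.0.1:5000/api/v1/world_info/folders/{folder_uid}",
#       json={})
#   print(r.json())  # {"uid": ...} of the new entry, or the 404 error detail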
9739
9740
@api_v1.delete("/world_info/<int(signed=True):uid>")
9741
@api_schema_wrap
9742
def delete_world_info_uid(uid: int):
9743
"""---
9744
delete:
9745
summary: Delete the world info entry with the given UID
9746
tags:
9747
- world_info
9748
parameters:
9749
- name: uid
9750
in: path
9751
description: |-2
9752
`uid` of the desired world info entry.
9753
schema:
9754
type: integer
9755
minimum: -2147483648
9756
maximum: 2147483647
9757
responses:
9758
200:
9759
description: Successful request
9760
content:
9761
application/json:
9762
schema: EmptySchema
9763
404:
9764
description: Not found
9765
content:
9766
application/json:
9767
schema: NotFoundErrorSchema
9768
example:
9769
detail:
9770
msg: No world info entry with the given uid exists.
9771
type: key_error
9772
"""
9773
if uid not in vars.worldinfo_u:
9774
abort(Response(json.dumps({"detail": {
9775
"msg": "No world info entry with the given uid exists.",
9776
"type": "key_error",
9777
}}), mimetype="application/json", status=404))
9778
deletewi(uid)
9779
return {}
9780
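# Illustrative sketch (not part of the original file): deleting an entry by
# uid. Host, port, /api/v1 prefix and the uid are assumptions.
#
#   import requests
#   r = requests.delete("http://127.0.0.1:5000/api/v1/world_info/1")
#   assert r.status_code in (200, 404)  # 404 if no entry with that uid exists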
9781
9782
@api_v1.post("/world_info/folders")
9783
@api_schema_wrap
9784
def post_world_info_folders(body: EmptySchema):
9785
"""---
9786
post:
9787
summary: Create a new world info folder at the end of the world info
9788
tags:
9789
- world_info
9790
requestBody:
9791
required: true
9792
content:
9793
application/json:
9794
schema: EmptySchema
9795
responses:
9796
200:
9797
description: Successful request
9798
content:
9799
application/json:
9800
schema: BasicUIDSchema
9801
{api_validation_error_response}
9802
"""
9803
addwifolder()
9804
return {"uid": vars.wifolders_l[-1]}
9805
9806
9807
@api_v1.delete("/world_info/folders/<int(signed=True):uid>")
9808
@api_schema_wrap
9809
def delete_world_info_folders_uid(uid: int):
9810
"""---
9811
delete:
9812
summary: Delete the world info folder with the given UID
9813
tags:
9814
- world_info
9815
parameters:
9816
- name: uid
9817
in: path
9818
description: |-2
9819
`uid` of the desired world info folder.
9820
schema:
9821
type: integer
9822
minimum: -2147483648
9823
maximum: 2147483647
9824
responses:
9825
200:
9826
description: Successful request
9827
content:
9828
application/json:
9829
schema: EmptySchema
9830
404:
9831
description: Not found
9832
content:
9833
application/json:
9834
schema: NotFoundErrorSchema
9835
example:
9836
detail:
9837
msg: No world info folder with the given uid exists.
9838
type: key_error
9839
"""
9840
if uid not in vars.wifolders_d:
9841
abort(Response(json.dumps({"detail": {
9842
"msg": "No world info folder with the given uid exists.",
9843
"type": "key_error",
9844
}}), mimetype="application/json", status=404))
9845
deletewifolder(uid)
9846
return {}
9847
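# Illustrative sketch (not part of the original file): a full folder lifecycle
# using the two endpoints above — create an empty world info folder, then
# delete it again. Host, port and /api/v1 prefix are assumptions.
#
#   import requests
#   BASE = "http://127.0.0.1:5000/api/v1"
#   folder_uid = requests.post(f"{BASE}/world_info/folders", json={}).json()["uid"]
#   requests.delete(f"{BASE}/world_info/folders/{folder_uid}")  # 200 on success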
9848
9849
def _make_f_get(obj, _var_name, _name, _schema, _example_yaml_value):
9850
def f_get():
9851
"""---
9852
get:
9853
summary: Retrieve the current {} setting value
9854
tags:
9855
- config
9856
responses:
9857
200:
9858
description: Successful request
9859
content:
9860
application/json:
9861
schema: {}
9862
example:
9863
value: {}
9864
"""
9865
_obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[obj]
9866
if _var_name.startswith("@"):
9867
return {"value": _obj[_var_name[1:]]}
9868
else:
9869
return {"value": getattr(_obj, _var_name)}
9870
f_get.__doc__ = f_get.__doc__.format(_name, _schema, _example_yaml_value)
9871
return f_get
9872
9873
def _make_f_put(schema_class: Type[KoboldSchema], obj, _var_name, _name, _schema, _example_yaml_value):
9874
def f_put(body: schema_class):
9875
"""---
9876
put:
9877
summary: Set {} setting to specified value
9878
tags:
9879
- config
9880
requestBody:
9881
required: true
9882
content:
9883
application/json:
9884
schema: {}
9885
example:
9886
value: {}
9887
responses:
9888
200:
9889
description: Successful request
9890
content:
9891
application/json:
9892
schema: EmptySchema
9893
{api_validation_error_response}
9894
"""
9895
_obj = {"vars": vars, "vars.formatoptns": vars.formatoptns}[obj]
9896
if _var_name.startswith("@"):
9897
_obj[_var_name[1:]] = body.value
9898
else:
9899
setattr(_obj, _var_name, body.value)
9900
settingschanged()
9901
refresh_settings()
9902
return {}
9903
f_put.__doc__ = f_put.__doc__.format(_name, _schema, _example_yaml_value, api_validation_error_response=api_validation_error_response)
9904
return f_put
9905
9906
def create_config_endpoint(method="GET", schema="MemorySchema"):
9907
_name = globals()[schema].KoboldMeta.name
9908
_var_name = globals()[schema].KoboldMeta.var_name
9909
_route_name = globals()[schema].KoboldMeta.route_name
9910
_obj = globals()[schema].KoboldMeta.obj
9911
_example_yaml_value = globals()[schema].KoboldMeta.example_yaml_value
9912
_schema = schema
9913
f = _make_f_get(_obj, _var_name, _name, _schema, _example_yaml_value) if method == "GET" else _make_f_put(globals()[schema], _obj, _var_name, _name, _schema, _example_yaml_value)
9914
f.__name__ = f"{method.lower()}_config_{_name}"
9915
f = api_schema_wrap(f)
9916
for api in (api_v1,):
9917
f = api.route(f"/config/{_route_name}", methods=[method])(f)
9918
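# How the factory above is used (illustrative sketch, not part of the original
# file): each schema class registered below carries a KoboldMeta block, and
# create_config_endpoint() turns it into a GET or PUT handler under
# /config/<route_name> whose OpenAPI docstring is filled in with .format().
# For a hypothetical schema the flow would look roughly like:
#
#   class ExampleSettingSchema(KoboldSchema):      # hypothetical schema
#       value = fields.Integer(required=True)
#       class KoboldMeta:
#           route_name = "example"
#           obj = "vars"
#           var_name = "example"                   # hypothetical vars attribute
#           name = "example"
#           example_yaml_value = "1"
#
#   create_config_endpoint(schema="ExampleSettingSchema", method="GET")
#   create_config_endpoint(schema="ExampleSettingSchema", method="PUT")
#   # -> GET/PUT /config/example reading and writing vars.example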
9919
class SoftPromptSettingSchema(KoboldSchema):
9920
value: str = fields.String(required=True, validate=[soft_prompt_validator, validate.Regexp(r"^[^/\\]*$")], metadata={"description": "Soft prompt name, or a string containing only whitespace for no soft prompt. If using the GET method and no soft prompt is loaded, this will always be the empty string."})
9921
9922
@api_v1.get("/config/soft_prompt")
9923
@api_schema_wrap
9924
def get_config_soft_prompt():
9925
"""---
9926
get:
9927
summary: Retrieve the current soft prompt name
9928
tags:
9929
- config
9930
responses:
9931
200:
9932
description: Successful request
9933
content:
9934
application/json:
9935
schema: SoftPromptSettingSchema
9936
example:
9937
value: ""
9938
"""
9939
return {"value": vars.spfilename.strip()}
9940
9941
class SoftPromptsListSchema(KoboldSchema):
9942
values: List[SoftPromptSettingSchema] = fields.List(fields.Nested(SoftPromptSettingSchema), required=True, metadata={"description": "Array of available softprompts."})
9943
9944
@api_v1.get("/config/soft_prompts_list")
9945
@api_schema_wrap
9946
def get_config_soft_prompts_list():
9947
"""---
9948
get:
9949
summary: Retrieve all available softprompt filenames
9950
tags:
9951
- config
9952
responses:
9953
200:
9954
description: Successful request
9955
content:
9956
application/json:
9957
schema: SoftPromptsListSchema
9958
example:
9959
values: []
9960
"""
9961
splist = []
9962
for sp in fileops.getspfiles(vars.modeldim):
9963
9964
splist.append({"value":sp["filename"]})
9965
return {"values": splist}
9966
9967
@api_v1.put("/config/soft_prompt")
9968
@api_schema_wrap
9969
def put_config_soft_prompt(body: SoftPromptSettingSchema):
9970
"""---
9971
put:
9972
summary: Set soft prompt by name
9973
tags:
9974
- config
9975
requestBody:
9976
required: true
9977
content:
9978
application/json:
9979
schema: SoftPromptSettingSchema
9980
example:
9981
value: ""
9982
responses:
9983
200:
9984
description: Successful request
9985
content:
9986
application/json:
9987
schema: EmptySchema
9988
{api_validation_error_response}
9989
"""
9990
if vars.allowsp:
9991
spRequest(body.value)
9992
settingschanged()
9993
return {}
9994
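# Illustrative sketch (not part of the original file): listing the available
# softprompts and loading one by name. Host, port and /api/v1 prefix are
# assumptions; an empty string unloads the current soft prompt, and the PUT
# only takes effect when soft prompts are allowed (vars.allowsp).
#
#   import requests
#   BASE = "http://127.0.0.1:5000/api/v1"
#   names = [v["value"] for v in
#            requests.get(f"{BASE}/config/soft_prompts_list").json()["values"]]
#   if names:
#       requests.put(f"{BASE}/config/soft_prompt", json={"value": names[0]})
#   requests.put(f"{BASE}/config/soft_prompt", json={"value": ""})  # unload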
9995
class SamplerSeedSettingSchema(KoboldSchema):
9996
value: int = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), required=True)
9997
9998
@api_v1.get("/config/sampler_seed")
9999
@api_schema_wrap
10000
def get_config_sampler_seed():
10001
"""---
10002
get:
10003
summary: Retrieve the current global sampler seed value
10004
tags:
10005
- config
10006
responses:
10007
200:
10008
description: Successful request
10009
content:
10010
application/json:
10011
schema: SamplerSeedSettingSchema
10012
example:
10013
value: 3475097509890965500
10014
"""
10015
return {"value": __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()}
10016
10017
@api_v1.put("/config/sampler_seed")
10018
@api_schema_wrap
10019
def put_config_sampler_seed(body: SamplerSeedSettingSchema):
10020
"""---
10021
put:
10022
summary: Set the global sampler seed value
10023
tags:
10024
- config
10025
requestBody:
10026
required: true
10027
content:
10028
application/json:
10029
schema: SamplerSeedSettingSchema
10030
example:
10031
value: 3475097509890965500
10032
responses:
10033
200:
10034
description: Successful request
10035
content:
10036
application/json:
10037
schema: EmptySchema
10038
{api_validation_error_response}
10039
"""
10040
if vars.use_colab_tpu:
10041
import tpu_mtj_backend
10042
tpu_mtj_backend.set_rng_seed(body.value)
10043
else:
10044
import torch
10045
torch.manual_seed(body.value)
10046
vars.seed = body.value
10047
return {}
10048
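# Illustrative sketch (not part of the original file): pinning the sampler
# seed so repeated generations start from a known RNG state. Host, port and
# /api/v1 prefix are assumptions; see also the sampler_full_determinism
# setting registered further below.
#
#   import requests
#   BASE = "http://127.0.0.1:5000/api/v1"
#   requests.put(f"{BASE}/config/sampler_seed", json={"value": 1234})
#   print(requests.get(f"{BASE}/config/sampler_seed").json())  # {"value": 1234}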
10049
config_endpoint_schemas: List[Type[KoboldSchema]] = []
10050
10051
def config_endpoint_schema(c: Type[KoboldSchema]):
10052
config_endpoint_schemas.append(c)
10053
return c
10054
10055
10056
@config_endpoint_schema
10057
class MemorySettingSchema(KoboldSchema):
10058
value = fields.String(required=True)
10059
class KoboldMeta:
10060
route_name = "memory"
10061
obj = "vars"
10062
var_name = "memory"
10063
name = "memory"
10064
example_yaml_value = "Memory"
10065
10066
@config_endpoint_schema
10067
class AuthorsNoteSettingSchema(KoboldSchema):
10068
value = fields.String(required=True)
10069
class KoboldMeta:
10070
route_name = "authors_note"
10071
obj = "vars"
10072
var_name = "authornote"
10073
name = "author's note"
10074
example_yaml_value = "''"
10075
10076
@config_endpoint_schema
10077
class AuthorsNoteTemplateSettingSchema(KoboldSchema):
10078
value = fields.String(required=True)
10079
class KoboldMeta:
10080
route_name = "authors_note_template"
10081
obj = "vars"
10082
var_name = "authornotetemplate"
10083
name = "author's note template"
10084
example_yaml_value = "\"[Author's note: <|>]\""
10085
10086
@config_endpoint_schema
10087
class TopKSamplingSettingSchema(KoboldSchema):
10088
value = fields.Integer(validate=validate.Range(min=0), required=True)
10089
class KoboldMeta:
10090
route_name = "top_k"
10091
obj = "vars"
10092
var_name = "top_k"
10093
name = "top-k sampling"
10094
example_yaml_value = "0"
10095
10096
@config_endpoint_schema
10097
class TopASamplingSettingSchema(KoboldSchema):
10098
value = fields.Float(validate=validate.Range(min=0), required=True)
10099
class KoboldMeta:
10100
route_name = "top_a"
10101
obj = "vars"
10102
var_name = "top_a"
10103
name = "top-a sampling"
10104
example_yaml_value = "0.0"
10105
10106
@config_endpoint_schema
10107
class TopPSamplingSettingSchema(KoboldSchema):
10108
value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
10109
class KoboldMeta:
10110
route_name = "top_p"
10111
obj = "vars"
10112
var_name = "top_p"
10113
name = "top-p sampling"
10114
example_yaml_value = "0.9"
10115
10116
@config_endpoint_schema
10117
class TailFreeSamplingSettingSchema(KoboldSchema):
10118
value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
10119
class KoboldMeta:
10120
route_name = "tfs"
10121
obj = "vars"
10122
var_name = "tfs"
10123
name = "tail free sampling"
10124
example_yaml_value = "1.0"
10125
10126
@config_endpoint_schema
10127
class TypicalSamplingSettingSchema(KoboldSchema):
10128
value = fields.Float(validate=validate.Range(min=0, max=1), required=True)
10129
class KoboldMeta:
10130
route_name = "typical"
10131
obj = "vars"
10132
var_name = "typical"
10133
name = "typical sampling"
10134
example_yaml_value = "1.0"
10135
10136
@config_endpoint_schema
10137
class TemperatureSamplingSettingSchema(KoboldSchema):
10138
value = fields.Float(validate=validate.Range(min=0, min_inclusive=False), required=True)
10139
class KoboldMeta:
10140
route_name = "temperature"
10141
obj = "vars"
10142
var_name = "temp"
10143
name = "temperature"
10144
example_yaml_value = "0.5"
10145
10146
@config_endpoint_schema
10147
class GensPerActionSettingSchema(KoboldSchema):
10148
value = fields.Integer(validate=validate.Range(min=0, max=5), required=True)
10149
class KoboldMeta:
10150
route_name = "n"
10151
obj = "vars"
10152
var_name = "numseqs"
10153
name = "Gens Per Action"
10154
example_yaml_value = "1"
10155
10156
@config_endpoint_schema
10157
class MaxLengthSettingSchema(KoboldSchema):
10158
value = fields.Integer(validate=validate.Range(min=1, max=512), required=True)
10159
class KoboldMeta:
10160
route_name = "max_length"
10161
obj = "vars"
10162
var_name = "genamt"
10163
name = "max length"
10164
example_yaml_value = "80"
10165
10166
@config_endpoint_schema
10167
class WorldInfoDepthSettingSchema(KoboldSchema):
10168
value = fields.Integer(validate=validate.Range(min=1, max=5), required=True)
10169
class KoboldMeta:
10170
route_name = "world_info_depth"
10171
obj = "vars"
10172
var_name = "widepth"
10173
name = "world info depth"
10174
example_yaml_value = "3"
10175
10176
@config_endpoint_schema
10177
class AuthorsNoteDepthSettingSchema(KoboldSchema):
10178
value = fields.Integer(validate=validate.Range(min=1, max=5), required=True)
10179
class KoboldMeta:
10180
route_name = "authors_note_depth"
10181
obj = "vars"
10182
var_name = "andepth"
10183
name = "author's note depth"
10184
example_yaml_value = "3"
10185
10186
@config_endpoint_schema
10187
class MaxContextLengthSettingSchema(KoboldSchema):
10188
value = fields.Integer(validate=validate.Range(min=512, max=2048), required=True)
10189
class KoboldMeta:
10190
route_name = "max_context_length"
10191
obj = "vars"
10192
var_name = "max_length"
10193
name = "max context length"
10194
example_yaml_value = "2048"
10195
10196
@config_endpoint_schema
10197
class TrimIncompleteSentencesSettingsSchema(KoboldSchema):
10198
value = fields.Boolean(required=True)
10199
class KoboldMeta:
10200
route_name = "frmttriminc"
10201
obj = "vars.formatoptns"
10202
var_name = "@frmttriminc"
10203
name = "trim incomplete sentences (output formatting)"
10204
example_yaml_value = "false"
10205
10206
@config_endpoint_schema
10207
class RemoveBlankLinesSettingsSchema(KoboldSchema):
10208
value = fields.Boolean(required=True)
10209
class KoboldMeta:
10210
route_name = "frmtrmblln"
10211
obj = "vars.formatoptns"
10212
var_name = "@frmtrmblln"
10213
name = "remove blank lines (output formatting)"
10214
example_yaml_value = "false"
10215
10216
@config_endpoint_schema
10217
class RemoveSpecialCharactersSettingsSchema(KoboldSchema):
10218
value = fields.Boolean(required=True)
10219
class KoboldMeta:
10220
route_name = "frmtrmspch"
10221
obj = "vars.formatoptns"
10222
var_name = "@frmtrmspch"
10223
name = "remove special characters (output formatting)"
10224
example_yaml_value = "false"
10225
10226
@config_endpoint_schema
10227
class SingleLineSettingsSchema(KoboldSchema):
10228
value = fields.Boolean(required=True)
10229
class KoboldMeta:
10230
route_name = "singleline"
10231
obj = "vars.formatoptns"
10232
var_name = "@singleline"
10233
name = "single line (output formatting)"
10234
example_yaml_value = "false"
10235
10236
@config_endpoint_schema
10237
class AddSentenceSpacingSettingsSchema(KoboldSchema):
10238
value = fields.Boolean(required=True)
10239
class KoboldMeta:
10240
route_name = "frmtadsnsp"
10241
obj = "vars.formatoptns"
10242
var_name = "@frmtadsnsp"
10243
name = "add sentence spacing (input formatting)"
10244
example_yaml_value = "false"
10245
10246
@config_endpoint_schema
10247
class SamplerOrderSettingSchema(KoboldSchema):
10248
value = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], required=True)
10249
class KoboldMeta:
10250
route_name = "sampler_order"
10251
obj = "vars"
10252
var_name = "sampler_order"
10253
name = "sampler order"
10254
example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
10255
10256
@config_endpoint_schema
10257
class SamplerFullDeterminismSettingSchema(KoboldSchema):
10258
value = fields.Boolean(required=True)
10259
class KoboldMeta:
10260
route_name = "sampler_full_determinism"
10261
obj = "vars"
10262
var_name = "full_determinism"
10263
name = "sampler full determinism"
10264
example_yaml_value = "false"
10265
10266
10267
for schema in config_endpoint_schemas:
10268
create_config_endpoint(schema=schema.__name__, method="GET")
10269
create_config_endpoint(schema=schema.__name__, method="PUT")
10270
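# Illustrative sketch (not part of the original file): after the loop above
# runs, every schema in config_endpoint_schemas is reachable as both GET and
# PUT /config/<route_name> (memory, authors_note, top_k, top_p, temperature,
# max_length, max_context_length, sampler_order, ...). For example, assuming
# the API is mounted at /api/v1 on a local server:
#
#   import requests
#   BASE = "http://127.0.0.1:5000/api/v1"
#   requests.put(f"{BASE}/config/temperature", json={"value": 0.7})
#   requests.put(f"{BASE}/config/max_length", json={"value": 80})
#   print(requests.get(f"{BASE}/config/max_length").json())  # {"value": 80}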
10271
10272
#==================================================================#
10273
# Final startup commands to launch Flask app
10274
#==================================================================#
10275
if __name__ == "__main__":
10276
10277
general_startup()
10278
# Start flask & SocketIO
10279
logger.init("Flask", status="Starting")
10280
Session(app)
10281
logger.init_ok("Flask", status="OK")
10282
logger.init("Webserver", status="Starting")
10283
patch_transformers()
10284
#show_select_model_list()
10285
if vars.model == "" or vars.model is None:
10286
vars.model = "ReadOnly"
10287
load_model(initial_load=True)
10288
10289
# Start Flask/SocketIO (Blocking, so this must be last method!)
10290
port = args.port if "port" in args and args.port is not None else 5000
10291
10292
#socketio.run(app, host='0.0.0.0', port=port)
10293
if(vars.host):
10294
if(args.localtunnel):
10295
import subprocess, shutil
10296
localtunnel = subprocess.Popen([shutil.which('lt'), '-p', str(port), 'http'], stdout=subprocess.PIPE)
10297
attempts = 0
10298
while attempts < 10:
10299
try:
10300
cloudflare = str(localtunnel.stdout.readline())
10301
cloudflare = (re.search("(?P<url>https?:\/\/[^\s]+loca.lt)", cloudflare).group("url"))
10302
break
10303
except:
10304
attempts += 1
10305
time.sleep(3)
10306
continue
10307
if attempts == 10:
10308
print("LocalTunnel could not be created, falling back to cloudflare...")
10309
from flask_cloudflared import _run_cloudflared
10310
cloudflare = _run_cloudflared(port)
10311
elif(args.ngrok):
10312
from flask_ngrok import _run_ngrok
10313
cloudflare = _run_ngrok()
10314
elif(args.remote):
10315
from flask_cloudflared import _run_cloudflared
10316
cloudflare = _run_cloudflared(port)
10317
if(args.localtunnel or args.ngrok or args.remote):
10318
with open('cloudflare.log', 'w') as cloudflarelog:
10319
cloudflarelog.write("KoboldAI has finished loading and is available at the following link : " + cloudflare)
10320
logger.init_ok("Webserver", status="OK")
10321
logger.message(f"KoboldAI has finished loading and is available at the following link: {cloudflare}")
10322
else:
10323
logger.init_ok("Webserver", status="OK")
10324
logger.message(f"Webserver has started, you can now connect to this machine at port: {port}")
10325
vars.serverstarted = True
10326
socketio.run(app, host='0.0.0.0', port=port)
10327
else:
10328
if args.unblock:
10329
if not args.no_ui:
10330
try:
10331
import webbrowser
10332
webbrowser.open_new('http://localhost:{0}'.format(port))
10333
except:
10334
pass
10335
logger.init_ok("Webserver", status="OK")
10336
logger.message(f"Webserver started! You may now connect with a browser at http://127.0.0.1:{port}")
10337
vars.serverstarted = True
10338
socketio.run(app, port=port, host='0.0.0.0')
10339
else:
10340
if not args.no_ui:
10341
try:
10342
import webbrowser
10343
webbrowser.open_new('http://localhost:{0}'.format(port))
10344
except:
10345
pass
10346
logger.init_ok("Webserver", status="OK")
10347
logger.message(f"Webserver started! You may now connect with a browser at http://127.0.0.1:{port}")
10348
vars.serverstarted = True
10349
socketio.run(app, port=port)
10350
logger.init("Webserver", status="Closed")
10351
10352
10353
else:
10354
general_startup()
10355
# Start flask & SocketIO
10356
logger.init("Flask", status="Starting")
10357
Session(app)
10358
logger.init_ok("Flask", status="OK")
10359
patch_transformers()
10360
#show_select_model_list()
10361
if vars.model == "" or vars.model is None:
10362
vars.model = "ReadOnly"
10363
load_model(initial_load=True)
10364
print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END), flush=True)
10365
10366
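# Illustrative note (not part of the original file): the module has two start
# paths. Executed directly, it launches its own eventlet/SocketIO webserver;
# the else-branch above runs when the file is imported by an external WSGI
# host, which then serves the already-configured `app` object itself. A
# minimal smoke test against a locally started server, assuming the API is
# mounted at /api/v1 on port 5000, might look like:
#
#   import requests
#   print(requests.get("http://127.0.0.1:5000/api/v1/config/max_length").json())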