Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
KoboldAI
GitHub Repository: KoboldAI/KoboldAI-Client
Path: blob/main/fileops.py
471 views
1
from os import getcwd, listdir, path
2
from typing import Tuple, Union, Optional
3
import os
4
import json
5
import zipfile
6
from logger import logger
7
8
#==================================================================#
9
# Generic method for prompting for a path to save a file to
10
#==================================================================#
11
def getsavepath(dir, title, types):
    """Prompt the user for a file path to save to, using a Tk dialog.

    Args:
        dir: Directory the dialog initially shows.
        title: Window title of the dialog.
        types: Tk ``filetypes`` specification forwarded unchanged to the
            dialog (e.g. a list of ``(label, pattern)`` tuples).

    Returns:
        The chosen path as a string, or ``None`` if the dialog was cancelled.
    """
    import tkinter as tk
    from tkinter import filedialog

    root = tk.Tk()
    root.attributes("-topmost", True)
    try:
        # Ask only for the name. asksaveasfile() would open the chosen file in
        # "w" mode (truncating it) and hand back a handle that was never
        # closed, since only its .name was used.
        selected = filedialog.asksaveasfilename(
            initialdir=dir,
            title=title,
            filetypes=types,
            defaultextension="*.*",
        )
    finally:
        # Always tear down the helper root window, even if the dialog fails.
        root.destroy()
    # Cancelling yields "" (or an empty tuple on some Tk builds).
    if not selected:
        return None
    return selected
27
28
#==================================================================#
29
# Generic method for prompting for the path of an existing file to load
30
#==================================================================#
31
def getloadpath(dir, title, types):
    """Prompt the user to pick an existing file to load, using a Tk dialog.

    Args:
        dir: Directory the dialog initially shows.
        title: Window title of the dialog.
        types: Tk ``filetypes`` specification forwarded unchanged to the
            dialog (e.g. a list of ``(label, pattern)`` tuples).

    Returns:
        The chosen path as a string, or ``None`` if the dialog was cancelled.
    """
    import tkinter as tk
    from tkinter import filedialog

    root = tk.Tk()
    root.attributes("-topmost", True)
    try:
        selected = filedialog.askopenfilename(
            initialdir=dir,
            title=title,
            filetypes=types,
        )
    finally:
        # Always tear down the helper root window, even if the dialog fails.
        root.destroy()
    # Cancelling yields "" on most platforms but an empty tuple on some Tk
    # builds; the old `!= ""` check let the tuple through as a "path".
    if not selected:
        return None
    return selected
46
47
#==================================================================#
48
# Generic Method for prompting for directory path
49
#==================================================================#
50
def getdirpath(dir, title):
    """Prompt the user to pick a directory, using a Tk dialog.

    Args:
        dir: Directory the dialog initially shows.
        title: Window title of the dialog.

    Returns:
        The chosen directory as a string, or ``None`` if the dialog was
        cancelled.
    """
    import tkinter as tk
    from tkinter import filedialog

    root = tk.Tk()
    root.attributes("-topmost", True)
    try:
        selected = filedialog.askdirectory(
            initialdir=dir,
            title=title,
        )
    finally:
        # Always tear down the helper root window, even if the dialog fails.
        root.destroy()
    # Cancelling yields "" (or an empty tuple on some Tk builds).
    if not selected:
        return None
    return selected
64
65
#==================================================================#
66
# Returns the path (as a string) to the given story by its name
67
#==================================================================#
68
def storypath(name):
    """Return the relative path of the JSON file that stores story *name*.

    The result is ``stories/<name>.json`` joined with the OS path separator.
    """
    filename = name + ".json"
    return path.join("stories", filename)
70
71
#==================================================================#
72
# Returns the path (as a string) to the given soft prompt by its filename
73
#==================================================================#
74
def sppath(filename):
    """Return the relative path of soft prompt *filename* in ``softprompts``."""
    directory = "softprompts"
    return path.join(directory, filename)
76
77
#==================================================================#
78
# Returns the path (as a string) to the given userscript by its filename
79
#==================================================================#
80
def uspath(filename):
    """Return the relative path of userscript *filename* in ``userscripts``."""
    directory = "userscripts"
    return path.join(directory, filename)
82
83
#==================================================================#
84
# Returns an array of dicts containing story files in /stories
85
#==================================================================#
86
def getstoryfiles():
    """List the saved stories found in the ``stories`` directory.

    Only ``*.json`` files are considered; ``*.v2.json`` files are skipped.
    A file that cannot be opened, is not valid JSON, or has no usable
    ``"actions"`` entry is reported and skipped instead of aborting the scan.
    The ``stories`` directory is created if it does not exist yet.

    Returns:
        A list of dicts with keys ``"name"`` (the file name without its
        trailing ``.json``) and ``"actions"`` (``len`` of the story's
        ``"actions"`` value).
    """
    suffix = ".json"
    stories = []
    os.makedirs("stories", exist_ok=True)
    for file in listdir("stories"):
        if not file.endswith(suffix) or file.endswith(".v2.json"):
            continue
        ob = {}
        # Strip only the trailing suffix; str.replace() would also remove any
        # ".json" that appears earlier in the name.
        ob["name"] = file[:-len(suffix)]
        try:
            with open(path.join("stories", file), "r") as f:
                js = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Browser loading error: {file} is malformed or not a JSON file.")
            continue
        try:
            ob["actions"] = len(js["actions"])
        except (TypeError, KeyError):
            # TypeError: top level is not a mapping / actions has no len();
            # KeyError: the story object has no "actions" entry at all.
            print(f"Browser loading error: {file} has incorrect format.")
            continue
        stories.append(ob)
    return stories
107
108
#==================================================================#
109
# Checks if the given soft prompt file is valid
110
#==================================================================#
111
def checksp(filename: str, model_dimension: int) -> Tuple[Union[zipfile.ZipFile, int], Optional[Tuple[int, int]], Optional[Tuple[int, int]], Optional[bool], Optional['np.dtype']]:
    """Validate the soft prompt archive ``softprompts/<filename>``.

    Only the header of the archive's ``tensor.npy`` member is parsed; the
    tensor data itself is never loaded.

    Args:
        filename: Name of the ZIP file inside the ``softprompts`` directory.
        model_dimension: Required size of the tensor's second axis.

    Returns:
        A 5-tuple ``(z, version, shape, fortran_order, dtype)``. On success
        ``z`` is the still-open ``zipfile.ZipFile`` and the caller is
        responsible for closing it. On failure the archive is closed and
        ``z`` is an error code:

        * ``1`` - unreadable ZIP, no ``tensor.npy``, bad ``.npy`` header, or
          a tensor that is not 2-D (the other four fields are ``None``);
        * ``2`` - dtype is not ``'V2'``, float16 or float32;
        * ``3`` - ``shape[1]`` differs from ``model_dimension``;
        * ``4`` - ``shape[0]`` is 2048 or more.
    """
    global np
    if 'np' not in globals():
        import numpy as np
    z = None
    try:
        z = zipfile.ZipFile("softprompts/"+filename)
        with z.open('tensor.npy') as f:
            # Read only the header of the npy file, for efficiency reasons
            version: Tuple[int, int] = np.lib.format.read_magic(f)
            shape: Tuple[int, int]
            fortran_order: bool
            dtype: np.dtype
            # NOTE: private NumPy helper; it dispatches on the header version.
            shape, fortran_order, dtype = np.lib.format._read_array_header(f, version)
        # Explicit check rather than `assert`, which `python -O` strips; a
        # non-2-D tensor would otherwise crash on shape[1] below.
        if len(shape) != 2:
            raise ValueError("soft prompt tensor must be 2-D")
    except Exception:
        # Best effort: any failure to open or parse means "not a soft prompt".
        if z is not None:
            z.close()
        return 1, None, None, None, None
    if dtype not in ('V2', np.float16, np.float32):
        z.close()
        return 2, version, shape, fortran_order, dtype
    if shape[1] != model_dimension:
        z.close()
        return 3, version, shape, fortran_order, dtype
    if shape[0] >= 2048:
        z.close()
        return 4, version, shape, fortran_order, dtype
    return z, version, shape, fortran_order, dtype
141
142
#==================================================================#
143
# Returns an array of dicts containing softprompt files in /softprompts
144
#==================================================================#
145
def getspfiles(model_dimension: int):
    """List the usable soft prompts in the ``softprompts`` directory.

    Every ``*.zip`` file is validated with ``checksp``; rejected files are
    logged and skipped. For accepted files the optional ``meta.json`` member
    is read and used as the base of the returned entry. The ``softprompts``
    directory is created if it does not exist yet.

    Args:
        model_dimension: Model dimension each soft prompt tensor must match.

    Returns:
        A list of dicts: the contents of ``meta.json`` (or ``{}`` when it is
        missing, unreadable, or not a JSON object) plus ``"filename"`` and
        ``"n_tokens"`` (the tensor's first dimension).
    """
    lst = []
    os.makedirs("softprompts", exist_ok=True)
    for file in listdir("softprompts"):
        if not file.endswith(".zip"):
            continue
        z, version, shape, fortran_order, dtype = checksp(file, model_dimension)
        if z == 1:
            logger.warning(f"Softprompt {file} is malformed or not a soft prompt ZIP file.")
            continue
        if z == 2:
            logger.warning(f"Softprompt {file} tensor.npy has unsupported dtype '{dtype.name}'.")
            continue
        if z == 3:
            logger.debug(f"Softprompt {file} tensor.npy has model dimension {shape[1]} which does not match your model's model dimension of {model_dimension}. This usually means this soft prompt is not compatible with your model.")
            continue
        if z == 4:
            logger.warning(f"Softprompt {file} tensor.npy has {shape[0]} tokens but it is supposed to have less than 2048 tokens.")
            continue
        assert isinstance(z, zipfile.ZipFile)
        # checksp leaves a valid archive open; the with-block closes it.
        with z:
            try:
                with z.open('meta.json') as f:
                    ob = json.load(f)
            except Exception:
                # meta.json is optional: ignore a missing or malformed one.
                ob = {}
        # Valid JSON that is not an object (e.g. a list) cannot take the keys
        # added below and would raise TypeError.
        if not isinstance(ob, dict):
            ob = {}
        ob["filename"] = file
        ob["n_tokens"] = shape[-2]
        lst.append(ob)
    return lst
175
176
#==================================================================#
177
# Returns an array of dicts containing userscript files in /userscripts
178
#==================================================================#
179
def getusfiles(long_desc=False):
    """List the Lua userscripts found in the ``userscripts`` directory.

    The directory is created if missing. For each ``*.lua`` file the leading
    comment lines are parsed: the first line supplies the module name and
    the following comment lines (``--`` or ``--[[ ... ]]``) the description.

    Args:
        long_desc: When false, descriptions longer than 250 characters are
            cut to 247 characters followed by ``"..."``.

    Returns:
        A list of dicts with keys ``"filename"``, ``"modulename"`` and
        ``"description"``.
    """
    os.makedirs("userscripts", exist_ok=True)
    scripts = []
    for filename in listdir("userscripts"):
        if not filename.endswith(".lua"):
            continue
        entry = {"filename": filename}
        with open(uspath(filename)) as handle:
            title, in_block = _us_module_name(handle.readline(), filename)
            entry["modulename"] = title
            desc_lines = _us_description(handle, in_block)
        text = "\n".join(desc_lines)
        if not long_desc and len(text) > 250:
            text = text[:247] + "..."
        entry["description"] = text
        scripts.append(entry)
    return scripts


def _us_module_name(first_line, filename):
    """Parse a userscript's first line into ``(module name, opened [[ block)``."""
    # ESC characters are removed so the text is safe to display.
    head = first_line.strip().replace("\033", "")
    if head[:2] != "--":
        # No leading comment: fall back to the file name.
        return filename, False
    head = head[2:]
    in_block = head[:2] == "[["
    if in_block:
        head = head[2:]
    return head.lstrip("-").strip(), in_block


def _us_description(lines, in_block):
    """Collect description lines from the comment header that follows line one."""
    collected = []
    for raw in lines:
        text = raw.strip().replace("\033", "")
        if in_block:
            end = text.find("]]")
            if end < 0:
                collected.append(text)
                continue
            collected.append(text[:end])
            # Anything after the closing ]] ends the header.
            if end != len(text) - 2:
                break
            in_block = False
            continue
        if text[:2] != "--":
            # First non-comment line ends the header.
            break
        text = text[2:]
        if text[:2] == "[[":
            in_block = True
            text = text[2:]
        collected.append(text.strip())
    return collected
223
224
#==================================================================#
225
# Returns True if json file exists with requested save name
226
#==================================================================#
227
def saveexists(name):
    """Return True if something exists at the save path for story *name*."""
    target = storypath(name)
    return path.exists(target)
229
230
#==================================================================#
231
# Delete save file by name; returns None if successful, or the exception if not
232
#==================================================================#
233
def deletesave(name):
    """Delete the save file of story *name*.

    Returns:
        ``None`` on success; otherwise the exception raised while building
        the path or removing the file. The exception is returned, not raised.
    """
    try:
        os.remove(storypath(name))
    except Exception as err:
        return err
    else:
        return None
238
239
#==================================================================#
240
# Rename save file; returns None if successful, or the exception if not
241
#==================================================================#
242
def renamesave(name, new_name):
    """Rename the save file of story *name* to that of story *new_name*.

    Uses ``os.replace``, which silently overwrites an existing file at the
    destination path.

    Returns:
        ``None`` on success; otherwise the exception raised while building
        the paths or renaming. The exception is returned, not raised.
    """
    try:
        os.replace(storypath(name), storypath(new_name))
    except Exception as err:
        return err
    else:
        return None
247
248