KoboldAI
GitHub Repository: KoboldAI/KoboldAI-Client
Path: blob/main/userscripts/kaipreset_basic_phrase_bias.lua
-- Basic phrase bias
-- Makes certain sequences of tokens more or less likely to appear than normal.

-- This file is part of KoboldAI.
--
-- KoboldAI is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU Affero General Public License for more details.
--
-- You should have received a copy of the GNU Affero General Public License
-- along with this program. If not, see <https://www.gnu.org/licenses/>.


kobold = require("bridge")() -- This line is optional and is only for EmmyLua type annotations
local userscript = {} ---@class KoboldUserScript
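
-- This script fills in two of the userscript callbacks defined further down:
-- genmod(), which runs while tokens are being generated and adjusts kobold.logits,
-- and outmod(), which runs on the finished output and is used here only to warn
-- if the generation modifier never ran.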

---@class PhraseBiasEntry
---@field starting_bias number
---@field ending_bias number
---@field tokens table<integer, integer>
---@field n_tokens integer

local example_config = [[# Phrase bias
#
# For each phrase you want to bias, add a new line into
# this config file as a comma-separated list in this format:
# <starting bias>, <ending bias>, <comma-separated list of token IDs>
# For <starting bias> and <ending bias>, this script accepts floating point
# numbers or -inf, where positive bias values make it more likely for tokens
# to appear, negative bias values make it less likely and -inf makes it
# impossible.
#
# Example 1 (makes it impossible for the word "CHAPTER", case-sensitive, to
# appear at the beginning of a line in the output):
# -inf, -inf, 41481
#
# Example 2 (makes it unlikely for the word " CHAPTER", case-sensitive, with
# a leading space, to appear in the output, with the unlikeliness increasing
# even more if the first token " CH" has appeared):
# -10.0, -20.0, 5870, 29485
#
# Example 3 (makes it more likely for " let the voice of love take you higher",
# case-sensitive, with a leading space, to appear in the output, with the
# bias increasing as each consecutive token in that phrase appears):
# 7, 25.4, 1309, 262, 3809, 286, 1842, 1011, 345, 2440
#
]]
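
-- The token IDs in the examples above are specific to the tokenizer of the model
-- being used (they were presumably written for a GPT-2-style vocabulary). One way
-- to look up the IDs for a phrase from a userscript is to print the result of
-- kobold.encode, e.g.:
--
--     for _, id in ipairs(kobold.encode(" CHAPTER")) do print(id) end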

-- If config file is empty, write example config
local f = kobold.get_config_file()
f:seek("set")
if f:read(1) == nil then
    f:write(example_config)
end
f:seek("set")
example_config = nil

-- Read config
print("Loading phrase bias config...")
local bias_array = {} ---@type table<integer, PhraseBiasEntry>
local bias_array_count = 0
local val_count = 0
local line_count = 0
local row = {} ---@type PhraseBiasEntry
local val_orig
for line in f:lines("l") do
    line_count = line_count + 1
    if line:find("^ *#") == nil and line:find("%S") ~= nil then
        bias_array_count = bias_array_count + 1
        val_count = 0
        row = {}
        row.tokens = {}
        row.n_tokens = 0
        for val in line:gmatch("[^,%s]+") do
            val_count = val_count + 1
            val_orig = val
            if val_count <= 2 then
                val = val:lower()
                if val:sub(-3) == "inf" then
                    val = math.tointeger(val:sub(1, -4) .. "1")
                    if val ~= val or type(val) ~= "number" or val > 0 then
                        f:close()
                        error("First two values of line " .. line_count .. " of config file must be finite floating-point numbers or -inf, but got '" .. val_orig .. "' as value #" .. val_count)
                    end
                    val = val * math.huge
                else
                    val = tonumber(val)
                    if val ~= val or type(val) ~= "number" then
                        f:close()
                        error("First two values of line " .. line_count .. " of config file must be finite floating-point numbers or -inf, but got '" .. val_orig .. "' as value #" .. val_count)
                    end
                end
                if val_count == 1 then
                    row.starting_bias = val
                else
                    row.ending_bias = val
                end
            else
                val = math.tointeger(val)
                if type(val) ~= "number" or val < 0 then
                    f:close()
                    error("All values after the first two values of line " .. line_count .. " of config file must be nonnegative integers, but got '" .. val_orig .. "' as value #" .. val_count)
                end
                row.n_tokens = row.n_tokens + 1
                row.tokens[row.n_tokens] = val
            end
        end
        if val_count < 3 then
            f:close()
            error("Line " .. line_count .. " of config file must contain at least 3 values, but found " .. val_count)
        end
        bias_array[bias_array_count] = row
    end
end
f:close()
print("Successfully loaded " .. bias_array_count .. " phrase bias entr" .. (bias_array_count == 1 and "y" or "ies") .. ".")


local genmod_run = false

---@param starting_val number
---@param ending_val number
---@param factor number
---@return number
local function logit_interpolate(starting_val, ending_val, factor)
    -- First use the logistic function on the start and end values
    starting_val = 1/(1 + math.exp(-starting_val))
    ending_val = 1/(1 + math.exp(-ending_val))

    -- Use linear interpolation between these two values
    local val = starting_val + factor*(ending_val - starting_val)

    -- Return logit of this value
    return math.log(val/(1 - val))
end
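
-- Note that logit_interpolate maps both bias values through the logistic function,
-- interpolates linearly between the resulting probabilities, and converts back with
-- the logit, so factor = 0 returns starting_val and factor = 1 returns ending_val
-- exactly. A bias of -inf also survives the round trip (it maps to probability 0),
-- which is what makes the "-inf" config entries work.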


function userscript.genmod()
    genmod_run = true

    local context_tokens = kobold.encode(kobold.worldinfo:compute_context(kobold.submission))
    local factor ---@type number
    local next_token ---@type integer
    local sequences = {} ---@type table<integer, table<integer, integer>>
    local n_tokens = 0
    local max_overlap = {} ---@type table<integer, integer>

    local biased_tokens = {} ---@type table<integer, table<integer, boolean>>
    for i = 1, kobold.generated_rows do
        biased_tokens[i] = {}
    end
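
    -- biased_tokens[i] records which token IDs have already been biased for
    -- sequence i during this call, so a given token is only adjusted once even if
    -- several phrase bias entries would target it.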

    -- For each partially-generated sequence...
    for i, generated_row in ipairs(kobold.generated) do

        -- Build an array `tokens` as the concatenation of the context
        -- tokens and the generated tokens of this sequence

        local tokens = {}
        n_tokens = 0
        for k, v in ipairs(context_tokens) do
            n_tokens = n_tokens + 1
            tokens[n_tokens] = v
        end
        for k, v in ipairs(generated_row) do
            n_tokens = n_tokens + 1
            tokens[n_tokens] = v
        end

        -- For each phrase bias entry in the config file...
        for _, bias_entry in ipairs(bias_array) do

            -- Determine the largest integer `max_overlap[i]` such that the last
            -- `max_overlap[i]` elements of `tokens` equal the first
            -- `max_overlap[i]` elements of `bias_entry.tokens`

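            -- One way to read the block below: take the first n_s tokens of the bias
            -- phrase followed by the last n_s tokens of the story, compute the Z-array
            -- of that combined sequence, and look for a position whose Z-value covers
            -- the entire remaining suffix; such a position marks a suffix of the story
            -- that is also a prefix of the phrase, and the first one found is the
            -- longest such overlap.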
            max_overlap[i] = 0
            local s = {}
            local z = {[0] = 0}
            local l = 0
            local r = 0
            local n_s = math.min(n_tokens, bias_entry.n_tokens)
            local j = 0
            for k = 1, n_s do
                s[j] = bias_entry.tokens[k]
                j = j + 1
            end
            for k = n_tokens - n_s + 1, n_tokens do
                s[j] = tokens[k]
                j = j + 1
            end
            for k = 1, (n_s<<1) - 1 do
                if k <= r and z[k - l] - 1 < r - k then
                    z[k] = z[k - l]
                else
                    l = k
                    if k > r then
                        r = k
                    end
                    while r < (n_s<<1) and s[r - l] == s[r] do
                        r = r + 1
                    end
                    z[k] = r - l
                    r = r - 1
                end
                if z[k] <= n_s and z[k] == (n_s<<1) - k then
                    max_overlap[i] = z[k]
                    break
                end
            end
        end
    end

    -- For each phrase bias entry in the config file...
    for _, bias_entry in ipairs(bias_array) do

        -- For each partially-generated sequence...
        for i, generated_row in ipairs(kobold.generated) do

            -- Use `max_overlap` to determine which token in the bias entry to
            -- apply bias to

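            -- For example, with a 4-token phrase and max_overlap of 2, the 3rd token of
            -- the phrase gets biased with factor 2/3; if nothing matches or the whole
            -- phrase has just appeared, the 1st token gets the starting bias (factor 0),
            -- except for single-token phrases, which always get the ending bias.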
            if max_overlap[i] == 0 or max_overlap[i] == bias_entry.n_tokens then
                if bias_entry.tokens[2] == nil then
                    factor = 1
                else
                    factor = 0
                end
                next_token = bias_entry.tokens[1]
            else
                factor = max_overlap[i]/(bias_entry.n_tokens - 1)
                next_token = bias_entry.tokens[max_overlap[i]+1]
            end

            -- Apply bias

            if not biased_tokens[i][next_token] then
                kobold.logits[i][next_token + 1] = kobold.logits[i][next_token + 1] + logit_interpolate(bias_entry.starting_bias, bias_entry.ending_bias, factor)
                biased_tokens[i][next_token] = true
            end
        end
    end
end

function userscript.outmod()
    if not genmod_run then
        warn("WARNING: Generation modifier was not executed, so this script has had no effect")
    end
end

return userscript