-- Path: blob/main/userscripts/kaipreset_basic_phrase_bias.lua
-- 473 views
-- Basic phrase bias
-- Makes certain sequences of tokens more or less likely to appear than normal.

-- This file is part of KoboldAI.
--
-- KoboldAI is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-- GNU Affero General Public License for more details.
--
-- You should have received a copy of the GNU Affero General Public License
-- along with this program. If not, see <https://www.gnu.org/licenses/>.

kobold = require("bridge")() -- This line is optional and is only for EmmyLua type annotations
local userscript = {} ---@class KoboldUserScript


--- One parsed (non-comment, non-blank) line of the config file.
---@class PhraseBiasEntry
---@field starting_bias number Bias for the phrase's first token (single-token phrases use ending_bias); may be -math.huge
---@field ending_bias number Bias for the phrase's last token; may be -math.huge
---@field tokens table<integer, integer> Token IDs of the phrase, 1-indexed
---@field n_tokens integer Number of entries in `tokens` (always >= 1)

local example_config = [[# Phrase bias
#
# For each phrase you want to bias, add a new line into
# this config file as a comma-separated list in this format:
# <starting bias>, <ending bias>, <comma-separated list of token IDs>
# For <starting bias> and <ending bias>, this script accepts floating point
# numbers or -inf, where positive bias values make it more likely for tokens
# to appear, negative bias values make it less likely and -inf makes it
# impossible.
#
# Example 1 (makes it impossible for the word "CHAPTER", case-sensitive, to
# appear at the beginning of a line in the output):
# -inf, -inf, 41481
#
# Example 2 (makes it unlikely for the word " CHAPTER", case-sensitive, with
# a leading space, to appear in the output, with the unlikeliness increasing
# even more if the first token " CH" has appeared):
# -10.0, -20.0, 5870, 29485
#
# Example 3 (makes it more likely for " let the voice of love take you higher",
# case-sensitive, with a leading space, to appear in the output, with the
# bias increasing as each consecutive token in that phrase appears):
# 7, 25.4, 1309, 262, 3809, 286, 1842, 1011, 345, 2440
#
]]

-- If the config file is empty, seed it with the documented example config
local f = kobold.get_config_file()
f:seek("set")
if f:read(1) == nil then
    f:write(example_config)
end
-- Rewind before reading (also required by stdio between a write and a read)
f:seek("set")
example_config = nil

--- Closes the config file, then raises `message` as an error attributed to
--- the caller.
---@param message string
local function config_error(message)
    f:close()
    error(message, 2)
end

--- Converts `text` to an integer, or returns nil if it is not an integral
--- number. `tonumber` runs first so this does not depend on whether the
--- running Lua version's `math.tointeger` accepts strings.
---@param text string
---@return integer|nil
local function parse_integer(text)
    local n = tonumber(text)
    return n and math.tointeger(n)
end

--- Parses one of the first two values of a config line. Accepts a number
--- (not +inf) or a case-insensitive "-inf"; returns nil otherwise.
---@param text string
---@return number|nil
local function parse_bias(text)
    text = text:lower()
    if text:sub(-3) == "inf" then
        -- Whatever precedes "inf" must act as a negative sign: "-inf" -> "-1".
        -- A positive or zero result is rejected (zero would give 0 * inf = NaN).
        local sign = parse_integer(text:sub(1, -4) .. "1")
        if sign == nil or sign >= 0 then
            return nil
        end
        return -math.huge
    end
    local val = tonumber(text)
    -- Reject non-numbers, NaN, and +inf from overflow such as "1e999"
    if val == nil or val ~= val or val == math.huge then
        return nil
    end
    return val
end

-- Read config
print("Loading phrase bias config...")
local bias_array = {} ---@type table<integer, PhraseBiasEntry>
local bias_array_count = 0
local line_count = 0
for line in f:lines("l") do
    line_count = line_count + 1
    -- Skip comment lines (optionally indented) and blank lines
    if line:find("^%s*#") == nil and line:find("%S") ~= nil then
        local row = {tokens = {}, n_tokens = 0} ---@type PhraseBiasEntry
        local val_count = 0
        for val in line:gmatch("[^,%s]+") do
            val_count = val_count + 1
            if val_count <= 2 then
                local bias = parse_bias(val)
                if bias == nil then
                    config_error("First two values of line " .. line_count
                        .. " of config file must be finite floating-point numbers or -inf, but got '"
                        .. val .. "' as value #" .. val_count)
                end
                if val_count == 1 then
                    row.starting_bias = bias
                else
                    row.ending_bias = bias
                end
            else
                local token = parse_integer(val)
                if token == nil or token < 0 then
                    config_error("All values after the first two values of line " .. line_count
                        .. " of config file must be nonnegative integers, but got '"
                        .. val .. "' as value #" .. val_count)
                end
                row.n_tokens = row.n_tokens + 1
                row.tokens[row.n_tokens] = token
            end
        end
        if val_count < 3 then
            config_error("Line " .. line_count .. " of config file must contain at least 3 values, but found " .. val_count)
        end
        bias_array_count = bias_array_count + 1
        bias_array[bias_array_count] = row
    end
end
f:close()
print("Successfully loaded " .. bias_array_count .. " phrase bias entr" .. (bias_array_count == 1 and "y" or "ies") .. ".")


local genmod_run = false

--- Interpolates linearly by `factor` between sigmoid(starting_val) and
--- sigmoid(ending_val) and returns the logit of the result.
--- p and 1 - p are computed independently from sigmoid(x) and sigmoid(-x).
--- For x above about 37, 1/(1 + exp(-x)) rounds to exactly 1.0, and
--- log(p / (1 - p)) would then overflow to +inf.
---@param starting_val number
---@param ending_val number
---@param factor number
---@return number
local function logit_interpolate(starting_val, ending_val, factor)
    -- Logistic function of the start and end values, and of their negations
    local p_start = 1/(1 + math.exp(-starting_val))
    local q_start = 1/(1 + math.exp(starting_val))
    local p_end = 1/(1 + math.exp(-ending_val))
    local q_end = 1/(1 + math.exp(ending_val))

    -- Linear interpolation of p and of q = 1 - p
    local p = p_start + factor*(p_end - p_start)
    local q = q_start + factor*(q_end - q_start)

    -- logit(p) = log(p) - log(1 - p); yields -inf when p == 0
    return math.log(p) - math.log(q)
end

--- Returns the largest m, with 0 <= m <= min(n_tokens, n_phrase), such that
--- the last m elements of `tokens` equal the first m elements of `phrase`.
--- Runs the Z-algorithm over a 0-indexed array s that holds the first
--- min(n_tokens, n_phrase) tokens of the phrase followed by that many
--- trailing elements of `tokens`.
---@param tokens table<integer, integer>
---@param n_tokens integer
---@param phrase table<integer, integer>
---@param n_phrase integer
---@return integer
local function suffix_prefix_overlap(tokens, n_tokens, phrase, n_phrase)
    local n_s = math.min(n_tokens, n_phrase)
    local len = 2 * n_s

    local s = {}
    for k = 1, n_s do
        s[k - 1] = phrase[k]
        s[n_s + k - 1] = tokens[n_tokens - n_s + k]
    end

    -- z[k] = length of the longest common prefix of s and s[k..]
    local z = {[0] = 0}
    local l, r = 0, 0
    for k = 1, len - 1 do
        if k <= r and z[k - l] - 1 < r - k then
            z[k] = z[k - l]
        else
            l = k
            if k > r then
                r = k
            end
            while r < len and s[r - l] == s[r] do
                r = r + 1
            end
            z[k] = r - l
            r = r - 1
        end
        -- The first k whose match runs to the end of s (and lies within the
        -- `tokens` half, implied by z[k] <= n_s) is the longest overlap
        if z[k] <= n_s and z[k] == len - k then
            return z[k]
        end
    end
    return 0
end


function userscript.genmod()
    genmod_run = true

    local context_tokens = kobold.encode(kobold.worldinfo:compute_context(kobold.submission))

    -- For each partially-generated sequence...
    for i, generated_row in ipairs(kobold.generated) do

        -- Build an array `tokens` as the concatenation of the context
        -- tokens and the generated tokens of this sequence
        local tokens = {}
        local n_tokens = 0
        for _, token in ipairs(context_tokens) do
            n_tokens = n_tokens + 1
            tokens[n_tokens] = token
        end
        for _, token in ipairs(generated_row) do
            n_tokens = n_tokens + 1
            tokens[n_tokens] = token
        end

        -- Token IDs already biased in this sequence: the first entry (in
        -- config order) that targets a token wins
        local biased_tokens = {} ---@type table<integer, boolean>

        -- For each phrase bias entry in the config file...
        for _, bias_entry in ipairs(bias_array) do

            -- How much of this entry's phrase ends the sequence. It must be
            -- computed per entry, since each phrase has its own progress.
            local overlap = suffix_prefix_overlap(tokens, n_tokens, bias_entry.tokens, bias_entry.n_tokens)

            -- Use the overlap to pick which phrase token to bias and how far
            -- to interpolate from the starting bias to the ending bias
            local factor ---@type number
            local next_token ---@type integer
            if overlap == 0 or overlap == bias_entry.n_tokens then
                -- Phrase not started (or just completed): target its first
                -- token; a single-token phrase uses the ending bias
                if bias_entry.n_tokens == 1 then
                    factor = 1
                else
                    factor = 0
                end
                next_token = bias_entry.tokens[1]
            else
                factor = overlap/(bias_entry.n_tokens - 1)
                next_token = bias_entry.tokens[overlap + 1]
            end

            -- Apply bias (the logits row is indexed by token ID + 1)
            if not biased_tokens[next_token] then
                kobold.logits[i][next_token + 1] = kobold.logits[i][next_token + 1]
                    + logit_interpolate(bias_entry.starting_bias, bias_entry.ending_bias, factor)
                biased_tokens[next_token] = true
            end
        end
    end
end

function userscript.outmod()
    if not genmod_run then
        warn("WARNING: Generation modifier was not executed, so this script has had no effect")
    end
end

return userscript