Path: blob/main/userscripts/kaipreset_logits_viewer.lua
473 views
-- Logit viewer
-- Displays raw token scores and softmax probabilities during generation.

kobold = require("bridge")()
local userscript = {} ---@class KoboldUserScript

-- Number of highest-scoring tokens printed per sequence on every genmod call.
local K = 10

---@class Pair
---@field id integer
---@field score number

---@class ArrayBase
---@type table<any, Pair>
local _ = {}

---@class Array : ArrayBase
---@field n integer

-- The helpers below implement a binary max-heap keyed on `score`, stored in a
-- 0-based table: the children of slot i live at 2i+1 and 2i+2, and `n` is the
-- number of live slots. genmod uses it to pull the top K scores without
-- sorting the whole vocabulary.

--- Sift the element at `index` down until neither child has a larger score.
---@param array Array
---@param index integer
---@return nil
local function bubble(array, index)
    local n = array.n
    while true do
        local left = 2 * index + 1
        if left >= n then
            break
        end
        local largest = index
        if array[left].score > array[largest].score then
            largest = left
        end
        local right = left + 1
        if right < n and array[right].score > array[largest].score then
            largest = right
        end
        if largest == index then
            break
        end
        array[index], array[largest] = array[largest], array[index]
        index = largest
    end
end

--- Rearrange the first `array.n` slots into a max-heap (bottom-up build).
---@param array Array
---@return nil
local function build(array)
    -- Floor division rather than `>>`: Lua shifts are logical, so for n == 0
    -- `(n - 1) >> 1` would be 2^63-1 instead of an empty range.
    for i = (array.n - 1) // 2, 0, -1 do
        bubble(array, i)
    end
end

--- Remove and return the highest-scoring Pair.
--- The caller must ensure the heap is non-empty (array.n >= 1).
---@param array Array
---@return Pair
local function pop(array)
    local top = array[0]
    array.n = array.n - 1
    -- Move the last live element into the root, drop its now-stale slot,
    -- then restore the heap property.
    array[0] = array[array.n]
    array[array.n] = nil
    bubble(array, 0)
    return top
end

--- Escape newlines so decoded token text prints on a single line.
---@param text string
---@return string
local function escape_newlines(text)
    -- Parenthesized to drop gsub's second return value (the match count).
    return (text:gsub("\n", "\\n"))
end

--- Generation modifier. It prints the most recently generated token of every
--- sequence, then the K highest raw logits of each sequence with their softmax
--- probabilities. It only prints; `kobold.logits` is not modified here.
function userscript.genmod()
    local cols = kobold.logits_cols
    if K > cols then
        error("K must be at most the vocabulary size of the model")
    end

    if kobold.generated_cols > 0 then
        for s in ipairs(kobold.logits) do
            local token = kobold.generated[s][kobold.generated_cols]
            local text = escape_newlines(kobold.decode(token))
            -- Fall back to the raw value so a non-integral id cannot make the
            -- concatenation below fail on nil.
            local token_id = math.tointeger(token) or token
            print("Previous result for sequence " .. s .. ": [" .. text .. "] (" .. token_id .. ")")
        end
    end

    for s, logits in ipairs(kobold.logits) do
        local a = {n = cols} ---@type Array
        for i = 0, cols - 1 do
            a[i] = {id = i, score = logits[i + 1]}
        end
        build(a)

        -- Softmax denominator, shifted by the largest logit (the heap root) so
        -- exp() cannot overflow to inf for large scores; the shift cancels in
        -- exp(x - m) / sum(exp(y - m)). Skip the shift when the maximum is not
        -- finite (all -inf, +inf or NaN) to avoid computing inf - inf.
        local shift = 0.0
        local root = a[0]
        if root and root.score > -math.huge and root.score < math.huge then
            shift = root.score
        end
        local sum = 0.0
        for i = 1, cols do
            sum = sum + math.exp(logits[i] - shift)
        end

        print()
        print("Top " .. K .. " scores for sequence " .. s .. ":")
        for _ = 1, K do
            local e = pop(a)
            local score = ("%.6f"):format(e.score)
            local prob = ("%.3f%% "):format(100 * (math.exp(e.score - shift) / sum))
            print(score, prob, e.id, "[" .. escape_newlines(kobold.decode(e.id)) .. "]")
        end
    end
end

return userscript