Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
KoboldAI
GitHub Repository: KoboldAI/KoboldAI-Client
Path: blob/main/userscripts/kaipreset_logits_viewer.lua
473 views
1
-- Logit viewer userscript.
-- On every generation step, prints the top K raw token scores together with
-- their softmax probabilities for each sequence being generated.

-- NOTE(review): `kobold` is assigned as a global here, not a local; presumably
-- this follows the KoboldAI userscript convention -- confirm before localizing.
kobold = require("bridge")()

-- How many of the highest-scoring tokens are printed per sequence.
local K = 10

---@class KoboldUserScript
local userscript = {}

--- A (token id, raw logit score) pair stored in the heap.
---@class Pair
---@field id integer
---@field score number

--- Type-only declaration for integer-keyed storage of Pairs. The dummy local
--- below appears to exist only so the annotation has a statement to attach to.
---@class ArrayBase
---@type table<any, Pair>
local _ = {}

--- 0-based binary max-heap of Pairs ordered by `score`; the live entries
--- occupy slots 0 .. n - 1.
---@class Array : ArrayBase
---@field n integer
--- Sift the element at `index` down until the max-heap property holds again.
-- The heap lives 0-based in `array[0 .. array.n - 1]` and is ordered by
-- `score`; the children of slot p are slots 2p + 1 and 2p + 2. On equal
-- scores the parent stays in place, and the left child wins over the right.
---@param array Array
---@param index integer
---@return nil
local function bubble(array, index)
    local size = array.n
    local pos = index
    while true do
        local left = (pos << 1) + 1
        if left >= size then
            return -- `pos` is a leaf
        end
        local right = left + 1

        -- Pick the strictly largest of parent, left child, right child.
        local largest = pos
        if array[left].score > array[largest].score then
            largest = left
        end
        if right < size and array[right].score > array[largest].score then
            largest = right
        end

        if largest == pos then
            return -- heap property already satisfied here
        end
        array[pos], array[largest] = array[largest], array[pos]
        pos = largest
    end
end
--- Heapify `array[0 .. array.n - 1]` in place into a max-heap ordered by score.
-- Sifts down every internal node, from the last parent back to the root.
-- Floor division is used instead of `>> 1`: since Lua 5.3 `>>` is a logical
-- shift, so `(0 - 1) >> 1` is math.maxinteger and an empty array would start
-- an effectively endless loop. `(n - 1) // 2` is -1 for n = 0 (no iterations)
-- and matches the shift for every n >= 1.
---@param array Array
---@return nil
local function build(array)
    for i = (array.n - 1) // 2, 0, -1 do
        bubble(array, i)
    end
end
--- Remove and return the highest-scoring Pair from a heap built by `build`.
-- The last live element is moved into the root slot and sifted down. The
-- vacated tail slot is cleared so the table keeps no stale reference.
-- Raises an error (blamed on the caller) when the heap is empty, instead of
-- silently driving `n` negative.
---@param array Array
---@return Pair
local function pop(array)
    local size = array.n
    if size <= 0 then
        error("cannot pop from an empty heap", 2)
    end
    local top = array[0]
    size = size - 1
    array.n = size
    array[0] = array[size]
    array[size] = nil
    if size > 0 then
        bubble(array, 0)
    end
    return top
end
--- Generation modifier: for every sequence, prints the token chosen on the
-- previous step (if any) and the top K candidates with their raw score and
-- softmax probability. It only reads kobold.logits and never writes to it.
-- Raises an error if K exceeds the vocabulary size (kobold.logits_cols).
function userscript.genmod()
    local vocab_size = kobold.logits_cols
    if K > vocab_size then
        error("K must be at most the vocabulary size of the model")
    end

    -- Echo the token sampled on the previous step for each sequence.
    if kobold.generated_cols > 0 then
        for s in ipairs(kobold.logits) do
            local token = kobold.generated[s][kobold.generated_cols]
            print("Previous result for sequence " .. s .. ": [" .. (kobold.decode(token):gsub("\n", "\\n")) .. "] (" .. math.tointeger(token) .. ")")
        end
    end

    for s, logits in ipairs(kobold.logits) do
        -- Copy the 1-based logits row into a 0-based heap array and track
        -- the largest logit for the stable softmax below.
        local heap = {} ---@type Array
        local max_score = -math.huge
        for i = 0, vocab_size - 1 do
            local score = logits[i + 1]
            heap[i] = {id = i, score = score}
            if score > max_score then
                max_score = score
            end
        end
        heap.n = vocab_size

        -- Softmax denominator shifted by the max logit: exp(x - max) <= 1, so
        -- the sum cannot overflow to inf (which would turn every probability
        -- into nan). The shift cancels out in exp(x - max) / sum.
        local sum = 0.0
        for i = 1, vocab_size do
            sum = sum + math.exp(logits[i] - max_score)
        end

        build(heap)
        print()
        print("Top " .. K .. " scores for sequence " .. s .. ":")
        for _ = 1, K do
            local e = pop(heap)
            local probability = math.exp(e.score - max_score) / sum
            print(("%.6f"):format(e.score), ("%.3f%% "):format(100 * probability), e.id, "[" .. (kobold.decode(e.id):gsub("\n", "\\n")) .. "]")
        end
    end
end

return userscript