Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aamini
GitHub Repository: aamini/introtodeeplearning
Path: blob/master/xtra_labs/llm_finetune/utils.py
549 views
1
"""
Contains helper functions that the students are not expected to interface with directly.
"""
4
import matplotlib.pyplot as plt
5
import numpy as np
6
import pandas as pd
7
import tensorflow as tf
8
import torch
9
import torch.nn.functional as F
10
from tqdm import tqdm
11
12
def run_benchmark(model, tokenizer, dataset, few_shot=7, num_steps=500, verbose=False):
    """
    Score a causal language model on a four-way multiple-choice benchmark.

    For every row of ``dataset`` the question is wrapped in the
    ``### Human: ...### Assistant:`` template and the model is run once per
    candidate answer. The candidate with the highest average per-token
    log-probability (averaged over the answer tokens only) is taken as the
    model's prediction.

    Args:
        model: callable returning an object with a ``.logits`` tensor when
            given a tensor of token ids; must also expose ``.device``.
        tokenizer: object whose ``encode(text)`` returns a sequence of token
            ids and whose ``encode(text, return_tensors="pt")`` returns a
            tensor of token ids with a leading batch dimension.
        dataset: pandas DataFrame with the columns ``Question``,
            ``Answer A`` .. ``Answer D`` and ``Category``. ``Answer A`` is
            treated as the correct answer. A ``Correct`` column is added
            (or overwritten) in place.
        few_shot: unused; kept so existing callers keep working.
        num_steps: unused; kept so existing callers keep working.
        verbose: when True, print each question with the chosen answer.

    Returns:
        Tuple ``(accs, overall)`` where ``accs`` is a Series of mean accuracy
        per ``Category`` and ``overall`` is the mean accuracy over all rows.
    """
    device = model.device
    dataset["Correct"] = 0.0

    # Loop through every question in the benchmark
    for step, row in tqdm(dataset.iterrows(), total=len(dataset)):
        question = row['Question']
        pre_text = f"### Human: {question}### Assistant:"
        len_prefix = len(tokenizer.encode(pre_text))

        # Run the model individually with each of the four responses.
        # Measure the model's logprob for outputting each of the four responses.
        # Choose the answer with the highest logprob.
        logprobs = []
        answers = []
        for choice in ["A", "B", "C", "D"]:
            answer = row[f'Answer {choice}']
            text = f"{pre_text} {answer}"

            # Run the model
            with torch.no_grad():
                x = tokenizer.encode(text, return_tensors="pt").to(device)
                logits = model(x).logits
                # log_softmax avoids the log(0) = -inf underflow that
                # log(softmax(.)) hits for very unlikely tokens; the float32
                # cast keeps half-precision models numerically sane.
                log_probs = F.log_softmax(logits.float(), dim=-1)[0, :-1, :]  # [seq_len-1, vocab_size]
                y = x[0, 1:]  # [seq_len-1]; token i is predicted from position i
                token_logprobs = log_probs.gather(1, y.unsqueeze(1)).squeeze(1)

            # Average logprob over the answer tokens only.
            num_ans_tokens = x.shape[1] - len_prefix
            if num_ans_tokens > 0:
                logprob = token_logprobs[-num_ans_tokens:].mean().item()
            else:
                # A slice of [-0:] would silently average over the whole
                # prompt; with no answer tokens to score, this choice must
                # never be preferred over one that has some.
                logprob = float("-inf")
            logprobs.append(logprob)
            answers.append(answer)

        # Check for the correct answer (always the zero-th index, by definition)
        best = int(np.argmax(logprobs))
        correct = best == 0

        # Record if the model got the answer correct or not.
        # Optionally print the question -> prediction if verbose
        dataset.at[step, "Correct"] = float(correct)
        if verbose:
            print(f"[{correct}] {question} -> {answers[best]}")

    # Group by the categories and compute the average accuracy
    accs = dataset.groupby("Category")["Correct"].mean()
    sorted_accs = accs.sort_values()
    print(sorted_accs)

    return accs, dataset["Correct"].mean()
61
62
def make_spider_plot(data):
    """
    Draw a radar ("spider") chart with one polygon per entity and save it
    to "spider.png" in the current working directory.

    Data is a dictionary where keys are different entities (used as legend labels).
    Values are pd Series where series indices are plot labels and series values show performance.

    The axis tick labels are taken from the last series in ``data``, so all
    series are expected to share the same index in the same order. The
    radial axis is fixed to the range [0, 1].

    Raises:
        ValueError: if ``data`` is empty.
    """
    # Fail early with a clear message; otherwise the tick-label code below
    # would hit unbound loop variables after an empty figure was created.
    if len(data) == 0:
        raise ValueError("make_spider_plot requires at least one series to plot")

    colors = ['#1aaf6c', '#429bf4', '#d42cea']
    fig, ax = plt.subplots(figsize=(8, 6), subplot_kw=dict(polar=True))
    for i, (k, v) in enumerate(data.items()):
        # Cycle through the palette so more than len(colors) entities
        # can be plotted without an IndexError.
        color = colors[i % len(colors)]
        labels = v.index.tolist()
        values = v.values.tolist()

        num_vars = len(labels)
        angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
        # Repeat the first point at the end so the polygon is closed.
        angles += angles[:1]
        values += values[:1]

        ax.plot(angles, values, color=color, linewidth=1, label=k)
        ax.fill(angles, values, color=color, alpha=0.25)

    # Put the first axis at the top and proceed clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(angles[:-1]), labels)

    # Align each tick label away from the plot so it does not overlap it.
    for label, angle in zip(ax.get_xticklabels(), angles):
        if angle in (0, np.pi):
            label.set_horizontalalignment('center')
        elif 0 < angle < np.pi:
            label.set_horizontalalignment('left')
        else:
            label.set_horizontalalignment('right')

    ax.set_ylim(0, 1)
    ax.set_rlabel_position(180 / num_vars)

    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))

    plt.savefig("spider.png")
101
102
103
104