# Path: blob/master/xtra_labs/llm_finetune/utils.py
# 549 views
"""
Contains functions that the students will not interface with
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
import torch.nn.functional as F
from tqdm import tqdm


def run_benchmark(model, tokenizer, dataset, few_shot=7, num_steps=500, verbose=False):
    """Score a causal language model on a four-way multiple-choice benchmark.

    For every row, the prompt ``"### Human: {question}### Assistant:"`` is
    combined with each of the four candidate answers. The model's average
    log-probability over the answer tokens is computed for each candidate and
    the candidate with the highest score is taken as the prediction. The
    prediction counts as correct when it is ``Answer A`` (index 0).

    Args:
        model: Callable returning an object with a ``.logits`` tensor when
            given a tensor of token ids; must expose ``.device``.
        tokenizer: Object whose ``encode(text)`` returns a sequence of token
            ids and whose ``encode(text, return_tensors="pt")`` returns a
            tensor that is indexed here as ``[batch, seq_len]``.
        dataset: ``pandas.DataFrame`` with columns ``Question``,
            ``Answer A`` .. ``Answer D`` and ``Category``. It is modified in
            place: a float ``Correct`` column is (re)written.
        few_shot: Unused; retained for interface compatibility.
        num_steps: Unused; retained for interface compatibility.
        verbose: When true, print each question with the chosen answer.

    Returns:
        Tuple ``(accs, overall)`` where ``accs`` is a ``pandas.Series`` of
        mean accuracy per ``Category`` and ``overall`` is the mean accuracy
        over all rows.
    """
    device = model.device
    dataset["Correct"] = 0.0

    # Loop through every question in the benchmark
    for step, row in tqdm(dataset.iterrows(), total=len(dataset)):
        question = row['Question']
        pre_text = f"### Human: {question}### Assistant:"
        len_prefix = len(tokenizer.encode(pre_text))

        # Run the model individually with each of the four responses.
        # Measure the model's logprob for outputting each of the four responses.
        # Choose the answer with the highest logprob.
        logprobs = []
        answers = []
        for choice in ["A", "B", "C", "D"]:
            answer = row[f'Answer {choice}']
            text = f"{pre_text} {answer}"

            # Run the model
            with torch.no_grad():
                x = tokenizer.encode(text, return_tensors="pt").to(device)
                logits = model(x).logits
                # log_softmax (in float32) instead of log(softmax(.)): the
                # latter underflows to log(0) = -inf for very unlikely tokens,
                # especially with half-precision logits.
                log_probs = F.log_softmax(logits.float(), dim=-1)[0, :-1, :]  # [seq_len-1, vocab_size]
                y = x[0, 1:]  # [seq_len-1], the token each position must predict
                # Log-probability assigned to every actual next token.
                token_logprobs = log_probs.gather(1, y.unsqueeze(-1)).squeeze(-1)

            # Average logprob over the answer tokens only. Clamp the count to
            # at least one: a count of 0 would turn the slice into [-0:] (the
            # whole sequence, prompt included) and a negative count would
            # slice from the front.
            num_ans_tokens = max(x.shape[1] - len_prefix, 1)
            logprob = token_logprobs[-num_ans_tokens:].mean().item()
            logprobs.append(logprob)
            answers.append(answer)

        # Check for the correct answer (always the zero-th index, by definition)
        best = int(np.argmax(logprobs))
        correct = best == 0

        # Record if the model got the answer correct or not.
        # Optionally print the question -> prediction if verbose
        dataset.at[step, "Correct"] = float(correct)
        if verbose:
            print(f"[{correct}] {question} -> {answers[best]}")

    # Group by the categories and compute the average accuracy
    accs = dataset.groupby("Category")["Correct"].mean()
    sorted_accs = accs.sort_values()
    print(sorted_accs)

    return accs, dataset["Correct"].mean()


def make_spider_plot(data, save_path="spider.png"):
    """Draw a radar ("spider") chart comparing several entities and save it.

    Args:
        data: Dictionary mapping an entity name (used as the legend label) to
            a ``pandas.Series`` whose index holds the axis labels and whose
            values hold the performance on each axis. The axis labels drawn
            are those of the last series, so all series are expected to share
            the same index in the same order.
        save_path: File the figure is written to. Defaults to ``"spider.png"``.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("make_spider_plot requires at least one entry in `data`")

    colors = ['#1aaf6c', '#429bf4', '#d42cea']
    fig, ax = plt.subplots(figsize=(8, 6), subplot_kw=dict(polar=True))
    for i, (k, v) in enumerate(data.items()):
        labels = v.index.tolist()
        values = v.values.tolist()

        num_vars = len(labels)
        angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
        # Repeat the first point so the polygon closes.
        angles += angles[:1]
        values += values[:1]

        # Cycle through the palette so more than three entities do not raise
        # an IndexError.
        color = colors[i % len(colors)]
        ax.plot(angles, values, color=color, linewidth=1, label=k)
        ax.fill(angles, values, color=color, alpha=0.25)

    # Put the first axis at the top and go clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(angles[:-1]), labels)
    # Align each tick label away from the plot depending on its side.
    for label, angle in zip(ax.get_xticklabels(), angles):
        if angle in (0, np.pi):
            label.set_horizontalalignment('center')
        elif 0 < angle < np.pi:
            label.set_horizontalalignment('left')
        else:
            label.set_horizontalalignment('right')

    ax.set_ylim(0, 1)
    ax.set_rlabel_position(180 / num_vars)

    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))

    plt.savefig(save_path)