Path: blob/master/xtra_labs/llm_finetune/draft.py
"""1Drafting lab flow in script format using PyTorch2"""3from datasets import load_dataset4import math5import numpy as np6import pandas as pd7import random8import torch9import torch.nn as nn10import torch.nn.functional as F11from torch.nn import CrossEntropyLoss12from torch.optim import Adam13import transformers14from trl import SFTTrainer15from tqdm import tqdm1617from utils import run_benchmark, make_spider_plot1819# Part 12021# TEXT: overview of LLM lab22# Load pretrained LLM (medium size model)2324# model_name = "facebook/opt-1.3b"25model_name = "facebook/opt-125m"26model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")27tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)2829# TEXT: explain tokenizer30# Include cell for tokenizer inspection3132# TEXT: explain how LLMs are trained for next token prediction33# Write a function to predict next token34def predict_next_token(probs, tokenizer):35new_token = np.random.choice(len(probs), p=probs.numpy())36print(tokenizer.decode(new_token), end='', flush=True)37return new_token3839# TEXT: explain that next token prediction must be called multiple times for inference40# Call in loop for autoregressive inference41def generate(start_text, model, tokenizer, num_steps=20, temp=1.):42print(start_text, end="")43x = tokenizer.encode(start_text)44num_start = len(x)4546for i in range(num_steps):47input_tensor = torch.tensor(x).view(1, -1).to("cuda")48logits = model(input_tensor).logits49probs = F.softmax(logits/temp, dim=-1)[0, -1, :].cpu().detach()5051new_token = predict_next_token(probs, tokenizer)52x.append(new_token)5354output = tokenizer.decode(x[num_start:])55return output5657def generate_pt(model, tokenizer, text, num_steps=50, until=None, temp=1.):58device = model.device59print(text, end='', flush=True)60x = tokenizer.encode(text)61enc_until = tokenizer.encode(until)[1:]62num_start = len(x)6364decoded = tokenizer.decode(x)6566for step in range(num_steps):67with torch.no_grad():68input_tensor = torch.reshape(torch.LongTensor(x), [1, -1]).to(device)69logits = model(input_tensor).logits70probs = F.softmax(logits/temp, dim=-1)[0, -1, :]71probs = probs.detach().cpu().numpy()7273new_token = np.random.choice(len(probs), p=probs)74x.append(new_token)7576new_decoded = tokenizer.decode(x)77new_part = new_decoded[len(decoded):]78decoded = new_decoded7980print(new_part, end='', flush=True)81text += new_part8283if len(x) >= len(until) and text[-len(until):] == until:84break858687output = tokenizer.decode(x[num_start:])88print("\n", flush=True)89return output9091# Test autoregressive generation92# while True:93# print("\n\n\n\n\n")94# input_text = input("Prompt: ")95# output = generate(input_text, model, tokenizer)9697# TEXT: some background on LLM benchmarking98# Load benchmark dataset and evaluate model99benchmark_dataset = pd.read_csv("benchmark.csv")100# category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, benchmark_dataset)101102# TEXT: ask them to make a prediction on how accuracy will be affected by different model sizes103104# Benchmark smaller model105# model_name_350m = "facebook/opt-350m"106# model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")107# tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)108109# category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, benchmark_dataset)110111# Benchmark larger model112# model_name_2700m = "facebook/opt-2.7b"113# model_2700m = 
# TEXT: explain how LLMs are trained for next token prediction
# Write a function to predict next token
def predict_next_token(probs, tokenizer):
    # sample a token id from the probability distribution and print it as text
    new_token = np.random.choice(len(probs), p=probs.numpy())
    print(tokenizer.decode(new_token), end='', flush=True)
    return new_token

# TEXT: explain that next token prediction must be called multiple times for inference
# Call in loop for autoregressive inference
def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
    print(start_text, end="")
    x = tokenizer.encode(start_text)
    num_start = len(x)

    for i in range(num_steps):
        input_tensor = torch.tensor(x).view(1, -1).to(model.device)
        logits = model(input_tensor).logits
        probs = F.softmax(logits / temp, dim=-1)[0, -1, :].cpu().detach()

        new_token = predict_next_token(probs, tokenizer)
        x.append(new_token)

    output = tokenizer.decode(x[num_start:])
    return output

def generate_pt(model, tokenizer, text, num_steps=50, until=None, temp=1.):
    # autoregressive sampling that stops early once the generated text ends with `until`
    device = model.device
    print(text, end='', flush=True)
    x = tokenizer.encode(text)
    num_start = len(x)

    decoded = tokenizer.decode(x)

    for step in range(num_steps):
        with torch.no_grad():
            input_tensor = torch.reshape(torch.LongTensor(x), [1, -1]).to(device)
            logits = model(input_tensor).logits
            probs = F.softmax(logits / temp, dim=-1)[0, -1, :]
            probs = probs.detach().cpu().numpy()

        new_token = np.random.choice(len(probs), p=probs)
        x.append(new_token)

        # decode incrementally so only the newly generated text is printed
        new_decoded = tokenizer.decode(x)
        new_part = new_decoded[len(decoded):]
        decoded = new_decoded

        print(new_part, end='', flush=True)
        text += new_part

        if until is not None and text.endswith(until):
            break

    output = tokenizer.decode(x[num_start:])
    print("\n", flush=True)
    return output

# Test autoregressive generation
# while True:
#     print("\n\n\n\n\n")
#     input_text = input("Prompt: ")
#     output = generate(input_text, model, tokenizer)

# TEXT: some background on LLM benchmarking
# Load benchmark dataset and evaluate model
benchmark_dataset = pd.read_csv("benchmark.csv")
# category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, benchmark_dataset)

# TEXT: ask them to make a prediction on how accuracy will be affected by different model sizes

# Benchmark smaller model
# model_name_350m = "facebook/opt-350m"
# model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
# tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)

# category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, benchmark_dataset)

# Benchmark larger model
# model_name_2700m = "facebook/opt-2.7b"
# model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
# tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)

# category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, benchmark_dataset)

# Spider plot

# benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "2700M-Model": category_accs_2700m}
# benchmark_data = {"350M-Model": category_accs_1300m}
# make_spider_plot(benchmark_data)

def print_lora_params(module, layer_type):
    # recursively sum trainable parameters in layers of `layer_type`;
    # a LoRALinear with r=8 should contribute 8 * (in_features + out_features)
    summ = 0
    for name, child in module.named_children():
        if isinstance(child, layer_type):
            num_params = sum(p.numel() for p in child.parameters() if p.requires_grad)

            print(name, num_params, child.in_features, child.out_features, (child.in_features * 8 + child.out_features * 8 == num_params))

            summ += num_params
        else:
            summ += print_lora_params(child, layer_type)

    return summ

# Part 2

# inspect current model
# print(model)

# summ = print_lora_params(model, nn.Linear)

# print("with function", summ)

# print("without function", sum(p.numel() for p in model.parameters() if p.requires_grad))

# freeze all parameter gradients
for param in model.parameters():
    param.requires_grad = False

# new LoRA linear layer class
class LoRALinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        pretrained_weight: torch.Tensor,
        pretrained_bias: torch.Tensor,
        r: int = 8,
        lora_alpha: int = 8,
        lora_dropout: float = 0.1,
        **kwargs
    ):
        super(LoRALinear, self).__init__()

        self.r = r
        self.in_features = in_features
        self.out_features = out_features
        self.lora_alpha = lora_alpha

        # frozen pretrained weight and bias
        self.weight = nn.Parameter(pretrained_weight)
        self.weight.requires_grad = False

        if pretrained_bias is not None:
            self.bias = nn.Parameter(pretrained_bias)
            self.bias.requires_grad = False
        else:
            self.bias = None

        # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
        self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
        self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
        self.scaling = self.lora_alpha / self.r
        self.lora_dropout = nn.Dropout(p=lora_dropout)

        # initialize A as in the reference implementation; with both A and B left at
        # zero the LoRA branch would receive zero gradients and never train
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x: torch.Tensor):
        result = F.linear(x, self.weight, bias=self.bias)
        result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return result

# replace linear layers in model recursively
def replace_linear_with_lora(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LoRALinear(child.in_features, child.out_features, child.weight, child.bias))
        else:
            replace_linear_with_lora(child)

replace_linear_with_lora(model)

# summ = print_lora_params(model, LoRALinear)

# print("with function", summ)

# print("without function", sum(p.numel() for p in model.parameters() if p.requires_grad))

# inspect new model
# print(model)
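
# Sanity-check sketch (not in the original draft): after the LoRA swap, only the
# lora_A / lora_B matrices should still require gradients, so the trainable
# fraction of the model should be small. Uses only `model` defined above.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable_params} / {total_params} "
      f"({100 * trainable_params / total_params:.2f}%)")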
# load chat dataset
dataset_name = "timdettmers/openassistant-guanaco"
ft_dataset = load_dataset(dataset_name, split="train")

# train model (barebones loop)
context_length = 768
loss_fn = CrossEntropyLoss()

learning_rate = 1e-4
optimizer = Adam(model.parameters(), lr=learning_rate)
num_epochs = 5

model = model.to("cuda")

### Train the model
# Define some training args
args = transformers.TrainingArguments(
    "/home/dnori/introtodeeplearning/xtra_labs/llm_finetune/outputs",
    per_device_train_batch_size=1,
    logging_first_step=True,
    logging_steps=20,
    save_steps=100,
)

# Define a callback to check the progress on a sample question
class PrinterCallback(transformers.TrainerCallback):
    def on_log(self, args, state, control, model, logs=None, **kwargs):
        start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
        generate_pt(model, tokenizer, start_text, num_steps=200, until="###")

# Actually train the model
trainer = SFTTrainer(
    model,
    args=args,
    train_dataset=ft_dataset,
    dataset_text_field="text",
    max_seq_length=context_length,
    callbacks=[PrinterCallback()],
)
trainer.train()


# for epoch in range(num_epochs):
#     total_loss = 0
#     num_batches = 0

#     for batch in tqdm(ft_dataset):
#         prompt = batch["text"]

#         # encode with tokenizer
#         x = tokenizer.encode(prompt)
#         x_tensor = torch.tensor(x).view(1, -1).to("cuda")
#         max_len = min(context_length, x_tensor.shape[1]-1)
#         selected_len = random.randint(1, max_len)

#         input_tensor = x_tensor[:, :selected_len]
#         target_tensor = x_tensor[0, 1:selected_len+1]

#         # zero gradients
#         optimizer.zero_grad()

#         # run through model
#         logits = model(input_tensor).logits[0]

#         # apply loss
#         loss = loss_fn(logits, target_tensor)

#         # backpropagation
#         loss.backward()

#         # optimizer step
#         optimizer.step()

#         total_loss += loss.item()
#         num_batches += 1

#     # Print average loss for the epoch
#     average_loss = total_loss / num_batches
#     print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")

# evaluate finetuned model on benchmark
# category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)

# add to spider plot
# benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
# benchmark_data = {"350M-Model": category_accs_1300m, "350M-Model-Finetuned": category_accs_1300m_ft}
# make_spider_plot(benchmark_data)
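
# Optional follow-up (a sketch, not part of the original draft flow): fold the
# trained low-rank update back into the frozen weight so the finetuned model can
# run without the extra LoRA matmul. This matches the LoRALinear forward pass at
# eval time (dropout inactive): W_merged = W + scaling * (B @ A).
def merge_lora_weights(module):
    for child in module.modules():
        if isinstance(child, LoRALinear):
            with torch.no_grad():
                child.weight += child.scaling * (child.lora_B @ child.lora_A)
                # zero the low-rank factors so the LoRA branch no longer contributes
                child.lora_A.zero_()
                child.lora_B.zero_()

# merge_lora_weights(model)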