Path: blob/master/labml_nn/lora/experiment.py
"""1---2title: Finetune GPT-2 with LoRA3summary: This is training code with notes for fine-tuning pre-trained GPT-2 model with LoRA.4---56# Finetune [GPT-2](gpt2.html) with [LoRA](index.html)78Here's a Colab notebook for training a feedback transformer on Tiny Shakespeare dataset.910[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/lora/experiment.ipynb)11"""1213import torch14from torch.optim import Adam15from torch.utils.data import DataLoader, TensorDataset16from transformers import AutoTokenizer, AutoModelForCausalLM1718from labml import lab, monit, tracker19from labml.configs import BaseConfigs, option20from labml.utils.download import download_file21from labml_nn.helpers.device import DeviceConfigs22from labml_nn.lora.gpt2 import GPTModel232425class Trainer(BaseConfigs):26"""27## Trainer configurations and the training loop2829The default configs can and will be over-ridden when we start the experiment30"""31device: torch.device = DeviceConfigs()3233# GPT-2 configs34layer_norm_epsilon: float = 1e-0535d_model: int = 76836n_layers: int = 1237n_heads: int = 1238n_positions: int = 102439vocab_size: int = 502574041# Training configs42epochs: int = 1043batch_size: int = 3244learning_rate: float = 1e-445context_len: int = 5124647# LoRA rank48lora_r: int = 324950# Dataset51text: TensorDataset = "tiny_shakespeare"52# Huggingface tokenizer53tokenizer = AutoTokenizer.from_pretrained("gpt2")54# [GPT2 model](gpt2.html)55model: GPTModel56# Optimizer57optimizer: torch.optim.Adam58# Cross entropy loss59loss_func = torch.nn.CrossEntropyLoss()60# Dataloader61data_loader: DataLoader6263def _load_pretrained_weights(self):64"""65### Load pre-trained [GPT-2 from huggingface](https://huggingface.co/openai-community/gpt2)66"""6768# Load the huggingface model and get the parameters69hf_model = AutoModelForCausalLM.from_pretrained("gpt2")70state_dict = hf_model.state_dict()7172# Transformer embedding and prediction layer parameter mapping (`hf: ours`)73mapping = {74'transformer.wte.weight': 'token_embedding.weight',75'transformer.wpe.weight': 'position_embedding.weight',76'transformer.ln_f.weight': 'final_norm.weight',77'transformer.ln_f.bias': 'final_norm.bias',78'lm_head.weight': 'lm_head.weight'79}8081# Mapping (`hf: ours`) of decoder layers82for i in range(12):83mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'84mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'85mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'86mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'87mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'88mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'89mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'90mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'91mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'92mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'93mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'94mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'9596# Move the parameters based on mapping97new_state_dict = {}98for old_key, new_key in mapping.items():99if old_key in state_dict:100new_state_dict[new_key] = state_dict[old_key]101102# GPT-2 hugging 
        # The GPT-2 implementation from Hugging Face uses 1D convolution layers.
        # We need to transpose those weights since we use linear layers
        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.output_projection.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

        # Load into our model. We use `strict=False` because the state dict does not have the LoRA weights
        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)

        # Make sure that the only missing keys are the LoRA weights,
        # and that there are no unexpected keys
        assert all('lora' in key for key in missing_keys)
        assert not unexpected_keys

    def initialize(self):
        """
        ### Initialize the model, optimizer and dataloader
        """
        # Initialize the [GPT2 model](gpt2.html)
        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)
        # Load pre-trained model weights
        self._load_pretrained_weights()

        # Initialize the optimizer
        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)

        # Initialize the data loader
        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

    def run(self):
        """
        ### Training loop
        """

        for _ in monit.loop(self.epochs):
            # `inputs` has shape `[batch_size, seq_len]`
            for (inputs,) in monit.iterate('Train', self.data_loader):
                # Move `inputs` to the device
                inputs = inputs.to(self.device)
                # Call the model with all but the last token
                logits = self.model(inputs[:, :-1])
                # Get the cross entropy loss against the tokens shifted by one
                loss = self.loss_func(logits.reshape(-1, logits.shape[-1]), inputs[:, 1:].reshape(-1))

                # Zero the gradients
                self.optimizer.zero_grad()
                # Compute gradients
                loss.backward()
                # Take an optimizer step
                self.optimizer.step()

                # Log the loss
                tracker.save({'loss': loss})
                tracker.add_global_step()
            # Start a new line in the logs at the end of each epoch
            tracker.new_line()


@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    """
    ### Tiny Shakespeare dataset

    It will download the dataset from the URL if it is not already present
    """
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Tokenize the whole text, trim it to a multiple of `batch_size * context_len`,
    # and reshape it into `[n_samples, context_len]` sequences
    tokens = c.tokenizer.encode(text)
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
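

# Editor's sketch (added; not shown in the original section): one way to run this experiment
# with labml's `experiment` API, using the default configs defined in `Trainer` above.
# Treat this as an illustrative entry point rather than the repository's own `main`.
from labml import experiment


def main():
    # Create the experiment
    experiment.create(name='lora_gpt2')
    # Create the trainer configurations
    conf = Trainer()
    # Register the configurations with the experiment so hyper-parameters get logged
    experiment.configs(conf)
    # Build the model, optimizer and dataloader, and load the pre-trained GPT-2 weights
    conf.initialize()
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()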