GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/lora/experiment.py

"""
---
title: Finetune GPT-2 with LoRA
summary: This is training code with notes for fine-tuning a pre-trained GPT-2 model with LoRA.
---

# Finetune [GPT-2](gpt2.html) with [LoRA](index.html)

Here's a Colab notebook for fine-tuning GPT-2 with LoRA on the Tiny Shakespeare dataset.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/lora/experiment.ipynb)
"""

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.lora.gpt2 import GPTModel
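

# A quick reminder of what [LoRA](index.html) does: a frozen pre-trained weight `W0` gets a
# low-rank update `B @ A`, so the adapted layer computes `W0 @ x + (alpha / r) * B @ A @ x` and
# only `A` and `B` are trained. The sketch below is purely illustrative and is not used by the
# trainer; the function name and the `alpha` scaling factor are hypothetical additions here,
# and the actual implementation lives in [`labml_nn.lora`](index.html).
def _lora_forward_sketch(x: torch.Tensor, w0: torch.Tensor,
                         a: torch.Tensor, b: torch.Tensor, alpha: float, r: int) -> torch.Tensor:
    # `w0` is `[d_out, d_in]` (frozen); `a` is `[r, d_in]` and `b` is `[d_out, r]` (trained)
    return x @ w0.T + (alpha / r) * (x @ a.T @ b.T)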


class Trainer(BaseConfigs):
    """
    ## Trainer configurations and the training loop

    The default configs can and will be overridden when we start the experiment
    """
    device: torch.device = DeviceConfigs()

    # GPT-2 configs
    layer_norm_epsilon: float = 1e-05
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257

    # Training configs
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512

    # LoRA rank
    lora_r: int = 32

    # Dataset (the string selects the `tiny_shakespeare` config option defined below)
    text: TensorDataset = "tiny_shakespeare"
    # Hugging Face tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # [GPT2 model](gpt2.html)
    model: GPTModel
    # Optimizer
    optimizer: torch.optim.Adam
    # Cross entropy loss
    loss_func = torch.nn.CrossEntropyLoss()
    # Dataloader
    data_loader: DataLoader

    def _load_pretrained_weights(self):
        """
        ### Load pre-trained [GPT-2 from Hugging Face](https://huggingface.co/openai-community/gpt2)
        """

        # Load the Hugging Face model and get the parameters
        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
        state_dict = hf_model.state_dict()

        # Transformer embedding and prediction layer parameter mapping (`hf: ours`)
        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }

        # Mapping (`hf: ours`) of decoder layers
        for i in range(12):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'

        # Move the parameters based on mapping
        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]

        # Hugging Face's GPT-2 uses `Conv1D` layers, which store the transpose of a linear layer's
        # weights. We need to transpose those weights since we use standard linear layers.
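        # For instance, Hugging Face stores `transformer.h.0.mlp.c_fc.weight` with shape
        # `[768, 3072]` (inputs first), while the corresponding `nn.Linear` weight
        # `blocks.0.ffn.linear_in.weight` needs shape `[3072, 768]` (outputs first).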
        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.output_projection.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

        # Load our model. We use `strict=False` because the state dict does not have the LoRA weights
        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)

        # Make sure that only the LoRA weights are missing
        assert all('lora' in key for key in missing_keys)
        assert not unexpected_keys

    def initialize(self):
        """
        ### Initialize the model, optimizer and dataloader
        """
        # Initialize the [GPT2 model](gpt2.html)
        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.lora_r,
        )
        self.model.to(self.device)
        # Load pre-trained model weights
        self._load_pretrained_weights()

        # Initialize the optimizer
        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)
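        # Note that this passes *all* parameters to the optimizer. In a typical LoRA setup the
        # pre-trained weights are frozen and only the LoRA matrices are trained; a minimal
        # sketch of that (relying on the LoRA parameter names containing `lora`, as the
        # assert in `_load_pretrained_weights` does) would be:
        #
        #     lora_params = [p for n, p in self.model.named_parameters() if 'lora' in n]
        #     self.optimizer = Adam(lora_params, lr=self.learning_rate)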

        # Initialize the data loader
        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

    def run(self):
        """
        ### Training loop
        """

        for _ in monit.loop(self.epochs):
            # `inputs` has shape `[batch_size, seq_len]`
            for (inputs,) in monit.iterate('Train', self.data_loader):
                # Move `inputs` to device
                inputs = inputs.to(self.device)
                # Call the model with all but the last token
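                # (for example, a sequence `[a, b, c, d]` is fed as `[a, b, c]` and the
                # shifted targets are `[b, c, d]`)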
                logits = self.model(inputs[:, :-1])
                # Get cross entropy loss
                loss = self.loss_func(logits.reshape(-1, logits.shape[-1]), inputs[:, 1:].reshape(-1))

                # Zero the gradients
                self.optimizer.zero_grad()
                # Compute gradients
                loss.backward()
                # Optimize
                self.optimizer.step()

                # Log the loss
                tracker.save({'loss': loss})
                tracker.add_global_step()
            # Start a new line in the output after each epoch
            tracker.new_line()


@option(Trainer.text)
def tiny_shakespeare(c: Trainer):
    """
    ### Tiny Shakespeare dataset

    It will be downloaded from the URL if it's not present
    """
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    tokens = c.tokenizer.encode(text)
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
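    # Each row of `input_ids` below is one training sample of `context_len` tokens; leftover
    # tokens that don't fill a complete `batch_size * context_len` block are dropped above.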
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
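

# A minimal sketch of how one might drive this trainer with a labml experiment. This block is an
# illustrative assumption and not part of the original file; the linked Colab notebook is the
# intended entry point, and the experiment name here is hypothetical.
def main():
    from labml import experiment

    # Create an experiment and the configs object
    experiment.create(name="lora_gpt2")
    conf = Trainer()
    # Register the configs (this lets labml resolve options such as `tiny_shakespeare`)
    experiment.configs(conf)
    # Initialize the model, optimizer and dataloader
    conf.initialize()
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()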