"""
---
title: GPT-2 with LoRA
summary: GPT-2 implementation with LoRA modules
---

# GPT-2 with [LoRA modules](index.html)

Here's [the training code](experiment.html) for training a GPT-2 model with LoRA
on the Tiny Shakespeare dataset.
"""

import torch
import torch.nn as nn

from labml_nn.lora import Linear, Embedding

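# A brief note on the imports above (a conceptual sketch, not a restatement of the
# `labml_nn.lora` internals): `Linear` and `Embedding` are the LoRA wrappers from
# [the LoRA module](index.html). Conceptually a LoRA layer keeps the pre-trained
# weight $W$ frozen and learns a low-rank update, roughly $h = Wx + \frac{\alpha}{r} B A x$,
# where only the small matrices $A$ and $B$ are trained. The classes below only pass
# the rank through the `r` argument; the exact scaling and initialization are defined
# in the LoRA module itself.
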
class FFN(nn.Module):
    """
    ### Feedforward Network
    """

    def __init__(self, d_model: int, d_ff: int, r: int):
        """
        :param d_model: is the number of dimensions in the embeddings
        :param d_ff: is the size of the hidden dimension
        :param r: is the LoRA rank
        """
        super().__init__()

        # The linear layers and the activation
        self.linear_in = Linear(d_model, d_ff, r=r, bias=True)
        self.linear_out = Linear(d_ff, d_model, r=r, bias=True)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: is the embeddings tensor with shape `[batch_size, seq_len, d_model]`
        """
        x = self.linear_in(x)
        x = self.act(x)
        x = self.linear_out(x)
        return x

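# For example, in the decoder block below the FFN is built with `d_ff = 4 * d_model`,
# so with `d_model = 768` the `linear_in` layer maps `[batch_size, seq_len, 768]`
# to `[batch_size, seq_len, 3072]` and `linear_out` maps it back, leaving the
# embedding shape unchanged.
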
class MultiHeadAttention(nn.Module):
    """
    ### Multi-Head Attention
    """

    def __init__(self, d_model: int, n_heads: int, r: int):
        """
        :param d_model: is the number of dimensions in the embeddings
        :param n_heads: is the number of heads
        :param r: is the LoRA rank
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Linear transformation for QKV
        self.qkv_projection = Linear(d_model, d_model * 3, r=r, bias=True)
        # Output projection
        self.output_projection = Linear(d_model, d_model, r=r, bias=True)

    def _split_heads(self, x: torch.Tensor):
        """
        :param x: is the tensor with shape `[batch_size, seq_len, d_model]`
        """
        # Split the last dimension to `[n_heads, d_head]`
        x = x.view(x.shape[:-1] + (self.n_heads, self.d_head))
        # Reorder to `[batch_size, n_heads, seq_len, d_head]`
        return x.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: is the embeddings tensor with shape `[batch_size, seq_len, d_model]`
        """
        batch_size, seq_length, _ = x.shape

        # Get query, key and value
        q, k, v = self.qkv_projection(x).split(self.d_model, dim=-1)

        # Transform them from shape `[batch_size, seq_len, d_model]` to `[batch_size, n_heads, seq_len, d_head]`
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # Apply causal attention
        attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Transform back from shape `[batch_size, n_heads, seq_len, d_head]` to `[batch_size, seq_len, d_model]`
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.d_model)

        # Final projection
        return self.output_projection(attn_output)

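# As a worked example of the shapes above: with `d_model = 768` and `n_heads = 12`,
# each head gets `d_head = 768 // 12 = 64` dimensions. The QKV projection maps
# `[batch_size, seq_len, 768]` to `[batch_size, seq_len, 2304]`, which is split into
# three `[batch_size, seq_len, 768]` tensors and reshaped to `[batch_size, 12, seq_len, 64]`
# before the scaled dot-product attention; the output projection then maps the merged
# heads back to `[batch_size, seq_len, 768]`.
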
class Block(nn.Module):
    """
    ### Decoder block
    """

    def __init__(self, d_model: int, n_heads: int, layer_norm_epsilon: float, r: int):
        """
        :param d_model: is the number of dimensions in the embeddings
        :param n_heads: is the number of heads
        :param layer_norm_epsilon: is the layer norm epsilon
        :param r: is the LoRA rank
        """
        super().__init__()
        # Attention pre-normalization layer
        self.attn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Attention layer
        self.attn = MultiHeadAttention(d_model, n_heads, r)
        # FFN pre-normalization layer
        self.ffn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Feed-forward network
        self.ffn = FFN(d_model, d_model * 4, r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: is the embeddings tensor with shape `[batch_size, seq_len, d_model]`
        """
        # Attention
        x = x + self.attn(self.attn_norm(x))
        # FFN
        x = x + self.ffn(self.ffn_norm(x))

        return x

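# Note that both sub-layers above are pre-normalized and residual, so a block maps
# `[batch_size, seq_len, d_model]` to the same shape; this is what lets `GPTModel`
# below simply stack `n_layers` of these blocks.
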
class GPTModel(nn.Module):
    """
    ## GPT-2 Model
    """

    def __init__(self, *, d_model: int,
                 n_heads: int, n_layers: int,
                 n_positions: int,
                 layer_norm_epsilon: float,
                 vocab_size: int, r: int):
        """
        :param d_model: is the number of dimensions in the embeddings
        :param n_heads: is the number of attention heads
        :param n_layers: is the number of decoder layers
        :param n_positions: is the number of positional embeddings
        :param layer_norm_epsilon: is the layer norm epsilon
        :param vocab_size: is the vocabulary size
        :param r: is the LoRA rank
        """
        super().__init__()

        # Token and absolute positional embeddings
        self.token_embedding = Embedding(vocab_size, d_model, r=r)
        self.position_embedding = Embedding(n_positions, d_model, r=r)

        # Decoder blocks
        self.blocks = nn.ModuleList([Block(d_model, n_heads, layer_norm_epsilon, r=r)
                                     for _ in range(n_layers)])

        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Projection layer to logit space
        self.lm_head = Linear(d_model, vocab_size, r=r, bias=False)

    def forward(self, input_ids: torch.Tensor):
        """
        :param input_ids: has shape `[batch_size, seq_len]`
        """
        batch_size, seq_len = input_ids.shape

        # Get token embeddings
        token_embeddings = self.token_embedding(input_ids)
        # Get position ids
        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
        # Get position embeddings
        position_embeddings = self.position_embedding(position_ids)

        # Add position embeddings
        x = token_embeddings + position_embeddings

        # Run through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization
        x = self.final_norm(x)
        # Get logits from the projection layer
        return self.lm_head(x)

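# A minimal usage sketch (not part of the original file): build a model with GPT-2
# small hyper-parameters and run a forward pass on random token ids. The LoRA rank
# `r=32` is just an illustrative value, and no pretrained GPT-2 weights are loaded
# here; the actual training setup is in [the training code](experiment.html).
if __name__ == '__main__':
    _model = GPTModel(d_model=768,
                      n_heads=12, n_layers=12,
                      n_positions=1024,
                      layer_norm_epsilon=1e-5,
                      vocab_size=50_257, r=32)
    _input_ids = torch.randint(0, 50_257, (1, 64))
    # Logits have shape `[batch_size, seq_len, vocab_size]`
    assert _model(_input_ids).shape == (1, 64, 50_257)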