"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
    - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
    - all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier

Originally forked from Andrej Karpathy's minGPT.

CS224N 2022-23: Homework 5

John Hewitt <[email protected]>
Ansh Khurana <[email protected]>
"""

import math

import torch
import torch.nn as nn
from torch.nn import functional as F

import attention

torch.manual_seed(1)


class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1
    perceiver = False
    bottleneck_dim = None

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k, v in kwargs.items():
            setattr(self, k, v)


class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768


class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = attention.CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class DownProjectBlock(nn.Module):
    """Transformer block used for down projection.

    Initialize similarly to the regular Transformer Block class,
    while using the CausalCrossAttention layer instead of the regular
    CausalSelfAttention layer.

    You also need to initialize the parameter for the basis vectors `self.C` here.
    Initialize `self.C` with appropriate dimensions and xavier_uniform initialization.

    `self.C` should be 1 x bottleneck_dim x n_embd. We need the first dimension
    for appropriate broadcasting along the batch_size dimension of the input
    sequence.

    `self.C` will be used to compute the Query vector for the cross attention
    layer.
    """
    def __init__(self, config):
        super().__init__()
        ### YOUR CODE HERE
        ### Hint: Copy over the code from Block and make necessary modifications.
        pass
        ### END YOUR CODE

    def forward(self, x_input):
        """Hint: perform cross-attention between x_input and self.C.
        Use the layernorm layers on C, and then on the input to the MLP.
        """
        ### YOUR CODE HERE
        ### Hint: Copy over the code from Block and make necessary modifications.
        ### Should be around 3-5 lines.
        pass
        ### END YOUR CODE
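

# NOTE: the class below is not part of the assignment scaffold. It is a hedged
# sketch of one possible way to fill in the YOUR CODE HERE blocks of
# DownProjectBlock, assuming attention.CausalCrossAttention(config) is built
# like CausalSelfAttention and that its forward(x_kv, x_q) draws keys/values
# from the first argument and queries from the second (check attention.py
# before relying on this).
class _DownProjectBlockSketch(nn.Module):
    """Sketch: cross-attend from the full input sequence down to a learned basis C."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = attention.CausalCrossAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )
        # learned basis vectors; the leading dim of 1 broadcasts over the batch
        self.C = nn.Parameter(torch.empty(1, config.bottleneck_dim, config.n_embd))
        nn.init.xavier_uniform_(self.C)

    def forward(self, x_input):
        # queries from the layernormed basis C, keys/values from x_input,
        # following the hint in DownProjectBlock.forward
        x = self.C + self.attn(x_input, self.ln1(self.C))
        x = x + self.mlp(self.ln2(x))
        return x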


class UpProjectBlock(nn.Module):
    """Transformer block used for up projection.

    Initialize similarly to the regular Transformer Block class,
    while using the CausalCrossAttention layer instead of the regular
    CausalSelfAttention layer.
    """
    def __init__(self, config):
        super().__init__()
        ### YOUR CODE HERE
        ### Hint: Copy over the code from Block and make necessary modifications.
        pass
        ### END YOUR CODE

    def forward(self, y, x_input):
        """Hint: perform cross-attention between previous layer's output y and
        x_input.
        Use the layernorm layers on y, and then on the input to the MLP.
        """
        ### YOUR CODE HERE
        ### Hint: Copy over the code from Block and make necessary modifications.
        ### Should be around 3-5 lines.
        pass
        ### END YOUR CODE
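

# NOTE: as with _DownProjectBlockSketch above, this is a hedged sketch (not the
# released solution) of one way to fill in UpProjectBlock, under the same
# assumption about CausalCrossAttention's forward(x_kv, x_q) argument order.
class _UpProjectBlockSketch(nn.Module):
    """Sketch: cross-attend from the bottleneck output back up to the input length."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = attention.CausalCrossAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, y, x_input):
        # keys/values from the layernormed bottleneck output y, queries from
        # x_input, following the hint in UpProjectBlock.forward; the residual
        # keeps the output at the full input sequence length
        x = x_input + self.attn(self.ln1(y), x_input)
        x = x + self.mlp(self.ln2(x))
        return x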


class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()

        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.perceiver = config.perceiver
        if config.perceiver:
            input_block_size = config.block_size

            # input sequence based causal mask
            self.down_block = DownProjectBlock(config)

            # bottleneck basis based causal mask
            config.block_size = config.bottleneck_dim
            self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer - 2)])

            # reset value of the block size back to the original.
            config.block_size = input_block_size
            self.up_block = UpProjectBlock(config)
        else:
            self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])

        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)

        print("number of parameters: {}".format(sum(p.numel() for p in self.parameters())))

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, sequence length %d exceeds model block size %d." % (t, self.block_size)

        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x_input = self.drop(token_embeddings + position_embeddings)

        if self.perceiver:
            x = self.down_block(x_input)
        else:
            x = x_input

        # always compute through the blocks
        x = self.blocks(x)

        if self.perceiver:
            x = self.up_block(x, x_input)

        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0)

        return logits, loss
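

# Minimal usage sketch (not part of the original assignment file): build a
# small vanilla (non-perceiver) GPT from GPTConfig and run one forward pass.
# The hyperparameter values below are illustrative, not the assignment's
# settings, and running this requires the assignment's attention.py module.
if __name__ == "__main__":
    config = GPTConfig(
        vocab_size=256,
        block_size=128,
        n_layer=4,
        n_head=4,
        n_embd=64,
    )
    model = GPT(config)

    idx = torch.randint(0, config.vocab_size, (2, 32))      # (batch, time) token indices
    targets = torch.randint(0, config.vocab_size, (2, 32))  # same shape as idx
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())  # expect torch.Size([2, 32, 256]) and a scalar loss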