Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
yiming-wange
GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a5/mingpt-demo/mingpt/model.py
1003 views
1
"""
2
GPT model:
3
- the initial stem consists of a combination of token encoding and a positional encoding
4
- the meat of it is a uniform sequence of Transformer blocks
5
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
6
- all blocks feed into a central residual pathway similar to resnets
7
- the final decoder is a linear projection into a vanilla Softmax classifier
8
"""
9
10
import math
11
import logging
12
13
import torch
14
import torch.nn as nn
15
from torch.nn import functional as F
16
17
logger = logging.getLogger(__name__)
18
19
class GPTConfig:
    """Base GPT configuration holding the settings shared by every GPT variant.

    The dropout rates below are class-level defaults. Any keyword argument
    passed to the constructor is stored on the instance, so it shadows a
    class-level default of the same name.
    """

    # default dropout probabilities, overridable per instance through kwargs
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        """Store the vocabulary size, the block size and any extra settings."""
        self.vocab_size = vocab_size
        self.block_size = block_size
        # every additional keyword becomes an attribute of this instance
        for name in kwargs:
            setattr(self, name, kwargs[name])
30
31
class GPT1Config(GPTConfig):
    """GPT-1 like network, roughly 125M parameters."""

    # model size: embedding width, attention heads per layer, number of layers
    n_embd = 768
    n_head = 12
    n_layer = 12
36
37
class CausalSelfAttention(nn.Module):
    """
    Multi-head masked self-attention followed by an output projection.
    Each position may attend only to itself and to earlier positions.
    The layer is written out explicitly instead of delegating to
    torch.nn.MultiheadAttention.
    """

    def __init__(self, config):
        """Build the projections, dropouts and causal mask from `config`.

        Reads n_embd, n_head, attn_pdrop, resid_pdrop and block_size from
        `config`; n_embd must be divisible by n_head.
        """
        super().__init__()
        width, heads = config.n_embd, config.n_head
        assert width % heads == 0
        # one projection each for keys, queries and values, covering all heads at once
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)
        # dropout on the attention weights and on the projected output
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # projection applied after the heads are put back together
        self.proj = nn.Linear(width, width)
        # lower-triangular matrix of ones kept as a buffer; zeros mark the
        # (query, key) pairs where the key lies to the right of the query
        span = config.block_size
        causal = torch.tril(torch.ones(span, span)).view(1, 1, span, span)
        self.register_buffer("mask", causal)
        self.n_head = heads

    def forward(self, x, layer_past=None):
        """Apply causal self-attention to `x` of shape (B, T, C).

        Returns a tensor of shape (B, T, C). `layer_past` is accepted but
        not used by this implementation.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs): heads become a batch-like dimension
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        # scaled dot products: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # positions to the right get -inf so softmax assigns them zero weight
        blocked = self.mask[:, :, :T, :T] == 0
        scores = scores.masked_fill(blocked, float('-inf'))
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs), then heads side by side
        mixed = weights @ v
        merged = mixed.transpose(1, 2).contiguous().view(B, T, C)

        # output projection followed by residual dropout
        return self.resid_drop(self.proj(merged))
80
81
class Block(nn.Module):
    """One Transformer block: layer-normed self-attention, then a layer-normed MLP,
    each added back onto its own input."""

    def __init__(self, config):
        """Create the two layer norms, the attention layer and the MLP.

        Reads n_embd and resid_pdrop from `config` and hands `config` on to
        CausalSelfAttention.
        """
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        # the MLP widens to four times the embedding size and narrows back again
        hidden = 4 * width
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        """Return `x` after the attention and MLP residual updates."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        expanded = self.mlp(self.ln2(x))
        return x + expanded
100
101
class GPT(nn.Module):
    """The full GPT language model, with a context size of block_size."""

    def __init__(self, config):
        """Build the embedding stem, the stack of blocks and the decoder head.

        Reads vocab_size, n_embd, block_size, embd_pdrop and n_layer from
        `config`, and hands `config` on to every Block.
        """
        super().__init__()

        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        """Return the maximum sequence length the model accepts."""
        return self.block_size

    def _init_weights(self, module):
        """Initialise one sub-module; called for every sub-module via `self.apply`.

        Linear and Embedding weights are drawn from N(0, 0.02^2) and Linear
        biases are zeroed. LayerNorm is reset to weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        Split the model's parameters into two buckets and return an AdamW optimizer.

        Linear weights receive weight decay (`train_config.weight_decay`). Biases,
        LayerNorm/Embedding weights and the position embedding do not. The
        learning rate and betas are read from `train_config.learning_rate` and
        `train_config.betas`.

        Raises AssertionError if a parameter ends up in both buckets or in neither.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            # recurse=False: look only at the parameters this module owns itself, so
            # each parameter is classified once, next to the module type that decides
            # its bucket, instead of being revisited from every ancestor module
            for pn, _ in m.named_parameters(recurse=False):
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object; names are sorted so the parameter order is deterministic
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        """Run the model on token indices `idx` of shape (b, t).

        Returns `(logits, loss)`. `logits` has vocab_size entries per position.
        `loss` is the cross-entropy against `targets` when they are given,
        otherwise None.

        Raises AssertionError if t exceeds the model's block size.
        """
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            # flatten batch and time so every position counts as one classification example
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
198
199