GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/minBERT/bert.py

from typing import Dict, List, Optional, Union, Tuple, Callable
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from base_bert import BertPreTrainedModel
from utils import *


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # initialize the linear transformation layers for key, value, query
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # this dropout is applied to the normalized attention scores, following the
        # original Transformer implementation; although it is a bit unusual, we
        # empirically observe that it yields better performance
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transform(self, x, linear_layer):
        # the corresponding linear_layer of k, v, q is used to project the hidden states (x)
        bs, seq_len = x.shape[:2]
        proj = linear_layer(x)
        # next, we need to produce multiple heads for proj;
        # this is done by splitting the hidden state into self.num_attention_heads
        # pieces, each of size self.attention_head_size
        proj = proj.view(bs, seq_len, self.num_attention_heads,
                         self.attention_head_size)
        # by a proper transpose, proj becomes [bs, num_attention_heads, seq_len, attention_head_size],
        # i.e. (B, h, n, d/h) where d/h is the attention head size
        proj = proj.transpose(1, 2)
        return proj

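    # Example (illustrative numbers): with hidden_size = 768 and 12 attention heads,
    # transform maps x of shape (B, n, 768) to (B, 12, n, 64), since 768 / 12 = 64
    # dimensions go to each head.
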
    def attention(self, key, query, value, attention_mask):
        """
        key, query, value -> (B, h, n, d/h)
        i.e. XK, XQ, XV after transform() has reshaped them into multi-head form
        """
        # each head's attention is calculated following eq. (1) of https://arxiv.org/pdf/1706.03762.pdf
        # attention scores are calculated by multiplying query and key,
        # giving a score matrix S of [bs, num_attention_heads, seq_len, seq_len];
        # S[b, i, j, k] is the (unnormalized) attention score between the j-th and
        # k-th tokens, given by the i-th attention head
        # before normalizing the scores, use the attention mask to mask out the padding token scores
        # (recall: in attention_mask, non-padding positions are 0 and padding positions
        # hold a large negative number)
        # normalize the scores, multiply them with the values to get V',
        # then concatenate the heads to recover the original shape
        # [bs, seq_len, num_attention_heads * attention_head_size = hidden_size]
        B, h, n, ahs = key.size()
        S = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(ahs)
        S += attention_mask
        weight = F.softmax(S, dim=-1)
        # dropout on the normalized attention scores (see the comment in __init__)
        weight = self.dropout(weight)
        # transposing back before the final view is the key step
        V_ = torch.matmul(weight, value).transpose(1, 2).contiguous()
        return V_.view(B, n, -1)
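
    # Note: on PyTorch >= 2.0, the computation above can also be expressed with
    # F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
    # dropout_p=self.dropout.p if self.training else 0.0), followed by the same
    # transpose/reshape; it is kept explicit here to mirror eq. (1) of the paper.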

    def forward(self, hidden_states, attention_mask):
        """
        hidden_states: [bs, seq_len, hidden_state]
        attention_mask: [bs, 1, 1, seq_len]
        output: [bs, seq_len, hidden_state]
        """
        # first, generate the key, value and query for each token for multi-head
        # attention using transform (more details inside the function);
        # each of the *_layer tensors has shape [bs, num_attention_heads, seq_len, attention_head_size]
        key_layer = self.transform(hidden_states, self.key)
        value_layer = self.transform(hidden_states, self.value)
        query_layer = self.transform(hidden_states, self.query)
        # calculate the multi-head attention
        attn_value = self.attention(
            key_layer, query_layer, value_layer, attention_mask)
        return attn_value


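# --- Illustrative usage (not part of the assignment) -------------------------
# A minimal smoke test for BertSelfAttention, using a hypothetical config object
# that only carries the attributes this class reads (values roughly BERT-base).
def _demo_bert_self_attention():
    from types import SimpleNamespace
    config = SimpleNamespace(hidden_size=768,
                             num_attention_heads=12,
                             attention_probs_dropout_prob=0.1)
    attn = BertSelfAttention(config)
    hidden_states = torch.randn(2, 5, 768)    # (B, seq_len, hidden_size)
    attention_mask = torch.zeros(2, 1, 1, 5)  # additive mask: 0 means "keep"
    out = attn(hidden_states, attention_mask)
    assert out.shape == (2, 5, 768)
    return out

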
class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # multi-head attention
        self.self_attention = BertSelfAttention(config)
        # add-norm
        self.attention_dense = nn.Linear(
            config.hidden_size, config.hidden_size)
        self.attention_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
        # feed forward
        self.interm_dense = nn.Linear(
            config.hidden_size, config.intermediate_size)
        self.interm_af = F.gelu
        # another add-norm
        self.out_dense = nn.Linear(
            config.intermediate_size, config.hidden_size)
        self.out_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.out_dropout = nn.Dropout(config.hidden_dropout_prob)

    def add_norm(self, input, output, dense_layer, dropout, ln_layer):
        """
        this function is applied after the multi-head attention layer or the feed-forward layer
        input: the input of the previous layer
        output: the output of the previous layer
        dense_layer: used to transform the output
        dropout: the dropout to be applied
        ln_layer: the layer norm to be applied
        """
        # Hint: remember that BERT applies dropout to the output of each sub-layer
        # before it is added to the sub-layer input and normalized
        output = dense_layer(output)
        output = dropout(output)
        out = input + output
        out = ln_layer(out)
        return out

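    # For reference, forward() below composes these pieces as a post-LayerNorm
    # Transformer block:
    #   h   = LayerNorm(x + Dropout(attention_dense(SelfAttention(x))))
    #   out = LayerNorm(h + Dropout(out_dense(GELU(interm_dense(h)))))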
    def forward(self, hidden_states, attention_mask):
        """
        hidden_states: either from the embedding layer (first bert layer) or from the previous bert layer
        as shown in the left half of Figure 1 of https://arxiv.org/pdf/1706.03762.pdf,
        each block consists of
        1. a multi-head attention layer (BertSelfAttention)
        2. an add-norm that takes the input and output of the multi-head attention layer
        3. a feed forward layer
        4. an add-norm that takes the input and output of the feed forward layer
        """
        output = self.self_attention(hidden_states, attention_mask)
        out_first_addnorm = self.add_norm(
            hidden_states, output, self.attention_dense, self.attention_dropout,
            self.attention_layer_norm)
        out = self.interm_af(self.interm_dense(out_first_addnorm))
        out = self.add_norm(out_first_addnorm, out, self.out_dense,
                            self.out_dropout, self.out_layer_norm)
        return out


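# --- Illustrative usage (not part of the assignment) -------------------------
# A minimal smoke test for a single BertLayer; the config is again a hypothetical
# stand-in that only carries the attributes read by BertLayer/BertSelfAttention.
def _demo_bert_layer():
    from types import SimpleNamespace
    config = SimpleNamespace(hidden_size=768,
                             num_attention_heads=12,
                             attention_probs_dropout_prob=0.1,
                             hidden_dropout_prob=0.1,
                             intermediate_size=3072,
                             layer_norm_eps=1e-12)
    layer = BertLayer(config)
    x = torch.randn(2, 5, 768)      # (B, seq_len, hidden_size)
    mask = torch.zeros(2, 1, 1, 5)  # additive mask: 0 means "keep"
    out = layer(x, mask)
    assert out.shape == (2, 5, 768)
    return out

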
class BertModel(BertPreTrainedModel):
    """
    the bert model returns the final embeddings for each token in a sentence
    it consists of
    1. an embedding layer (used in self.embed)
    2. a stack of n bert layers (used in self.encode)
    3. a linear transformation layer for the [CLS] token (used in self.forward, as given)
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # embedding
        self.word_embedding = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.pos_embedding = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.tk_type_embedding = nn.Embedding(
            config.type_vocab_size, config.hidden_size)
        self.embed_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is a constant, register it as a buffer
        position_ids = torch.arange(
            config.max_position_embeddings).unsqueeze(0)
        self.register_buffer('position_ids', position_ids)

        # bert encoder
        self.bert_layers = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)])

        # for the [CLS] token
        self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.pooler_af = nn.Tanh()

        self.init_weights()

    def embed(self, input_ids):
        input_shape = input_ids.size()
        seq_length = input_shape[1]

        # get the word embedding from self.word_embedding
        inputs_embeds = self.word_embedding(
            input_ids)  # (B, seq_len, hidden_size)

        # get the position index and position embedding from self.pos_embedding
        pos_ids = self.position_ids[:, :seq_length]
        pos_embeds = self.pos_embedding(pos_ids)  # (1, seq_len, hidden_size)

        # get token type ids; since we are not considering token types, this is just a placeholder
        tk_type_ids = torch.zeros(
            input_shape, dtype=torch.long, device=input_ids.device)
        tk_type_embeds = self.tk_type_embedding(
            tk_type_ids)  # (B, seq_len, hidden_size)

        # add the three embeddings together, then apply embed_layer_norm and dropout
        embed = inputs_embeds + pos_embeds + tk_type_embeds
        out = self.embed_layer_norm(embed)
        out = self.embed_dropout(out)
        return out

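    # Shape note: inputs_embeds and tk_type_embeds are (B, seq_len, hidden_size),
    # while pos_embeds is (1, seq_len, hidden_size); the sum above relies on
    # broadcasting the position embeddings across the batch dimension.
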
    def encode(self, hidden_states, attention_mask):
        """
        hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
        attention_mask: [batch_size, seq_len]
        """
        # get the extended attention mask for self-attention;
        # it has shape [batch_size, 1, 1, seq_len], with 0 for non-padding tokens
        # and a large negative number for padding tokens
        extended_attention_mask: torch.Tensor = get_extended_attention_mask(
            attention_mask, self.dtype)

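        # get_extended_attention_mask (imported from utils) is expected to turn the
        # 0/1 padding mask into this additive form; a typical implementation
        # (sketch only, the real helper lives in utils) would be:
        #   extended = attention_mask[:, None, None, :].to(dtype)
        #   extended = (1.0 - extended) * -10000.0
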
        # pass the hidden states through the encoder layers
        for i, layer_module in enumerate(self.bert_layers):
            # feed the encoding from the previous bert_layer to the next
            hidden_states = layer_module(
                hidden_states, extended_attention_mask)

        return hidden_states

    def forward(self, input_ids, attention_mask):
        """
        input_ids: [batch_size, seq_len], where seq_len is the max length of the batch
        attention_mask: same size as input_ids; 1 represents non-padding tokens, 0 represents padding tokens
        """
        # get the embedding for each input token
        embedding_output = self.embed(input_ids=input_ids)

        # feed to a transformer (a stack of BertLayers)
        sequence_output = self.encode(
            embedding_output, attention_mask=attention_mask)  # (B, seq_len, hidden_size)

        # get the [CLS] token's hidden state and pool it
        first_tk = sequence_output[:, 0]
        first_tk = self.pooler_dense(first_tk)
        first_tk = self.pooler_af(first_tk)  # (B, hidden_size)

        return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}

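
# --- Optional usage sketch (not part of the assignment) ----------------------
# BertModel itself also depends on base_bert and utils (weight init, self.dtype,
# get_extended_attention_mask), so only the self-contained demos defined above are
# exercised here. Given a suitable config and a tokenized batch, a typical call is:
#   model = BertModel(config)
#   out = model(input_ids, attention_mask)   # input_ids, attention_mask: (B, T)
#   out['last_hidden_state']                 # (B, T, hidden_size)
#   out['pooler_output']                     # (B, hidden_size)
if __name__ == "__main__":
    print(_demo_bert_self_attention().shape)  # expected: torch.Size([2, 5, 768])
    print(_demo_bert_layer().shape)           # expected: torch.Size([2, 5, 768])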