Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/text/models/transformer.py
840 views
1
from ...torch_core import *
2
from ...layers import *
3
from .awd_lstm import RNNDropout, LinearDecoder, SequentialRNN
4
5
# Public API of this module, in definition order.
__all__ = [
    # building blocks
    'Activation', 'PositionalEncoding', 'GeLU', 'Swish', 'feed_forward',
    'MultiHeadAttention', 'MultiHeadRelativeAttention', 'DecoderLayer',
    # models
    'Transformer', 'TransformerXL',
    # configurations and layer-group splitters
    'tfmer_lm_config', 'tfmer_clas_config', 'tfmer_lm_split', 'tfmer_clas_split',
    'tfmerXL_lm_config', 'tfmerXL_clas_config', 'tfmerXL_lm_split', 'tfmerXL_clas_split',
]
8
9
# Activation functions selectable through the `act` argument of `feed_forward`.
Activation = Enum('Activation', ['ReLU', 'Swish', 'GeLU'])
10
11
class PositionalEncoding(Module):
    "Encode the position with a sinusoid."
    def __init__(self, d:int):
        # Inverse frequencies 1 / 10000^(2i/d) for i = 0, 2, 4, ... < d, stored as a (non-trainable) buffer.
        self.register_buffer('freq', 1 / (10000 ** (torch.arange(0., d, 2.)/d)))

    def forward(self, pos:Tensor):
        """Return the sinusoidal encoding of the 1-D tensor of positions `pos`.

        The result has shape (len(pos), 2*len(freq)): sines in the first half, cosines in the second.
        Integer positions (e.g. built with the dtype of token ids) are cast to the dtype of `freq`,
        which the previous `torch.ger` call required.
        """
        # Outer product pos x freq written with broadcasting (same single multiply per entry as torch.ger).
        inp = pos.type_as(self.freq)[:, None] * self.freq[None]
        return torch.cat([inp.sin(), inp.cos()], dim=-1)
19
20
class GeLU(Module):
    "GeLU activation, tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))."
    def forward(self, x):
        cubic = x + 0.044715 * torch.pow(x, 3)
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * cubic))
22
23
class Swish(Module):
    "Swish activation: the input gated by its own sigmoid, x * sigmoid(x)."
    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
25
26
# One shared (parameter-free) activation module per `Activation` member, looked up by `feed_forward`.
_activ_func = {
    Activation.ReLU: nn.ReLU(inplace=True),
    Activation.GeLU: GeLU(),
    Activation.Swish: Swish(),
}
27
28
def feed_forward(d_model:int, d_ff:int, ff_p:float=0., act:Activation=Activation.ReLU, double_drop:bool=True):
    """Position-wise feed-forward sub-layer of a Transformer block.

    Linear(d_model->d_ff) -> activation `act` -> [Dropout(ff_p) if `double_drop`] -> Linear(d_ff->d_model)
    -> Dropout(ff_p) -> MergeLayer -> LayerNorm(d_model), wrapped in a `SequentialEx`.
    NOTE(review): `MergeLayer` is presumably the residual add with the block input — confirm in fastai.layers.
    """
    # Modules are created in the same order as before so weight initialization consumes the RNG identically.
    modules = [nn.Linear(d_model, d_ff), _activ_func[act]]
    if double_drop:
        modules.append(nn.Dropout(ff_p))
    modules += [nn.Linear(d_ff, d_model), nn.Dropout(ff_p), MergeLayer(), nn.LayerNorm(d_model)]
    return SequentialEx(*modules)
32
33
class MultiHeadAttention(Module):
    "Multi-head self-attention, followed by output projection, dropout, residual add and LayerNorm."
    def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True):
        d_head = ifnone(d_head, d_model//n_heads)
        self.n_heads,self.d_head,self.scale = n_heads,d_head,scale
        # A single projection yields queries, keys and values; it is chunked in three at attention time.
        self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias)
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Attend over `x`, project back to `d_model`, add the residual `x` and layer-normalize."
        return self.ln(x + self.drop_res(self.out(self._apply_attention(x, mask=mask, **kwargs))))

    def _apply_attention(self, x:Tensor, mask:Tensor=None):
        """Scaled dot-product attention using permute + matmul.

        `x` is (bs, x_len, d_model); `mask` must broadcast to (bs, n_heads, x_len, x_len), True/1 where
        attention is forbidden. Returns (bs, x_len, n_heads*d_head).
        """
        bs,x_len = x.size(0),x.size(1)
        wq,wk,wv = torch.chunk(self.attention(x), 3, dim=-1)
        wq,wk,wv = map(lambda t:t.view(bs, t.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        # wq, wv -> (bs, n_heads, x_len, d_head); wk -> (bs, n_heads, d_head, x_len) so matmul gives the scores.
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
        attn_score = torch.matmul(wq, wk)
        if self.scale: attn_score.div_(self.d_head ** 0.5)
        if mask is not None:
            # Recent PyTorch versions only accept boolean masks in masked_fill; callers may pass uint8 masks.
            attn_score = attn_score.float().masked_fill(mask.bool(), -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)

    def _attention_einsum(self, x, mask=None):
        """Same computation as `_apply_attention`, written with einsum.

        Permute and matmul is a little bit faster but this implementation is more readable. Scores are laid
        out (bs, x_len_query, x_len_key, n_heads), so `mask` must broadcast to that layout (e.g. [None,:,:,None]).
        """
        bs,x_len = x.size(0),x.size(1)
        wq,wk,wv = torch.chunk(self.attention(x), 3, dim=-1)
        wq,wk,wv = map(lambda t:t.view(bs, t.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        attn_score = torch.einsum('bind,bjnd->bijn', (wq, wk))
        if self.scale: attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            attn_score = attn_score.float().masked_fill(mask.bool(), -float('inf')).type_as(attn_score)
        # Softmax over the key axis (dim 2 in this layout).
        attn_prob = self.drop_att(F.softmax(attn_score, dim=2))
        attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv))
        return attn_vec.contiguous().view(bs, x_len, -1)
72
73
#def _line_shift1(x:Tensor, mask:bool=False):
74
# "Shift the line i of `x` by p-i elements to the left, if `mask` puts 0s on the diagonal."
75
# bs,n,p,nh = x.size()
76
# x_pad = torch.cat([x.new_zeros(bs,n,1,nh), x], dim=2)
77
# x_shift = x_pad.view(bs,p + 1,n,nh)[:,1:].view_as(x)
78
# if mask: x_shift.mul_(torch.tril(x.new_ones(n,p), p-n)[None,:,:,None])
79
# return x_shift
80
81
def _line_shift(x:Tensor, mask:bool=False):
82
"Shift the line i of `x` by p-i elements to the left, is `mask` puts 0s on the diagonal."
83
bs,nh,n,p = x.size()
84
x_pad = torch.cat([x.new_zeros(bs,nh,n,1), x], dim=3)
85
x_shift = x_pad.view(bs,nh,p + 1,n)[:,:,1:].view_as(x)
86
if mask: x_shift.mul_(torch.tril(x.new_ones(n,p), p-n)[None,None,])
87
return x_shift
88
89
class MultiHeadRelativeAttention(MultiHeadAttention):
    "Multi-head attention with relative positional encoding (Transformer-XL)."

    def __init__(self, n_heads:int, d_model:int, d_head:int, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True):
        super().__init__(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)
        # Projection applied to the relative positional encodings `r`.
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)

    def _apply_attention(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):
        """Relative attention using permute + matmul.

        Notation from the paper: `x` input (bs, x_len, d_model), `r` encodings of the relative distances
        (one row per context position), `u` and `v` learnable parameters shared between all layers, `mask`
        to avoid cheating, `mem` the previous hidden states (prepended to `x` as extra keys/values).
        Returns (bs, x_len, n_heads*d_head).
        """
        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)
        context = x if mem is None else torch.cat([mem, x], dim=1)
        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)
        # Only the current segment issues queries; the memory contributes keys and values.
        wq = wq[:,-x_len:]
        wq,wk,wv = map(lambda t:t.view(bs, t.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
        wkr = self.r_attn(r).view(seq_len, self.n_heads, self.d_head).permute(1, 2, 0)
        # AC is (a) + (c) and BD is (b) + (d) in the paper.
        AC = torch.matmul(wq+u, wk)
        BD = _line_shift(torch.matmul(wq+v, wkr))
        attn_score = AC + BD
        # Scaling is optional: with scale=False the raw scores are used (previously this left attn_score unbound).
        if self.scale: attn_score = attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            # Recent PyTorch versions only accept boolean masks in masked_fill; callers may pass uint8 masks.
            attn_score = attn_score.float().masked_fill(mask.bool(), -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)

    def _attention_einsum(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):
        """Same computation as `_apply_attention`, written with einsum.

        Permute and matmul is a little bit faster but this implementation is more readable.
        Scores are laid out (bs, x_len, seq_len, n_heads).
        NOTE(review): this variant expects `u`/`v` broadcastable to (n_heads, d_head) and a mask laid out as
        [None,:,:,None]; the inline comments in TransformerXL suggest it builds them for `_apply_attention`.
        """
        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)
        context = x if mem is None else torch.cat([mem, x], dim=1)
        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)
        wq = wq[:,-x_len:]
        wkr = self.r_attn(r)
        wq,wk,wv = map(lambda t:t.view(bs, t.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        # AC is (a) + (c) and BD is (b) + (d) in the paper.
        AC = torch.einsum('bind,bjnd->bijn', (wq+u, wk))
        bd = torch.einsum('bind,jnd->bijn', (wq+v, wkr))
        # `_line_shift` works on (bs, n_heads, x_len, seq_len): move the head axis there and back.
        BD = _line_shift(bd.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        attn_score = AC + BD
        if self.scale: attn_score = attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            attn_score = attn_score.float().masked_fill(mask.bool(), -float('inf')).type_as(attn_score)
        # Softmax over the key axis (dim 2 in this layout).
        attn_prob = self.drop_att(F.softmax(attn_score, dim=2))
        attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv))
        return attn_vec.contiguous().view(bs, x_len, -1)
137
138
class DecoderLayer(Module):
    "Basic block of a Transformer model: an attention sub-layer followed by a feed-forward sub-layer."
    # Not an nn.Sequential: the attention sub-layer takes extra inputs (mask and any attention-specific kwargs).
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
                 bias:bool=True, scale:bool=True, act:Activation=Activation.ReLU, double_drop:bool=True,
                 attn_cls:Callable=MultiHeadAttention):
        attn_kwargs = dict(resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)
        self.mhra = attn_cls(n_heads, d_model, d_head, **attn_kwargs)
        self.ff = feed_forward(d_model, d_inner, ff_p=ff_p, act=act, double_drop=double_drop)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Run attention (forwarding `mask` and any extra kwargs) then the feed-forward sub-layer."
        attended = self.mhra(x, mask=mask, **kwargs)
        return self.ff(attended)
148
149
class Transformer(Module):
    "Transformer model: https://arxiv.org/abs/1706.03762."
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=True, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadAttention,
                 learned_pos_enc:bool=True, mask:bool=True):
        self.mask = mask
        self.encoder = nn.Embedding(vocab_sz, d_model)
        self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model)
        self.drop_emb = nn.Dropout(embed_p)
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                                                  ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop,
                                                  attn_cls=attn_cls) for _ in range(n_layers)])

    def reset(self):
        "Nothing to reset: this model keeps no state between batches."
        pass

    def forward(self, x):
        """Encode the token ids `x` of shape (bs, x_len).

        Returns `([out], [out])` (the format expected by the LinearDecoder), where `out` is the output of
        the last layer.
        """
        bs, x_len = x.size()
        pos = torch.arange(0, x_len, device=x.device, dtype=x.dtype)
        emb = self.encoder(x)
        # Learned position embeddings index with integers; the sinusoidal encoding needs floating positions.
        if not isinstance(self.pos_enc, nn.Embedding): pos = pos.type_as(emb)
        inp = self.drop_emb(emb + self.pos_enc(pos)[None])
        # Causal mask, True above the diagonal (future tokens). Boolean because recent PyTorch rejects uint8
        # masks in masked_fill. NOTE: the einsum attention variant would need the layout [None,:,:,None].
        mask = torch.triu(x.new_ones(x_len, x_len), diagonal=1).bool()[None,None] if self.mask else None
        for layer in self.layers: inp = layer(inp, mask=mask)
        return ([inp],[inp]) #For the LinearDecoder
173
174
class TransformerXL(Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=False, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadRelativeAttention,
                 learned_pos_enc:bool=False, mask:bool=True, mem_len:int=0):
        self.encoder = nn.Embedding(vocab_sz, d_model)
        self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model)
        self.drop_emb = nn.Dropout(embed_p)
        # u and v are the learnable biases shared by all layers. They start at zero (not uninitialized
        # memory); an init function such as `init_transformer` may overwrite them.
        # The middle dim of 1 is for the matmul attention; the einsum implementation would drop it.
        self.u = nn.Parameter(torch.zeros(n_heads, 1, d_head))
        self.v = nn.Parameter(torch.zeros(n_heads, 1, d_head))
        self.mem_len,self.n_layers,self.d_model,self.mask = mem_len,n_layers,d_model,mask
        self.init = False
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                                                  ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop,
                                                  attn_cls=attn_cls) for _ in range(n_layers)])

    def reset(self):
        "Reset the internal memory to empty placeholders (one per layer, plus one for the embeddings)."
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]

    def _update_mems(self, hids):
        "Append the new hidden states `hids` to the memory, keeping only the last `mem_len` positions."
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i, hid in enumerate(hids):
                prev = self.hidden[i]
                # An empty placeholder left by `reset` has no sequence dim: start the memory from `hid`.
                cat = torch.cat([prev, hid], dim=1) if prev.dim() > 1 else hid
                self.hidden[i] = cat[:,-self.mem_len:].detach()

    def select_hidden(self, idxs):
        "Keep only the memory rows at `idxs` (indexing the first dim)."
        self.hidden = [h[idxs] for h in self.hidden]

    def forward(self, x):
        """Encode the token ids `x` of shape (bs, x_len), attending to the memory when `mem_len > 0`.

        Returns `(hidden_or_out, [out])`: the updated memory list if `mem_len > 0`, else `[out]`.
        """
        #The hidden state has to be initialized in the forward pass for nn.DataParallel
        if self.mem_len > 0 and not self.init:
            self.reset()
            self.init = True
        bs,x_len = x.size()
        inp = self.drop_emb(self.encoder(x))
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
        # Causal mask over memory + current segment, boolean because recent PyTorch rejects uint8 masks in
        # masked_fill. NOTE: the einsum attention variant would need the layout [None,:,:,None].
        mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).bool()[None,None] if self.mask else None
        hids = []
        # Relative distances seq_len-1 down to 0. Learned embeddings need integer indices (and then
        # presumably seq_len <= ctx_len); the sinusoidal encoding needs floats.
        pos_dtype = x.dtype if isinstance(self.pos_enc, nn.Embedding) else inp.dtype
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=pos_dtype)
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            # Only pass a memory once it holds positions (the `reset` placeholders are empty).
            mem = self.hidden[i] if self.mem_len > 0 and m_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        if self.mem_len > 0: self._update_mems(hids)
        return (self.hidden if self.mem_len > 0 else [core_out]),[core_out]
227
228
def init_transformer(m):
    """Initialize module `m` in place, dispatching on its class name.

    Linear*: weight ~ N(0, 0.02), bias = 0. LayerNorm*: weight ~ N(1, 0.02), bias = 0.
    TransformerXL*: `u` and `v` ~ N(0, 0.02). Attributes that are missing or None are skipped.
    """
    def _weight_and_bias(w_mean):
        if getattr(m, 'weight', None) is not None: nn.init.normal_(m.weight, w_mean, 0.02)
        if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0.)

    name = m.__class__.__name__
    if 'Linear' in name:
        _weight_and_bias(0.)
    elif 'LayerNorm' in name:
        _weight_and_bias(1.)
    elif 'TransformerXL' in name:
        for attr in ('u', 'v'):
            if hasattr(m, attr): nn.init.normal_(getattr(m, attr), 0., 0.02)
239
240
# Default hyper-parameters for a `Transformer` language model; `init` is the weight-initialization function.
tfmer_lm_config = {
    'ctx_len': 512, 'n_layers': 12, 'n_heads': 12, 'd_model': 768, 'd_head': 64, 'd_inner': 3072,
    'resid_p': 0.1, 'attn_p': 0.1, 'ff_p': 0.1, 'embed_p': 0.1, 'output_p': 0.,
    'bias': True, 'scale': True, 'act': Activation.GeLU, 'double_drop': False,
    'tie_weights': True, 'out_bias': False, 'init': init_transformer, 'mask': True,
}

# Default hyper-parameters for a `Transformer` classifier (no causal mask, no weight tying).
tfmer_clas_config = {
    'ctx_len': 512, 'n_layers': 12, 'n_heads': 12, 'd_model': 768, 'd_head': 64, 'd_inner': 3072,
    'resid_p': 0.1, 'attn_p': 0.1, 'ff_p': 0.1, 'embed_p': 0.1, 'output_p': 0.,
    'bias': True, 'scale': True, 'act': Activation.GeLU, 'double_drop': False,
    'init': init_transformer, 'mask': False,
}
247
248
def tfmer_lm_split(model:nn.Module) -> List[nn.Module]:
    """Split a Transformer language `model` into 4 layer groups for differential learning rates.

    `model[0]` is the encoder: its `layers` are cut into thirds (the last third takes the remainder).
    The final group holds the token embedding `encoder.encoder` and the decoder `model[1]`.
    """
    enc = model[0]
    k = len(enc.layers) // 3
    cuts = (0, k, 2*k, None)
    groups = [list(enc.layers[lo:hi]) for lo, hi in zip(cuts[:-1], cuts[1:])]
    groups.append([enc.encoder, model[1]])
    return groups
254
255
def tfmer_clas_split(model:nn.Module) -> List[nn.Module]:
    """Split a Transformer classifier `model` into 5 layer groups for differential learning rates.

    The encoder is `model[0].module`. Groups: [token embedding], the encoder layers in thirds
    (the last third takes the remainder), then [classifier head `model[1]`].
    """
    enc = model[0].module
    k = len(enc.layers) // 3
    cuts = (0, k, 2*k, None)
    body = [list(enc.layers[lo:hi]) for lo, hi in zip(cuts[:-1], cuts[1:])]
    return [[enc.encoder], *body, [model[1]]]
261
262
# Default hyper-parameters for a `TransformerXL` language model (memory of `mem_len` positions).
tfmerXL_lm_config = {
    'ctx_len': 150, 'n_layers': 12, 'n_heads': 10, 'd_model': 410, 'd_head': 41, 'd_inner': 2100,
    'resid_p': 0.1, 'attn_p': 0.1, 'ff_p': 0.1, 'embed_p': 0.1, 'output_p': 0.1,
    'bias': False, 'scale': True, 'act': Activation.ReLU, 'double_drop': True,
    'tie_weights': True, 'out_bias': True, 'init': init_transformer, 'mem_len': 150, 'mask': True,
}

# Default hyper-parameters for a `TransformerXL` classifier (no causal mask, no weight tying).
tfmerXL_clas_config = {
    'ctx_len': 150, 'n_layers': 12, 'n_heads': 10, 'd_model': 410, 'd_head': 41, 'd_inner': 2100,
    'resid_p': 0.1, 'attn_p': 0.1, 'ff_p': 0.1, 'embed_p': 0.1, 'output_p': 0.1,
    'bias': False, 'scale': True, 'act': Activation.ReLU, 'double_drop': True,
    'init': init_transformer, 'mem_len': 150, 'mask': False,
}
269
270
def tfmerXL_lm_split(model:nn.Module) -> List[nn.Module]:
    """Split a TransformerXL language `model` into 4 layer groups for differential learning rates.

    The encoder layers are cut into thirds (the last third takes the remainder). The shared `u`/`v`
    parameters join the first group, wrapped in `ParameterModule` (presumably so raw parameters can
    sit among modules). The last group holds the token embedding and the decoder `model[1]`.
    """
    enc = model[0]
    k = len(enc.layers) // 3
    first = list(enc.layers[:k])
    first += [ParameterModule(enc.u), ParameterModule(enc.v)]
    return [first, list(enc.layers[k:2*k]), list(enc.layers[2*k:]), [enc.encoder, model[1]]]
276
277
def tfmerXL_clas_split(model:nn.Module) -> List[nn.Module]:
    """Split a TransformerXL classifier `model` into 5 layer groups for differential learning rates.

    The encoder is `model[0].module`. Groups: [token embedding], first third of the layers plus the
    shared `u`/`v` parameters (wrapped in `ParameterModule`), second third, the remaining layers, and
    finally [classifier head `model[1]`].
    """
    enc = model[0].module
    k = len(enc.layers) // 3
    first = list(enc.layers[:k])
    first += [ParameterModule(enc.u), ParameterModule(enc.v)]
    return [[enc.encoder], first, list(enc.layers[k:2*k]), list(enc.layers[2*k:]), [model[1]]]
283
284