GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/transformers/models.py
"""
2
---
3
title: Transformer Encoder and Decoder Models
4
summary: >
5
These are PyTorch implementations of Transformer based encoder and decoder models,
6
as well as other related modules.
7
---
8
9
# Transformer Encoder and Decoder Models
10
11
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/basic/autoregressive_experiment.ipynb)
12
"""
13
import math
14
15
import torch
16
import torch.nn as nn
17
18
from labml_nn.utils import clone_module_list
19
from .feed_forward import FeedForward
20
from .mha import MultiHeadAttention
21
from .positional_encoding import get_positional_encoding
22
23
class EmbeddingsWithPositionalEncoding(nn.Module):
    """
    <a id="EmbeddingsWithPositionalEncoding"></a>

    ## Embed tokens and add [fixed positional encoding](positional_encoding.html)
    """

    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe

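# A minimal usage sketch for the embedding module above, assuming the labml convention
# of token ids shaped `[seq_len, batch_size]` and activations shaped
# `[seq_len, batch_size, d_model]`; the function name is illustrative only.
def _embedding_shapes_sketch():
    emb = EmbeddingsWithPositionalEncoding(d_model=512, n_vocab=10_000)
    # A batch of 3 sequences, 7 tokens each
    tokens = torch.randint(0, 10_000, (7, 3))
    out = emb(tokens)
    # Token embeddings are scaled by sqrt(d_model) before the fixed,
    # non-trainable positional encodings are added
    assert out.shape == (7, 3, 512)
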
class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
    """
    <a id="EmbeddingsWithLearnedPositionalEncoding"></a>

    ## Embed tokens and add parameterized positional encodings
    """

    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe

class TransformerLayer(nn.Module):
    """
    <a id="TransformerLayer"></a>

    ## Transformer Layer

    This can act as an encoder layer or a decoder layer. We use pre-norm,
    i.e. layer normalization is applied before each sub-layer.
    """

    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        """
        * `d_model` is the token embedding size
        * `self_attn` is the self attention module
        * `src_attn` is the source attention module (when this is used in a decoder)
        * `feed_forward` is the feed forward module
        * `dropout_prob` is the dropout probability applied after the attention and feed-forward sub-layers
        """
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
        # Whether to save the input to the feed forward layer
        self.is_save_ff_input = False

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor,
                src: torch.Tensor = None,
                src_mask: torch.Tensor = None):
        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Run through self attention, i.e. keys and values are from self
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
        # Add the self attention results
        x = x + self.dropout(self_attn)

        # If a source is provided, get results from attention to the source.
        # This is when you have a decoder layer that pays attention to
        # encoder outputs
        if src is not None:
            # Normalize vectors
            z = self.norm_src_attn(x)
            # Attention to source, i.e. keys and values are from the encoder outputs
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
            # Add the source attention results
            x = x + self.dropout(attn_src)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Save the input to the feed forward layer if specified
        if self.is_save_ff_input:
            self.ff_input = z.clone()
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x

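# A minimal sketch of the two ways the layer above can be used. It assumes the
# `MultiHeadAttention(heads, d_model)` and `FeedForward(d_model, d_ff)` constructors
# from the accompanying `mha` and `feed_forward` modules, and attention masks shaped
# `[query_seq_len, key_seq_len, batch_size]` where a size-1 dimension broadcasts.
def _transformer_layer_sketch():
    d_model = 128
    x = torch.randn(5, 2, d_model)  # `[seq_len, batch_size, d_model]`

    # Encoder-style: self attention only (no `src_attn`), all positions visible
    enc_layer = TransformerLayer(d_model=d_model,
                                 self_attn=MultiHeadAttention(4, d_model),
                                 feed_forward=FeedForward(d_model, 256),
                                 dropout_prob=0.1)
    full_mask = torch.ones(5, 5, 1, dtype=torch.bool)
    assert enc_layer(x=x, mask=full_mask).shape == x.shape

    # Decoder-style: causal self attention plus attention over encoder outputs
    dec_layer = TransformerLayer(d_model=d_model,
                                 self_attn=MultiHeadAttention(4, d_model),
                                 src_attn=MultiHeadAttention(4, d_model),
                                 feed_forward=FeedForward(d_model, 256),
                                 dropout_prob=0.1)
    memory = torch.randn(7, 2, d_model)  # encoder outputs
    causal_mask = torch.tril(torch.ones(5, 5, dtype=torch.bool)).unsqueeze(-1)
    src_mask = torch.ones(1, 7, 1, dtype=torch.bool)
    assert dec_layer(x=x, mask=causal_mask, src=memory, src_mask=src_mask).shape == x.shape
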
class Encoder(nn.Module):
    """
    <a id="Encoder"></a>

    ## Transformer Encoder
    """

    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        # Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=mask)
        # Finally, normalize the vectors
        return self.norm(x)

class Decoder(nn.Module):
    """
    <a id="Decoder"></a>

    ## Transformer Decoder
    """

    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        # Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
        # Finally, normalize the vectors
        return self.norm(x)

class Generator(nn.Module):
    """
    <a id="Generator"></a>

    ## Generator

    This predicts the tokens by projecting the decoder output to vocabulary logits.
    You don't need a softmax here if you are using `nn.CrossEntropyLoss`,
    which applies log softmax internally.
    """

    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        return self.projection(x)

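# A minimal training sketch for the projection above, assuming logits shaped
# `[seq_len, batch_size, n_vocab]` and integer targets shaped `[seq_len, batch_size]`;
# the function name is illustrative only.
def _generator_loss_sketch():
    n_vocab, d_model = 1000, 128
    generator = Generator(n_vocab, d_model)
    x = torch.randn(7, 3, d_model)
    targets = torch.randint(0, n_vocab, (7, 3))
    # Raw logits; `nn.CrossEntropyLoss` applies log softmax internally,
    # so no explicit softmax layer is needed
    logits = generator(x)
    loss = nn.CrossEntropyLoss()(logits.reshape(-1, n_vocab), targets.reshape(-1))
    loss.backward()
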
class EncoderDecoder(nn.Module):
    """
    <a id="EncoderDecoder"></a>

    ## Combined Encoder-Decoder
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

        # Initialize parameters with Glorot / fan_avg.
        # This was important in the original implementation.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        # Run the source through the encoder
        enc = self.encode(src, src_mask)
        # Run the encodings and targets through the decoder
        return self.decode(enc, src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

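# A minimal end-to-end sketch wiring the modules above together. It assumes the
# `MultiHeadAttention(heads, d_model)` and `FeedForward(d_model, d_ff)` constructors
# from the accompanying modules, and masks shaped `[query_seq_len, key_seq_len, batch_size]`
# where a size-1 dimension broadcasts; the function name is illustrative only.
def _encoder_decoder_sketch():
    d_model, n_heads, d_ff, n_layers, n_vocab = 128, 4, 256, 2, 1000

    def make_layer(is_decoder: bool):
        # Decoder layers additionally get a source-attention module
        return TransformerLayer(d_model=d_model,
                                self_attn=MultiHeadAttention(n_heads, d_model),
                                src_attn=MultiHeadAttention(n_heads, d_model) if is_decoder else None,
                                feed_forward=FeedForward(d_model, d_ff),
                                dropout_prob=0.1)

    model = EncoderDecoder(encoder=Encoder(make_layer(is_decoder=False), n_layers),
                           decoder=Decoder(make_layer(is_decoder=True), n_layers),
                           src_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
                           tgt_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
                           generator=Generator(n_vocab, d_model))

    src = torch.randint(0, n_vocab, (10, 3))  # `[src_seq_len, batch_size]`
    tgt = torch.randint(0, n_vocab, (7, 3))   # `[tgt_seq_len, batch_size]`
    # Source positions attend to every source position; target positions
    # attend only to the current and earlier target positions
    src_mask = torch.ones(1, 10, 1, dtype=torch.bool)
    tgt_mask = torch.tril(torch.ones(7, 7, dtype=torch.bool)).unsqueeze(-1)

    out = model(src, tgt, src_mask, tgt_mask)  # `[tgt_seq_len, batch_size, d_model]`
    logits = model.generator(out)              # `[tgt_seq_len, batch_size, n_vocab]`
    assert logits.shape == (7, 3, n_vocab)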