GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/transformers/configs.py
"""
2
---
3
title: Configurable Transformer Components
4
summary: These are configurable components that can be re-used quite easily.
5
---
6
7
# Configurable Transformer Components
8
"""
9
import copy
10
11
import torch.nn as nn
12
13
from labml.configs import BaseConfigs, option, calculate, aggregate
14
from .feed_forward import FeedForward
15
from .mha import MultiHeadAttention
16
from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
17
Encoder, Decoder, Generator, EncoderDecoder
18
19
20
class FeedForwardConfigs(BaseConfigs):
    """
    <a id="FFN"></a>

    ## FFN Configurations

    Creates a Position-wise FeedForward Network defined in
    [`feed_forward.py`](feed_forward.html).
    """
    # Position-wise feedforward layer
    ffn: FeedForward
    # Number of features in the embedding
    d_model: int
    # Number of features in the hidden layer
    d_ff: int = 2048
    # Dropout probability
    dropout: float = 0.1
    # Activation in the position-wise feedforward layer
    activation: nn.Module = 'ReLU'
    # Whether the FFN layer should be gated
    is_gated: bool = False
    # Whether the first fully connected layer should have a learnable bias
    bias1: bool = True
    # Whether the second fully connected layer should have a learnable bias
    bias2: bool = True
    # Whether the fully connected layer for the gate should have a learnable bias
    bias_gate: bool = False
    # Predefined GLU variants
    glu_variant: str = 'none'


@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    """
    ### ReLU activation

    $$\max(0, x)$$
    """
    return nn.ReLU()


@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    """
    ### GELU activation

    $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$

    It was introduced in the paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
    """
    return nn.GELU()


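# Illustrative usage sketch (not part of the original file): for a plain, non-gated
# FFN the activation is selected by name from the options registered above.
# The model width used here is an arbitrary example value.
def _example_gelu_ffn_configs():
    conf = FeedForwardConfigs()
    conf.d_model = 512
    # Pick the GELU option instead of the default ReLU
    conf.activation = 'GELU'
    return conf

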
@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    """
    Initialize a [feed forward network](feed_forward.html)
    """
    return FeedForward(c.d_model, c.d_ff,
                       dropout=c.dropout,
                       activation=c.activation,
                       is_gated=c.is_gated,
                       bias1=c.bias1,
                       bias2=c.bias2,
                       bias_gate=c.bias_gate)

# ## GLU Variants
# These are variants with gated hidden layers for the FFN
# as introduced in the paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202).
# We have omitted the bias terms as specified in the paper.

# ### FFN with Gated Linear Units
#
# $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))

# ### FFN with Bilinear hidden layer
#
# $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))

# ### FFN with ReLU gate
#
# $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))

# ### FFN with GELU gate
#
# $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))

# ### FFN with Swish gate
#
# $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.SiLU()))
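
# Illustrative usage sketch (not part of the original file): picking one of the GLU
# presets above only means setting `glu_variant`; the corresponding `aggregate` then
# selects the gating, bias, and activation options when the configs are calculated.
# The dimensions used here are arbitrary example values.
def _example_swiglu_ffn_configs():
    conf = FeedForwardConfigs()
    conf.d_model = 512
    conf.d_ff = 2048
    # SwiGLU: gated FFN with a SiLU (Swish) gate and no biases
    conf.glu_variant = 'SwiGLU'
    return conf

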
class TransformerConfigs(BaseConfigs):
    """
    <a id="TransformerConfigs"></a>

    ## Transformer Configurations

    This defines configurations for a transformer.
    The configurations are calculated using option functions.
    These are lazily loaded, so only the necessary modules
    are calculated.
    """
    # Number of attention heads
    n_heads: int = 8
    # Transformer embedding size
    d_model: int = 512
    # Number of layers
    n_layers: int = 6
    # Dropout probability
    dropout: float = 0.1
    # Number of tokens in the source vocabulary (for token embeddings)
    n_src_vocab: int
    # Number of tokens in the target vocabulary (to generate logits for prediction)
    n_tgt_vocab: int

    # The encoder self-attention
    encoder_attn: MultiHeadAttention = 'mha'
    # The decoder self-attention
    decoder_attn: MultiHeadAttention = 'mha'
    # The decoder memory attention
    decoder_mem_attn: MultiHeadAttention = 'mha'

    # Configurable feedforward layer
    ffn: FeedForwardConfigs

    # Encoder layer
    encoder_layer: TransformerLayer = 'default'
    # Decoder layer
    decoder_layer: TransformerLayer = 'default'

    # Encoder consisting of multiple encoder layers
    encoder: Encoder = 'default'
    # Decoder consisting of multiple decoder layers
    decoder: Decoder = 'default'

    # Embedding layer for the source
    src_embed: nn.Module = 'fixed_pos'
    # Embedding layer for the target (for the decoder)
    tgt_embed: nn.Module = 'fixed_pos'

    # Logit generator for prediction
    generator: Generator = 'default'

    # Encoder-decoder
    encoder_decoder: EncoderDecoder
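
# Illustrative usage sketch (not part of the original file): a typical way to use
# these configurations is to create the class, set the options that have no defaults
# (the vocabulary sizes), and adjust the plain hyper-parameters directly.
# All values below are arbitrary example values.
def _example_transformer_configs():
    conf = TransformerConfigs()
    conf.n_src_vocab = 10_000
    conf.n_tgt_vocab = 10_000
    conf.d_model = 256
    conf.n_layers = 4
    conf.dropout = 0.2
    return conf

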
# ### Multi-head Attention
def _mha(c: TransformerConfigs):
    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)


# ### Relative Multi-head Attention
def _relative_mha(c: TransformerConfigs):
    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    return RelativeMultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
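
# Illustrative usage sketch (not part of the original file): with the `calculate`
# registrations above, switching from standard to relative multi-head attention
# (Transformer-XL style) is just a matter of picking the 'relative' option by name.
# The vocabulary sizes are arbitrary example values.
def _example_relative_attention_configs():
    conf = TransformerConfigs()
    conf.n_src_vocab = 10_000
    conf.n_tgt_vocab = 10_000
    # Use relative multi-head attention everywhere
    conf.encoder_attn = 'relative'
    conf.decoder_attn = 'relative'
    conf.decoder_mem_attn = 'relative'
    return conf

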
@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    """
    Create feedforward layer configurations
    """
    conf = FeedForwardConfigs()
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf
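
# Illustrative usage sketch (not part of the original file): since `ffn` is itself a
# `FeedForwardConfigs`, its options can be overridden through the dictionary passed to
# `labml.experiment.configs`, using dotted keys as the experiments in this repository
# do for nested configs. The experiment name, the override values, and the assumption
# that dotted keys resolve one level into this configs class are illustrative.
def _example_override_nested_ffn():
    from labml import experiment

    experiment.create(name='ffn_override_example')
    conf = TransformerConfigs()
    experiment.configs(conf, {
        # Vocabulary sizes have no defaults, so they must be provided
        'n_src_vocab': 10_000,
        'n_tgt_vocab': 10_000,
        # Select the GELU-gated FFN variant and shrink its hidden layer
        'ffn.glu_variant': 'GEGLU',
        'ffn.d_ff': 1024,
    })
    return conf

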
@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    """
    Encoder layer
    """
    # Each layer gets its own deep copy of the configured position-wise FFN
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    """
    Decoder layer
    """
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    """
    Encoder
    """
    return Encoder(c.encoder_layer, c.n_layers)


@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    """
    Decoder
    """
    return Decoder(c.decoder_layer, c.n_layers)


@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    """
    Logit generator
    """
    return Generator(c.n_tgt_vocab, c.d_model)


# ### Fixed Positional Embeddings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    """
    Source embedding with fixed positional encodings
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    """
    Target embedding with fixed positional encodings
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ### Learned Positional Embeddings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    """
    Source embedding with learned positional encodings
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    """
    Target embedding with learned positional encodings
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ### No Positional Embeddings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    """
    Source embedding without positional encodings
    """
    return nn.Embedding(c.n_src_vocab, c.d_model)


@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    """
    Target embedding without positional encodings
    """
    return nn.Embedding(c.n_tgt_vocab, c.d_model)
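
# Illustrative usage sketch (not part of the original file): the source and target
# embeddings are chosen by name from the three options above; for example, learned
# positional embeddings for the source and a plain `nn.Embedding` without positional
# information for the target. The vocabulary sizes are arbitrary example values.
def _example_embedding_options():
    conf = TransformerConfigs()
    conf.n_src_vocab = 10_000
    conf.n_tgt_vocab = 10_000
    conf.src_embed = 'learned_pos'
    conf.tgt_embed = 'no_pos'
    return conf

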
@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    """
    Encoder-decoder
    """
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
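
# A minimal end-to-end sketch (not part of the original file), loosely following the
# experiment pattern used elsewhere in this repository: create an experiment, let
# `labml.experiment.configs` calculate the options, and pick up the assembled
# encoder-decoder. The experiment name, the vocabulary sizes, and the assumption that
# a `TransformerConfigs` can be passed directly as the experiment configs are illustrative.
def _example_build_encoder_decoder():
    from labml import experiment

    experiment.create(name='transformer_configs_example')
    conf = TransformerConfigs()
    # The vocabulary sizes have no defaults, so they must be provided
    experiment.configs(conf, {'n_src_vocab': 10_000, 'n_tgt_vocab': 10_000})
    # Accessing the option should trigger the lazy calculation of everything it needs
    return conf.encoder_decoder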