Path: blob/master/labml_nn/transformers/configs.py
"""1---2title: Configurable Transformer Components3summary: These are configurable components that can be re-used quite easily.4---56# Configurable Transformer Components7"""8import copy910import torch.nn as nn1112from labml.configs import BaseConfigs, option, calculate, aggregate13from .feed_forward import FeedForward14from .mha import MultiHeadAttention15from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \16Encoder, Decoder, Generator, EncoderDecoder171819class FeedForwardConfigs(BaseConfigs):20"""21<a id="FFN"></a>2223## FFN Configurations2425Creates a Position-wise FeedForward Network defined in26[`feed_forward.py`](feed_forward.html).27"""28# Position-wise feedforward layer29ffn: FeedForward30# Number of features in the embedding31d_model: int32# Number of features in in the hidden layer33d_ff: int = 204834# Dropout probability35dropout: float = 0.136# Activation in position-wise feedforward layer37activation: nn.Module = 'ReLU'38# Whether the FFN layer should be gated39is_gated: bool = False40# Whether the first fully connected layer should have a learnable bias41bias1: bool = True42# Whether the second fully connected layer should have a learnable bias43bias2: bool = True44# Whether the fully connected layer for the gate should have a learnable bias45bias_gate: bool = False46# Predefined GLU variants47glu_variant: str = 'none'484950@option(FeedForwardConfigs.activation, 'ReLU')51def _ffn_activation_relu():52"""53### ReLU activation5455$$\max(0, x)$$56"""57return nn.ReLU()585960@option(FeedForwardConfigs.activation, 'GELU')61def _ffn_activation_gelu():62"""63### GELU activation6465$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$6667It was introduced in paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).68"""69return nn.GELU()707172@option(FeedForwardConfigs.ffn, 'default')73def _feed_forward(c: FeedForwardConfigs):74"""75Initialize a [feed forward network](feed_forward.html)76"""77return FeedForward(c.d_model, c.d_ff,78dropout=c.dropout,79activation=c.activation,80is_gated=c.is_gated,81bias1=c.bias1,82bias2=c.bias2,83bias_gate=c.bias_gate)8485# ## GLU Variants86# These are variants with gated hidden layers for the FFN87# as introduced in paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202).88# We have omitted the bias terms as specified in the paper.8990# ### FFN with Gated Linear Units91#92# $$FFN_{GLU}(x)(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$93aggregate(FeedForwardConfigs.glu_variant, 'GLU',94(FeedForwardConfigs.is_gated, True),95(FeedForwardConfigs.bias1, False),96(FeedForwardConfigs.bias2, False),97(FeedForwardConfigs.bias_gate, False),98(FeedForwardConfigs.activation, nn.Sigmoid()))99100# ### FFN with Bilinear hidden layer101#102# $$FFN_{Bilinear}(x)(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$103aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',104(FeedForwardConfigs.is_gated, True),105(FeedForwardConfigs.bias1, False),106(FeedForwardConfigs.bias2, False),107(FeedForwardConfigs.bias_gate, False),108(FeedForwardConfigs.activation, nn.Identity()))109110# ### FFN with ReLU gate111#112# $$FFN_{ReGLU}(x)(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$113aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',114(FeedForwardConfigs.is_gated, True),115(FeedForwardConfigs.bias1, False),116(FeedForwardConfigs.bias2, False),117(FeedForwardConfigs.bias_gate, False),118(FeedForwardConfigs.activation, nn.ReLU()))119120# ### FFN with GELU gate121#122# 
class TransformerConfigs(BaseConfigs):
    """
    <a id="TransformerConfigs"></a>

    ## Transformer Configurations

    This defines configurations for a transformer.
    The configurations are calculated using option functions.
    These are lazily loaded, so only the necessary modules
    are calculated.
    """
    # Number of attention heads
    n_heads: int = 8
    # Transformer embedding size
    d_model: int = 512
    # Number of layers
    n_layers: int = 6
    # Dropout probability
    dropout: float = 0.1
    # Number of tokens in the source vocabulary (for token embeddings)
    n_src_vocab: int
    # Number of tokens in the target vocabulary (to generate logits for prediction)
    n_tgt_vocab: int

    # The encoder self attention
    encoder_attn: MultiHeadAttention = 'mha'
    # The decoder self attention
    decoder_attn: MultiHeadAttention = 'mha'
    # The decoder memory attention
    decoder_mem_attn: MultiHeadAttention = 'mha'

    # Configurable Feedforward Layer
    ffn: FeedForwardConfigs

    # Encoder layer
    encoder_layer: TransformerLayer = 'default'
    # Decoder layer
    decoder_layer: TransformerLayer = 'default'

    # Encoder consisting of multiple encoder layers
    encoder: Encoder = 'default'
    # Decoder consisting of multiple decoder layers
    decoder: Decoder = 'default'

    # Embedding layer for source
    src_embed: nn.Module = 'fixed_pos'
    # Embedding layer for target (for decoder)
    tgt_embed: nn.Module = 'fixed_pos'

    # Logit generator for prediction
    generator: Generator = 'default'

    # Encoder-decoder
    encoder_decoder: EncoderDecoder


# ### Multi-head Attention
def _mha(c: TransformerConfigs):
    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)


calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)


# ### Relative Multi-head Attention
def _relative_mha(c: TransformerConfigs):
    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    return RelativeMultiHeadAttention(c.n_heads, c.d_model)


calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
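# ### Example: registering another attention option
#
# A sketch (not part of the original module) of how a further attention
# implementation could be registered under its own name with `calculate`,
# following the same pattern as `'mha'` and `'relative'` above.
# The option name `'mha_no_dropout'` is purely illustrative.
#
# ```python
# def _mha_no_dropout(c: TransformerConfigs):
#     # Same multi-head attention module, but without attention dropout
#     return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=0.0)
#
# calculate(TransformerConfigs.encoder_attn, 'mha_no_dropout', _mha_no_dropout)
# ```
#
# After this registration, setting `encoder_attn` to `'mha_no_dropout'`
# (for example through experiment configuration overrides) would build the
# encoder layers with this module instead.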
@option(TransformerConfigs.ffn, 'default')
def _feed_forward(c: TransformerConfigs):
    """
    Create feedforward layer configurations
    """
    conf = FeedForwardConfigs()
    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
    return conf


@option(TransformerConfigs.encoder_layer, 'default')
def _encoder_layer(c: TransformerConfigs):
    """
    Encoder layer
    """
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.decoder_layer, 'default')
def _decoder_layer(c: TransformerConfigs):
    """
    Decoder layer
    """
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.encoder, 'default')
def _encoder(c: TransformerConfigs):
    """
    Encoder
    """
    return Encoder(c.encoder_layer, c.n_layers)


@option(TransformerConfigs.decoder, 'default')
def _decoder(c: TransformerConfigs):
    """
    Decoder
    """
    return Decoder(c.decoder_layer, c.n_layers)


@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    """
    Logit generator
    """
    return Generator(c.n_tgt_vocab, c.d_model)


# ### Fixed Positional Embeddings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    """
    Source embedding with fixed positional encodings
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    """
    Target embedding with fixed positional encodings
    """
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ### Learned Positional Embeddings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    """
    Source embedding with learned positional encodings
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    """
    Target embedding with learned positional encodings
    """
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ### No Positional Embeddings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    """
    Source embedding without positional encodings
    """
    return nn.Embedding(c.n_src_vocab, c.d_model)


@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
    """
    Target embedding without positional encodings
    """
    return nn.Embedding(c.n_tgt_vocab, c.d_model)


@option(TransformerConfigs.encoder_decoder, 'default')
def _encoder_decoder(c: TransformerConfigs):
    """
    Full encoder-decoder model
    """
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
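# ### Example: building a model from these configurations
#
# A minimal usage sketch (not part of the original module), following the pattern
# used by the [labml](https://github.com/labmlai/labml) experiments in this
# repository. The experiment name, vocabulary sizes, and the particular override
# keys chosen here are illustrative assumptions.
#
# ```python
# from labml import experiment
# from labml_nn.transformers.configs import TransformerConfigs
#
# conf = TransformerConfigs()
# experiment.create(name='transformer_configs_example')
# experiment.configs(conf, {
#     'n_src_vocab': 10000,        # illustrative vocabulary sizes
#     'n_tgt_vocab': 10000,
#     'encoder_attn': 'relative',  # relative multi-head attention in the encoder
#     'ffn.glu_variant': 'SwiGLU', # gated FFN variant
#     'src_embed': 'learned_pos',  # learned positional embeddings
#     'tgt_embed': 'learned_pos',
# })
#
# # Options are resolved lazily, so only the selected modules get built
# model = conf.encoder_decoder
# ```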