Path: labml_nn/lora/gpt2.py
"""1---2title: GPT-2 with LoRA3summary: GPT-2 implementation with LoRA modules4---56# GPT-2 with [LoRA modules](index.html)78Here's [the training code](experiment.html) for training a GPT2 model with LoRA9on Tiny Shakespeare dataset.10"""1112import torch13import torch.nn as nn1415from labml_nn.lora import Linear, Embedding161718class FFN(nn.Module):19"""20### Feedforward Network21"""2223def __init__(self, d_model: int, d_ff: int, r: int):24"""25:param d_model: is the number of dimensions26:param d_ff: is the size of the hidden dimension27:param r: is the lora rank28"""29super().__init__()3031# The linear layers and the activation32self.linear_in = Linear(d_model, d_ff, r=r, bias=True)33self.linear_out = Linear(d_ff, d_model, r=r, bias=True)34self.act = nn.GELU()3536def forward(self, x: torch.Tensor) -> torch.Tensor:37"""38:param x: is the embeddings tensor with shape `[batch_size, seq_len, d_model]`39"""40x = self.linear_in(x)41x = self.act(x)42x = self.linear_out(x)43return x444546class MultiHeadAttention(nn.Module):47"""48### Multi-Head Attention49"""5051def __init__(self, d_model: int, n_heads: int, r: int):52"""53:param d_model: is the number of dimensions in the embeddings54:param n_heads: is the number of heads55:param r: is the lora rank56"""57super().__init__()58self.d_model = d_model59self.n_heads = n_heads60self.d_head = d_model // n_heads6162# Linear transformation for QKV63self.qkv_projection = Linear(d_model, d_model * 3, r=r, bias=True)64# Output projection65self.output_projection = Linear(d_model, d_model, r=r, bias=True)6667def _split_heads(self, x: torch.Tensor):68"""69:param x: is the tensor with shape `[batch_size, seq_len, d_model]`70"""71# Split last dimension to `[n_heads, d_head]`72x = x.view(x.shape[:-1] + (self.n_heads, self.d_head))73# Reorder to `[batch_size, head, seq_length, d_head]`74return x.permute(0, 2, 1, 3)7576def forward(self, x: torch.Tensor) -> torch.Tensor:77"""78:param x: is the embeddings tensor with shape `[batch_size, seq_len, d_model]`79"""80batch_size, seq_length, _ = x.shape8182# Get query, key and value83q, k, v = self.qkv_projection(x).split(self.d_model, dim=-1)8485# Transform them from shape `[batch_size, seq_len, d_model]` to `[batch_size, head, seq_length, d_head]`86q = self._split_heads(q)87k = self._split_heads(k)88v = self._split_heads(v)8990# Apply causal attention91attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)9293# Transform them from shape `[batch_size, head, seq_length, d_head]` to `[batch_size, seq_len, d_model]`94attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.d_model)9596# Final project97return self.output_projection(attn_output)9899100class Block(nn.Module):101"""102### Decoder block103"""104105def __init__(self, d_model: int, n_heads: int, layer_norm_epsilon: float, r: int):106"""107:param d_model: is the number of dimensions in the embeddings108:param n_heads: is the number of heads109:param layer_norm_epsilon: is the layer norm epsilon110:param r: is the lora rank111"""112super().__init__()113# Attention pre-normalization layer114self.attn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)115# Attention layer116self.attn = MultiHeadAttention(d_model, n_heads, r)117# FFN pre-normalization layer118self.ffn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)119# Feed-forward network120self.ffn = FFN(d_model, d_model * 4, r)121122def forward(self, x: torch.Tensor) -> torch.Tensor:123"""124:param x: is the embeddings tensor with shape `[batch_size, 


class GPTModel(nn.Module):
    """
    ## GPT2 Model
    """

    def __init__(self, *, d_model: int,
                 n_heads: int, n_layers: int,
                 n_positions: int,
                 layer_norm_epsilon: float,
                 vocab_size: int, r: int):
        """
        :param d_model: is the number of dimensions in the embeddings
        :param n_heads: is the number of attention heads
        :param n_layers: is the number of decoder layers
        :param n_positions: is the number of positional embeddings
        :param layer_norm_epsilon: is the layer norm epsilon
        :param vocab_size: is the vocabulary size
        :param r: is the LoRA rank
        """
        super().__init__()

        # Token and absolute positional embeddings
        self.token_embedding = Embedding(vocab_size, d_model, r=r)
        self.position_embedding = Embedding(n_positions, d_model, r=r)

        # Decoder blocks
        self.blocks = nn.ModuleList([Block(d_model, n_heads, layer_norm_epsilon, r=r)
                                     for _ in range(n_layers)])

        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Projection layer to logit space
        self.lm_head = Linear(d_model, vocab_size, r=r, bias=False)

    def forward(self, input_ids: torch.Tensor):
        """
        :param input_ids: has shape `[batch_size, seq_len]`
        """
        batch_size, seq_len = input_ids.shape

        # Get token embeddings
        token_embeddings = self.token_embedding(input_ids)
        # Get position ids
        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
        # Get position embeddings
        position_embeddings = self.position_embedding(position_ids)

        # Add position embeddings
        x = token_embeddings + position_embeddings

        # Run through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization
        x = self.final_norm(x)
        # Get logits from projection layer
        return self.lm_head(x)
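

#
# *(A minimal usage sketch added for illustration; `_usage_sketch` is a
# hypothetical helper and not part of the original implementation.)*
# The hyper-parameters below are the usual GPT-2 (small) values, and the LoRA
# rank `r = 32` is arbitrary. The weights here are random and untrained; in a
# real run they would be loaded from a pretrained GPT-2 checkpoint before
# fine-tuning the LoRA parameters (see [the training code](experiment.html)).
def _usage_sketch():
    model = GPTModel(d_model=768, n_heads=12, n_layers=12,
                     n_positions=1024, layer_norm_epsilon=1e-5,
                     vocab_size=50257, r=32)
    # A dummy batch of token ids with shape `[batch_size, seq_len]`
    input_ids = torch.randint(0, 50257, (2, 16))
    # Logits have shape `[batch_size, seq_len, vocab_size]`
    logits = model(input_ids)
    assert logits.shape == (2, 16, 50257)


#
if __name__ == '__main__':
    _usage_sketch()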