# Path: blob/main/a5/mingpt-demo/mingpt/model.py
# 1003 views
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""

import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F

logger = logging.getLogger(__name__)


class GPTConfig:
    """Base GPT config: hyperparameters common to all GPT versions.

    ``vocab_size`` and ``block_size`` (maximum sequence length) are required.
    Every extra keyword argument is stored as an attribute on the instance,
    which is how ``n_layer``, ``n_head`` and ``n_embd`` (all read by ``GPT``)
    or overrides of the dropout rates below are supplied. Subclasses such as
    ``GPT1Config`` may provide them as class attributes instead.
    """

    embd_pdrop = 0.1   # dropout applied to token + position embeddings
    resid_pdrop = 0.1  # dropout applied to each residual-branch output
    attn_pdrop = 0.1   # dropout applied to the attention weights

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k, v in kwargs.items():
            setattr(self, k, v)


class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence:
        # a lower-triangular matrix of ones shaped (1, 1, block_size, block_size) so it
        # broadcasts over the batch and head dimensions
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (B, T, C) where C == n_embd; T must not exceed
                block_size, otherwise the sliced mask does not match.
            layer_past: accepted but not used by this implementation.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        hs = C // self.n_head  # per-head size

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, hs).transpose(1, 2)    # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
        # future positions get -inf so softmax gives them zero weight; the diagonal is
        # always unmasked, so no row is fully masked
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y


class Block(nn.Module):
    """ an unassuming Transformer block

    Pre-LayerNorm residual block: x + attn(ln1(x)), then x + mlp(ln2(x)).
    The MLP expands to 4 * n_embd, applies GELU, projects back and applies dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()

        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        # learned position embedding; it stays at zero initially because _init_weights
        # only re-initializes Linear/Embedding/LayerNorm modules, not bare Parameters
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        """Return the maximum sequence length this model accepts."""
        return self.block_size

    def _init_weights(self, module):
        """Initialize one submodule (used via ``self.apply``).

        Linear/Embedding weights ~ N(0, 0.02); Linear biases zeroed;
        LayerNorm weight set to 1 and bias to 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.

        Reads ``train_config.weight_decay``, ``train_config.learning_rate`` and
        ``train_config.betas``. Returns a ``torch.optim.AdamW`` with two parameter
        groups: the decayed group first, then the non-decayed group (weight_decay=0.0).
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            # m.named_parameters() recurses into children, so a parameter is seen once per
            # ancestor; every visit produces the same full name, so the sets de-duplicate.
            # Weights are only classified when m is their owning Linear/LayerNorm/Embedding.
            for pn, _ in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object (sorted for a deterministic parameter order)
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        """Run the model on a batch of token indices.

        Args:
            idx: 2-D tensor of token indices, shape (b, t) with t <= block_size.
            targets: optional tensor of class indices with b * t elements
                (typically shape (b, t)); may be non-contiguous.

        Returns:
            (logits, loss): logits of shape (b, t, vocab_size); loss is the mean
            cross-entropy over all positions, or None when targets is None.
        """
        _, t = idx.size()  # unpacking also enforces a 2-D input
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets, e.g. a transposed or
            # strided slice, are flattened instead of raising a RuntimeError
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss