Path: blob/master/labml_nn/neox/model.py
4918 views
"""1---2title: GPT-NeoX Model Definition3summary: >4This is the model definition of GPT-NeoX.5---67# GPT-NeoX Model89Here is the code for layers of GPT-NeoX model and the code to load1020B checkpoint.1112The method `load_state` in the layers load the checkpoints of that layer.13The checkpoint loading helpers are on [`checkpoint.py`](checkpoint.html)14"""15import copy16import math17from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple1819import torch20from torch import nn21from torch.cuda.amp import autocast2223from labml import monit, logger24from labml.logger import Text25from labml_nn.neox import checkpoint26from labml_nn.neox.utils.cache import get_cache272829class NeoXModule(nn.Module):30def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):31pass323334class Embedding(NeoXModule):35"""36## Embedding layer3738This is a standard embeddings layer with code to load the checkpoint.39"""4041def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):42"""43:param n_vocab: is the size of the vocabulary44:param n_hidden: is the size of the embeddings45"""46super().__init__()4748self.emb = nn.Embedding(n_vocab, n_hidden)4950def forward(self, x: torch.Tensor):51"""52:param x: are the token ids of shape `[batch_size, seq_len]`53"""54return self.emb(x)5556def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):57"""58Code to load the checkpoint59"""60with monit.section('Load embedding layer'):61checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)626364class RoPE(nn.Module):65"""66## Rotary Positional Embeddings6768GPT-NeoX uses [rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).6970WE have annotated implementation of RoPE [here](https://nn.labml.ai/transformers/rope/index.html)71with more notes the theory.72"""7374def __init__(self, d_rope: int, base: float = 10_000.):75"""76:param d_rope: is the number of features for RoPE embeddings77:param base: is the base for $\theta_i = 10000^{\frac{2(i-1)}{d}}$, which defaults to $10000$78"""79super().__init__()8081# To store $\theta_i$ for the features82self.theta = None83# Cache $\cos m\theta_i$ and $\sin m\theta_i$84self.cos_cached = None85self.sin_cached = None8687# Base for $\theta_i = 10000^{\frac{2(i-1)}{d}}$88self.base = base89# Number of features for RoPE90self.d_rope = d_rope9192@staticmethod93def rotate_half(x: torch.Tensor):94"""95### Rotate the features9697$[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., -x^{(\frac{d}{2})}]$98"""99x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]100return torch.cat((-x2, x1), dim=-1)101102def forward(self, x: torch.Tensor, offset: int = 0):103"""104:param x: has shape `[..., seq, n_heads, d_k]`105:param offset: is the starting position of `x`. This is $\gt 0$ when we have106cached the keys and queries of previous positions107"""108109# Get the actual sequence length110seq_len = x.shape[-3] + offset111112# Initialize $\theta$113if self.theta is None:114# $\theta_i = 10000^{\frac{2(i-1)}{d}}$115theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))116self.theta = theta.to(x.device).to(x.dtype)117118# Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache119if (120self.cos_cached is None or121seq_len > self.cos_cached.shape[1] or122self.cos_cached.device != x.device or123self.cos_cached.dtype != x.dtype124):125# Get position indexes $m$126seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)127# $m \theta_i$128idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)129# Concatenate so that for row $m$ we have130#131# $$[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$$132idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)133134# Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32135with autocast(enabled=False):136idx_theta2 = idx_theta2.float()137# Add head dimension138self.cos_cached = idx_theta2.cos()[:, None, :]139self.sin_cached = idx_theta2.sin()[:, None, :]140141# Cache them142self.cos_cached = self.cos_cached.to(x.dtype)143self.sin_cached = self.sin_cached.to(x.dtype)144145# Split the features. We apply RoPE to only `d_rope` features146x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]147148# Get the sin and cos values from the cache149cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]150151# RoPE embeddings152#153# \begin{align}154# \begin{pmatrix}155# x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\156# x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\157# \end{pmatrix} \\158# \end{align}159#160# for $i \in {1, 2, ..., \frac{d}{2}}$161x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)162163# Concatenate with features that didn't get RoPE embeddings164return torch.cat((x_rope, x_pass), dim=-1)165166167class AttentionLayer(nn.Module):168"""169## Attention layer170"""171172def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,173mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):174"""175:param n_hidden: the number of features in embeddings176:param n_heads: the number of attention heads177:param rope_percentage: percentage of features to add RoPE embeddings178:param mask_fill: masking fill value for attention matrix179:param is_flash_attention: specifies whether to use180[FlashAttention](https://github.com/HazyResearch/flash-attention)181"""182super().__init__()183184self.n_heads = n_heads185self.mask_fill = mask_fill186187# Linear layer for query, key and value188self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)189# Final linear layer190self.output = nn.Linear(n_hidden, n_hidden)191192# Number of features per head193d_k = n_hidden // n_heads194# RoPE embedding module195self.rope = RoPE(int(d_k * rope_percentage))196197# Attention scaling factor198self.scale = 1 / math.sqrt(d_k)199200# To cache causal mask201self.causal_mask = None202203# Attention softmax module204self.softmax = nn.Softmax(dim=-2)205206# [FlashAttention](https://github.com/HazyResearch/flash-attention)207if is_flash_attention:208try:209from flash_attn.flash_attention import FlashAttention210self.flash_attention = FlashAttention()211except ImportError:212logger.log('Install flash attention github.com/HazyResearch/flash-attention. '213'Falling back to normal attention', Text.warning)214self.flash_attention = None215else:216self.flash_attention = None217218def _get_mask(self, attn: torch.Tensor):219"""220#### Calculate the causal mask221222* `attn` has shape [batch_size, query_seq_len, key_seq_len, n_heads]223"""224225# Query and key lengths226nq, nk = attn.shape[1:3]227228# Create mask229if (230self.causal_mask is None or231self.causal_mask.shape[0] != nq or232self.causal_mask.shape[1] != nk or233self.causal_mask.device != attn.device234):235self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)236237# Return from cache238return self.causal_mask[None, :, :, None]239240def forward(self, x: torch.Tensor):241"""242:param x: has shape `[batch_size, seq_len, n_hidden]`243"""244# Get query, key and value embeddings (all concatenated).245# The last dimension size will change from n_hidden -> `3 x n_hidden`246qkv = self.qkv_lin(x)247248# Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`249qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)250# Split into query, key and value each of shape `[batch_size, seq_len, n_heads, 3 * d_k]`251q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)252253# If we are caching the states of previous tokens254if get_cache().get('use_cache', False):255# Get the state id's. We use to retrieve previous states and store the next states256prev_state_id, next_state_id = get_cache().get('state_ids')257# If there's cache258if prev_state_id is not None:259# Get the past keys and values. These will have shape `[batch_size, prev_seq_len, n_heads, d_k]`260k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')261# Offset of the current embeddings262offset = k_past.shape[1]263264# Add RoPE embeddings265q = self.rope(q, offset=offset)266k = self.rope(k, offset=offset)267268# Concatenate the past269k = torch.cat([k_past, k], dim=1)270v = torch.cat([v_past, v], dim=1)271else:272# Add RoPE embeddings273q = self.rope(q)274k = self.rope(k)275276# Save the current state277get_cache().push(f'attn_kv_{next_state_id}', (k, v))278else:279# No cache - simply add RoPE embeddings280q = self.rope(q)281k = self.rope(k)282283# Use flash attention284if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:285output = self.compute_flash_attention(q, k, v)286# Otherwise, use normal attention287else:288output = self.compute_attention(q, k, v)289290# Reshape from `[batch_size, seq_len, n_heads, d_k] to `[batch_size, seq_len, n_hidden]`291output = output.reshape(*x.shape)292293# Final linear layer294return self.output(output)295296def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):297# Stack them into shape `[batch_size, seq_len, 3, n_heads, d_k]`298qkv = torch.stack((q, k, v), dim=2)299d_k = qkv.shape[-1]300if d_k <= 32:301pad = 32 - d_k302elif d_k <= 64:303pad = 64 - d_k304elif d_k <= 128:305pad = 128 - d_k306else:307raise ValueError(f'Head size {d_k} too large for flash attention')308309if pad > 0:310qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)311312output, _ = self.flash_attention(qkv, causal=True)313# The output is of shape `[batch_size, seq_len, n_heads, d_k + padding]`314output = output[:, :, :, :d_k]315316return output317318def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):319# Disable auto-casting to fp16 for attention computation320with autocast(enabled=False):321if q.dtype == torch.float16:322# Convert to fp32 if the current dtype is fp16323attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())324else:325# Do not cast for bfloat326attn = torch.einsum('bihk,bjhk->bijh', q, k)327328# Scale attention329attn = attn * self.scale330331# Get causal mask332mask = self._get_mask(attn)333# Apply mask334attn.masked_fill_(mask, self.mask_fill)335336# Attention softmax337attn = self.softmax(attn)338339# Get attention weighted values340output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)341342return output343344345class FFNLayer(nn.Module):346"""347## Feedforward Network348"""349350def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):351"""352:param n_hidden: is the embedding size353"""354super().__init__()355356if not d_ff:357d_ff = n_hidden * 4358359# Expansion linear layer360self.dense_h_h4 = nn.Linear(n_hidden, d_ff)361# GELU activation362self.activation = nn.GELU()363# Contraction linear layer364self.dense_h4_h = nn.Linear(d_ff, n_hidden)365366def forward(self, x: torch.Tensor):367"""368:param x: has shape `[batch_size, seq_len, n_hidden]`369"""370x = self.dense_h_h4(x)371x = self.activation(x)372x = self.dense_h4_h(x)373374return x375376377class TransformerLayer(NeoXModule):378"""379## Transformer Layer380"""381382def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):383"""384:param n_hidden: is the embedding size385:param n_heads: is the number of heads386:param is_flash_attention: specifies whether to use387[FlashAttention](https://github.com/HazyResearch/flash-attention)388389*Out implementation doesn't include dropout*.390"""391super().__init__()392393# Layer normalization before attention394self.pre_ln_attn = nn.LayerNorm(n_hidden)395# Layer normalization before FFN396self.pre_ln_ffn = nn.LayerNorm(n_hidden)397398# Attention layer399self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)400# FFN layer401self.ffn = FFNLayer(n_hidden)402403def forward(self, x: torch.Tensor):404"""405:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`406"""407408# Residual connection409residual = x410# NeoX runs attention and feedforward network in parallel411attn = self.attention(self.pre_ln_attn(x))412ffn = self.ffn(self.pre_ln_ffn(x))413# Add them and the residual connection414return attn + ffn + residual415416def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):417"""418Code to load the checkpoint419"""420with monit.section('Load transformer layer'):421# Attention output transform422checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)423checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)424425# Attention query, key and value transform426checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)427checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)428429# Layer norm before attention430checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)431checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)432433# FFN second transform434checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)435checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)436437# FFN first transform438checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)439checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)440441# Layer norm before FFN442checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)443checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)444445446class FinalNorm(NeoXModule):447"""448## Final normalization layer449"""450451def __init__(self, n_hidden: int = 6_144):452"""453:param n_hidden: is the embedding size454"""455super().__init__()456457self.ln = nn.LayerNorm(n_hidden)458459def forward(self, x: torch.Tensor):460"""461:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`462"""463return self.ln(x)464465def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):466"""467Code to load the checkpoint468"""469with monit.section('Load final normalization layer'):470checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)471checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)472473474class ReadoutLayer(NeoXModule):475"""476Readout layer477"""478479def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):480"""481:param n_hidden: is the embedding size482:param n_vocab: is the size of the vocabulary483"""484super().__init__()485486self.linear = nn.Linear(n_hidden, n_vocab, bias=False)487488def forward(self, x: torch.Tensor):489"""490:param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`491"""492return self.linear(x)493494def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):495"""496Code to load the checkpoint497"""498with monit.section('Load final linear layer'):499checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)500501502class LayerGenerator:503pre_created_layers: Dict[Any, Optional[NeoXModule]]504505def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,506n_layers: int = 44, n_heads: int = 64,507filter_layers: Optional[Set] = None,508is_clone_layers: bool = True,509dtype: torch.dtype = torch.float,510device: torch.device = torch.device('cpu'),511is_llm_int8: bool = False,512llm_int8_threshold: float = 6.0,513is_flash_attention: bool = False514):515"""516### Generator to create layers517518The layers are generated in the same order as checkpoints.519520It gives `None` when a layer is not available; we use the layer indices as NeoX and there are two521transformation layers we don't need in our implementation.522523:param n_vocab: is the number of tokens in the vocabulary524:param n_hidden: is the number of features in the embeddings525:param n_layers: is the number of transformer layers526:param n_heads: is the number of attention heads527:param filter_layers: are the set of layers to be used. All layers will be used if None.528This is used to test smaller versions of the model with fewer layers529:param is_clone_layers: specifies whether to clone the transformer layers (a bit faster)530:param dtype: is the data type of the model531:param device: is the device of the model532:param is_llm_int8: specifies whether to use int8 quantization533:param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features534:param is_flash_attention: specifies whether to use535[FlashAttention](https://github.com/HazyResearch/flash-attention)536"""537if filter_layers is None:538filter_layers = set(range(n_layers + 3))539540self.n_vocab = n_vocab541self.n_hidden = n_hidden542self.n_layers = n_layers543self.n_heads = n_heads544self.filter_layers = filter_layers545self.is_clone_layers = is_clone_layers546self.dtype = dtype547self.device = device548self.is_llm_int8 = is_llm_int8549self.llm_int8_threshold = llm_int8_threshold550self.is_flash_attention = is_flash_attention551552self.pre_created_layers = dict(553transformer_layer=None,554)555556def _prepare_layer(self, layer: NeoXModule):557"""558#### Prepares the layer for usage559560We move the layer to the device and convert it to the correct data type561562:param layer: is the layer to prepare563:return: the prepared layer564"""565return layer.to(self.device, self.dtype)566567@torch.no_grad()568def post_load_prepare(self, layer: NeoXModule, *,569is_llm_int8: bool = None,570device: torch.device = None,571llm_int8_threshold: float = None,572):573"""574<a id="post_load_prepare"></a>575576### Layer transformations after loading the checkpoint577578This function implements layer transformations after loading the checkpoint.579580Currently, it only applies the int8 quantization.581582:param layer: is the layer to prepare583:param is_llm_int8: specifies whether to use int8 quantization584:param device: is the device of the model585:param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features586:return: the prepared layer587"""588589# Get default values if not specified590if is_llm_int8 is None:591is_llm_int8 = self.is_llm_int8592if device is None:593device = self.device594if llm_int8_threshold is None:595llm_int8_threshold = self.llm_int8_threshold596597# Skip if not using int8 quantization598if not is_llm_int8:599return layer600601# Only convert the linear layers in the transformer layers602if not isinstance(layer, TransformerLayer):603return layer604605# Use `make_llm_int8_linear` defined in [utilities](./utils/llm_int8.html).606from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear607608# Convert the linear layers609with monit.section('Convert to int8'):610layer.attention.output = make_llm_int8_linear(layer.attention.output,611device=device,612threshold=llm_int8_threshold)613layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,614device=device,615threshold=llm_int8_threshold)616layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,617device=device,618threshold=llm_int8_threshold)619layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,620device=device,621threshold=llm_int8_threshold)622#623return layer624625def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):626"""627#### Creates and caches a layer628629Copying cached layers is faster than initializing new layers because it takes time to630initialize parameters.631632:param name: is the name of the layer633:param creator: is the function to create the layer634:return: the created layer or a copy of the cached layer635"""636637if not self.is_clone_layers:638return self._prepare_layer(creator())639640if self.pre_created_layers[name] is None:641self.pre_created_layers[name] = self._prepare_layer(creator())642643layer = copy.deepcopy(self.pre_created_layers[name])644return layer645646def _create_transformer_layer(self):647return self._create_and_cache_layer(648'transformer_layer',649lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)650)651652def _create_embedding_layer(self):653return Embedding(self.n_vocab, self.n_hidden)654655def _create_final_norm_layer(self):656return FinalNorm(self.n_hidden)657658def _create_readout_layer(self):659return ReadoutLayer(self.n_hidden, self.n_vocab)660661@torch.no_grad()662def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:663"""664### Generator to get layers665"""666# Embedding layer667if 0 in self.filter_layers:668with monit.section('Embedding layer'):669layer = self._prepare_layer(self._create_embedding_layer())670yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')671672# Transformer layers673for i in range(self.n_layers):674# Transformer layer675if i + 1 in self.filter_layers:676with monit.section(f'Transformer Layer {i}'):677yield self._create_transformer_layer(), \678(f'layer_{i + 2 :02d}-model_00-model_states.pt',679f'layer_{i + 2 :02d}-model_01-model_states.pt')680681# Final normalization layer682if self.n_layers + 1 in self.filter_layers:683with monit.section('Final norm layer'):684layer = self._prepare_layer(self._create_final_norm_layer())685yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')686687# Readout layer688if self.n_layers + 2 in self.filter_layers:689with monit.section('Readout layer'):690layer = self._prepare_layer(self._create_readout_layer())691yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')692693for k in self.pre_created_layers.keys():694self.pre_created_layers[k] = None695696@property697def total_layers(self):698"""699### Returns the total number of layers700"""701return self.n_layers + 3702703@torch.no_grad()704def load(self) -> Generator[NeoXModule, None, None]:705"""706### Generator to load layers707"""708with monit.section("Layers"):709for i, (layer, files) in enumerate(self.get_layers()):710if files is not None:711layer.load_state(*checkpoint.load_checkpoint_files(files))712713layer = self.post_load_prepare(layer)714715monit.progress(min(0.99, (i + 1) / self.total_layers))716yield layer717718719