Path: blob/master/labml_nn/neox/checkpoint.py
4918 views
"""
---
title: GPT-NeoX Checkpoints
summary: >
  Code to download checkpoints and helpers to load them.
---

# GPT-NeoX Checkpoints

"""
from pathlib import Path
from typing import Dict, List, Union, Tuple, Optional

import torch
from torch import nn

from labml import monit, lab, logger
from labml.logger import Text, inspect
from labml.utils.download import download_file

# Parent url
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'

# Cached checkpoint directory; resolved lazily by `get_checkpoints_download_path`
_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None


# Download path
def get_checkpoints_download_path() -> Path:
    """
    ### Get the local checkpoint directory

    Uses `<data path>/neox_fast/slim_weights` if that directory exists, otherwise
    `<data path>/neox/slim_weights`. The result is cached in a module-level global,
    so the lookup (and the `inspect` log of the chosen path) only happens on the first call.

    :return: the directory that holds (or will hold) the checkpoint files
    """
    global _CHECKPOINTS_DOWNLOAD_PATH

    # Return the cached value if it was already resolved
    if _CHECKPOINTS_DOWNLOAD_PATH is not None:
        return _CHECKPOINTS_DOWNLOAD_PATH

    _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
    if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
        # Fall back to the default location
        _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
    inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)

    return _CHECKPOINTS_DOWNLOAD_PATH


def get_files_to_download(n_layers: int = 44) -> List[str]:
    """
    ### Get files to download

    :param n_layers: number of transformer layers to fetch; layer files `2 .. 2 + n_layers - 1`
        are included. The embedding layer (`0`) and layers `47`, `48` are always included,
        regardless of `n_layers`.
    :return: a list of files to be downloaded, relative to `CHECKPOINTS_URL`
    """
    layers = (
        # Embedding layer
        [0] +
        # Transformer layers
        list(range(2, 2 + n_layers)) +
        # Final normalization layer and readout layer (fixed indices, independent of `n_layers`)
        [47, 48]
    )

    return (
        # Vocabulary and configs
        ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +
        # Layer checkpoints; each layer is stored as two partitions (`model_00`, `model_01`)
        [f'global_step150000/layer_{i:02d}-model_{p:02d}-model_states.pt' for i in layers for p in range(2)] +
        # Empty states (not used)
        [f'global_step150000/mp_rank_{i:02d}_model_states.pt' for i in range(8)]
    )


def download(n_layers: int = 44):
    """
    ## Download all checkpoint files

    Downloads every file returned by `get_files_to_download(n_layers)` from
    `CHECKPOINTS_URL` into `get_checkpoints_download_path()`, keeping the same relative paths.

    :param n_layers: number of transformer layers to download
    """

    # Get files to download
    files = get_files_to_download(n_layers)

    # Iterate
    for i, f in monit.enum('Download All', files):
        # Log
        logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
        # Download
        download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)


def load_checkpoint_files(files: Tuple[str, str]):
    """
    ### Load a pair of checkpoint files

    :param files: pair of file names, relative to the `global_step150000` checkpoint directory
    :return: a list with the loaded content of each file, in the same order as `files`
    """
    checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
    with monit.section('Load checkpoint'):
        data = [torch.load(checkpoint_path / f) for f in files]

    return data


def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by merging the partitions along first dimension

    The first partition fills rows `[0, p1[key].shape[0])` and the second partition fills the rest.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]
    param.data[:w1.shape[0]] = w1
    param.data[w1.shape[0]:] = w2


def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by merging the partitions along second dimension

    The first partition fills columns `[0, p1[key].shape[1])` and the second partition fills the rest.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]
    param.data[:, :w1.shape[1]] = w1
    param.data[:, w1.shape[1]:] = w2


def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                           p2: Dict[str, torch.Tensor]):
    """
    ### Load an un-partitioned parameter

    This does a sanity check to make sure both partitions are the same,
    then loads the average of the two copies.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]

    # Sum of squared differences over *all* elements. `torch.sum` reduces to a scalar for
    # any tensor rank; the builtin `sum` only iterates over the first dimension, which leaves
    # a multi-element tensor (so `.item()` fails) for parameters with two or more dimensions.
    diff = torch.sum((w1 - w2) ** 2).item()
    assert diff < 1e-4, f'The partitions do not match: {key}'

    # Average the two (near-identical) copies
    param.data[:] = (w1 + w2) / 2.


def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                     p2: Dict[str, torch.Tensor]):
    """
    ### Load biases that are partitioned and get added on reduce

    The loaded value is the element-wise sum of the two partitions.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]

    param.data[:] = w1 + w2


#
if __name__ == '__main__':
    download()