GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/neox/checkpoint.py
"""
---
title: GPT-NeoX Checkpoints
summary: >
    Code to download checkpoints and helpers to load them.
---

# GPT-NeoX Checkpoints
"""

from pathlib import Path
from typing import Dict, Union, Tuple, Optional

import torch
from torch import nn

from labml import monit, lab, logger
from labml.logger import Text, inspect
from labml.utils.download import download_file

# Parent URL
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'

_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None


# Download path
def get_checkpoints_download_path():
    global _CHECKPOINTS_DOWNLOAD_PATH

    # Return the cached path if it has already been resolved
    if _CHECKPOINTS_DOWNLOAD_PATH is not None:
        return _CHECKPOINTS_DOWNLOAD_PATH

    # Use `neox_fast` if it exists; otherwise fall back to the default `neox` location
    _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
    if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
        _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
    inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)

    return _CHECKPOINTS_DOWNLOAD_PATH


def get_files_to_download(n_layers: int = 44):
    """
    ### Get files to download

    :return: a list of files to be downloaded
    """
    layers = (
        # Embedding layer
        [0] +
        # Transformer layers
        list(range(2, 2 + n_layers)) +
        # Final normalization layer and readout layer
        [47, 48]
    )

    return (
        # Vocabulary and configs
        ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +
        # Layer checkpoints
        [f'global_step150000/layer_{i :02d}-model_{p :02d}-model_states.pt' for i in layers for p in range(2)] +
        # Empty states (not used)
        [f'global_step150000/mp_rank_{i :02d}_model_states.pt' for i in range(8)]
    )
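

# A quick, illustrative check (not part of the original module) of what
# `get_files_to_download` produces: for the default 44 transformer layers there
# are 47 layer checkpoints, each split into two model-parallel partitions,
# plus 3 vocabulary/config files and 8 rank state files.
def _demo_files_to_download():
    files = get_files_to_download()
    # 3 + 47 * 2 + 8 = 105 files in total
    assert len(files) == 105
    # The first layer checkpoint: 'global_step150000/layer_00-model_00-model_states.pt'
    print(files[3])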


def download(n_layers: int = 44):
    """
    ## Download all checkpoint files
    """

    # Get files to download
    files = get_files_to_download(n_layers)

    # Iterate
    for i, f in monit.enum('Download All', files):
        # Log
        logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
        # Download
        download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)


def load_checkpoint_files(files: Tuple[str, str]):
    """
    ### Load a pair of checkpoint files

    :param files: pair of files to load
    :return: the loaded parameter tensors
    """
    checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
    with monit.section('Load checkpoint'):
        data = [torch.load(checkpoint_path / f) for f in files]

    return data
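

# A minimal usage sketch (not part of the original module): load the two
# model-parallel partitions of the embedding layer. This assumes the
# checkpoints have already been fetched with `download()`; the file names
# follow the pattern from `get_files_to_download`.
def _demo_load_embedding():
    p1, p2 = load_checkpoint_files(('layer_00-model_00-model_states.pt',
                                    'layer_00-model_01-model_states.pt'))
    # Each partition is a dictionary of parameter tensors
    inspect(partition_keys=list(p1.keys()))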


def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by merging the partitions along the first dimension

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]
    param.data[:w1.shape[0]] = w1
    param.data[w1.shape[0]:] = w2
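

# A toy sketch (not part of the original module) of the dim-0 merge: the two
# partitions are stacked one below the other. Shapes are made up for
# illustration.
def _demo_merge_dim_0():
    param = torch.empty(4, 3)
    p1 = {'w': torch.ones(2, 3)}
    p2 = {'w': torch.zeros(2, 3)}
    merge_params_dim_0(param, 'w', p1, p2)
    # Rows 0-1 came from the first partition, rows 2-3 from the second
    assert param[:2].eq(1.).all() and param[2:].eq(0.).all()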


def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by merging the partitions along the second dimension

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]
    param.data[:, :w1.shape[1]] = w1
    param.data[:, w1.shape[1]:] = w2


def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                           p2: Dict[str, torch.Tensor]):
    """
    ### Load an un-partitioned parameter

    This does a sanity check to make sure both partitions are the same

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]

    # Sanity check: both partitions should hold (nearly) identical copies
    diff = sum((w1 - w2) ** 2).item()
    assert diff < 1e-4, f'The partitions do not match: {key}'

    param.data[:] = (w1 + w2) / 2.
147
def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
148
p2: Dict[str, torch.Tensor]):
149
"""
150
### Load biases that are partitioned which gets added on reduce
151
152
:param param: is the parameter
153
:param key: is the name of the parameter
154
:param p1: first partition dictionary
155
:param p2: second partition dictionary
156
"""
157
w1, w2 = p1[key], p2[key]
158
159
param.data[:] = w1 + w2
160
161


# Download the checkpoints when this file is run as a script
if __name__ == '__main__':
    download()