GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/neox/model.py
"""
---
title: GPT-NeoX Model Definition
summary: >
    This is the model definition of GPT-NeoX.
---

# GPT-NeoX Model

Here is the code for the layers of the GPT-NeoX model and the code to load
the 20B checkpoint.

The method `load_state` in each layer loads the checkpoint of that layer.
The checkpoint loading helpers are in [`checkpoint.py`](checkpoint.html).
"""

import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit, logger
from labml.logger import Text
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache


class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass


class Embedding(NeoXModule):
    """
    ## Embedding layer

    This is a standard embedding layer with code to load the checkpoint.
    """

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        """
        :param n_vocab: is the size of the vocabulary
        :param n_hidden: is the size of the embeddings
        """
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the token ids of shape `[batch_size, seq_len]`
        """
        return self.emb(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
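

# An illustrative sketch (not part of the original model): the embedding layer maps
# token ids of shape `[batch_size, seq_len]` to embeddings of shape
# `[batch_size, seq_len, n_hidden]`. GPT-NeoX-20B uses `n_vocab = 50_432` and
# `n_hidden = 6_144`; small hypothetical sizes are used here to keep the example cheap.
def _embedding_shape_example():
    emb = Embedding(n_vocab=100, n_hidden=16)
    tokens = torch.randint(0, 100, (2, 8))
    assert emb(tokens).shape == (2, 8, 16)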


class RoPE(nn.Module):
    """
    ## Rotary Positional Embeddings

    GPT-NeoX uses [rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).

    We have an annotated implementation of RoPE [here](https://nn.labml.ai/transformers/rope/index.html)
    with more notes on the theory.
    """

    def __init__(self, d_rope: int, base: float = 10_000.):
        """
        :param d_rope: is the number of features for RoPE embeddings
        :param base: is the base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$, which defaults to $10000$
        """
        super().__init__()

        # To store $\theta_i$ for the features
        self.theta = None
        # Cache $\cos m\theta_i$ and $\sin m\theta_i$
        self.cos_cached = None
        self.sin_cached = None

        # Base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$
        self.base = base
        # Number of features for RoPE
        self.d_rope = d_rope

    @staticmethod
    def rotate_half(x: torch.Tensor):
        """
        ### Rotate the features

        $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        """
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, offset: int = 0):
        """
        :param x: has shape `[..., seq, n_heads, d_k]`
        :param offset: is the starting position of `x`. This is $\gt 0$ when we have
            cached the keys and queries of previous positions
        """

        # Get the actual sequence length
        seq_len = x.shape[-3] + offset

        # Initialize $\theta$
        if self.theta is None:
            # $\theta_i = 10000^{-\frac{2(i-1)}{d}}$
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)

        # Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache.
        # The cache has the sequence dimension first, so its length is `shape[0]`.
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[0] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
            # Get position indexes $m$
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
            # $m \theta_i$
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
            # Concatenate so that for row $m$ we have
            #
            # $$[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$$
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

            # Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
                # Add head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]

            # Cache them
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)

        # Split the features. We apply RoPE to only `d_rope` features
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

        # Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

        # RoPE embeddings
        #
        # \begin{align}
        # \begin{pmatrix}
        # x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
        # x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\
        # \end{pmatrix} \\
        # \end{align}
        #
        # for $i \in \{1, 2, ..., \frac{d}{2}\}$
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

        # Concatenate with features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)
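

# An illustrative sketch (not part of the original model): applying RoPE to a dummy
# query tensor of shape `[batch_size, seq_len, n_heads, d_k]`. Only the first `d_rope`
# features of each head are rotated; the rest pass through unchanged, and `offset`
# shifts the positions when keys and queries of earlier tokens are cached.
def _rope_example():
    rope = RoPE(d_rope=16)
    q = torch.randn(2, 10, 4, 64)
    q_rotated = rope(q, offset=0)
    # The shape is unchanged
    assert q_rotated.shape == q.shape
    # Features beyond `d_rope` are passed through untouched
    assert torch.allclose(q_rotated[..., 16:], q[..., 16:])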


class AttentionLayer(nn.Module):
    """
    ## Attention layer
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
        """
        :param n_hidden: the number of features in embeddings
        :param n_heads: the number of attention heads
        :param rope_percentage: percentage of features to add RoPE embeddings
        :param mask_fill: masking fill value for attention matrix
        :param is_flash_attention: specifies whether to use
            [FlashAttention](https://github.com/HazyResearch/flash-attention)
        """
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

        # Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
        # Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)

        # Number of features per head
        d_k = n_hidden // n_heads
        # RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))

        # Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)

        # To cache causal mask
        self.causal_mask = None

        # Attention softmax module
        self.softmax = nn.Softmax(dim=-2)

        # [FlashAttention](https://github.com/HazyResearch/flash-attention)
        if is_flash_attention:
            try:
                from flash_attn.flash_attention import FlashAttention
                self.flash_attention = FlashAttention()
            except ImportError:
                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
                           'Falling back to normal attention', Text.warning)
                self.flash_attention = None
        else:
            self.flash_attention = None

    def _get_mask(self, attn: torch.Tensor):
        """
        #### Calculate the causal mask

        * `attn` has shape `[batch_size, query_seq_len, key_seq_len, n_heads]`
        """

        # Query and key lengths
        nq, nk = attn.shape[1:3]

        # Create mask
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

        # Return from cache
        return self.causal_mask[None, :, :, None]

    def forward(self, x: torch.Tensor):
        """
        :param x: has shape `[batch_size, seq_len, n_hidden]`
        """
        # Get query, key and value embeddings (all concatenated).
        # The last dimension size will change from `n_hidden` to `3 * n_hidden`
        qkv = self.qkv_lin(x)

        # Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
        # Split into query, key and value, each of shape `[batch_size, seq_len, n_heads, d_k]`
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

        # If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
            # Get the state ids. We use them to retrieve previous states and store the next states
            prev_state_id, next_state_id = get_cache().get('state_ids')
            # If there's a cache
            if prev_state_id is not None:
                # Get the past keys and values. These will have shape `[batch_size, prev_seq_len, n_heads, d_k]`
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
                # Offset of the current embeddings
                offset = k_past.shape[1]

                # Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)

                # Concatenate the past
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
                # Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)

            # Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
            # No cache - simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)

        # Use flash attention
        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
            output = self.compute_flash_attention(q, k, v)
        # Otherwise, use normal attention
        else:
            output = self.compute_attention(q, k, v)

        # Reshape from `[batch_size, seq_len, n_heads, d_k]` to `[batch_size, seq_len, n_hidden]`
        output = output.reshape(*x.shape)

        # Final linear layer
        return self.output(output)

    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Stack them into shape `[batch_size, seq_len, 3, n_heads, d_k]`
        qkv = torch.stack((q, k, v), dim=2)
        d_k = qkv.shape[-1]
        # Pad the head features to the sizes supported by FlashAttention
        if d_k <= 32:
            pad = 32 - d_k
        elif d_k <= 64:
            pad = 64 - d_k
        elif d_k <= 128:
            pad = 128 - d_k
        else:
            raise ValueError(f'Head size {d_k} too large for flash attention')

        if pad > 0:
            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)

        output, _ = self.flash_attention(qkv, causal=True)
        # The output is of shape `[batch_size, seq_len, n_heads, d_k + padding]`
        output = output[:, :, :, :d_k]

        return output

    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Disable auto-casting to fp16 for attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
                # Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
                # Do not cast for bfloat
                attn = torch.einsum('bihk,bjhk->bijh', q, k)

            # Scale attention
            attn = attn * self.scale

            # Get causal mask
            mask = self._get_mask(attn)
            # Apply mask
            attn.masked_fill_(mask, self.mask_fill)

            # Attention softmax
            attn = self.softmax(attn)

            # Get attention weighted values
            output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        return output
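

# An illustrative sketch (not part of the original model) of the causal mask built in
# `AttentionLayer._get_mask`. With `nq` queries attending over `nk` keys (the keys may
# include `nk - nq` cached prefix tokens), taking the upper triangle above diagonal
# `1 + nk - nq` marks exactly the positions a query must not attend to: each query sees
# the whole cached prefix and the keys up to its own position, but nothing after it.
def _causal_mask_example():
    nq, nk = 3, 5
    mask = torch.triu(torch.ones(nq, nk, dtype=torch.bool), 1 + nk - nq)
    # `mask[i, j]` is `True` where query `i` must *not* attend to key `j`:
    # [[False, False, False, True,  True ],
    #  [False, False, False, False, True ],
    #  [False, False, False, False, False]]
    assert not mask[0, 2] and mask[0, 3]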


class FFNLayer(nn.Module):
    """
    ## Feedforward Network
    """

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        """
        :param n_hidden: is the embedding size
        :param d_ff: is the size of the hidden layer; defaults to four times `n_hidden`
        """
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4

        # Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
        # GELU activation
        self.activation = nn.GELU()
        # Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: has shape `[batch_size, seq_len, n_hidden]`
        """
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
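

# An illustrative sketch (not part of the original model): the FFN expands `n_hidden`
# features to `4 * n_hidden`, applies GELU, and contracts back to `n_hidden`
# (6144 -> 24576 -> 6144 in GPT-NeoX-20B). Small hypothetical sizes are used here.
def _ffn_shape_example():
    ffn = FFNLayer(n_hidden=8)
    x = torch.randn(2, 4, 8)
    assert ffn.dense_h_h4.out_features == 32
    assert ffn(x).shape == (2, 4, 8)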


class TransformerLayer(NeoXModule):
    """
    ## Transformer Layer
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
        """
        :param n_hidden: is the embedding size
        :param n_heads: is the number of heads
        :param is_flash_attention: specifies whether to use
            [FlashAttention](https://github.com/HazyResearch/flash-attention)

        *Our implementation doesn't include dropout.*
        """
        super().__init__()

        # Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
        # Layer normalization before FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

        # Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
        # FFN layer
        self.ffn = FFNLayer(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """

        # Residual connection
        residual = x
        # NeoX runs attention and feedforward network in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
        # Add them and the residual connection
        return attn + ffn + residual

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load transformer layer'):
            # Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

            # Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

            # Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

            # FFN first transform (expansion)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

            # FFN second transform (contraction)
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

            # Layer norm before FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
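

# A conceptual sketch (not the actual `checkpoint.py` helpers): the 20B checkpoint stores
# each layer as two model-parallel partitions `p1` and `p2`. Merging along dim 0
# concatenates output features (column-parallel weights such as `query_key_value.weight`),
# merging along dim 1 concatenates input features (row-parallel weights such as
# `attention.dense.weight`), and layer-norm parameters are duplicated across partitions.
def _merge_concept_example():
    w1 = torch.randn(4, 6)  # hypothetical partition from `p1`
    w2 = torch.randn(4, 6)  # hypothetical partition from `p2`
    merged_dim_0 = torch.cat([w1, w2], dim=0)  # shape `[8, 6]`
    merged_dim_1 = torch.cat([w1, w2], dim=1)  # shape `[4, 12]`
    assert merged_dim_0.shape == (8, 6) and merged_dim_1.shape == (4, 12)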


class FinalNorm(NeoXModule):
    """
    ## Final normalization layer
    """

    def __init__(self, n_hidden: int = 6_144):
        """
        :param n_hidden: is the embedding size
        """
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        return self.ln(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)


class ReadoutLayer(NeoXModule):
    """
    ## Readout layer
    """

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        """
        :param n_hidden: is the embedding size
        :param n_vocab: is the size of the vocabulary
        """
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        return self.linear(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
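

# An illustrative sketch (not part of the original model): the readout layer maps final
# hidden states to vocabulary logits, from which the next token can be picked greedily.
# Small hypothetical sizes are used here.
def _readout_example():
    readout = ReadoutLayer(n_hidden=8, n_vocab=16)
    h = torch.randn(1, 5, 8)  # `[batch_size, seq_len, n_hidden]`
    logits = readout(h)
    assert logits.shape == (1, 5, 16)
    # Greedy choice of the next token from the last position
    next_token = logits[:, -1].argmax(dim=-1)
    assert next_token.shape == (1,)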


class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]

    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 is_flash_attention: bool = False
                 ):
        """
        ### Generator to create layers

        The layers are generated in the same order as the checkpoints.

        It gives `None` when a layer is not available; we use the same layer indices as NeoX,
        and there are two transformation layers we don't need in our implementation.

        :param n_vocab: is the number of tokens in the vocabulary
        :param n_hidden: is the number of features in the embeddings
        :param n_layers: is the number of transformer layers
        :param n_heads: is the number of attention heads
        :param filter_layers: are the set of layers to be used. All layers will be used if None.
            This is used to test smaller versions of the model with fewer layers
        :param is_clone_layers: specifies whether to clone the transformer layers (a bit faster)
        :param dtype: is the data type of the model
        :param device: is the device of the model
        :param is_llm_int8: specifies whether to use int8 quantization
        :param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features
        :param is_flash_attention: specifies whether to use
            [FlashAttention](https://github.com/HazyResearch/flash-attention)
        """
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold
        self.is_flash_attention = is_flash_attention

        self.pre_created_layers = dict(
            transformer_layer=None,
        )

    def _prepare_layer(self, layer: NeoXModule):
        """
        #### Prepares the layer for usage

        We move the layer to the device and convert it to the correct data type

        :param layer: is the layer to prepare
        :return: the prepared layer
        """
        return layer.to(self.device, self.dtype)

    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: Optional[bool] = None,
                          device: Optional[torch.device] = None,
                          llm_int8_threshold: Optional[float] = None,
                          ):
        """
        <a id="post_load_prepare"></a>

        ### Layer transformations after loading the checkpoint

        This function implements layer transformations after loading the checkpoint.

        Currently, it only applies the int8 quantization.

        :param layer: is the layer to prepare
        :param is_llm_int8: specifies whether to use int8 quantization
        :param device: is the device of the model
        :param llm_int8_threshold: is the threshold $\alpha$ used to separate outlier features
        :return: the prepared layer
        """

        # Get default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold

        # Skip if not using int8 quantization
        if not is_llm_int8:
            return layer

        # Only convert the linear layers in the transformer layers
        if not isinstance(layer, TransformerLayer):
            return layer

        # Use `make_llm_int8_linear` defined in [utilities](./utils/llm_int8.html).
        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

        # Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
        #
        return layer

    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        """
        #### Creates and caches a layer

        Copying cached layers is faster than initializing new layers because it takes time to
        initialize parameters.

        :param name: is the name of the layer
        :param creator: is the function to create the layer
        :return: the created layer or a copy of the cached layer
        """

        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer

    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)

    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
        """
        ### Generator to get layers
        """
        # Embedding layer
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

        # Transformer layers
        for i in range(self.n_layers):
            # Transformer layer
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                          (f'layer_{i + 2:02d}-model_00-model_states.pt',
                           f'layer_{i + 2:02d}-model_01-model_states.pt')

        # Final normalization layer
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

        # Readout layer
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None

    @property
    def total_layers(self):
        """
        ### Returns the total number of layers
        """
        return self.n_layers + 3

    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        """
        ### Generator to load layers
        """
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
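

# A usage sketch, assuming the 20B checkpoint files have been downloaded so that
# `checkpoint.load_checkpoint_files` can find them: iterating `LayerGenerator.load()`
# builds and loads the layers one at a time, so the full set of parameters is never
# held in memory twice.
def _layer_generator_example():
    generator = LayerGenerator(is_clone_layers=True,
                               dtype=torch.float16,
                               device=torch.device('cpu'))
    # `load()` yields the Embedding, 44 TransformerLayers, FinalNorm and ReadoutLayer in order
    layers = list(generator.load())
    # The layers can be composed into a single module for sampling
    model = nn.Sequential(*layers)
    return model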