GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/neox/samples/generate.py
"""
---
title: Generate Text with GPT-NeoX
summary: >
 Generate Text with GPT-NeoX
---

# Generate Text with GPT-NeoX

This shows how to generate text from GPT-NeoX with a single GPU.

This needs a GPU with more than 45GB of memory.
"""

# Imports
from typing import List

import torch
from torch import nn

from labml import monit
from labml_nn.neox.model import LayerGenerator
from labml_nn.neox.utils import get_tokens, print_tokens
from labml_nn.neox.utils.cache import get_cache

# List of layers to load. This is used for testing.
# You can assign a subset of layers like `{0, 1}` so that it only loads
# the first two transformer layers.
LAYERS = None
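# For example, to smoke-test with only the first two transformer layers
# (an illustrative setting, not what the full run uses):
# LAYERS = {0, 1}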

# Prompt to complete
PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'


def infer(model: nn.Module, ids: List[int], device: torch.device):
    """
    ### Predict the next token

    :param model: is the model
    :param ids: are the input token ids
    :param device: is the device of the model
    """

    with torch.no_grad():
        # Convert the token ids to a tensor and add a batch dimension
        x = torch.tensor(ids)[None, :].to(device)
        # Evaluate the model
        x = model(x)

    # Return the greedy (argmax) prediction for each position
    return x[0].max(dim=-1)[1].tolist()


def generate():
    """
    ## Generate text
    """

    # Setup [cache](../utils/cache.html) to cache intermediate key/value pairs for faster generation
    cache = get_cache()
    cache.set('use_cache', True)

    # Device
    device = torch.device('cuda:0')

    # Load layers
    layers = list(LayerGenerator(is_clone_layers=True,
                                 filter_layers=LAYERS,
                                 dtype=torch.float16,
                                 device=device,
                                 ).load())

    model = nn.Sequential(*layers)
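    # Rough memory arithmetic (an estimate, not measured in this file): GPT-NeoX
    # has about 20 billion parameters, so the float16 weights alone take roughly
    # 20e9 * 2 bytes, i.e. about 40 GB, which is why the note at the top asks for
    # a GPU with more than 45GB of memory.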

    # Get token ids
    ids = get_tokens(PROMPT)

    # Run the model
    cache.set('state_ids', (None, 1))
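    # Note on `state_ids` (inferred from how it is used here, not documented in
    # this file): the tuple appears to be (id of the cached state to read, id to
    # save the new state under); `(None, 1)` reads nothing on this first pass
    # over the prompt and stores its key/value pairs under id `1`.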
    with monit.section('Infer'):
        next_token = infer(model, ids, device)[-1]

    # Append the predicted token
    ids += [next_token]

    # Predict the next 99 tokens (100 in total, counting the one above)
    for i in range(1, 100):
        # Set the state to use cached activations
        cache.set('state_ids', (i, i + 1))
        # Get next token. Note that we only feed the last token to the model because
        # we cache the key/value pairs of previous tokens.
        with monit.section('Infer'):
            next_token = infer(model, [next_token], device)[-1]
        # Append the predicted token
        ids += [next_token]
        # Print
        print_tokens(ids, [ids])


#
if __name__ == '__main__':
    generate()