Path: labml_nn/neox/samples/generate.py
"""1---2title: Generate Text with GPT-NeoX3summary: >4Generate Text with GPT-NeoX5---67# Generate Text with GPT-NeoX89This shows how to generate text from GPT-NeoX with a single GPU.1011This needs a GPU with more than 45GB memory.12"""1314# Imports15from typing import List1617import torch18from torch import nn1920from labml import monit21from labml_nn.neox.model import LayerGenerator22from labml_nn.neox.utils import get_tokens, print_tokens23from labml_nn.neox.utils.cache import get_cache2425# List of layers to load. This is used for testing.26# You can assign a subset of layers like `{0, 1}` so that it only loads27# the first to transformer layers.28LAYERS = None2930# Prompt to complete31PROMPT = 'Einstein was born in the German Empire, but moved to Switzerland in 1895, forsaking his German'323334def infer(model: nn.Module, ids: List[int], device: torch.device):35"""36### Predict the next token3738:param model: is the model39:param ids: are the input token ids40:param device: is the device of the model41"""4243with torch.no_grad():44# Get the tokens45x = torch.tensor(ids)[None, :].to(device)46# Eval model47x = model(x)4849# Return predicted token50return x[0].max(dim=-1)[1].tolist()515253def generate():54"""55## Generate text56"""5758# Setup [cache](../utils/cache.html) to cache intermediate key/value pairs for faster generation59cache = get_cache()60cache.set('use_cache', True)6162# Device63device = torch.device('cuda:0')6465# Load layers66layers = list(LayerGenerator(is_clone_layers=True,67filter_layers=LAYERS,68dtype=torch.float16,69device=device,70).load())7172model = nn.Sequential(*layers)7374# Get token ids75ids = get_tokens(PROMPT)7677# Run the model78cache.set('state_ids', (None, 1))79with monit.section('Infer'):80next_token = infer(model, ids, device)[-1]8182# Append the predicted token83ids += [next_token]8485# Predict 100 tokens86for i in range(1, 100):87# Set the state to use cached activations88cache.set('state_ids', (i, i + 1))89# Get next token. Note that we only feed the last token to the model because90# we cache the key/value pairs of previous tokens.91with monit.section('Infer'):92next_token = infer(model, [next_token], device)[-1]93# Append the predicted token94ids += [next_token]9596print_tokens(ids, [ids])979899#100if __name__ == '__main__':101generate()102103104