"""
---
title: Trying out Sampling Techniques for Language Models
summary: >
 We try out different sampling techniques for language models on HuggingFace's GPT2 model.
---

# Trying out Sampling Techniques for Language Models

* [Greedy Sampling](greedy.html)
* [Temperature Sampling](temperature.html)
* [Top-k Sampling](top_k.html)
* [Nucleus Sampling](nucleus.html)

This experiment uses the above sampling techniques on HuggingFace's GPT2 model.
"""

import torch

from labml import monit, logger, lab
from labml.logger import Text

from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from transformers import GPT2Tokenizer, GPT2LMHeadModel
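

# The samplers imported above implement the `Sampler` interface: a callable that takes
# a batch of logits of shape `[batch_size, vocab_size]` and returns a tensor of sampled
# token ids of shape `[batch_size]`. As a minimal illustrative sketch (not part of the
# original experiment), a plain multinomial sampler that draws directly from the softmax
# distribution could look like this; it behaves like `TemperatureSampler(1.)`.
class MultinomialSampler(Sampler):
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # `Categorical` applies a softmax to the logits internally
        return torch.distributions.Categorical(logits=logits).sample()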


@torch.no_grad()
def sample(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, sampler: Sampler,
           n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """
    ## Sample from model

    :param model: is the model to sample from
    :param tokenizer: is the tokenizer to use
    :param sampler: is the sampler to use
    :param n_samples: is the number of samples to generate
    :param n_tokens: is the number of tokens to generate
    :param seq_len: is the maximum sequence length for the model
    :param prompt: is the starting prompt
    """
    # Tokenize the `prompt` and make `n_samples` copies of it
    data = torch.tile(torch.tensor(tokenizer.encode(prompt))[None, :], (n_samples, 1))

    # Collect output for printing
    logs = [[(prompt, Text.meta)] for _ in range(n_samples)]
    # Sample `n_tokens`
    for _ in monit.iterate('Sample', n_tokens):
        # Truncate the data to the maximum sequence length (along the sequence dimension)
        data = data[:, -seq_len:]
        # Get the model output. The `logits` have shape `[batch_size, seq_len, vocab_size]`
        logits = model(data)[0]
        # Get the `logits` of the last token
        logits = logits[:, -1]
        # Sample from the `logits`
        res = sampler(logits)
        # Add the sampled token to the data
        data = torch.cat([data, res[:, None]], dim=1)
        # Decode and add the sampled token for logging
        for j in range(n_samples):
            logs[j] += [(tokenizer.decode(res[j]), Text.value)]

    # Print the sampled outputs
    for j in range(n_samples):
        logger.log(logs[j])
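
# Illustrative direct usage (hypothetical values; `model` and `tokenizer` as loaded in
# `main` below): generate 2 continuations of 16 tokens each from a custom prompt.
#
#     sample(model, tokenizer, MultinomialSampler(), 2, 16, 128, 'Once upon a time')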


def main():
    """
    ### Try different sampling techniques
    """

    # Load the model and tokenizer
    with monit.section('Load tokenizer/model'):
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
        model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
    # Set the model to eval mode
    model.eval()

    # Prompt to use for sampling
    prompt = 'I saw an interesting dream last night. '

    # [Greedy Sampling](greedy.html)
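    # Greedy sampling always picks the most likely next token, so it is
    # deterministic; all four samples below come out identical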
    with monit.section('greedy'):
        sample(model, tokenizer, GreedySampler(), 4, 32, 128, prompt)

    # [Temperature Sampling](temperature.html)
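    # Temperature sampling divides the logits by a temperature T before the softmax:
    # T < 1 sharpens the distribution towards greedy decoding, T > 1 flattens it towards uniform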
    with monit.section('temperature=1.'):
        sample(model, tokenizer, TemperatureSampler(1.), 4, 32, 128, prompt)
    with monit.section('temperature=.1'):
        sample(model, tokenizer, TemperatureSampler(.1), 4, 32, 128, prompt)
    with monit.section('temperature=10.'):
        sample(model, tokenizer, TemperatureSampler(10.), 4, 32, 128, prompt)

    # [Top-k Sampling](top_k.html)
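    # Top-k sampling zeroes out all but the k highest-probability tokens
    # and samples among what remains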
    with monit.section('top_k=2'):
        sample(model, tokenizer, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, prompt)

    # [Nucleus Sampling](nucleus.html)
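    # Nucleus sampling samples from the smallest set of tokens whose cumulative
    # probability exceeds p; a smaller p restricts sampling to fewer, more likely tokens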
    with monit.section('nucleus p=.95'):
        sample(model, tokenizer, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, prompt)
    with monit.section('nucleus p=.1'):
        sample(model, tokenizer, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, prompt)


#
if __name__ == '__main__':
    main()