Path: blob/master/labml_nn/sampling/experiment.py
"""1---2title: Trying out Sampling Techniques for Language Models3summary: >4We try out different sampling techniques for language models on HuggingFace's GPT2 model.5---67# Trying out Sampling Techniques for Language Models89* [Greedy Sampling](greedy.html)10* [Temperature Sampling](temperature.html)11* [Top-k Sampling](top_k.html)12* [Nucleus Sampling](nucleus.html)1314This experiment uses the above sampling techniques, on HuggingFace's GPT2 model.15"""1617import torch1819from labml import monit, logger, lab2021from labml.logger import Text2223from labml_nn.sampling import Sampler24from labml_nn.sampling.greedy import GreedySampler25from labml_nn.sampling.nucleus import NucleusSampler26from labml_nn.sampling.temperature import TemperatureSampler27from labml_nn.sampling.top_k import TopKSampler28from transformers import GPT2Tokenizer, GPT2LMHeadModel293031@torch.no_grad()32def sample(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, sampler: Sampler,33n_samples: int, n_tokens: int, seq_len: int, prompt: str):34"""35## Sample from model3637:param model: is the model to sample from38:param tokenizer: is the tokenizer to use39:param sampler: is the sampler to use40:param n_samples: is the number of samples to generate41:param n_tokens: is the number of tokens to generate42:param seq_len: is the maximum sequence length for the model43:param prompt: is the starting prompt44"""45# Tokenize the `prompt` and make `n_samples` copies of it46data = torch.tile(torch.tensor(tokenizer.encode(prompt))[None, :], (n_samples, 1))4748# Collect output for printing49logs = [[(prompt, Text.meta)] for _ in range(n_samples)]50# Sample `n_tokens`51for i in monit.iterate('Sample', n_tokens):52# Truncate the data to the maximum sequence length53data = data[-seq_len:]54# Get the model output. The 'logits' has shape `[batch_size, seq_len, n_tokens]`55logits = model(data)[0]56# Get the `logits` of the last token57logits = logits[:, -1]58# Sample from the `logits`59res = sampler(logits)60# Add the sampled token to the data61data = torch.cat([data, res[:, None]], dim=1)62# Decode and add the sampled token for logging63for j in range(n_samples):64logs[j] += [('' + tokenizer.decode(res[j]), Text.value)]6566# Print the sampled outputs67for j in range(n_samples):68logger.log(logs[j])697071def main():72"""73### Try different sampling techniques74"""7576# Load the model and tokenizer77with monit.section('Load tokenizer/model'):78tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')79model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')80# Set the model to eval mode81model.eval()8283# Prompts to use for sampling84prompt = 'I saw an interesting dream last night. 
'8586# [Greedy Sampling](greedy.html)87with monit.section('greedy'):88sample(model, tokenizer, GreedySampler(), 4, 32, 128, prompt)8990# [Temperature Sampling](temperature.html)91with monit.section('temperature=1.'):92sample(model, tokenizer, TemperatureSampler(1.), 4, 32, 128, prompt)93with monit.section('temperature=.1'):94sample(model, tokenizer, TemperatureSampler(.1), 4, 32, 128, prompt)95with monit.section('temperature=10.'):96sample(model, tokenizer, TemperatureSampler(10.), 4, 32, 128, prompt)9798# [Top-k Sampling](top_k.html)99with monit.section('top_k=5'):100sample(model, tokenizer, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, prompt)101102# [Nucleus Sampling](nucleus.html)103with monit.section('nucleus p=.95'):104sample(model, tokenizer, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, prompt)105with monit.section('nucleus p=.1'):106sample(model, tokenizer, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, prompt)107108#109if __name__ == '__main__':110main()111112113
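The samplers imported above (`GreedySampler`, `TemperatureSampler`, `TopKSampler`, `NucleusSampler`) are implemented in the sibling modules linked from the docstring. For readers who only have this page, here is a minimal, self-contained sketch of the interface that `sample()` relies on, with rough stand-ins for the four samplers. It assumes only what this file implies: a sampler is a callable that takes logits of shape `[n_samples, vocab_size]` and returns one token index per row. These are illustrative approximations, not the labml_nn implementations.

# --- Illustrative sampler sketches (assumptions, not the labml_nn implementations) ---
import torch
from torch.distributions import Categorical


class Sampler:
    """Assumed interface: map logits `[n_samples, vocab_size]` to token indices `[n_samples]`."""
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class GreedySampler(Sampler):
    """Pick the most likely token."""
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1)


class TemperatureSampler(Sampler):
    """Sample from softmax(logits / temperature)."""
    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        return Categorical(logits=logits / self.temperature).sample()


class TopKSampler(Sampler):
    """Mask out everything except the `k` largest logits, then delegate to another sampler."""
    def __init__(self, k: int, sampler: Sampler):
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        masked = logits.new_full(logits.shape, float('-inf'))
        values, indices = torch.topk(logits, self.k, dim=-1)
        masked.scatter_(-1, indices, values)
        return self.sampler(masked)


class NucleusSampler(Sampler):
    """Keep the smallest set of tokens whose cumulative probability reaches `p`, then delegate."""
    def __init__(self, p: float, sampler: Sampler):
        self.p = p
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sort_indices = torch.sort(probs, dim=-1, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Tokens inside the nucleus; shift right by one so the token that crosses `p`
        # is kept and at least one token always survives
        nucleus = cum_probs < self.p
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
        # Mask logits outside the nucleus and sample in the sorted ordering
        sorted_logits = logits.gather(-1, sort_indices)
        sorted_logits[~nucleus] = float('-inf')
        sampled = self.sampler(sorted_logits)
        # Map back from sorted positions to vocabulary indices
        return sort_indices.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)

With stand-ins like these in place of the imports, the `sample()` loop above runs unchanged. Running the module itself calls `main()`, which downloads GPT-2 through `transformers` and prints four 32-token continuations of the prompt for each sampler configuration.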