Path: blob/master/labml_nn/sampling/nucleus.py
"""1---2title: Nucleus Sampling3summary: A PyTorch implementation of nucleus sampling from language models.4---56# Nucleus Sampling78This is an implementation of nucleus sampling, introduced in the paper9[The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751).1011The paper discusses the problems with other sampling methods such as Beam Search,12[Pure sampling](temperature.html), [Temperature sampling](temperature.html), and13[Top-k sampling](top_k.html). The paper introduces the idea of nucleus sampling,14which practically performs better than other sampling methods for text generation.1516Nucleus sampling first picks a subset of the vocabulary $V^{(p)} \subset V$,17where $V^{(p)}$ is smallest set of tokens such that1819$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$2021That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$.2223Then we sample from the selected tokens.2425Here's an [experiment](experiment.html) that uses these sampling techniques.26"""2728import torch29from torch import nn3031from labml_nn.sampling import Sampler323334class NucleusSampler(Sampler):35"""36## Nucleus Sampler37"""38def __init__(self, p: float, sampler: Sampler):39"""40:param p: is the sum of probabilities of tokens to pick $p$41:param sampler: is the sampler to use for the selected tokens42"""43self.p = p44self.sampler = sampler45# Softmax to compute $P(x_i | x_{1:i-1})$ from the logits46self.softmax = nn.Softmax(dim=-1)4748def __call__(self, logits: torch.Tensor):49"""50Sample from logits with Nucleus Sampling51"""5253# Get probabilities $P(x_i | x_{1:i-1})$54probs = self.softmax(logits)5556# Sort probabilities in descending order57sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)58# Get the cumulative sum of probabilities in the sorted order59cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)60# Find the cumulative sums less than $p$.61nucleus = cum_sum_probs < self.p62# Prepend ones so that we add one token after the minimum number63# of tokens with cumulative probability less that $p$.64nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)6566# Get log probabilities and mask out the non-nucleus67sorted_log_probs = torch.log(sorted_probs)68sorted_log_probs[~nucleus] = float('-inf')6970# Sample from the sampler71sampled_sorted_indexes = self.sampler(sorted_log_probs)7273# Get the actual indexes74res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))7576#77return res.squeeze(-1)787980