GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/sampling/nucleus.py
"""
2
---
3
title: Nucleus Sampling
4
summary: A PyTorch implementation of nucleus sampling from language models.
5
---
6
7
# Nucleus Sampling
8
9
This is an implementation of nucleus sampling, introduced in the paper
10
[The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751).
11
12
The paper discusses the problems with other sampling methods such as Beam Search,
13
[Pure sampling](temperature.html), [Temperature sampling](temperature.html), and
14
[Top-k sampling](top_k.html). The paper introduces the idea of nucleus sampling,
15
which practically performs better than other sampling methods for text generation.
16
17
Nucleus sampling first picks a subset of the vocabulary $V^{(p)} \subset V$,
18
where $V^{(p)}$ is smallest set of tokens such that
19
20
$$\sum_{x_i \in V^{(p)}} P(x_i | x_{1:i-1}) \ge p$$
21
22
That is, we pick the highest probable tokens until the sum of their probabilities is less that $p$.
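
For example, with illustrative sorted probabilities $(0.45, 0.30, 0.15, 0.07, 0.03)$ and $p = 0.8$
(numbers not from the paper), the nucleus is the first three tokens: $0.45 + 0.30 = 0.75 < 0.8$,
while $0.45 + 0.30 + 0.15 = 0.90 \ge 0.8$, so the remaining two tokens are excluded.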

Then we sample from the selected tokens.

Here's an [experiment](experiment.html) that uses these sampling techniques.
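
As a rough usage sketch (the `TemperatureSampler` import, the `0.95` and `1.0` values, and the
logits shape below are illustrative assumptions, not part of this file), the nucleus sampler
wraps another sampler that draws the final token from the restricted set:

```python
import torch
from labml_nn.sampling.temperature import TemperatureSampler

# Dummy logits for a batch of 2 sequences over a vocabulary of 10 tokens
logits = torch.randn(2, 10)

# Keep the smallest set of tokens whose probabilities sum to at least 0.95,
# then let the temperature sampler draw from that set
sampler = NucleusSampler(p=0.95, sampler=TemperatureSampler(1.0))
next_tokens = sampler(logits)  # tensor of token indices, shape (2,)
```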
"""

import torch
from torch import nn

from labml_nn.sampling import Sampler


class NucleusSampler(Sampler):
    """
    ## Nucleus Sampler
    """
    def __init__(self, p: float, sampler: Sampler):
        """
        :param p: is the sum of probabilities of tokens to pick $p$
        :param sampler: is the sampler to use for the selected tokens
        """
        self.p = p
        self.sampler = sampler
        # Softmax to compute $P(x_i | x_{1:i-1})$ from the logits
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits with Nucleus Sampling
        """

        # Get probabilities $P(x_i | x_{1:i-1})$
        probs = self.softmax(logits)

        # Sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
        # Get the cumulative sum of probabilities in the sorted order
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Find the cumulative sums less than $p$.
        nucleus = cum_sum_probs < self.p
        # Prepend ones so that we add one token after the minimum number
        # of tokens with cumulative probability less than $p$.
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
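        # Illustrative example (values not from this file): with sorted probabilities
        # [0.45, 0.30, 0.15, 0.07, 0.03] and p=0.8, `cum_sum_probs < p` gives
        # [True, True, False, False, False]; after the shift the mask becomes
        # [True, True, True, False, False], keeping the token that pushes the total past p.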

        # Get log probabilities and mask out the non-nucleus
        sorted_log_probs = torch.log(sorted_probs)
        sorted_log_probs[~nucleus] = float('-inf')

        # Sample from the sampler
        sampled_sorted_indexes = self.sampler(sorted_log_probs)

        # Get the actual indexes
        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
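        # `sampled_sorted_indexes` indexes into the sorted order, so `gather` maps it back
        # to positions in the original vocabulary; for logits of shape `[..., n_vocab]`,
        # `res` has shape `[..., 1]` here and is squeezed to `[...]` below.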

        #
        return res.squeeze(-1)