Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/sampling/top_k.py
4918 views
1
"""
2
---
3
title: Top-k Sampling
4
summary: A PyTorch implementation of top-k sampling from language models.
5
---
6
7
# Top-k Sampling
8
9
Here we first pick the top-k tokens from the distribution of logits, and then
10
sample from them.
11
12
Here's an [experiment](experiment.html) that uses these sampling techniques.
13
"""
14
import torch

from labml_nn.sampling import Sampler
class TopKSampler(Sampler):
    """
    ## Top-k Sampler

    Masks all but the $k$ largest logits to $-\infty$ and delegates the
    actual sampling over the surviving tokens to a wrapped sampler.
    """
    def __init__(self, k: int, sampler: Sampler):
        """
        :param k: is the number of tokens to pick
        :param sampler: is the sampler to use for the top-k tokens

        `sampler` can be any sampler that takes a logits tensor as input and returns a token tensor;
        e.g. [`TemperatureSampler`](temperature.html).
        """
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits

        :param logits: logits tensor; the last dimension is the token (vocabulary) dimension
        :return: whatever the wrapped `sampler` returns for the masked logits
        """
        # New logits filled with $-\infty$; i.e. zero probability.
        # `new_full` creates the mask directly (same device/dtype as `logits`)
        # instead of building a ones-tensor and multiplying by $-\infty$.
        zeros = logits.new_full(logits.shape, float('-inf'))
        # Pick the largest $k$ logits and their indices
        values, indices = torch.topk(logits, self.k, dim=-1)
        # Set the values of the top-k selected indices to actual logits.
        # Logits of other tokens remain $-\infty$
        zeros.scatter_(-1, indices, values)

        # Sample from the top-k logits with the specified sampler.
        return self.sampler(zeros)