Path: blob/master/labml_nn/sampling/top_k.py
4918 views
"""
---
title: Top-k Sampling
summary: A PyTorch implementation of top-k sampling from language models.
---

# Top-k Sampling

Here we first pick the top-k tokens from the distribution of logits, and then
sample from them.

Here's an [experiment](experiment.html) that uses these sampling techniques.
"""

import torch

from labml_nn.sampling import Sampler


class TopKSampler(Sampler):
    """
    ## Top-k Sampler

    Restricts sampling to the $k$ highest-probability tokens: all other
    logits are masked to $-\infty$ and a wrapped sampler draws from what
    remains.
    """
    def __init__(self, k: int, sampler: Sampler):
        """
        :param k: is the number of tokens to pick
        :param sampler: is the sampler to use for the top-k tokens

        `sampler` can be any sampler that takes a logits tensor as input and returns a token tensor;
        e.g. [`TemperatureSampler`](temperature.html).
        """
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor):
        """
        Sample from logits

        :param logits: logits over the vocabulary; top-k is taken along the
            last dimension.
        :return: the token tensor produced by the wrapped sampler.
        """
        # New logits filled with $-\infty$; i.e. zero probability.
        # `new_full` allocates the mask in one step on the same
        # device/dtype as `logits` (clearer than `new_ones(...) * -inf`).
        masked = logits.new_full(logits.shape, float('-inf'))
        # Pick the largest $k$ logits and their indices
        values, indices = torch.topk(logits, self.k, dim=-1)
        # Set the values of the top-k selected indices to actual logits.
        # Logits of other tokens remain $-\infty$
        masked.scatter_(-1, indices, values)

        # Sample from the top-k logits with the specified sampler.
        return self.sampler(masked)