Path: blob/master/labml_nn/sampling/temperature.py
```python
r"""
---
title: Sampling from Language Models with Temperature
summary: A PyTorch implementation of sampling from language models with temperature.
---

# Sampling from Language Models with Temperature

Here we sample from the following probability distribution, where $V$ is the vocabulary,
$u_{1:|V|}$ are the logits of the distribution and $T$ is the temperature:

$$P(x_i=V_l \mid x_{1:i-1}) = \frac{\exp(\frac{u_l}{T})}{\sum_j \exp(\frac{u_j}{T})}$$

$T = 1$ is normal random sampling from the model's softmax distribution.
$T < 1$ sharpens the distribution towards the highest-scoring tokens,
and $T > 1$ flattens it towards uniform.

Here's an [experiment](experiment.html) that uses these sampling techniques.
"""

import torch
from torch.distributions import Categorical

from labml_nn.sampling import Sampler


class TemperatureSampler(Sampler):
    """
    ## Sampler with Temperature
    """
    def __init__(self, temperature: float = 1.0):
        """
        :param temperature: is the temperature to sample with; it must be nonzero,
            because the logits are divided by it
        """
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor):
        """
        Sample token indices from the temperature-scaled `logits`
        """

        # Create a categorical distribution with temperature-adjusted logits
        dist = Categorical(logits=logits / self.temperature)

        # Sample
        return dist.sample()
```
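To see what the temperature does to the distribution, here is a minimal sketch of the formula above using only the Python standard library, with no PyTorch. The function name `softmax_with_temperature` and the example logits are illustrative assumptions, not part of `labml_nn`.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Return exp(u_l / T) / sum_j exp(u_j / T) for each logit u_l (hypothetical helper)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Divide every logit by the temperature T
    scaled = [u / temperature for u in logits]
    # Subtract the largest scaled logit before exponentiating to avoid overflow;
    # the shift cancels in the ratio, so the probabilities are unchanged
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Running it should show the probability of the largest logit growing as $T$ decreases and the three probabilities moving closer together as $T$ increases, while $T = 1$ gives the ordinary softmax of the logits.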
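`TemperatureSampler` delegates the actual sampling to `torch.distributions.Categorical`. The sketch below is a hypothetical, standard-library-only stand-in (`PureTemperatureSampler` is an illustrative name, not part of `labml_nn`) that draws one index per row of logits from the same temperature-scaled softmax, so the effect of $T$ on the samples can be checked empirically. It assumes the logits arrive as a list of rows, one per batch element, analogous to sampling over the last dimension of a `torch.Tensor`.

```python
import math
import random
from typing import List, Optional


class PureTemperatureSampler:
    """Hypothetical stdlib-only analogue of Categorical(logits=logits / T).sample()."""

    def __init__(self, temperature: float = 1.0, seed: Optional[int] = None):
        # Reject T <= 0 up front instead of producing non-finite logits
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.temperature = temperature
        # A seeded RNG makes the draws reproducible
        self.rng = random.Random(seed)

    def __call__(self, logits: List[List[float]]) -> List[int]:
        samples = []
        for row in logits:
            scaled = [u / self.temperature for u in row]
            m = max(scaled)
            # Unnormalized weights exp(u_l / T - max); random.choices normalizes them
            weights = [math.exp(s - m) for s in scaled]
            samples.append(self.rng.choices(range(len(row)), weights=weights, k=1)[0])
        return samples


# Share of draws that pick the top logit at a low, normal, and high temperature
logits = [[2.0, 1.0, 0.1]]
for t in (0.1, 1.0, 10.0):
    sampler = PureTemperatureSampler(t, seed=0)
    draws = [sampler(logits)[0] for _ in range(10_000)]
    print(t, draws.count(0) / len(draws))
```

At $T = 0.1$ nearly every draw should pick index 0, while at $T = 10$ the three indices are drawn with roughly similar frequency. Rejecting non-positive temperatures is a design choice of this sketch; when a deterministic output is wanted, greedy decoding (taking the argmax) is the usual alternative to a very small $T$.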