# GitHub repository: labmlai/annotated_deep_learning_paper_implementations, path: blob/master/labml_nn/sampling/greedy.py
"""
---
title: Greedy Sampling
summary: A PyTorch implementation of greedy sampling from language models.
---

# Greedy Sampling

Here we pick the most likely token, i.e. the one with the largest logit, instead of sampling at random from the distribution that the logits define.
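
A minimal sketch in plain PyTorch, independent of the `Sampler` base class; `greedy_sample` is a hypothetical helper name used only for illustration. Greedy sampling reduces to an `argmax` over the last (vocabulary) dimension, so any leading batch or sequence dimensions are preserved:

```python
import torch


def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    # Index of the largest logit along the vocabulary (last) dimension.
    # Works for any leading shape: [..., n_tokens] -> [...]
    return logits.argmax(dim=-1)


# A batch of 2 sequences over a toy vocabulary of 4 tokens
logits = torch.tensor([[1.0, 3.0, 0.5, -2.0],
                       [0.1, 0.2, 0.3, 2.5]])

tokens = greedy_sample(logits)
print(tokens)  # tensor([1, 3])

# Softmax is monotonic, so the most probable token is the same
# as the token with the largest logit; no normalization is needed.
probs = torch.softmax(logits, dim=-1)
assert torch.equal(probs.argmax(dim=-1), tokens)
```

Because softmax preserves the ordering of the logits, there is no need to normalize them before picking the greedy token. The flip side is that greedy sampling is deterministic: the same logits always yield the same token.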

Here's an [experiment](experiment.html) that uses these sampling techniques.
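
Language models are usually decoded one token at a time. The sketch below assumes a hypothetical toy model (an `nn.Embedding` followed by an `nn.Linear`, with random weights) and a hypothetical `greedy_generate` helper; at each step it takes the logits for the last position, applies the same `argmax` that `GreedySampler` performs, and appends the chosen token:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy language model: embedding followed by a linear head.
# Random weights, for illustration only; a real model would also attend
# to earlier tokens in the sequence.
vocab_size, d_model = 10, 8
model = nn.Sequential(nn.Embedding(vocab_size, d_model), nn.Linear(d_model, vocab_size))


@torch.no_grad()
def greedy_generate(prompt: torch.Tensor, n_new: int) -> torch.Tensor:
    # prompt: [batch, seq_len] token ids
    tokens = prompt
    for _ in range(n_new):
        logits = model(tokens)          # [batch, seq_len, vocab_size]
        next_logits = logits[:, -1, :]  # logits for the next position only
        # Greedy choice, kept as [batch, 1] so it can be concatenated
        next_token = next_logits.argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens


prompt = torch.tensor([[1, 2, 3]])
out = greedy_generate(prompt, n_new=5)
print(out.shape)  # torch.Size([1, 8])

# Greedy decoding is deterministic: a second run gives the same tokens.
assert torch.equal(out, greedy_generate(prompt, n_new=5))
```

In a real generation loop one could pass `next_logits` to a sampler object rather than calling `argmax` directly, which makes it easy to swap greedy sampling for another strategy.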
"""

import torch

from labml_nn.sampling import Sampler


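class GreedySampler(Sampler):
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample the most likely token from the distribution of logits.

        :param logits: are the logits of the distribution, of shape `[..., n_tokens]`
        :return: the indices of the highest-scoring tokens, of shape `[...]`
        """
        return logits.argmax(dim=-1)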