Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/sampling/experiment_tiny.py
4910 views
1
from typing import Tuple
2
3
import torch
4
5
from labml import experiment, monit
6
from labml import logger
7
from labml.logger import Text
8
from labml_nn.helpers.datasets import TextDataset
9
from labml_nn.sampling import Sampler
10
from labml_nn.sampling.greedy import GreedySampler
11
from labml_nn.sampling.nucleus import NucleusSampler
12
from labml_nn.sampling.temperature import TemperatureSampler
13
from labml_nn.sampling.top_k import TopKSampler
14
from labml_nn.transformers.basic.autoregressive_experiment import Configs, AutoregressiveTransformer
15
16
17
def get_model_dataset(run_uuid: str) -> Tuple[AutoregressiveTransformer, TextDataset]:
    """
    Load a trained autoregressive transformer and its dataset from a saved run.

    :param run_uuid: UUID of the labml experiment run to restore.
    :return: the restored model and the text dataset it was trained on.
    """
    # Put the experiment into evaluation mode — no new training run is recorded
    experiment.evaluate()

    # Create a fresh configuration object and populate it with the saved run's configs
    config = Configs()
    experiment.configs(config, experiment.load_configs(run_uuid))

    # Select the checkpoint of the given run to restore from
    experiment.load(run_uuid)

    # Register the model so its weights get loaded from the checkpoint
    experiment.add_pytorch_models({'model': config.model})

    # Start the experiment; this performs the actual checkpoint load
    experiment.start()

    return config.model, config.text
31
32
33
def sample(model, ds, sampler: Sampler, n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """
    Autoregressively generate text with *model* and print the samples.

    :param model: autoregressive model; called as ``model(tokens)`` and expected to
        return logits (extra return values are ignored) — presumably shaped
        ``[seq, batch, vocab]``, since the last time step is taken below; TODO confirm.
    :param ds: dataset exposing ``text_to_i`` (encode) and ``itos`` (decode).
    :param sampler: strategy that picks the next token from the final-step logits.
    :param n_samples: number of sequences generated in parallel.
    :param n_tokens: number of new tokens to generate per sequence.
    :param seq_len: maximum context length fed to the model.
    :param prompt: starting text shared by every sequence.
    """
    # Inference only — gradients are never needed here
    with torch.no_grad():
        # Encode the prompt once and replicate it across the batch dimension
        tokens = torch.tile(ds.text_to_i(prompt)[:, None], (1, n_samples))

        # One log per sequence, seeded with the prompt
        outputs = [[(prompt, Text.meta)] for _ in range(n_samples)]

        # Generate n_tokens new tokens
        for _ in monit.iterate('Sample', n_tokens):
            # Truncate the context to the most recent seq_len tokens
            tokens = tokens[-seq_len:]

            # Forward pass; keep only the logits of the last time step
            logits, *_ = model(tokens)
            logits = logits[-1]

            # Let the sampler pick the next token for every sequence
            sampled = sampler(logits)

            # Append the new tokens to the running sequences
            tokens = torch.cat([tokens, sampled[None, :]], dim=0)

            # Record the decoded token of each sequence for printing
            for j in range(n_samples):
                outputs[j].append((ds.itos[sampled[j]], Text.value))

        # Print every generated sequence
        for j in range(n_samples):
            logger.log(outputs[j])
56
57
58
def main():
    """
    Load a trained model and compare sampling strategies on the same prompt.

    Runs greedy sampling, temperature sampling at several temperatures,
    top-k sampling, and nucleus sampling, printing 4 samples of 32 tokens each.
    """
    # Restore the model and dataset from a saved training run
    model, ds = get_model_dataset('074d4004cc6b11ecad7a0242ac1c0002')
    model.eval()

    with monit.section('greedy'):
        sample(model, ds, GreedySampler(), 4, 32, 128, 'It is')

    with monit.section('temperature=1.'):
        sample(model, ds, TemperatureSampler(1.), 4, 32, 128, 'It is')
    with monit.section('temperature=.1'):
        sample(model, ds, TemperatureSampler(.1), 4, 32, 128, 'It is')
    with monit.section('temperature=10.'):
        sample(model, ds, TemperatureSampler(10.), 4, 32, 128, 'It is')

    # Label fixed: it previously said 'top_k=5' but the sampler uses k=2
    with monit.section('top_k=2'):
        sample(model, ds, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, 'It is')

    # Labels fixed: 'nucles' typo, and the second section actually uses p=0.1
    with monit.section('nucleus p=.95'):
        sample(model, ds, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, 'It is')
    with monit.section('nucleus p=.1'):
        sample(model, ds, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, 'It is')
79
80
81
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()
83
84