Path: blob/master/labml_nn/sampling/experiment_tiny.py
4910 views
from typing import Tuple

import torch

from labml import experiment, monit
from labml import logger
from labml.logger import Text
from labml_nn.helpers.datasets import TextDataset
from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from labml_nn.transformers.basic.autoregressive_experiment import Configs, AutoregressiveTransformer


def get_model_dataset(run_uuid: str) -> Tuple[AutoregressiveTransformer, TextDataset]:
    """Load the model and text dataset of a finished training run.

    Starts a labml experiment in evaluation mode, restores the run's configs,
    loads its saved checkpoint into ``conf.model``, and returns the model
    together with the run's text dataset.

    :param run_uuid: UUID of the labml run to load.
    :return: ``(conf.model, conf.text)``
    """
    experiment.evaluate()

    conf = Configs()
    # Reuse the hyper-parameters the run was trained with
    experiment.configs(conf, experiment.load_configs(run_uuid))

    # Mark the run's checkpoint for loading; it is restored into the
    # registered models once the experiment starts
    experiment.load(run_uuid)
    experiment.add_pytorch_models({'model': conf.model})
    experiment.start()

    return conf.model, conf.text


def sample(model, ds, sampler: Sampler, n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """Autoregressively extend ``prompt`` ``n_samples`` times and log the results.

    :param model: autoregressive model; its first output is taken as the logits
        and the last position (``logits[-1]``) is used as the next-token
        distribution
    :param ds: dataset providing ``text_to_i`` (tokenizer) and ``itos`` (vocabulary)
    :param sampler: decoding strategy mapping last-position logits to token ids
    :param n_samples: number of independent continuations, sampled as one batch
    :param n_tokens: number of tokens to generate per continuation
    :param seq_len: maximum context length fed to the model
    :param prompt: text to continue
    """
    with torch.no_grad():
        # Prompt tokens repeated across a trailing batch dimension:
        # (prompt_len, n_samples), assuming text_to_i returns a 1-D tensor
        data = torch.tile(ds.text_to_i(prompt)[:, None], (1, n_samples))

        # Collected (text, style) segments per sample for printing
        logs = [[(prompt, Text.meta)] for _ in range(n_samples)]

        for _ in monit.iterate('Sample', n_tokens):
            # Keep only the most recent `seq_len` tokens as context
            data = data[-seq_len:]
            logits, *_ = model(data)
            # Next-token logits come from the last sequence position
            logits = logits[-1]
            # Pick the next token for every sample with the chosen strategy
            res = sampler(logits)
            data = torch.cat([data, res[None, :]], dim=0)

            # Record each sample's new token for logging
            for log, token in zip(logs, res.tolist()):
                log.append((ds.itos[token], Text.value))

    # Print the sampled output
    for log in logs:
        logger.log(log)


def main():
    """Load a trained run and print samples for several decoding strategies."""
    model, ds = get_model_dataset('074d4004cc6b11ecad7a0242ac1c0002')
    model.eval()

    # Settings shared by every strategy
    n_samples, n_tokens, seq_len, prompt = 4, 32, 128, 'It is'

    # (section label, sampler): each label sits next to the sampler it
    # describes so the two cannot drift apart
    strategies = [
        ('greedy', GreedySampler()),
        ('temperature=1.', TemperatureSampler(1.)),
        ('temperature=.1', TemperatureSampler(.1)),
        ('temperature=10.', TemperatureSampler(10.)),
        ('top_k=5', TopKSampler(5, TemperatureSampler(1.))),
        ('nucleus p=.95', NucleusSampler(0.95, TemperatureSampler(1.))),
        ('nucleus p=.1', NucleusSampler(0.1, TemperatureSampler(1.))),
    ]

    for label, sampler in strategies:
        with monit.section(label):
            sample(model, ds, sampler, n_samples, n_tokens, seq_len, prompt)


if __name__ == '__main__':
    main()