Path: blob/main/a5/mingpt-demo/mingpt/utils.py
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def set_seed(seed):
    # seed every RNG in play (python, numpy, torch CPU and CUDA) for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def top_k_logits(logits, k):
    # mask everything below the k-th largest logit with -inf so softmax ignores it
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        logits, _ = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x
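
# ----------------------------------------------------------------------
# Illustrative usage (not part of the original file): a minimal sketch of
# how top_k_logits and sample() behave. DummyModel and the toy tensors
# below are hypothetical stand-ins invented for this demo; a real minGPT
# model exposes the same get_block_size() / forward() interface, returning
# a (logits, loss) pair.
if __name__ == "__main__":
    set_seed(42)

    # top_k_logits keeps the k largest entries and masks the rest to -inf
    toy_logits = torch.tensor([[0.1, 2.0, 0.5, 1.5]])
    print(top_k_logits(toy_logits, k=2))  # tensor([[  -inf, 2.0000,   -inf, 1.5000]])

    class DummyModel(nn.Module):
        """Hypothetical stand-in that emits random logits over a toy vocab."""
        def __init__(self, vocab_size=10, block_size=8):
            super().__init__()
            self.vocab_size = vocab_size
            self.block_size = block_size

        def get_block_size(self):
            return self.block_size

        def forward(self, idx):
            b, t = idx.size()
            return torch.randn(b, t, self.vocab_size), None

    # start from a single dummy token and autoregressively generate 5 more
    context = torch.zeros((1, 1), dtype=torch.long)
    out = sample(DummyModel(), context, steps=5, temperature=1.0, sample=True, top_k=3)
    print(out.shape)  # torch.Size([1, 6]) -> 1 prompt token + 5 generated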