Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
yiming-wange
GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a5/mingpt-demo/mingpt/utils.py
1003 views
1
import random
2
import numpy as np
3
import torch
4
import torch.nn as nn
5
from torch.nn import functional as F
6
7
def set_seed(seed):
    """
    Seed every random number generator this module relies on.

    The same ``seed`` is passed to Python's ``random`` module, NumPy's global
    RNG, PyTorch's CPU generator and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
12
13
def top_k_logits(logits, k):
    """
    Keep only the ``k`` largest logits along the last dimension and mask the rest.

    Entries strictly smaller than the k-th largest value in their slice are set
    to ``-inf``, so a subsequent softmax gives them zero probability. Entries
    tied with the k-th largest value are kept, so a slice may retain more than
    ``k`` entries.

    Args:
        logits: tensor of unnormalized scores whose last dimension indexes the
            vocabulary. Any number of leading dimensions (including none) is
            accepted.
        k: number of top entries to keep. Values larger than the size of the
            last dimension are clamped to it, in which case nothing is masked.

    Returns:
        A new tensor with the same shape as ``logits``. The input is not modified.
    """
    # Clamp so torch.topk does not raise when k exceeds the vocabulary size.
    k = min(k, logits.size(-1))
    # `...` rather than `[:, ...]` so 1-D and >2-D inputs work, not only (b, vocab).
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
    # masked_fill is out-of-place: it returns a new tensor and leaves `logits` intact.
    return logits.masked_fill(logits < kth_largest, -float('Inf'))
18
19
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    Autoregressively extend the index sequence ``x`` by ``steps`` tokens.

    At every step the model is run on the most recent tokens (at most
    ``model.get_block_size()`` of them). The logits at the last position are
    divided by ``temperature``, optionally restricted to the ``top_k`` best
    entries, and turned into probabilities. The next token is then either drawn
    from that distribution (``sample=True``) or chosen greedily as the most
    likely one. Because the model is re-run over the whole (cropped) context at
    every step, cost grows with context length per step. Tokens further back
    than the block size are not seen by the model.

    Args:
        model: exposes ``get_block_size()`` and ``eval()``. When called on a
            ``(b, t)`` index tensor, it returns a pair whose first element is
            indexed as ``[:, -1, :]``.
        x: conditioning indices of shape ``(b, t)``.
        steps: number of tokens to append.
        temperature: divisor applied to the final-position logits.
        sample: draw from the distribution if True, otherwise take the argmax.
        top_k: if not None, keep only this many logits before the softmax.

    Returns:
        ``x`` with ``steps`` new indices concatenated along dimension 1.

    Note:
        Puts ``model`` into eval mode and does not restore its previous mode.
    """
    context_len = model.get_block_size()
    model.eval()
    seq = x
    for _step in range(steps):
        # The model can only attend to `context_len` tokens; keep the newest ones.
        window = seq[:, -context_len:] if seq.size(1) > context_len else seq
        all_logits, _aux = model(window)
        # Only the prediction for the position after the last token matters.
        last_logits = all_logits[:, -1, :] / temperature
        if top_k is not None:
            last_logits = top_k_logits(last_logits, top_k)
        probs = F.softmax(last_logits, dim=-1)
        if not sample:
            _best_prob, next_ix = torch.topk(probs, k=1, dim=-1)
        else:
            next_ix = torch.multinomial(probs, num_samples=1)
        seq = torch.cat((seq, next_ix), dim=1)
    return seq
48
49