import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
def set_seed(seed):
    # seed all relevant RNGs (Python, NumPy, PyTorch CPU and CUDA) for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def top_k_logits(logits, k):
    # keep only the k largest logits per row; set everything else to -inf so that
    # softmax assigns those tokens zero probability
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    Take a conditioning sequence of indices x (of shape (b, t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Note that this sampling
    has quadratic complexity, unlike an RNN, which is linear, and it works within a finite
    context window of block_size, unlike an RNN, whose context window is unbounded.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        # crop the context to the last block_size tokens if it has grown too long
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:]
        logits, _ = model(x_cond)
        # take the logits at the final time step and scale by the temperature
        logits = logits[:, -1, :] / temperature
        # optionally restrict the distribution to the top_k most likely tokens
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely token
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append the chosen index to the running sequence and continue
        x = torch.cat((x, ix), dim=1)
    return x
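# A minimal usage sketch for sample() (assuming a trained `model` and a character-level
# dataset object `dataset` with `stoi`/`itos` vocabulary maps; those names are
# hypothetical and not defined in this file):
#
#   context = torch.tensor([dataset.stoi[c] for c in "hello"], dtype=torch.long)[None, ...]
#   out = sample(model, context, steps=20, temperature=1.0, sample=True, top_k=10)
#   print(''.join(dataset.itos[int(i)] for i in out[0]))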
def get_name_prediction(model, dataset_object, input_string):
    # run greedy decoding on the model, conditioned on the characters of input_string
    device = next(model.parameters()).device
    x = torch.tensor([dataset_object.stoi[s] for s in input_string],
                     dtype=torch.long)[None, ...].to(device)
    pred = sample(model, x, 32, sample=False)[0]
    completion = ''.join([dataset_object.itos[int(i)] for i in pred])
    # the prediction is the span between the first pair of mask characters '⁇'
    pred = completion.split('⁇')[1]
    return pred
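# A minimal usage sketch for get_name_prediction() (assumes the prompt ends with the
# mask character '⁇' so the model's continuation up to the next '⁇' is the prediction;
# `model` and `dataset` here are hypothetical names):
#
#   place = get_name_prediction(model, dataset, 'Leonhard Euler⁇')
#   print(place)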
def evaluate_places(filepath, predicted_places):
    """ Computes the percentage of correctly predicted birth places.
    Arguments:
      filepath: path to a file with our name, birth place data.
      predicted_places: a list of strings representing the
          predicted birth place of each person.
    Returns: (total, correct), both floats
    """
    with open(filepath, encoding='utf-8') as fin:
        lines = [x.strip().split('\t') for x in fin]
        if len(lines[0]) == 1:
            print('No gold birth places provided; returning (0, 0)')
            return (0, 0)
        true_places = [x[1] for x in lines]
        total = len(true_places)
        assert total == len(predicted_places)
        correct = len(list(filter(lambda x: x[0] == x[1],
                                  zip(true_places, predicted_places))))
        return (float(total), float(correct))
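# A minimal usage sketch for evaluate_places() (assuming a hypothetical tab-separated
# file 'birth_places.tsv' with one "name\tbirth place" pair per line):
#
#   predictions = ['London', 'Paris', 'Berlin']
#   total, correct = evaluate_places('birth_places.tsv', predictions)
#   if total > 0:
#       print('Accuracy: {:.2f}%'.format(100 * correct / total))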