import utils
import trainer
import model
import dataset
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
import random
import argparse
# Seed only Python's RNG for reproducibility.
# NOTE(review): torch/numpy RNGs are not seeded here — sampling and weight
# init are still nondeterministic; confirm whether that is intended.
random.seed(0)
argp = argparse.ArgumentParser()
argp.add_argument('function', help="Choose pretrain, finetune, or evaluate")
argp.add_argument('variant', help="Choose vanilla or perceiver")
argp.add_argument('--bottleneck_dim', type=int, default=32)
# NOTE(review): `default` on a required positional has no effect — argparse
# ignores it unless nargs='?' is also given, so this argument is mandatory.
argp.add_argument('pretrain_corpus_path', default=None)
argp.add_argument('--reading_params_path', default=None)
argp.add_argument('--writing_params_path', default=None)
argp.add_argument('--finetune_corpus_path', default=None)
argp.add_argument('--eval_corpus_path', default=None)
argp.add_argument('--outputs_path', default=None)
argp.add_argument('--pretrain_lr', default=6e-3, type=float)
argp.add_argument('--finetune_lr', default=6e-4, type=float)
argp.add_argument('--tb_expt_name', help='debug string for tb log.',
default='run')
args = argp.parse_args()
# Current CUDA device index when a GPU is available, else the string 'cpu';
# torch's `.to(...)` accepts either form.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
# TensorBoard log directory encodes the full experiment configuration so
# separate runs land in separate subdirectories.
writer = SummaryWriter(log_dir='expt/%s/%s_%s_%d_pt_lr_%f_ft_lr_%f' % (
args.function,
args.tb_expt_name,
args.variant,
args.bottleneck_dim,
args.pretrain_lr,
args.finetune_lr))
block_size = 128  # maximum context length fed to the model
# The pretraining corpus is loaded unconditionally: CharCorruptionDataset
# defines the vocabulary (stoi/itos) that finetune and evaluate reuse below.
text = open(args.pretrain_corpus_path, encoding='utf-8').read()
pretrain_dataset = dataset.CharCorruptionDataset(text, block_size)
mconf = model.GPTConfig(pretrain_dataset.vocab_size, pretrain_dataset.block_size,
    n_layer=4, n_head=8, n_embd=256)
"""
Don't change above here; write your code below
"""
# Instantiate the requested model variant. The module name `model` is
# intentionally shadowed by the GPT instance so the rest of the script
# can use `model` uniformly.
if args.variant == 'vanilla':
    model = model.GPT(mconf).to(device)
elif args.variant == 'perceiver':
    # Fail fast with a clear message. The previous `pass` left `model`
    # bound to the module, which crashed much later (e.g. at
    # `model.state_dict()`) with a confusing AttributeError.
    raise NotImplementedError("perceiver variant is not implemented yet")
else:
    raise ValueError("Unknown model variant")
if args.function == 'pretrain':
    assert args.writing_params_path is not None
    # Trainer hyperparameters for char-corruption pretraining.
    hyperparameters = {
        "max_epochs": 650,
        "batch_size": 128,
        # Honor the --pretrain_lr CLI flag (default 6e-3) instead of
        # hard-coding the value, which silently ignored the flag.
        "learning_rate": args.pretrain_lr,
        "lr_decay": True,
        "warmup_tokens": 512*20,
        "final_tokens": 200*len(pretrain_dataset)*block_size,
        "num_workers": 4
    }
    tconf = trainer.TrainerConfig(**hyperparameters)
    # No dev set during pretraining (test_dataset=None).
    trainer.Trainer(model, pretrain_dataset, None, tconf).train()
    torch.save(model.state_dict(), args.writing_params_path)
    # NOTE: removed a leftover `raise NotImplementedError` that fired AFTER
    # training and saving completed, aborting an otherwise successful run.
elif args.function == 'finetune':
assert args.writing_params_path is not None
assert args.finetune_corpus_path is not None
if args.reading_params_path is not None:
model.load_state_dict(torch.load(args.reading_params_path))
hyperparameters = {
"max_epochs": 10,
"batch_size": 256,
"learning_rate": 6e-4,
"lr_decay": True,
"warmup_tokens": 512*20,
"final_tokens": 200*len(pretrain_dataset)*block_size,
"num_workers": 4
}
else:
hyperparameters = {
"max_epochs": 75,
"batch_size": 256,
"learning_rate": 6e-4,
"lr_decay": True,
"warmup_tokens": 512*20,
"final_tokens": 200*len(pretrain_dataset)*block_size,
"num_workers": 4
}
finetune_corpus = open(args.finetune_corpus_path).read()
finetune_dataset = dataset.NameDataset(pretrain_dataset, finetune_corpus)
if args.eval_corpus_path is not None:
eval_corpus = open(args.eval_corpus_path).read()
eval_dataset = dataset.NameDataset(pretrain_dataset, eval_corpus)
else:
eval_dataset = None
tconf = trainer.TrainerConfig(**hyperparameters)
trainer.Trainer(model, finetune_dataset, eval_dataset, tconf).train()
torch.save(model.state_dict(), args.writing_params_path)
raise NotImplementedError
elif args.function == 'evaluate':
assert args.outputs_path is not None
assert args.reading_params_path is not None
assert args.eval_corpus_path is not None
model.load_state_dict(torch.load(args.reading_params_path))
correct = 0
total = 0
with open(args.outputs_path, 'w', encoding='utf-8') as fout:
predictions = []
for line in tqdm(open(args.eval_corpus_path, encoding='utf-8')):
x = line.split('\t')[0]
x = x + '⁇'
x = torch.tensor([pretrain_dataset.stoi[s]
for s in x], dtype=torch.long)[None, ...].to(device)
pred = utils.sample(model, x, 32, sample=False)[0]
completion = ''.join([pretrain_dataset.itos[int(i)] for i in pred])
pred = completion.split('⁇')[1]
predictions.append(pred)
fout.write(pred + '\n')
total, correct = utils.evaluate_places(
args.eval_corpus_path, predictions)
if total > 0:
print('Correct: {} out of {}: {}%'.format(
correct, total, correct/total*100))
else:
print('Predictions written to {}; no targets provided'
.format(args.outputs_path))