"""
Simple training loop; boilerplate that could apply to any neural network,
so nothing in this file really has anything to do with GPT specifically.
"""
import math
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
logger = logging.getLogger(__name__)
class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1  # only applied to parameters whose names don't match the no-decay list below
    # learning rate decay: linear warmup over warmup_tokens, then cosine decay to 10% of the base LR
    lr_decay = False
    warmup_tokens = 375e6  # warmup/final token counts follow the GPT-3 paper schedule
    final_tokens = 260e9   # at this many tokens the schedule bottoms out at 10% of the base LR
    # checkpoint and dataloader settings
    ckpt_path = None
    num_workers = 0  # for DataLoader
    writer = None  # optional summary writer (e.g. a TensorBoard SummaryWriter) for scalar logging
    def __init__(self, **kwargs):
        # allow any of the defaults above to be overridden via keyword arguments
        for k, v in kwargs.items():
            setattr(self, k, v)
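# Example (illustrative only, not part of the original code): any default can be
# overridden at construction time, e.g. TrainerConfig(max_epochs=2, batch_size=128, lr_decay=True).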
class Trainer:
    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        # take over whatever GPU is available; otherwise fall back to CPU
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            # wrap the model in DataParallel so multi-GPU machines are used automatically
            self.model = torch.nn.DataParallel(self.model).to(self.device)
    def save_checkpoint(self):
        if self.config.ckpt_path is not None:
            # DataParallel wrappers keep the raw model in the .module attribute
            ckpt_model = self.model.module if hasattr(self.model, "module") else self.model
            logger.info("saving %s", self.config.ckpt_path)
            torch.save(ckpt_model.state_dict(), self.config.ckpt_path)
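    # To restore a checkpoint saved above, one would typically do something like the
    # following (illustrative sketch only, not part of the original trainer):
    #   state_dict = torch.load(trainer.config.ckpt_path, map_location='cpu')
    #   raw_model = trainer.model.module if hasattr(trainer.model, "module") else trainer.model
    #   raw_model.load_state_dict(state_dict)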
    def train(self):
        model, config = self.model, self.config
        # separate parameters into those that will and won't receive weight decay;
        # biases and LayerNorm weights are conventionally excluded from decay
        no_decay = ["bias", "LayerNorm.weight"]
        params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
        params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
        optim_groups = [
            {"params": params_decay, "weight_decay": config.weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]
        optimizer = optim.AdamW(optim_groups, lr=config.learning_rate, betas=config.betas)
        step = 0  # global step counter, used for scalar logging
        def run_epoch(split):
            nonlocal step
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            # shuffle only the training split
            loader = DataLoader(data, batch_size=config.batch_size, shuffle=is_train,
                                num_workers=config.num_workers)
            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y) in pbar:
                # place the batch on the training device
                x = x.to(self.device)
                y = y.to(self.device)
                # forward the model; gradients are only tracked during training
                with torch.set_grad_enabled(is_train):
                    logits, loss = model(x, y)
                    loss = loss.mean()  # collapse losses if DataParallel returned one per GPU
                    losses.append(loss.item())
                if is_train:
                    # backprop, clip the gradient norm, and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()
                    # decay the learning rate based on progress through training:
                    # linear warmup over the first warmup_tokens, then cosine decay down to a
                    # floor of 10% of the base LR (e.g. halfway through the decay phase
                    # lr_mult is 0.5; at final_tokens it bottoms out at 0.1)
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum()  # number of tokens processed this step
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine decay from 1.0 down to the 0.1 floor
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate
                    # report progress and optionally log scalars
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")
                    if config.writer is not None:
                        config.writer.add_scalar('train/loss', loss.item(), step)
                        config.writer.add_scalar('train/lr', lr, step)
                    step += 1
            if not is_train:
                logger.info("test loss: %f", np.mean(losses))
        self.tokens = 0  # counter of tokens processed so far, used for learning rate decay
        for epoch in range(config.max_epochs):
            run_epoch('train')
            if self.test_dataset is not None:
                run_epoch('test')
            # checkpoint at the end of every epoch (if a ckpt_path was given)
            self.save_checkpoint()
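# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original trainer): it assumes a
# model whose forward(x, y) returns (logits, loss) and Datasets yielding
# (x, y) token tensors, which is what the loop above expects. RandomDataset
# and DummyModel below are hypothetical stand-ins, used only to exercise the
# training loop end to end.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import Dataset

    class RandomDataset(Dataset):
        """Random (x, y) token pairs, just enough to drive the loop."""
        def __init__(self, vocab_size=16, block_size=8, length=256):
            self.vocab_size, self.block_size, self.length = vocab_size, block_size, length
        def __len__(self):
            return self.length
        def __getitem__(self, idx):
            x = torch.randint(0, self.vocab_size, (self.block_size,))
            y = torch.randint(0, self.vocab_size, (self.block_size,))
            return x, y

    class DummyModel(nn.Module):
        """Tiny embedding + linear head that returns (logits, loss) like a GPT would."""
        def __init__(self, vocab_size=16, n_embd=32):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, n_embd)
            self.head = nn.Linear(n_embd, vocab_size)
        def forward(self, x, y=None):
            logits = self.head(self.emb(x))
            loss = None
            if y is not None:
                loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            return logits, loss

    config = TrainerConfig(max_epochs=1, batch_size=32, lr_decay=True,
                           warmup_tokens=512, final_tokens=256 * 8)
    trainer = Trainer(DummyModel(), RandomDataset(), RandomDataset(length=64), config)
    trainer.train()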