# Path: blob/main/a5/mingpt-demo/mingpt/trainer.py
"""
Simple training loop; boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import math
import logging

try:
    from tqdm import tqdm
except ImportError:  # tqdm only drives the cosmetic progress bar; training works without it
    tqdm = None
import numpy as np

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

logger = logging.getLogger(__name__)


class TrainerConfig:
    """Hyperparameters for :class:`Trainer`.

    Every attribute below is a class-level default; each keyword passed to the
    constructor sets (overriding or adding) an instance attribute of that name.
    """

    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1  # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6  # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9  # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


class Trainer:
    """Generic epoch-based training loop with optional evaluation and checkpointing.

    ``model`` is called as ``model(x, y)`` and must return ``(logits, loss)``; it
    must also provide ``configure_optimizers(config)`` returning a torch optimizer
    (both are required by :meth:`train`; their implementations live elsewhere).
    """

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset  # may be None: no evaluation, checkpoint every epoch
        self.config = config
        self.tokens = 0  # counter used for learning rate decay (reset at the start of train())

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        """Write the unwrapped model's ``state_dict`` to ``config.ckpt_path``."""
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)

    def _lr_multiplier(self):
        """Return the learning-rate scale factor for the current ``self.tokens``.

        Linear warmup from 0 to 1 over ``config.warmup_tokens`` tokens, then cosine
        decay that reaches its 0.1 floor at ``config.final_tokens`` and stays there.
        """
        config = self.config
        if self.tokens < config.warmup_tokens:
            # linear warmup
            return float(self.tokens) / float(max(1, config.warmup_tokens))
        # cosine learning rate decay; progress is clamped to 1.0 because cos() is
        # periodic and would otherwise swing the LR back up after final_tokens
        progress = float(self.tokens - config.warmup_tokens) / float(
            max(1, config.final_tokens - config.warmup_tokens))
        progress = min(1.0, progress)
        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

    def train(self):
        """Train for ``config.max_epochs`` epochs, evaluating when a test set exists.

        When ``config.ckpt_path`` is set, a checkpoint is written after every epoch
        if there is no test set, otherwise only when the test loss improves.
        """
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split, epoch):
            """Run one pass over ``split``; return the mean loss for 'test', else None."""
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(data, shuffle=is_train,  # eval order doesn't matter; keep it deterministic
                                pin_memory=self.device != 'cpu',  # pinning only helps host->GPU copies
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

            show_progress = is_train and tqdm is not None
            batches = enumerate(loader)
            if show_progress:
                batches = tqdm(batches, total=len(loader))

            losses = []
            for it, (x, y) in batches:

                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y)
                    loss = loss.mean()  # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                if not is_train:
                    continue

                # backprop and update the parameters
                model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                optimizer.step()

                # decay the learning rate based on our progress
                if config.lr_decay:
                    # tokens processed this step (i.e. label is not -100); .item() keeps the
                    # counter a Python int instead of a device tensor that syncs every step
                    self.tokens += int((y >= 0).sum().item())
                    lr = config.learning_rate * self._lr_multiplier()
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                else:
                    lr = config.learning_rate

                # report progress
                if show_progress:
                    batches.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            if is_train:
                return None
            # an empty test loader would make np.mean warn and return NaN anyway
            test_loss = float(np.mean(losses)) if losses else float('nan')
            logger.info("test loss: %f", test_loss)
            return test_loss

        best_loss = float('inf')
        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train', epoch)
            test_loss = None  # stays None when there is no test set
            if self.test_dataset is not None:
                test_loss = run_epoch('test', epoch)

            # keep only the best checkpoint by test loss, or just save always if no test set is provided
            good_model = test_loss is None or test_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                if test_loss is not None:
                    best_loss = test_loss
                self.save_checkpoint()