GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a5/mingpt-demo/mingpt/trainer.py
"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import math
import logging

from tqdm import tqdm
import numpy as np

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

logger = logging.getLogger(__name__)

class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            setattr(self, k, v)
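# Illustrative sketch (not part of the original file): TrainerConfig keeps its defaults as
# class attributes and lets callers override any of them through keyword arguments. The
# names `train_dataset` and `block_size` below are placeholders, and the specific values
# are only examples; a typical override might look roughly like:
#
#   tconf = TrainerConfig(max_epochs=2, batch_size=256, learning_rate=6e-4,
#                         lr_decay=True, warmup_tokens=512*20,
#                         final_tokens=2*len(train_dataset)*block_size,
#                         num_workers=4, ckpt_path='model.pt')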
class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)
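    # Illustrative note (not part of the original file): save_checkpoint() stores only the raw
    # state_dict of the unwrapped model, so restoring it on a fresh, non-DataParallel model
    # would look roughly like the following (GPT/mconf are assumed to come from mingpt.model):
    #
    #   model = GPT(mconf)
    #   model.load_state_dict(torch.load(ckpt_path))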
    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(data, shuffle=True, pin_memory=True,
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y) in pbar:

                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    logits, loss = model(x, y)
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            if not is_train:
                test_loss = float(np.mean(losses))
                logger.info("test loss: %f", test_loss)
                return test_loss
        best_loss = float('inf')
        self.tokens = 0 # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')
            if self.test_dataset is not None:
                test_loss = run_epoch('test')

            # supports early stopping based on the test loss, or just save always if no test set is provided
            good_model = self.test_dataset is None or test_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                if self.test_dataset is not None:
                    best_loss = test_loss # guard: test_loss is never assigned when there is no test set
                self.save_checkpoint()
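# Illustrative end-to-end sketch (not part of the original file), assuming a model that follows
# the minGPT interface (forward returns (logits, loss) and exposes configure_optimizers(config))
# and a PyTorch Dataset yielding (x, y) token tensors. `vocab_size`, `block_size`, and
# `train_dataset` are placeholders the caller would define:
#
#   from mingpt.model import GPT, GPTConfig   # assumed module layout of this repo
#   mconf = GPTConfig(vocab_size, block_size, n_layer=8, n_head=8, n_embd=512)
#   model = GPT(mconf)
#   tconf = TrainerConfig(max_epochs=2, batch_size=256, lr_decay=True,
#                         final_tokens=2*len(train_dataset)*block_size, ckpt_path='model.pt')
#   trainer = Trainer(model, train_dataset, None, tconf)
#   trainer.train()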