GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a5/src/trainer.py
"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import math
import logging

from tqdm import tqdm
import numpy as np

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

logger = logging.getLogger(__name__)

class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1  # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6  # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9  # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0  # for DataLoader
    writer = None

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
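
# For example, any of the defaults above can be overridden through keyword arguments:
#   config = TrainerConfig(max_epochs=2, batch_size=128, lr_decay=True, ckpt_path='model.pt')
# __init__ simply setattr's every key it is given, so unrecognized names are attached to
# the config object without any validation.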

class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        if self.config.ckpt_path is not None:
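            # DataParallel wraps the underlying model in a .module attribute; saving the
            # unwrapped state_dict keeps the checkpoint free of 'module.' key prefixes,
            # so it can be loaded into a plain (non-DataParallel) model later.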
            ckpt_model = self.model.module if hasattr(self.model, "module") else self.model
            logger.info("saving %s", self.config.ckpt_path)
            torch.save(ckpt_model.state_dict(), self.config.ckpt_path)

    def train(self):
        model, config = self.model, self.config

        # create the optimizer
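        # AdamW applies weight decay in decoupled form; by convention, biases and
        # LayerNorm weights are exempted from decay, so the parameters are split into
        # a decayed group and a non-decayed group below.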
        no_decay = ["bias", "LayerNorm.weight"]
        params_decay = [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
        params_nodecay = [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
        optim_groups = [
            {"params": params_decay, "weight_decay": config.weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]
        optimizer = optim.AdamW(optim_groups, lr=config.learning_rate, betas=config.betas)
        step = 0
        def run_epoch(split):
            nonlocal step
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
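            # note: only batch_size and num_workers are set here; the other DataLoader
            # arguments keep their PyTorch defaults, so shuffle is False and batches
            # arrive in dataset order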
            loader = DataLoader(data, batch_size=config.batch_size, num_workers=config.num_workers)

            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)

            for it, (x, y) in pbar:

                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    logits, loss = model(x, y)
                    loss = loss.mean()  # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum()  # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
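                        # progress runs from 0 (end of warmup) to 1 (at final_tokens), so the
                        # cosine term falls from 1.0 towards 0.0; the max(0.1, ...) clamp keeps
                        # the multiplier, and hence the LR, at no less than 10% of the base value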
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

                    if config.writer is not None:
                        config.writer.add_scalar('train/loss', loss.item(), step)
                        config.writer.add_scalar('train/lr', lr, step)

                step += 1
            if not is_train:
                logger.info("test loss: %f", np.mean(losses))

        self.tokens = 0  # counter used for learning rate decay
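        # one training pass (and an optional test pass) per epoch; when a ckpt_path is
        # configured, the checkpoint file is rewritten at the end of every epoch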
        for epoch in range(config.max_epochs):

            run_epoch('train')
            if self.test_dataset is not None:
                run_epoch('test')

            self.save_checkpoint()
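

# A minimal usage sketch (hypothetical, for illustration only): ToyModel and the random
# tensors below are stand-ins, not part of the assignment code. The only contract the
# Trainer relies on is that the model's forward takes (x, y) and returns (logits, loss),
# and that the datasets yield (x, y) pairs.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import TensorDataset

    vocab_size, block_size = 16, 8

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, 32)
            self.head = nn.Linear(32, vocab_size)

        def forward(self, x, y=None):
            logits = self.head(self.emb(x))  # (batch, block_size, vocab_size)
            loss = None
            if y is not None:
                loss = nn.functional.cross_entropy(
                    logits.view(-1, vocab_size), y.view(-1), ignore_index=-100)
            return logits, loss

    # random token sequences standing in for a real dataset
    x = torch.randint(0, vocab_size, (256, block_size))
    y = torch.randint(0, vocab_size, (256, block_size))
    train_data = TensorDataset(x, y)

    config = TrainerConfig(max_epochs=1, batch_size=32, lr_decay=True,
                           warmup_tokens=512, final_tokens=256 * block_size)
    trainer = Trainer(ToyModel(), train_data, None, config)
    trainer.train()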