GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/neox/utils/trainer.py

from typing import Optional, Set, List

import torch.nn as nn
import torch.optim
import torch.utils.data
from torch.cuda import amp
from torch.cuda.amp import GradScaler

from labml import monit, tracker
from labml.configs import BaseConfigs, option
from labml_nn.neox.utils.finetune import FineTuner


def get_trainable_params(model: nn.Module):
    """
    ### Get trainable parameters

    :param model: is the model to train
    :return: a list of parameters for training
    """
    # Get all parameters
    params = list(model.parameters())
    # Filter parameters that require gradients
    trainable_params = [p for p in params if p.requires_grad]
    #
    return trainable_params
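

# Usage sketch (an editorial addition, not part of the original file): with
# partially frozen parameters only the unfrozen ones are returned, so the
# optimizers defined below never touch frozen weights.
#
#     model = nn.Linear(4, 4)
#     model.bias.requires_grad = False
#     params = get_trainable_params(model)
#     assert len(params) == 1 and params[0] is model.weight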


class TrainerConf(BaseConfigs):
    # Model to train
    model: nn.Module
    # Layers of the model (not used directly in this file)
    layers: List[nn.Module]
    # Optimizer; the string `'Adam'` selects the config option defined below
    optimizer: torch.optim.Optimizer = 'Adam'
    # Data loaders
    train_loader: torch.utils.data.DataLoader
    valid_loader: Optional[torch.utils.data.DataLoader] = None
    # Device to move the data to
    device: torch.device = torch.device('cuda:0')
    # Gradient scaler; the string `'Default'` selects the config option defined below
    scaler: Optional[GradScaler] = 'Default'
    # Whether to use automatic mixed precision
    is_amp: bool = True
    # Data type; determines the optimizer and gradient scaler implementations
    dtype: torch.dtype = torch.float16

    # Whether to clone layers (not used in this file)
    is_clone_layers: bool = True

    # Loss function
    loss_func: nn.Module = nn.CrossEntropyLoss()
    # Number of checkpoint saves per epoch (`0` disables checkpointing)
    checkpoints_per_epoch: int = 0
    # Number of sampling calls per epoch (`0` disables sampling)
    samples_per_epoch: int = 0

    # Gradient clipping norm (`None` disables clipping)
    grad_norm: Optional[float] = 1.0
    # Training hyper-parameters
    learning_rate: float = 3e-4
    max_seq_len: int = 1024
    batch_size: int = 64
    epochs: int = 16

    # Number of available GPUs
    n_gpus: int = torch.cuda.device_count()

    # Layers to filter (not used in this file)
    filter_layers: Optional[Set] = None

    def get_loss(self, sample, dataset_split: str):
        """
        :param dataset_split: train/valid
        :param sample: is a `(data, target)` pair
        :return: the loss, output and the target
        """
        data, target = sample

        # Forward pass
        with monit.section('Forward pass'):
            output = self.model(data.to(self.device))
        # Move targets to the same device as output
        target = target.to(output.device)
        # Calculate loss
        loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))

        return loss, output, target
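
    # Shape note (an editorial addition, not part of the original file): the model
    # is assumed to return logits of shape `[batch, seq_len, n_vocab]` with targets
    # of shape `[batch, seq_len]`; the `view` calls above flatten both so that
    # `loss_func` (cross entropy by default) sees `[batch * seq_len, n_vocab]`
    # logits against `[batch * seq_len]` token ids.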

    def train(self):
        # Run `train_epoch` for `epochs` epochs
        for epoch in monit.loop(self.epochs):
            self.train_epoch()
            tracker.new_line()

    def sample(self, idx):
        # Sampling stub; a no-op here (presumably overridden by subclasses)
        pass

    def save_checkpoint(self, idx):
        # Checkpoint-saving stub; a no-op here (presumably overridden by subclasses)
        pass

    def get_iterators(self):
        # Iterators to mix during an epoch: training batches, and optionally
        # validation batches, sampling calls and checkpoint saves
        iterators = [('train', self.train_loader)]
        if self.valid_loader is not None:
            iterators.append(('valid', self.valid_loader))

        if self.samples_per_epoch > 0:
            iterators.append((self.sample, [i for i in range(self.samples_per_epoch)]))

        if self.checkpoints_per_epoch > 0:
            iterators.append((self.save_checkpoint, [i for i in range(self.checkpoints_per_epoch)]))

        return iterators
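
    # Editorial note (not part of the original file): `monit.mix` in `train_epoch`
    # below interleaves these iterators across the epoch, so validation, sampling
    # and checkpointing are spread throughout training rather than run at the end.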

    def train_epoch(self):
        # Set the model to training mode
        self.model.train()

        iterators = self.get_iterators()
        for split_name, sample in monit.mix(1024, *iterators):
            if split_name == 'train':
                # Set gradients to zero
                self.optimizer.zero_grad()
                tracker.add_global_step()

            # Enable gradients only for the training split
            with torch.set_grad_enabled(split_name == 'train'):
                if self.is_amp:
                    # Forward pass with automatic mixed precision
                    with amp.autocast():
                        loss, output, target = self.get_loss(sample, split_name)
                else:
                    # Forward pass
                    loss, output, target = self.get_loss(sample, split_name)

                # Get predictions
                pred = output.argmax(dim=-1)
                # Calculate accuracy, ignoring positions where the target is `-100`
                accuracy = pred.eq(target).sum().item() / (target != -100).sum()

                tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})

            if split_name == 'train':
                if self.scaler is not None:
                    # Scale the loss for the backward pass
                    loss = self.scaler.scale(loss)
                    # tracker.add({'loss.scaled': loss})

                # Backward pass
                with monit.section('Backward pass'):
                    loss.backward()

                # Optimize
                with monit.section('Optimize'):
                    if self.scaler is None:
                        self.optimizer.step()
                    else:
                        # Unscale the gradients before clipping, then step and
                        # update the scale factor
                        self.scaler.unscale_(self.optimizer)
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()

                tracker.save()


@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
    # Pick an Adam implementation based on the parameter data type
    if c.dtype == torch.float32:
        return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
    elif c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import AdamFP16
        return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
    else:
        raise NotImplementedError()


@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
    return torch.optim.SGD(get_trainable_params(c.model), lr=c.learning_rate)


@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
    # No gradient scaling when mixed precision is disabled
    if not c.is_amp:
        return None

    # Use the fp16-aware scaler when training in `float16`
    if c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import GradScalerFP16
        return GradScalerFP16()
    else:
        return GradScaler()
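
# Editorial note (not part of the original file): the `@option` decorators above
# register named alternatives for `TrainerConf.optimizer` and `TrainerConf.scaler`;
# with labml's config system a different one can be selected by name, e.g.
#
#     experiment.configs(conf, {'optimizer': 'SGD'})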


class PipelineParallelTrainerConf(TrainerConf):
    # Whether to use checkpointing in the pipeline (presumably activation checkpointing)
    is_checkpointing: bool = False
    # Number of chunks (presumably micro-batches for pipeline parallelism)
    chunks: int

    # Fine tuner (see `labml_nn.neox.utils.finetune`)
    fine_tuner: FineTuner
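

# Usage sketch (an editorial addition, not part of the original file). It assumes
# a hypothetical `TrainerConf` subclass, `MyTrainerConf`, that provides `model`
# and `train_loader`, and follows the usual labml experiment pattern:
#
#     from labml import experiment
#
#     conf = MyTrainerConf()
#     experiment.create(name='neox_trainer')
#     experiment.configs(conf, {'learning_rate': 3e-4, 'epochs': 16})
#     with experiment.start():
#         conf.train()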