GitHub Repository: prophesier/diff-svc
Path: blob/main/training/task/base_task.py
from datetime import datetime
import shutil

import matplotlib

matplotlib.use('Agg')

from utils.hparams import hparams, set_hparams
import random
import sys
import numpy as np
import torch
import torch.distributed as dist
from pytorch_lightning.loggers import TensorBoardLogger
from utils.pl_utils import LatestModelCheckpoint, BaseTrainer, data_loader, DDP
from torch import nn
import torch.utils.data
import utils
import logging
import os

torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')

class BaseTask(nn.Module):
    '''
        Base class for training tasks.
        1. *load_ckpt*:
            load a checkpoint;
        2. *training_step*:
            record and log the loss;
        3. *optimizer_step*:
            run the backward step;
        4. *start*:
            load training configs, back up code, log to TensorBoard, start training;
        5. *configure_ddp* and *init_ddp_connection*:
            start parallel training.

        Subclasses should define:
        1. *build_model*, *build_optimizer*, *build_scheduler*:
            how to build the model, the optimizer and the training scheduler;
        2. *_training_step*:
            one training step of the model;
        3. *validation_end* and *_validation_end*:
            postprocess the validation output.
    '''
    def __init__(self, *args, **kwargs):
        # dataset configs
        super(BaseTask, self).__init__(*args, **kwargs)
        self.current_epoch = 0
        self.global_step = 0
        self.loaded_optimizer_states_dict = {}
        self.trainer = None
        self.logger = None
        self.on_gpu = False
        self.use_dp = False
        self.use_ddp = False
        self.example_input_array = None

        self.max_tokens = hparams['max_tokens']
        self.max_sentences = hparams['max_sentences']
        self.max_eval_tokens = hparams['max_eval_tokens']
        if self.max_eval_tokens == -1:
            hparams['max_eval_tokens'] = self.max_eval_tokens = self.max_tokens
        self.max_eval_sentences = hparams['max_eval_sentences']
        if self.max_eval_sentences == -1:
            hparams['max_eval_sentences'] = self.max_eval_sentences = self.max_sentences

        self.model = None
        self.training_losses_meter = None

    ###########
    # Training, validation and testing
    ###########
    def build_model(self):
        raise NotImplementedError

    def load_ckpt(self, ckpt_base_dir, current_model_name=None, model_name='model', force=True, strict=True):
        # This function was updated on 2021.12.13
        if current_model_name is None:
            current_model_name = model_name
        utils.load_ckpt(self.__getattr__(current_model_name), ckpt_base_dir, current_model_name, force, strict)

    def on_epoch_start(self):
        self.training_losses_meter = {'total_loss': utils.AvgrageMeter()}

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """
        One training step of the model, to be implemented by subclasses.

        :param sample: a batch of training data
        :param batch_idx: index of the batch
        :param optimizer_idx: index of the optimizer being used
        :return: total loss: torch.Tensor, loss_log: dict
        """
        raise NotImplementedError

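    # Wraps the subclass's _training_step: converts the returned log tensors to
    # scalars, updates the running loss meters, and returns the Lightning-style
    # dict of loss, progress-bar values and TensorBoard scalars (prefixed 'tr/').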
    def training_step(self, sample, batch_idx, optimizer_idx=-1):
        loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
        self.opt_idx = optimizer_idx
        if loss_ret is None:
            return {'loss': None}
        total_loss, log_outputs = loss_ret
        log_outputs = utils.tensors_to_scalars(log_outputs)
        for k, v in log_outputs.items():
            if k not in self.training_losses_meter:
                self.training_losses_meter[k] = utils.AvgrageMeter()
            if not np.isnan(v):
                self.training_losses_meter[k].update(v)
        self.training_losses_meter['total_loss'].update(total_loss.item())

        try:
            log_outputs['lr'] = self.scheduler.get_lr()
            if isinstance(log_outputs['lr'], list):
                log_outputs['lr'] = log_outputs['lr'][0]
        except Exception:
            # the scheduler may not be built yet
            pass

        # log_outputs['all_loss'] = total_loss.item()
        progress_bar_log = log_outputs
        tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
        return {
            'loss': total_loss,
            'progress_bar': progress_bar_log,
            'log': tb_log
        }

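    # Runs the optimizer update and steps the LR scheduler once per optimizer
    # update (global_step // accumulate_grad_batches), not once per micro-batch.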
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        optimizer.step()
        optimizer.zero_grad()
        if self.scheduler is not None:
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])

    def on_epoch_end(self):
        loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()}
        print(f"\n==============\n "
              f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}"
              f"\n==============\n")

    def validation_step(self, sample, batch_idx):
        """
        One validation step, to be implemented by subclasses.

        :param sample: a batch of validation data
        :param batch_idx: index of the batch
        :return: output: dict
        """
        raise NotImplementedError

    def _validation_end(self, outputs):
        """
        Postprocess the collected validation outputs, to be implemented by subclasses.

        :param outputs: list of outputs returned by validation_step
        :return: loss_output: dict
        """
        raise NotImplementedError

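    # Aggregates the outputs collected from validation_step via _validation_end
    # and exposes 'val_loss' so the checkpoint callback can monitor it.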
    def validation_end(self, outputs):
        loss_output = self._validation_end(outputs)
        print(f"\n==============\n "
              f"valid results: {loss_output}"
              f"\n==============\n")
        return {
            'log': {f'val/{k}': v for k, v in loss_output.items()},
            'val_loss': loss_output['total_loss']
        }

    def build_scheduler(self, optimizer):
        raise NotImplementedError

    def build_optimizer(self, model):
        raise NotImplementedError

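    # Builds the single optimizer and keeps the scheduler on the task itself;
    # the scheduler is stepped manually in optimizer_step rather than returned.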
    def configure_optimizers(self):
        optm = self.build_optimizer(self.model)
        self.scheduler = self.build_scheduler(optm)
        return [optm]

    def test_start(self):
        pass

    def test_step(self, sample, batch_idx):
        return self.validation_step(sample, batch_idx)

    def test_end(self, outputs):
        return self.validation_end(outputs)

    ###########
    # Running configuration
    ###########

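    # Entry point: loads hparams, seeds the RNGs, builds the trainer with
    # checkpointing and TensorBoard logging, then runs training (fit) or
    # inference/testing depending on hparams['infer'].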
    @classmethod
    def start(cls):
        set_hparams()
        os.environ['MASTER_PORT'] = str(random.randint(15000, 30000))
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        task = cls()
        work_dir = hparams['work_dir']
        trainer = BaseTrainer(checkpoint_callback=LatestModelCheckpoint(
            filepath=work_dir,
            verbose=True,
            monitor='val_loss',
            mode='min',
            num_ckpt_keep=hparams['num_ckpt_keep'],
            save_best=hparams['save_best'],
            period=1 if hparams['save_ckpt'] else 100000
        ),
            logger=TensorBoardLogger(
                save_dir=work_dir,
                name='lightning_logs',
                version='lastest'
            ),
            gradient_clip_val=hparams['clip_grad_norm'],
            val_check_interval=hparams['val_check_interval'],
            row_log_interval=hparams['log_interval'],
            max_updates=hparams['max_updates'],
            num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000,
            accumulate_grad_batches=hparams['accumulate_grad_batches'])
        if not hparams['infer']:  # train
            # copy_code = input(f'{hparams["save_codes"]} code backup? y/n: ') == 'y'
            # copy_code = True # backup code every time
            # if copy_code:
            #     t = datetime.now().strftime('%Y%m%d%H%M%S')
            #     code_dir = f'{work_dir}/codes/{t}'
            #     # TODO: test filesystem calls
            #     os.makedirs(code_dir, exist_ok=True)
            #     # subprocess.check_call(f'mkdir "{code_dir}"', shell=True)
            #     for c in hparams['save_codes']:
            #         shutil.copytree(c, code_dir, dirs_exist_ok=True)
            #         # subprocess.check_call(f'xcopy "{c}" "{code_dir}/" /s /e /y', shell=True)
            #     print(f"| Copied codes to {code_dir}.")
            trainer.checkpoint_callback.task = task
            trainer.fit(task)
        else:
            trainer.test(task)

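    # Wraps the model in DistributedDataParallel, silences stdout/stderr on
    # non-zero ranks (unless debugging), and reseeds the RNGs per process.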
    def configure_ddp(self, model, device_ids):
        model = DDP(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        if dist.get_rank() != 0 and not hparams['debug']:
            sys.stdout = open(os.devnull, "w")
            sys.stderr = open(os.devnull, "w")
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        return model

    def training_end(self, *args, **kwargs):
        return None

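    # Resolves MASTER_ADDR/MASTER_PORT (respecting a user-provided port) and
    # initializes the NCCL process group for distributed training.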
    def init_ddp_connection(self, proc_rank, world_size):
        set_hparams(print_hparams=False)
        # guarantees unique ports across jobs from same grid search
        default_port = 12910
        # if user gave a port number, use that one instead
        try:
            default_port = os.environ['MASTER_PORT']
        except Exception:
            os.environ['MASTER_PORT'] = str(default_port)

        # figure out the root node addr
        root_node = '127.0.0.2'
        root_node = self.trainer.resolve_root_node_address(root_node)
        os.environ['MASTER_ADDR'] = root_node
        dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)

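    # Dataloader stubs; subclasses override these @data_loader-decorated
    # methods to return their train/test/validation loaders.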
    @data_loader
    def train_dataloader(self):
        return None

    @data_loader
    def test_dataloader(self):
        return None

    @data_loader
    def val_dataloader(self):
        return None

    def on_load_checkpoint(self, checkpoint):
        pass

    def on_save_checkpoint(self, checkpoint):
        pass

    def on_sanity_check_start(self):
        pass

    def on_train_start(self):
        pass

    def on_train_end(self):
        pass

    def on_batch_start(self, batch):
        pass

    def on_batch_end(self):
        pass

    def on_pre_performance_check(self):
        pass

    def on_post_performance_check(self):
        pass

    def on_before_zero_grad(self, optimizer):
        pass

    def on_after_backward(self):
        pass

    def backward(self, loss, optimizer):
        loss.backward()

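    # Computes the norm of each parameter's gradient plus the total gradient
    # norm, returned as a dict of rounded scalars for logging.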
    def grad_norm(self, norm_type):
        results = {}
        total_norm = 0
        for name, p in self.named_parameters():
            if p.requires_grad:
                try:
                    param_norm = p.grad.data.norm(norm_type)
                    total_norm += param_norm ** norm_type
                    norm = param_norm ** (1 / norm_type)

                    grad = round(norm.data.cpu().numpy().flatten()[0], 3)
                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
                except Exception:
                    # this param had no grad
                    pass

        total_norm = total_norm ** (1. / norm_type)
        grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
        results['grad_{}_norm_total'.format(norm_type)] = grad
        return results
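

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original diff-svc code).
# A minimal subclass provides the pieces listed in the BaseTask docstring.
# `ToyTask`, its linear model, the AdamW/StepLR choices and the 'x'/'y'
# sample keys below are assumptions for illustration, not the real task.
class ToyTask(BaseTask):
    def build_model(self):
        self.model = nn.Linear(8, 1)
        return self.model

    def build_optimizer(self, model):
        return torch.optim.AdamW(model.parameters(), lr=1e-3)

    def build_scheduler(self, optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)

    def _training_step(self, sample, batch_idx, optimizer_idx):
        # mean squared error between prediction and target
        loss = ((self.model(sample['x']).squeeze(-1) - sample['y']) ** 2).mean()
        return loss, {'mse': loss}

    def validation_step(self, sample, batch_idx):
        loss, _ = self._training_step(sample, batch_idx, -1)
        return {'total_loss': loss.item()}

    def _validation_end(self, outputs):
        return {'total_loss': float(np.mean([o['total_loss'] for o in outputs]))}
# Such a subclass would normally be launched with `ToyTask.start()` once the
# hparams config has been loaded and the dataloader methods are overridden.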