# GitHub Repository: prophesier/diff-svc
# Path: blob/main/utils/pl_utils.py
import matplotlib
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

matplotlib.use('Agg')
import glob
import itertools
import subprocess
import threading
import traceback

from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.callbacks import ModelCheckpoint

from functools import wraps
from torch.cuda._utils import _get_device_index
import numpy as np
import torch.optim
import torch.utils.data
import copy
import logging
import os
import re
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import tqdm
from torch.optim.optimizer import Optimizer

def get_a_var(obj):  # pragma: no cover
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, list) or isinstance(obj, tuple):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None

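
# Illustrative sketch (added for documentation, not used by diff-svc itself):
# `get_a_var` walks an arbitrarily nested batch structure and returns the
# first tensor it finds; `parallel_apply` below relies on it to infer the
# target CUDA device.
def _example_get_a_var():  # pragma: no cover
    batch = {'mel': [torch.zeros(2, 80)], 'name': 'demo'}
    first_tensor = get_a_var(batch)  # -> the (2, 80) zero tensor
    return first_tensor
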
def data_loader(fn):
    """
    Decorator that lazily builds a dataloader the first time it is requested
    and memoizes the result on the instance.
    :param fn: the dataloader method to wrap
    :return: the memoizing wrapper
    """

    attr_name = '_lazy_' + fn.__name__

    @wraps(fn)
    def _get_data_loader(self):
        try:
            value = getattr(self, attr_name)
        except AttributeError:
            try:
                value = fn(self)  # Lazy evaluation, done only once.
                if (
                        value is not None and
                        not isinstance(value, list) and
                        fn.__name__ in ['test_dataloader', 'val_dataloader']
                ):
                    value = [value]
            except AttributeError as e:
                # Guard against AttributeError suppression. (Issue #142)
                traceback.print_exc()
                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
                raise RuntimeError(error) from e
            setattr(self, attr_name, value)  # Memoize evaluation.
        return value

    return _get_data_loader

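
# Illustrative sketch (added for documentation, not part of the original code):
# a task class would typically wrap its dataloader accessors with @data_loader
# so the loaders are built once and memoized on first access. The class below
# is a stand-in, not a real diff-svc task.
class _ExampleTask:  # pragma: no cover
    @data_loader
    def val_dataloader(self):
        dataset = torch.utils.data.TensorDataset(torch.zeros(4, 1))
        # val/test loaders are wrapped in a list by the decorator
        return torch.utils.data.DataLoader(dataset, batch_size=2)
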
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):  # pragma: no cover
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)

                # ---------------
                # CHANGE
                if module.training:
                    output = module.training_step(*input, **kwargs)

                elif module.testing:
                    output = module.test_step(*input, **kwargs)

                else:
                    output = module.validation_step(*input, **kwargs)
                # ---------------

            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    # make sure each module knows what training state it's in...
    # fixes weird bug where copies are out of sync
    root_m = modules[0]
    for m in modules[1:]:
        m.training = root_m.training
        m.testing = root_m.testing

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs

def _find_tensors(obj):  # pragma: no cover
    r"""
    Recursively find all tensors contained in the specified object.
    """
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain(*map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain(*map(_find_tensors, obj.values()))
    return []

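
# Illustrative sketch (added for documentation): `_find_tensors` is what the
# DDP override below uses to locate every tensor inside a free-form step
# output so the reducer can prepare the backward pass.
def _example_find_tensors():  # pragma: no cover
    output = {'loss': torch.zeros(1), 'log': {'mel_loss': torch.zeros(1)}}
    return list(_find_tensors(output))  # -> [tensor, tensor]
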
class DDP(DistributedDataParallel):
    """
    Override the forward call in lightning so it dispatches to training_step,
    test_step or validation_step as appropriate.
    """

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                # --------------
                # LIGHTNING MOD
                # --------------
                # normal
                # output = self.module(*inputs[0], **kwargs[0])
                # lightning
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            # normal
            output = self.module(*inputs, **kwargs)

        if torch.is_grad_enabled():
            # We'll return the output object verbatim since it is a freeform
            # object. We need to find any tensors in this object, though,
            # because we need to figure out which parameters were used during
            # this forward pass, to ensure we short circuit reduction for any
            # unused parameters. Only if `find_unused_parameters` is set.
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        return output

class DP(DataParallel):
    """
    Override the forward call in lightning so it dispatches to training_step,
    test_step or validation_step as appropriate.
    """

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in itertools.chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            # lightning
            if self.module.training:
                return self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                return self.module.test_step(*inputs[0], **kwargs[0])
            else:
                return self.module.validation_step(*inputs[0], **kwargs[0])

        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

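
# Illustrative sketch (added for documentation, assumes at least one CUDA
# device is available): the DP wrapper dispatches to `training_step` /
# `validation_step` instead of `forward`, so a minimal module only needs those
# methods plus a `testing` flag. `ToyTask` is a placeholder, not a diff-svc class.
def _example_dp_wrapper():  # pragma: no cover
    class ToyTask(torch.nn.Module):
        testing = False

        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 1)

        def training_step(self, batch):
            return {'loss': self.proj(batch).mean()}

        def validation_step(self, batch):
            return {'val_loss': self.proj(batch).mean()}

    if not torch.cuda.is_available():
        return None
    task = ToyTask().cuda(0)
    wrapped = DP(task, device_ids=[0])
    # module is in training mode by default, so this is routed to training_step
    return wrapped(torch.randn(4, 8).cuda(0))
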
class GradientAccumulationScheduler:
    def __init__(self, scheduling: dict):
        if scheduling == {}:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly")

        for key in scheduling.keys():
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epochs and accumulation factors must be integers")

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 1:
            msg = f"Epochs are indexed from 1, epoch {minimal_epoch} cannot be interpreted correctly"
            raise IndexError(msg)
        elif minimal_epoch != 1:  # if user didn't define first epoch accumulation factor
            scheduling.update({1: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def on_epoch_begin(self, epoch, trainer):
        epoch += 1  # indexing epochs from 1
        for i in reversed(range(len(self.epochs))):
            if epoch >= self.epochs[i]:
                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                break

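
# Illustrative sketch (added for documentation): the scheduler maps the epoch
# (1-indexed) to an accumulation factor and writes it onto the trainer. The
# SimpleNamespace stands in for a real BaseTrainer instance.
def _example_accumulation_schedule():  # pragma: no cover
    import types
    trainer = types.SimpleNamespace(accumulate_grad_batches=1)
    scheduler = GradientAccumulationScheduler({1: 1, 5: 4})
    scheduler.on_epoch_begin(epoch=0, trainer=trainer)  # epoch 1 -> factor 1
    scheduler.on_epoch_begin(epoch=4, trainer=trainer)  # epoch 5 -> factor 4
    return trainer.accumulate_grad_batches  # -> 4
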
class LatestModelCheckpoint(ModelCheckpoint):
    def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5,
                 save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True):
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        os.makedirs(filepath, exist_ok=True)
        self.num_ckpt_keep = num_ckpt_keep
        self.save_best = save_best
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_check = 0
        self.prefix = prefix
        self.best_k_models = {}
        # {filename: monitor}
        self.kth_best_model = ''
        self.save_top_k = 1
        self.task = None
        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
            self.mode = 'min'
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
            self.mode = 'max'
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
                self.mode = 'max'
            else:
                self.monitor_op = np.less
                self.best = np.Inf
                self.mode = 'min'
        if os.path.exists(f'{self.filepath}/best_valid.npy'):
            self.best = np.load(f'{self.filepath}/best_valid.npy')[0]

    def get_all_ckpts(self):
        return sorted(glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'),
                      key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_check += 1
        best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt'
        if self.epochs_since_last_check >= self.period:
            self.epochs_since_last_check = 0
            filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt'
            if self.verbose > 0:
                logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}')
            self._save_model(filepath)
            for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]:
                # TODO: test filesystem calls
                os.remove(old_ckpt)
                # subprocess.check_call(f'del "{old_ckpt}"', shell=True)
                if self.verbose > 0:
                    logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}')
            current = logs.get(self.monitor)
            if current is not None and self.save_best:
                if self.monitor_op(current, self.best):
                    self.best = current
                    if self.verbose > 0:
                        logging.info(
                            f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached'
                            f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                            f' {best_filepath} as top 1')
                    self._save_model(best_filepath)
                    np.save(f'{self.filepath}/best_valid.npy', [self.best])

    def _save_model(self, path):
        return self.save_function(path)


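# Illustrative sketch (added for documentation): BaseTrainer below assigns
# `save_function` to this callback and the training task sets `task`; the
# directory name here is a placeholder, not the project's real configuration.
def _example_checkpoint_callback(work_dir='checkpoints/example'):  # pragma: no cover
    callback = LatestModelCheckpoint(filepath=work_dir, monitor='val_loss',
                                     num_ckpt_keep=5, save_best=True)
    callback.save_function = lambda path: torch.save({}, path)
    return callback
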
class BaseTrainer:
    def __init__(
            self,
            logger=True,
            checkpoint_callback=True,
            default_save_path=None,
            gradient_clip_val=0,
            process_position=0,
            gpus=-1,
            log_gpu_memory=None,
            show_progress_bar=True,
            track_grad_norm=-1,
            check_val_every_n_epoch=1,
            accumulate_grad_batches=1,
            max_updates=1000,
            min_epochs=1,
            val_check_interval=1.0,
            log_save_interval=100,
            row_log_interval=10,
            print_nan_grads=False,
            weights_summary='full',
            num_sanity_val_steps=5,
            resume_from_checkpoint=None,
    ):
        self.log_gpu_memory = log_gpu_memory
        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
        self.process_position = process_position
        self.weights_summary = weights_summary
        self.max_updates = max_updates
        self.min_epochs = min_epochs
        self.num_sanity_val_steps = num_sanity_val_steps
        self.print_nan_grads = print_nan_grads
        self.resume_from_checkpoint = resume_from_checkpoint
        self.default_save_path = default_save_path

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_idx = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.get_train_dataloader = None
        self.get_test_dataloaders = None
        self.get_val_dataloaders = None
        self.is_iterable_train_dataloader = False

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.checkpoint_callback.save_function = self.save_checkpoint
        self.weights_save_path = self.checkpoint_callback.filepath

        # accumulated grads
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.data_parallel_device_ids = [
            int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
        if len(self.data_parallel_device_ids) == 0:
            self.root_gpu = None
            self.on_gpu = False
        else:
            self.root_gpu = self.data_parallel_device_ids[0]
            self.on_gpu = True

        # distributed backend choice
        self.use_ddp = False
        self.use_dp = False
        self.single_gpu = False
        self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp'
        self.set_distributed_mode(self.distributed_backend)

        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0

        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        self.show_progress_bar = show_progress_bar

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval
        self.logger = logger
        self.logger.rank = 0
        self.row_log_interval = row_log_interval

    @property
    def num_gpus(self):
        gpus = self.data_parallel_device_ids
        if gpus is None:
            return 0
        else:
            return len(gpus)

    @property
    def data_parallel(self):
        return self.use_dp or self.use_ddp

    def get_model(self):
        is_dp_module = isinstance(self.model, (DDP, DP))
        model = self.model.module if is_dp_module else self.model
        return model

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    def fit(self, model):
        if self.use_ddp:
            mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
        else:
            model.model = model.build_model()
            if not self.testing:
                self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
            if self.use_dp:
                model.cuda(self.root_gpu)
                model = DP(model, device_ids=self.data_parallel_device_ids)
            elif self.single_gpu:
                model.cuda(self.root_gpu)
            self.run_pretrain_routine(model)
        return 1

    def init_optimizers(self, optimizers):

        # single optimizer
        if isinstance(optimizers, Optimizer):
            return [optimizers], []

        # two lists
        elif len(optimizers) == 2 and isinstance(optimizers[0], list):
            optimizers, lr_schedulers = optimizers
            return optimizers, lr_schedulers

        # single list or tuple
        elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
            return optimizers, []

    def run_pretrain_routine(self, model):
        """Sanity check a few things before starting actual training.

        :param model:
        """
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.copy_trainer_model_properties(ref_model)

        # link up experiment object
        if self.logger is not None:
            ref_model.logger = self.logger
            self.logger.save()

        if self.use_ddp:
            dist.barrier()

        # set up checkpoint callback
        # self.configure_checkpoint_callback()

        # transfer data loaders from model
        self.get_dataloaders(ref_model)

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc call
        self.restore_weights(model)

        # when testing requested only run test and return
        if self.testing:
            self.run_evaluation(test=True)
            return

        # check if we should run validation during training
        self.disable_validation = self.num_val_batches == 0

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        ref_model.on_train_start()
        if not self.disable_validation and self.num_sanity_val_steps > 0:
            # init progress bars for validation sanity check
            pbar = tqdm.tqdm(desc='Validation sanity check',
                             total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
                             leave=False, position=2 * self.process_position,
                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
            self.main_progress_bar = pbar
            # dummy validation progress bar
            self.val_progress_bar = tqdm.tqdm(disable=True)

            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)

            # close progress bars
            self.main_progress_bar.close()
            self.val_progress_bar.close()

        # init progress bar
        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
                         file=sys.stdout)
        self.main_progress_bar = pbar

        # clear cache before training
        if self.on_gpu:
            torch.cuda.empty_cache()

        # CORE TRAINING LOOP
        self.train()

    def test(self, model):
        self.testing = True
        self.fit(model)

    @property
    def training_tqdm_dict(self):
        tqdm_dict = {
            'step': '{}'.format(self.global_step),
        }
        tqdm_dict.update(self.tqdm_metrics)
        return tqdm_dict

    # --------------------
    # restore ckpt
    # --------------------
    def restore_weights(self, model):
        """
        To restore weights we have two cases.
        First, attempt to restore hpc weights. If successful, don't restore
        other weights.

        Otherwise, try to restore actual weights
        :param model:
        :return:
        """
        # clear cache before restore
        if self.on_gpu:
            torch.cuda.empty_cache()

        if self.resume_from_checkpoint is not None:
            self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
        else:
            # restore weights if same exp version
            self.restore_state_if_checkpoint_exists(model)

        # wait for all models to restore weights
        if self.use_ddp:
            # wait for all processes to catch up
            dist.barrier()

        # clear cache after restore
        if self.on_gpu:
            torch.cuda.empty_cache()

    def restore_state_if_checkpoint_exists(self, model):
        did_restore = False

        # do nothing if there's no dir or callback
        no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback)
        if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
            return did_restore

        # restore trainer state and model if there is a weight for this experiment
        last_steps = -1
        last_ckpt_name = None

        # find last epoch
        checkpoints = os.listdir(self.checkpoint_callback.filepath)
        for name in checkpoints:
            if '.ckpt' in name and not name.endswith('part'):
                if 'steps_' in name:
                    steps = name.split('steps_')[1]
                    steps = int(re.sub('[^0-9]', '', steps))

                    if steps > last_steps:
                        last_steps = steps
                        last_ckpt_name = name

        # restore last checkpoint
        if last_ckpt_name is not None:
            last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
            self.restore(last_ckpt_path, self.on_gpu)
            logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}')
            did_restore = True

        return did_restore

    def restore(self, checkpoint_path, on_gpu):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # load model state
        model = self.get_model()

        # load the state_dict on the model automatically
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        if on_gpu:
            model.cuda(self.root_gpu)
        # load training state (affects trainer only)
        self.restore_training_state(checkpoint)
        model.global_step = self.global_step
        del checkpoint

        try:
            if dist.is_initialized() and dist.get_rank() > 0:
                return
        except Exception as e:
            print(e)
            return

    def restore_training_state(self, checkpoint):
        """
        Restore trainer state.
        Model will get its chance to update.
        :param checkpoint:
        :return:
        """
        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
            # return allowing checkpoints with meta information (global_step, etc)
            self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']

        self.global_step = checkpoint['global_step']
        self.current_epoch = checkpoint['epoch']

        if self.testing:
            return

        # restore the optimizers
        optimizer_states = checkpoint['optimizer_states']
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            if optimizer is None:
                return
            optimizer.load_state_dict(opt_state)

            # move optimizer to GPU 1 weight at a time
            # avoids OOM
            if self.root_gpu is not None:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.root_gpu)

        # restore the lr schedulers
        lr_schedulers = checkpoint['lr_schedulers']
        for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
            scheduler.load_state_dict(lrs_state)

    # --------------------
    # MODEL SAVE CHECKPOINT
    # --------------------
    def _atomic_save(self, checkpoint, filepath):
        """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

        This will create a temporary checkpoint with a suffix of ``.part``, then move it to the final location once
        saving is finished.

        Args:
            checkpoint (object): The object to save.
                Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
                accepts.
            filepath (str|pathlib.Path): The path to which the checkpoint will be saved.
                This points to the file that the checkpoint will be stored in.
        """
        tmp_path = str(filepath) + ".part"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, filepath)

    def save_checkpoint(self, filepath):
        checkpoint = self.dump_checkpoint()
        self._atomic_save(checkpoint, filepath)

    def dump_checkpoint(self):

        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step
        }

        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
            checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best

        # save optimizers
        optimizer_states = []
        for i, optimizer in enumerate(self.optimizers):
            if optimizer is not None:
                optimizer_states.append(optimizer.state_dict())

        checkpoint['optimizer_states'] = optimizer_states

        # save lr schedulers
        lr_schedulers = []
        for i, scheduler in enumerate(self.lr_schedulers):
            lr_schedulers.append(scheduler.state_dict())

        checkpoint['lr_schedulers'] = lr_schedulers

        # add the hparams and state_dict from the model
        model = self.get_model()
        checkpoint['state_dict'] = model.state_dict()
        # give the model a chance to add a few things
        model.on_save_checkpoint(checkpoint)

        return checkpoint

    def copy_trainer_model_properties(self, model):
        if isinstance(model, DP):
            ref_model = model.module
        elif isinstance(model, DDP):
            ref_model = model.module
        else:
            ref_model = model

        for m in [model, ref_model]:
            m.trainer = self
            m.on_gpu = self.on_gpu
            m.use_dp = self.use_dp
            m.use_ddp = self.use_ddp
            m.testing = self.testing
            m.single_gpu = self.single_gpu

    def transfer_batch_to_gpu(self, batch, gpu_id):
        # base case: object can be directly moved using `cuda` or `to`
        if callable(getattr(batch, 'cuda', None)):
            return batch.cuda(gpu_id, non_blocking=True)

        elif callable(getattr(batch, 'to', None)):
            return batch.to(torch.device('cuda', gpu_id), non_blocking=True)

        # when list
        elif isinstance(batch, list):
            for i, x in enumerate(batch):
                batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
            return batch

        # when tuple
        elif isinstance(batch, tuple):
            batch = list(batch)
            for i, x in enumerate(batch):
                batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
            return tuple(batch)

        # when dict
        elif isinstance(batch, dict):
            for k, v in batch.items():
                batch[k] = self.transfer_batch_to_gpu(v, gpu_id)

            return batch

        # nothing matches, return the value as is without transform
        return batch

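
    # Usage note (added for documentation): `transfer_batch_to_gpu` recurses
    # through lists, tuples and dicts, so a batch such as
    #     {'mels': tensor, 'f0': tensor, 'names': [...]}
    # can be moved field by field with, e.g.,
    #     batch = trainer.transfer_batch_to_gpu(batch, trainer.root_gpu)
    # Objects without `.cuda()` / `.to()` (like a list of names) are returned
    # unchanged.
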
    def set_distributed_mode(self, distributed_backend):
        # skip for CPU
        if self.num_gpus == 0:
            return

        # single GPU case
        # in single gpu case we allow ddp so we can train on multiple
        # nodes, 1 gpu per node
        elif self.num_gpus == 1:
            self.single_gpu = True
            self.use_dp = False
            self.use_ddp = False
            self.root_gpu = 0
            self.data_parallel_device_ids = [0]
        else:
            if distributed_backend is not None:
                self.use_dp = distributed_backend == 'dp'
                self.use_ddp = distributed_backend == 'ddp'
            elif distributed_backend is None:
                self.use_dp = True
                self.use_ddp = False

        logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}')

    def ddp_train(self, gpu_idx, model):
        """
        Entry point into a DDP training process.
        :param gpu_idx:
        :param model:
        :return:
        """
        # otherwise default to node rank 0
        self.node_rank = 0

        # show progressbar only on progress_rank 0
        self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0

        # determine which process we are and world size
        if self.use_ddp:
            self.proc_rank = self.node_rank * self.num_gpus + gpu_idx
            self.world_size = self.num_gpus

        # let the exp know the rank to avoid overwriting logs
        if self.logger is not None:
            self.logger.rank = self.proc_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        model.trainer = self
        model.init_ddp_connection(self.proc_rank, self.world_size)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        model.model = model.build_model()
        if not self.testing:
            self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

        # MODEL
        # copy model to each gpu
        if self.distributed_backend == 'ddp':
            torch.cuda.set_device(gpu_idx)
        model.cuda(gpu_idx)

        # set model properties before going into wrapper
        self.copy_trainer_model_properties(model)

        # override root GPU
        self.root_gpu = gpu_idx

        if self.distributed_backend == 'ddp':
            device_ids = [gpu_idx]
        else:
            device_ids = None

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # continue training routine
        self.run_pretrain_routine(model)

    def resolve_root_node_address(self, root_node):
        if '[' in root_node:
            name = root_node.split('[')[0]
            number = root_node.split(',')[0]
            if '-' in number:
                number = number.split('-')[0]

            number = re.sub('[^0-9]', '', number)
            root_node = name + number

        return root_node

    def log_metrics(self, metrics, grad_norm_dic, step=None):
        """Logs the metric dict passed in.

        :param metrics:
        :param grad_norm_dic:
        """
        # added metrics by Lightning for convenience
        metrics['epoch'] = self.current_epoch

        # add norms
        metrics.update(grad_norm_dic)

        # turn all tensors to scalars
        scalar_metrics = self.metrics_to_scalars(metrics)

        step = step if step is not None else self.global_step
        # log actual metrics
        if self.proc_rank == 0 and self.logger is not None:
            self.logger.log_metrics(scalar_metrics, step=step)
            self.logger.save()

    def add_tqdm_metrics(self, metrics):
        for k, v in metrics.items():
            if type(v) is torch.Tensor:
                v = v.item()

            self.tqdm_metrics[k] = v

    def metrics_to_scalars(self, metrics):
        new_metrics = {}
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            if type(v) is dict:
                v = self.metrics_to_scalars(v)

            new_metrics[k] = v

        return new_metrics

    def process_output(self, output, train=False):
        """Reduces output according to the training mode.

        Separates loss from logging and tqdm metrics
        :param output:
        :return:
        """
        # ---------------
        # EXTRACT CALLBACK KEYS
        # ---------------
        # all keys not progress_bar or log are candidates for callbacks
        callback_metrics = {}
        for k, v in output.items():
            if k not in ['progress_bar', 'log', 'hiddens']:
                callback_metrics[k] = v

        if train and self.use_dp:
            num_gpus = self.num_gpus
            callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)

        for k, v in callback_metrics.items():
            if isinstance(v, torch.Tensor):
                callback_metrics[k] = v.item()

        # ---------------
        # EXTRACT PROGRESS BAR KEYS
        # ---------------
        try:
            progress_output = output['progress_bar']

            # reduce progress metrics for tqdm when using dp
            if train and self.use_dp:
                num_gpus = self.num_gpus
                progress_output = self.reduce_distributed_output(progress_output, num_gpus)

            progress_bar_metrics = progress_output
        except Exception:
            progress_bar_metrics = {}

        # ---------------
        # EXTRACT LOGGING KEYS
        # ---------------
        # extract metrics to log to experiment
        try:
            log_output = output['log']

            # reduce log metrics when using dp
            if train and self.use_dp:
                num_gpus = self.num_gpus
                log_output = self.reduce_distributed_output(log_output, num_gpus)

            log_metrics = log_output
        except Exception:
            log_metrics = {}

        # ---------------
        # EXTRACT LOSS
        # ---------------
        # if output dict doesn't have the keyword loss
        # then assume the output=loss if scalar
        loss = None
        if train:
            try:
                loss = output['loss']
            except Exception:
                if type(output) is torch.Tensor:
                    loss = output
                else:
                    raise RuntimeError(
                        'No `loss` value in the dictionary returned from `model.training_step()`.'
                    )

            # when using dp need to reduce the loss
            if self.use_dp:
                loss = self.reduce_distributed_output(loss, self.num_gpus)

        # ---------------
        # EXTRACT HIDDEN
        # ---------------
        hiddens = output.get('hiddens')

        # use every metric passed in as a candidate for callback
        callback_metrics.update(progress_bar_metrics)
        callback_metrics.update(log_metrics)

        # convert tensors to Python scalars
        for k, v in callback_metrics.items():
            if isinstance(v, torch.Tensor):
                callback_metrics[k] = v.item()

        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens

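
    # Usage note (added for documentation): `process_output` expects the dict
    # convention used by the training/validation steps, e.g.
    #     return {'loss': loss,
    #             'progress_bar': {'mel_loss': mel_loss},
    #             'log': {'mel_loss': mel_loss, 'lr': lr}}
    # 'loss' drives the backward pass, 'progress_bar' feeds the tqdm postfix,
    # 'log' goes to the logger, and any remaining keys become callback metrics.
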
    def reduce_distributed_output(self, output, num_gpus):
        if num_gpus <= 1:
            return output

        # when using DP, we get one output per gpu
        # average outputs and return
        if type(output) is torch.Tensor:
            return output.mean()

        for k, v in output.items():
            # recurse on nested dicts
            if isinstance(output[k], dict):
                output[k] = self.reduce_distributed_output(output[k], num_gpus)

            # do nothing when there's a scalar
            elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
                pass

            # reduce only metrics that have the same number of gpus
            elif output[k].size(0) == num_gpus:
                reduced = torch.mean(output[k])
                output[k] = reduced
        return output

    def clip_gradients(self):
        if self.gradient_clip_val > 0:
            model = self.get_model()
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)

    def print_nan_gradients(self):
        model = self.get_model()
        for param in model.parameters():
            if (param.grad is not None) and torch.isnan(param.grad.float()).any():
                logging.info(f'{param}, {param.grad}')

    def configure_accumulated_gradients(self, accumulate_grad_batches):
        self.accumulate_grad_batches = None

        if isinstance(accumulate_grad_batches, dict):
            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
        elif isinstance(accumulate_grad_batches, int):
            schedule = {1: accumulate_grad_batches}
            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
        else:
            raise TypeError("Gradient accumulation supports only int and dict types")

    def get_dataloaders(self, model):
        if not self.testing:
            self.init_train_dataloader(model)
            self.init_val_dataloader(model)
        else:
            self.init_test_dataloader(model)

        if self.use_ddp:
            dist.barrier()
            if not self.testing:
                self.get_train_dataloader()
                self.get_val_dataloaders()
            else:
                self.get_test_dataloaders()

    def init_train_dataloader(self, model):
        self.first_epoch = True
        self.get_train_dataloader = model.train_dataloader
        if isinstance(self.get_train_dataloader(), torch.utils.data.DataLoader):
            self.num_training_batches = len(self.get_train_dataloader())
            self.num_training_batches = int(self.num_training_batches)
        else:
            self.num_training_batches = float('inf')
            self.is_iterable_train_dataloader = True
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
        else:
            self._percent_range_check('val_check_interval')
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

    def init_val_dataloader(self, model):
        self.get_val_dataloaders = model.val_dataloader
        self.num_val_batches = 0
        if self.get_val_dataloaders() is not None:
            if isinstance(self.get_val_dataloaders()[0], torch.utils.data.DataLoader):
                self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
                self.num_val_batches = int(self.num_val_batches)
            else:
                self.num_val_batches = float('inf')

    def init_test_dataloader(self, model):
        self.get_test_dataloaders = model.test_dataloader
        if self.get_test_dataloaders() is not None:
            if isinstance(self.get_test_dataloaders()[0], torch.utils.data.DataLoader):
                self.num_test_batches = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
                self.num_test_batches = int(self.num_test_batches)
            else:
                self.num_test_batches = float('inf')

    def evaluate(self, model, dataloaders, max_batches, test=False):
        """Run evaluation code.

        :param model: PT model
        :param dataloaders: list of PT dataloaders
        :param max_batches: Scalar
        :param test: boolean
        :return:
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        if test:
            self.get_model().test_start()
        # bookkeeping
        outputs = []

        # run evaluation over each dataloader
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []
            for batch_idx, batch in enumerate(dataloader):

                if batch is None:  # pragma: no cover
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= max_batches:
                    break

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                output = self.evaluation_forward(model,
                                                 batch,
                                                 batch_idx,
                                                 dataloader_idx,
                                                 test)

                # track outputs for collation
                dl_outputs.append(output)

                # batch done
                if test:
                    self.test_progress_bar.update(1)
                else:
                    self.val_progress_bar.update(1)
            outputs.append(dl_outputs)

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]

        # give model a chance to do something with the outputs (and method defined)
        model = self.get_model()
        if test:
            eval_results_ = model.test_end(outputs)
        else:
            eval_results_ = model.validation_end(outputs)
        eval_results = eval_results_

        # enable train mode again
        model.train()

        # re-enable gradients
        torch.set_grad_enabled(True)

        return eval_results

    def run_evaluation(self, test=False):
        # when testing make sure user defined a test step
        model = self.get_model()
        model.on_pre_performance_check()

        # select dataloaders
        if test:
            dataloaders = self.get_test_dataloaders()
            max_batches = self.num_test_batches
        else:
            # val
            dataloaders = self.get_val_dataloaders()
            max_batches = self.num_val_batches

        # init validation or test progress bar
        # main progress bar will already be closed when testing so initial position is free
        position = 2 * self.process_position + (not test)
        desc = 'Testing' if test else 'Validating'
        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
                         disable=not self.show_progress_bar, dynamic_ncols=True,
                         unit='batch', file=sys.stdout)
        setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)

        # run evaluation
        eval_results = self.evaluate(self.model,
                                     dataloaders,
                                     max_batches,
                                     test)
        if eval_results is not None:
            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
                eval_results)

            # add metrics to prog bar
            self.add_tqdm_metrics(prog_bar_metrics)

            # log metrics
            self.log_metrics(log_metrics, {})

            # track metrics for callbacks
            self.callback_metrics.update(callback_metrics)

        # hook
        model.on_post_performance_check()

        # add model specific metrics
        tqdm_metrics = self.training_tqdm_dict
        if not test:
            self.main_progress_bar.set_postfix(**tqdm_metrics)

        # close progress bar
        if test:
            self.test_progress_bar.close()
        else:
            self.val_progress_bar.close()

        # model checkpointing
        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
                                                  logs=self.callback_metrics)

    def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]
        # print(batch)
        if test and len(self.get_test_dataloaders()) > 1:
            args.append(dataloader_idx)

        elif not test and len(self.get_val_dataloaders()) > 1:
            args.append(dataloader_idx)

        # handle DP, DDP forward
        if self.use_ddp or self.use_dp:
            output = model(*args)
            return output

        # single GPU
        if self.single_gpu:
            # for single GPU put inputs on gpu manually
            root_gpu = 0
            if isinstance(self.data_parallel_device_ids, list):
                root_gpu = self.data_parallel_device_ids[0]
            batch = self.transfer_batch_to_gpu(batch, root_gpu)
            args[0] = batch

        # CPU
        if test:
            output = model.test_step(*args)
        else:
            output = model.validation_step(*args)

        return output

    def train(self):
        model = self.get_model()
        # run all epochs
        for epoch in range(self.current_epoch, 1000000):
            # set seed for distributed sampler (enables shuffling for each epoch)
            if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
                self.get_train_dataloader().sampler.set_epoch(epoch)

            # get model
            model = self.get_model()

            # update training progress in trainer and model
            model.current_epoch = epoch
            self.current_epoch = epoch

            total_val_batches = 0
            if not self.disable_validation:
                # val can be checked multiple times in epoch
                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
                total_val_batches = self.num_val_batches * val_checks_per_epoch

            # total batches includes multiple val checks
            self.total_batches = self.num_training_batches + total_val_batches
            self.batch_loss_value = 0  # accumulated grads

            if self.is_iterable_train_dataloader:
                # for iterable train loader, the progress bar never ends
                num_iterations = None
            else:
                num_iterations = self.total_batches

            # reset progress bar
            # .reset() doesn't work on disabled progress bar so we should check
            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
            self.main_progress_bar.set_description(desc)

            # changing gradient according accumulation_scheduler
            self.accumulation_scheduler.on_epoch_begin(epoch, self)

            # -----------------
            # RUN TNG EPOCH
            # -----------------
            self.run_training_epoch()

            # update LR schedulers
            if self.lr_schedulers is not None:
                for lr_scheduler in self.lr_schedulers:
                    lr_scheduler.step(epoch=self.current_epoch)

        self.main_progress_bar.close()

        model.on_train_end()

        if self.logger is not None:
            self.logger.finalize("success")

    def run_training_epoch(self):
        # before epoch hook
        if self.is_function_implemented('on_epoch_start'):
            model = self.get_model()
            model.on_epoch_start()

        # run epoch
        for batch_idx, batch in enumerate(self.get_train_dataloader()):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx

            model = self.get_model()
            model.global_step = self.global_step

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            output = self.run_training_batch(batch, batch_idx)
            batch_result, grad_norm_dic, batch_step_metrics = output

            # when returning -1 from train_step, we end epoch early
            early_stop_epoch = batch_result == -1

            # ---------------
            # RUN VAL STEP
            # ---------------
            should_check_val = (
                    not self.disable_validation and self.global_step % self.val_check_batch == 0
                    and not self.first_epoch)
            self.first_epoch = False

            if should_check_val:
                self.run_evaluation(test=self.testing)

            # when logs should be saved
            should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log:
                if self.proc_rank == 0 and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics:
                # logs user requested information to logger
                self.log_metrics(batch_step_metrics, grad_norm_dic)

            self.global_step += 1
            self.total_batch_idx += 1

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if early_stop_epoch:
                break
            if self.global_step > self.max_updates:
                print("| Training end..")
                exit()

        # epoch end hook
        if self.is_function_implemented('on_epoch_end'):
            model = self.get_model()
            model.on_epoch_end()

    def run_training_batch(self, batch, batch_idx):
        # track grad norms
        grad_norm_dic = {}

        # track all metrics for callbacks
        all_callback_metrics = []

        # track metrics to log
        all_log_metrics = []

        if batch is None:
            return 0, grad_norm_dic, {}

        # hook
        if self.is_function_implemented('on_batch_start'):
            model_ref = self.get_model()
            response = model_ref.on_batch_start(batch)

            if response == -1:
                return -1, grad_norm_dic, {}

        splits = [batch]
        self.hiddens = None
        for split_idx, split_batch in enumerate(splits):
            self.split_idx = split_idx

            # call training_step once per optimizer
            for opt_idx, optimizer in enumerate(self.optimizers):
                if optimizer is None:
                    continue
                # make sure only the gradients of the current optimizer's parameters are calculated
                # in the training step to prevent dangling gradients in multiple-optimizer setup.
                if len(self.optimizers) > 1:
                    for param in self.get_model().parameters():
                        param.requires_grad = False
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.requires_grad = True

                # wrap the forward step in a closure so second order methods work
                def optimizer_closure():
                    # forward pass
                    output = self.training_forward(
                        split_batch, batch_idx, opt_idx, self.hiddens)

                    closure_loss = output[0]
                    progress_bar_metrics = output[1]
                    log_metrics = output[2]
                    callback_metrics = output[3]
                    self.hiddens = output[4]
                    if closure_loss is None:
                        return None

                    # accumulate loss
                    # (if accumulate_grad_batches = 1 no effect)
                    closure_loss = closure_loss / self.accumulate_grad_batches

                    # backward pass
                    model_ref = self.get_model()
                    if closure_loss.requires_grad:
                        model_ref.backward(closure_loss, optimizer)

                    # track metrics for callbacks
                    all_callback_metrics.append(callback_metrics)

                    # track progress bar metrics
                    self.add_tqdm_metrics(progress_bar_metrics)
                    all_log_metrics.append(log_metrics)

                    # insert after step hook
                    if self.is_function_implemented('on_after_backward'):
                        model_ref = self.get_model()
                        model_ref.on_after_backward()

                    return closure_loss

                # calculate loss
                loss = optimizer_closure()
                if loss is None:
                    continue

                # nan grads
                if self.print_nan_grads:
                    self.print_nan_gradients()

                # track total loss for logging (avoid mem leaks)
                self.batch_loss_value += loss.item()

                # gradient update with accumulated gradients
                if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:

                    # track gradient norms when requested
                    if batch_idx % self.row_log_interval == 0:
                        if self.track_grad_norm > 0:
                            model = self.get_model()
                            grad_norm_dic = model.grad_norm(
                                self.track_grad_norm)

                    # clip gradients
                    self.clip_gradients()

                    # calls .step(), .zero_grad()
                    # override function to modify this behavior
                    model = self.get_model()
                    model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx)

                    # calculate running loss for display
                    self.running_loss.append(self.batch_loss_value)
                    self.batch_loss_value = 0
                    self.avg_loss = np.mean(self.running_loss[-100:])

        # activate batch end hook
        if self.is_function_implemented('on_batch_end'):
            model = self.get_model()
            model.on_batch_end()

        # update progress bar
        self.main_progress_bar.update(1)
        self.main_progress_bar.set_postfix(**self.training_tqdm_dict)

        # collapse all metrics into one dict
        all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}

        # track all metrics for callbacks
        self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})

        return 0, grad_norm_dic, all_log_metrics

    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
        """
        Handle forward for each training case (distributed, single gpu, etc...)
        :param batch:
        :param batch_idx:
        :return:
        """
        # ---------------
        # FORWARD
        # ---------------
        # enable not needing to add opt_idx to training_step
        args = [batch, batch_idx, opt_idx]

        # distributed forward
        if self.use_ddp or self.use_dp:
            output = self.model(*args)
        # single GPU forward
        elif self.single_gpu:
            gpu_id = 0
            if isinstance(self.data_parallel_device_ids, list):
                gpu_id = self.data_parallel_device_ids[0]
            batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
            args[0] = batch
            output = self.model.training_step(*args)
        # CPU forward
        else:
            output = self.model.training_step(*args)

        # allow any mode to define training_end
        model_ref = self.get_model()
        output_ = model_ref.training_end(output)
        if output_ is not None:
            output = output_

        # format and reduce outputs accordingly
        output = self.process_output(output, train=True)

        return output

    # ---------------
    # Utils
    # ---------------
    def is_function_implemented(self, f_name):
        model = self.get_model()
        f_op = getattr(model, f_name, None)
        return callable(f_op)

    def _percent_range_check(self, name):
        value = getattr(self, name)
        msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
        if name == "val_check_interval":
            msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."

        if not 0. <= value <= 1.:
            raise ValueError(msg)