GitHub Repository: jantic/deoldify
Path: blob/master/fastai/callbacks/tracker.py

# Contribution from @fredguth, https://github.com/fredguth/fastai_playground.

from fastai.torch_core import *
from fastai.callback import *
from fastai.basic_train import *

__all__ = ['TerminateOnNaNCallback', 'EarlyStoppingCallback', 'SaveModelCallback', 'TrackerCallback',
           'ReduceLROnPlateauCallback', 'TrackEpochCallback']

class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN."

    def __init__(self):
        self.stop = False

    def on_batch_end(self, last_loss, epoch, num_batch, **kwargs:Any):
        "Test if `last_loss` is NaN and interrupt training."
        if self.stop: return True  # skip validation after stopping during training
        if torch.isnan(last_loss):
            print(f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')
            return {'stop_epoch': True, 'stop_training': True, 'skip_validate': True}
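
# Usage sketch (illustrative, not part of the original file): attach the callback when
# constructing a Learner so training stops as soon as the loss turns NaN; `data` and
# `model` are placeholder names for an existing DataBunch and nn.Module.
#
#   learn = Learner(data, model, callbacks=[TerminateOnNaNCallback()])
#   learn.fit_one_cycle(5, 1e-3)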

class TrackerCallback(LearnerCallback):
    "A `LearnerCallback` that keeps track of the best value in `monitor`."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto'):
        super().__init__(learn)
        self.monitor,self.mode = monitor,mode
        if self.mode not in ['auto', 'min', 'max']:
            warn(f'{self.__class__} mode {self.mode} is invalid, falling back to "auto" mode.')
            self.mode = 'auto'
        # 'auto' assumes lower is better for losses and higher is better for other metrics.
        mode_dict = {'min': np.less, 'max':np.greater}
        mode_dict['auto'] = np.less if 'loss' in self.monitor else np.greater
        self.operator = mode_dict[self.mode]

    def on_train_begin(self, **kwargs:Any)->None:
        "Initialize the best value."
        self.best = float('inf') if self.operator == np.less else -float('inf')

    def get_monitor_value(self):
        "Pick the monitored value."
        if self.monitor=='train_loss' and len(self.learn.recorder.losses) == 0: return None
        elif len(self.learn.recorder.val_losses) == 0: return None
        values = {'train_loss':self.learn.recorder.losses[-1].cpu().numpy(),
                  'valid_loss':self.learn.recorder.val_losses[-1]}
        if values['valid_loss'] is None: return None
        if self.learn.recorder.metrics:
            for m, n in zip(self.learn.recorder.metrics[-1], self.learn.recorder.names[3:-1]):
                values[n] = m
        if values.get(self.monitor) is None:
            warn(f'{self.__class__} conditioned on metric `{self.monitor}` which is not available. '
                 f'Available metrics are: {", ".join(map(str, self.learn.recorder.names[1:-1]))}')
        return values.get(self.monitor)
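
# Subclassing sketch (illustrative, not part of the original file): TrackerCallback is a
# base class; `self.best` and `self.operator` are maintained here, so a subclass only has
# to compare the monitored value against them. The class below is hypothetical.
#
#   class PrintImprovementCallback(TrackerCallback):
#       "Print a message whenever `monitor` improves."
#       def on_epoch_end(self, epoch, **kwargs):
#           current = self.get_monitor_value()
#           if current is not None and self.operator(current, self.best):
#               self.best = current
#               print(f'Epoch {epoch}: {self.monitor} improved to {current}.')
#
#   learn = Learner(data, model, callback_fns=[partial(PrintImprovementCallback, monitor='accuracy')])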

class EarlyStoppingCallback(TrackerCallback):
    "A `TrackerCallback` that terminates training when the monitored quantity stops improving."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', min_delta:float=0, patience:int=0):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.min_delta,self.patience = min_delta,patience
        # For a minimized quantity, an improvement must be at least `min_delta` below the best.
        if self.operator == np.less: self.min_delta *= -1

    def on_train_begin(self, **kwargs:Any)->None:
        "Initialize inner arguments."
        self.wait = 0
        super().on_train_begin(**kwargs)

    def on_epoch_end(self, epoch, **kwargs:Any):
        "Compare the value monitored to its best score and maybe stop training."
        current = self.get_monitor_value()
        if current is None: return
        if self.operator(current - self.min_delta, self.best):
            self.best,self.wait = current,0
        else:
            self.wait += 1
            if self.wait > self.patience:
                print(f'Epoch {epoch}: early stopping')
                return {"stop_training":True}

class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model when the monitored quantity is best."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', every:str='improvement', name:str='bestmodel'):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.every,self.name = every,name
        if self.every not in ['improvement', 'epoch']:
            warn(f'SaveModel every {self.every} is invalid, falling back to "improvement".')
            self.every = 'improvement'

    def jump_to_epoch(self, epoch:int)->None:
        "Load the model that was saved at the end of `epoch-1`, if it exists."
        try:
            self.learn.load(f'{self.name}_{epoch-1}', purge=False)
            print(f"Loaded {self.name}_{epoch-1}")
        except: print(f'Model {self.name}_{epoch-1} not found.')

    def on_epoch_end(self, epoch:int, **kwargs:Any)->None:
        "Compare the value monitored to its best score and maybe save the model."
        if self.every=="epoch": self.learn.save(f'{self.name}_{epoch}')
        else: # every="improvement"
            current = self.get_monitor_value()
            if current is not None and self.operator(current, self.best):
                print(f'Better model found at epoch {epoch} with {self.monitor} value: {current}.')
                self.best = current
                self.learn.save(f'{self.name}')

    def on_train_end(self, **kwargs):
        "Load the best model."
        if self.every=="improvement" and (self.learn.path/f'{self.learn.model_dir}/{self.name}.pth').is_file():
            self.learn.load(f'{self.name}', purge=False)
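
# Usage sketch (illustrative, not part of the original file): with every='improvement'
# (the default) the best weights so far are saved to `learn.model_dir/name.pth` and
# reloaded at the end of training; with every='epoch' a `name_{epoch}.pth` file is
# written after each epoch instead. `data` and `model` are placeholders.
#
#   learn = Learner(data, model,
#                   callback_fns=[partial(SaveModelCallback, monitor='valid_loss', name='bestmodel')])
#   learn.fit_one_cycle(10)
#   learn.load('bestmodel')  # weights from the best epoch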

class ReduceLROnPlateauCallback(TrackerCallback):
    "A `TrackerCallback` that reduces the learning rate when a metric has stopped improving."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', patience:int=0, factor:float=0.2,
                 min_delta:float=0):
        super().__init__(learn, monitor=monitor, mode=mode)
        self.patience,self.factor,self.min_delta = patience,factor,min_delta
        if self.operator == np.less: self.min_delta *= -1

    def on_train_begin(self, **kwargs:Any)->None:
        "Initialize inner arguments."
        self.wait, self.opt = 0, self.learn.opt
        super().on_train_begin(**kwargs)

    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Compare the value monitored to its best and maybe reduce the lr."
        current = self.get_monitor_value()
        if current is None: return
        if self.operator(current - self.min_delta, self.best): self.best,self.wait = current,0
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.opt.lr *= self.factor
                self.wait = 0
                print(f'Epoch {epoch}: reducing lr to {self.opt.lr}')
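
# Usage sketch (illustrative, not part of the original file): multiply the learning rate
# by `factor` whenever the monitored value has not improved for `patience` epochs.
# `data` and `model` are placeholders.
#
#   learn = Learner(data, model,
#                   callback_fns=[partial(ReduceLROnPlateauCallback,
#                                         monitor='valid_loss', patience=2, factor=0.5)])
#   learn.fit(20, 1e-2)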

class TrackEpochCallback(LearnerCallback):
    _order = -20  # needs to run before fit_one_cycle
    def __init__(self, learn:Learner, name:str='epoch', epoch_offset:int=None):
        "Store completed epoch number in `learn.model_dir/name`."
        super().__init__(learn)
        learn._test_writeable_path()
        self.path = learn.path/learn.model_dir/name
        if epoch_offset is None:
            if os.path.isfile(self.path):
                with self.path.open('r') as f:
                    try:    self.start_epoch = int(f.read())+1
                    except: self.start_epoch = 0
            else: self.start_epoch = 0
        else: self.start_epoch = epoch_offset

    def on_train_begin(self, **kwargs:Any):
        return {'epoch': self.start_epoch}

    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Write the number of the completed epoch to `self.path`."
        with self.path.open('w') as f: f.write(f'{epoch}')

    def restart(self): os.remove(self.path)
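
# Usage sketch (illustrative, not part of the original file): persist the number of the
# last completed epoch to `learn.model_dir/'epoch'` so an interrupted run can resume its
# epoch counter, and call `restart()` to reset it. The attribute name below assumes the
# usual LearnerCallback behaviour of exposing the instance as a snake_case attribute.
#
#   learn = Learner(data, model, callback_fns=[TrackEpochCallback])
#   learn.fit(5)                          # writes the completed epoch after every epoch
#   learn.track_epoch_callback.restart()  # assumed attribute name; resets the counter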