Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/callback.py
781 views
1
"Callbacks provide extensibility to the `basic_train` loop. See `train` for examples of custom callbacks."
2
from .basic_data import *
3
from .torch_core import *
4
import torch.distributed as dist
5
6
# Names exported by `from <this module> import *`.
__all__ = [
    'AverageMetric', 'Callback', 'CallbackHandler', 'OptimWrapper', 'SmoothenValue', 'Scheduler',
    'annealing_cos', 'CallbackList',
    'annealing_exp', 'annealing_linear', 'annealing_no', 'annealing_poly',
]
8
9
class OptimWrapper():
    "Basic wrapper around `opt` to simplify hyper-parameters changes."
    def __init__(self, opt:optim.Optimizer, wd:Floats=0., true_wd:bool=False, bn_wd:bool=True):
        """Wrap the optimizer `opt`, which must not already be an `OptimWrapper`.

        `wd` is the weight decay. With `true_wd` it is applied directly to the weights in `step`
        instead of through the optimizer's `weight_decay`; `bn_wd` controls whether the second group
        of each (even, odd) pair of param groups is decayed as well.
        """
        assert not isinstance(opt, OptimWrapper)
        self.opt,self.true_wd,self.bn_wd = opt,true_wd,bn_wd
        # Hyper-parameter names understood by the inner optimizer (everything except the params).
        self.opt_keys = list(self.opt.param_groups[0].keys())
        self.opt_keys.remove('params')
        self.read_defaults()
        self.wd = wd

    @classmethod
    def create(cls, opt_func:Union[type,Callable], lr:Union[float,Tuple,List], layer_groups:ModuleList, wd:Floats=0.,
               true_wd:bool=False, bn_wd:bool=True)->'OptimWrapper':
        "Create an `OptimWrapper` around `opt_func`'s optimizer with `lr`. Set lr on `layer_groups`."
        # NOTE(review): presumably `split_no_wd_params` yields two param lists per layer group
        # (decayed / not decayed); the rest of the class walks `param_groups` as (even, odd) pairs.
        split_params = split_no_wd_params(layer_groups)
        opt = opt_func([{'params': p, 'lr':0} for p in split_params])
        opt = cls(opt, wd=wd, true_wd=true_wd, bn_wd=bn_wd)
        opt.lr,opt.opt_func = listify(lr, layer_groups),opt_func
        return opt

    def _get_opt_func(self):
        "Return the `opt_func` this wrapper was created with, falling back to the inner optimizer's class."
        # `__getattr__` below returns None instead of raising, so `getattr(self, 'opt_func', default)`
        # would never use its default. Look only at real attributes instead.
        try: opt_func = object.__getattribute__(self, 'opt_func')
        except AttributeError: opt_func = None
        return self.opt.__class__ if opt_func is None else opt_func

    def new(self, layer_groups:Collection[nn.Module], split_no_wd:bool=True):
        """Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters.

        Only the last value of lr/wd/mom/beta is carried over (the properties return `[-1]`).
        """
        # NOTE(review): `split_no_wd` is accepted but unused; `create` always splits the params.
        res = self.create(self._get_opt_func(), self.lr, layer_groups, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
        res.mom,res.beta = self.mom,self.beta
        return res

    def new_with_params(self, param_groups:Collection[Collection[nn.Parameter]]):
        "Create a new `OptimWrapper` from `self` with another `param_groups` but the same hyper-parameters."
        opt_func = self._get_opt_func()
        opt = opt_func([{'params': p, 'lr':0} for p in param_groups])
        opt = self.__class__(opt, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
        opt.lr,opt.opt_func,opt.mom,opt.beta = self.lr,opt_func,self.mom,self.beta
        return opt

    def __repr__(self)->str:
        return f'OptimWrapper over {repr(self.opt)}.\nTrue weight decay: {self.true_wd}'

    #Pytorch optimizer methods
    def step(self)->None:
        "Set weight decay and step optimizer."
        # weight decay outside of optimizer step (AdamW)
        if self.true_wd:
            for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
                for p in pg1['params']: p.data.mul_(1 - wd*lr)
                if self.bn_wd:
                    for p in pg2['params']: p.data.mul_(1 - wd*lr)
            # The decay was applied above, so the optimizer itself must not add any more.
            self.set_val('weight_decay', listify(0, self._wd))
        self.opt.step()

    def zero_grad(self)->None:
        "Clear optimizer gradients."
        self.opt.zero_grad()

    #Passthrough to the inner opt.
    def __getattr__(self, k:str)->Any:
        "Forward unknown attributes to the inner optimizer (None when it lacks them too)."
        # Only reached when normal lookup fails. Without `opt` (e.g. a half-built instance), reading
        # `self.opt` here would re-enter `__getattr__` forever, so fail like a normal object instead.
        if 'opt' not in self.__dict__: raise AttributeError(k)
        return getattr(self.opt, k, None)
    # Explicit pickling hooks: without `__getstate__`, older Pythons find the inner optimizer's
    # `__getstate__` through `__getattr__` and pickle the wrong state.
    def __getstate__(self): return self.__dict__
    def __setstate__(self,data:Any): self.__dict__.update(data)

    def clear(self):
        "Reset the state of the inner optimizer."
        sd = self.state_dict()
        sd['state'] = {}
        self.load_state_dict(sd)

    @property
    def n_params(self): return sum([len(pg['params']) for pg in self.opt.param_groups])

    #Hyperparameters as properties
    @property
    def lr(self)->float: return self._lr[-1]
    @lr.setter
    def lr(self, val:float)->None:
        self._lr = self.set_val('lr', listify(val, self._lr))

    @property
    def mom(self)->float:
        "Momentum (or first beta); None for optimizers that have neither."
        return None if self._mom is None else self._mom[-1]
    @mom.setter
    def mom(self, val:float)->None:
        "Set momentum (or the first of `betas`, as makes sense for given optimizer)."
        # Mirror the `beta` setter: None means "nothing to set" (e.g. copied from an optimizer without momentum).
        if val is None: return
        if 'momentum' in self.opt_keys: self.set_val('momentum', listify(val, self._mom))
        elif 'betas' in self.opt_keys: self.set_val('betas', (listify(val, self._mom), self._beta))
        self._mom = listify(val, self._mom)

    @property
    def beta(self)->float: return None if self._beta is None else self._beta[-1]
    @beta.setter
    def beta(self, val:float)->None:
        "Set beta (or alpha as makes sense for given optimizer)."
        if val is None: return
        if 'betas' in self.opt_keys: self.set_val('betas', (self._mom, listify(val, self._beta)))
        elif 'alpha' in self.opt_keys: self.set_val('alpha', listify(val, self._beta))
        self._beta = listify(val, self._beta)

    @property
    def wd(self)->float: return self._wd[-1]
    @wd.setter
    def wd(self, val:float)->None:
        "Set weight decay."
        # With `true_wd` the decay is applied manually in `step`, not via the optimizer's `weight_decay`.
        if not self.true_wd: self.set_val('weight_decay', listify(val, self._wd), bn_groups=self.bn_wd)
        self._wd = listify(val, self._wd)

    #Helper functions
    def read_defaults(self)->None:
        "Read the values inside the optimizer for the hyper-parameters."
        self._mom,self._beta = None,None
        if 'lr' in self.opt_keys: self._lr = self.read_val('lr')
        if 'momentum' in self.opt_keys: self._mom = self.read_val('momentum')
        if 'alpha' in self.opt_keys: self._beta = self.read_val('alpha')
        if 'betas' in self.opt_keys: self._mom,self._beta = self.read_val('betas')
        if 'weight_decay' in self.opt_keys: self._wd = self.read_val('weight_decay')
        # Any other optimizer hyper-parameter is tracked generically through `get_stat`/`set_stat`.
        reserved_names = ['params', 'lr', 'momentum', 'alpha', 'betas', 'weight_decay']
        stat_names = [n for n in self.opt_keys if n not in reserved_names]
        self._stats = {n:self.read_val(n) for n in stat_names}

    def get_stat(self, name:str)->float:
        "Return the (last group's) value of hyper-parameter `name`."
        if name in ['lr', 'mom', 'beta', 'wd']: return getattr(self, name)
        else: return self._stats[name][-1]
    def set_stat(self, name:str, value:Union[float, Collection[float]])->None:
        "Set hyper-parameter `name` to `value` (one value, or one per layer group)."
        if name in ['lr', 'mom', 'beta', 'wd']: setattr(self, name, value)
        else:
            val = listify(value, self._stats[name])
            self.set_val(name, val)
            self._stats[name] = val

    def set_val(self, key:str, val:Any, bn_groups:bool=True)->Any:
        "Set `val` inside the optimizer dictionary at `key`; the odd groups only when `bn_groups`."
        # A (list, list) tuple such as `betas` is turned into one (v1, v2) pair per layer group.
        if is_tuple(val): val = [(v1,v2) for v1,v2 in zip(*val)]
        for v,pg1,pg2 in zip(val,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
            pg1[key] = v
            if bn_groups: pg2[key] = v
        return val

    def read_val(self, key:str) -> Union[List[float],Tuple[List[float],List[float]]]:
        "Read a hyperparameter `key` in the optimizer dictionary (from the even param groups)."
        val = [pg[key] for pg in self.opt.param_groups[::2]]
        # Tuple-valued entries (e.g. `betas`) are split into two parallel lists.
        if is_tuple(val[0]): val = [o[0] for o in val], [o[1] for o in val]
        return val

    def get_state(self):
        "Return the inner state minus the layer groups."
        return {'opt_state':self.opt.state_dict(), 'lr':self._lr, 'wd':self._wd, 'beta':self._beta, 'mom':self._mom,
                'opt_func':self._get_opt_func(), 'true_wd':self.true_wd, 'bn_wd':self.bn_wd}

    @classmethod
    def load_with_state_and_layer_group(cls, state:dict, layer_groups:Collection[nn.Module]):
        "Rebuild a wrapper from `get_state()` output `state` over `layer_groups`."
        res = cls.create(state['opt_func'], state['lr'], layer_groups, wd=state['wd'], true_wd=state['true_wd'],
                         bn_wd=state['bn_wd'])
        res._mom,res._beta = state['mom'],state['beta']
        res.load_state_dict(state['opt_state'])
        return res
158
159
class Callback():
    "Base class for callbacks that want to record values, dynamically change learner params, etc."
    _order=0

    # Every hook below is a no-op; subclasses override the ones they need. Keyword arguments carry
    # the training state, and a hook may return a dict to update that state.
    def on_train_begin(self, **kwargs:Any)->None:
        "Hook run once before training starts, e.g. to initialize constants."
    def on_epoch_begin(self, **kwargs:Any)->None:
        "Hook run when an epoch starts."
    def on_batch_begin(self, **kwargs:Any)->None:
        "Hook run before the output and loss of a batch are computed, e.g. to set hyper-parameters."
    def on_loss_begin(self, **kwargs:Any)->None:
        "Hook run after the forward pass, before the loss is computed."
    def on_backward_begin(self, **kwargs:Any)->None:
        "Hook run once the loss is computed, before backpropagation."
    def on_backward_end(self, **kwargs:Any)->None:
        "Hook run after backpropagation, before the optimizer step (e.g. for AdamW-style weight decay)."
    def on_step_end(self, **kwargs:Any)->None:
        "Hook run after the optimizer step, before gradients are zeroed."
    def on_batch_end(self, **kwargs:Any)->None:
        "Hook run when a batch is finished."
    def on_epoch_end(self, **kwargs:Any)->None:
        "Hook run when an epoch is finished."
    def on_train_end(self, **kwargs:Any)->None:
        "Hook run when training ends, e.g. to clean up or save files/models."
    def jump_to_epoch(self, epoch)->None:
        "Hook used to resume training directly at `epoch`."

    def get_state(self, minimal:bool=True):
        "Return the instance attributes, minus `exclude` entries (and `not_min` entries when `minimal`)."
        hidden = ['exclude', 'not_min'] + getattr(self, 'exclude', []).copy()
        if minimal: hidden += getattr(self, 'not_min', []).copy()
        return {name: value for name, value in self.__dict__.items() if name not in hidden}

    def __repr__(self):
        init_args = func_args(self.__init__)
        hidden = getattr(self, 'exclude', [])
        lines = [self.__class__.__name__]
        lines += [f'{k}: {getattr(self, k)}' for k in init_args if k != 'self' and k not in hidden]
        return '\n'.join(lines)
207
208
class SmoothenValue():
    "Create a smooth moving average for a value (loss, etc) using `beta`."
    def __init__(self, beta:float):
        self.beta = beta
        self.n = 0
        self.mov_avg = 0

    def add_value(self, val:float)->None:
        "Fold `val` into the running average and refresh the bias-corrected `smooth` attribute."
        self.n += 1
        beta = self.beta
        self.mov_avg = beta * self.mov_avg + (1 - beta) * val
        # Dividing by (1 - beta**n) compensates for the average starting at 0.
        self.smooth = self.mov_avg / (1 - beta ** self.n)
218
219
# Type alias for any collection of `Callback` objects (used to annotate `CallbackHandler` fields).
CallbackList = Collection[Callback]
220
221
def _get_init_state(): return {'epoch':0, 'iteration':0, 'num_batch':0, 'skip_validate': False}
222
223
@dataclass
class CallbackHandler():
    "Manage all of the registered `callbacks` and `metrics`, smoothing loss by momentum `beta`."
    callbacks:CallbackList=None
    metrics:CallbackList=None
    beta:float=0.98

    def __post_init__(self)->None:
        "Initialize smoother and learning stats."
        self.callbacks = ifnone(self.callbacks, [])
        self.metrics = ifnone(self.metrics, [])
        # Plain metric functions are wrapped in `AverageMetric` so that every metric is a `Callback`.
        self.metrics = [(met if isinstance(met, Callback) else AverageMetric(met)) for met in self.metrics]
        # Callbacks run in ascending `_order` (0 when a callback does not define it).
        self.callbacks = sorted(self.callbacks, key=lambda o: getattr(o, '_order', 0))
        self.smoothener = SmoothenValue(self.beta)
        self.state_dict:Dict[str,Union[int,float,Tensor]]=_get_init_state()

    def _call_and_update(self, cb, cb_name, **kwargs)->None:
        """Call `cb.on_{cb_name}` with the current state and merge the dict it returns into the state.

        Raises `Exception` when the returned dict holds a key that is not already in the state.
        """
        new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
        for k,v in new.items():
            if k not in self.state_dict:
                raise Exception(f"{k} isn't a valid key in the state of the callbacks.")
            else: self.state_dict[k] = v

    def __call__(self, cb_name, call_mets=True, **kwargs)->None:
        "Call `on_{cb_name}` on the metrics (only if `call_mets`), then on every `CallbackHandler` callback."
        if call_mets:
            for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
        for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)

    def set_dl(self, dl:DataLoader):
        "Set the current `dl` used, registering its dataset as a callback when it is one."
        if hasattr(self, 'cb_dl'):
            self.callbacks.remove(self.cb_dl)
            # Forget it too, otherwise a later call would try to remove it again (ValueError)
            # whenever the next dataset is not a `Callback`.
            del self.cb_dl
        if isinstance(dl.dataset, Callback):
            self.callbacks.append(dl.dataset)
            self.cb_dl = dl.dataset

    def on_train_begin(self, epochs:int, pbar:PBar, metrics:MetricFuncList)->None:
        "About to start learning."
        self.state_dict = _get_init_state()
        self.state_dict.update(dict(n_epochs=epochs, pbar=pbar, metrics=metrics))
        names = [(met.name if hasattr(met, 'name') else camel2snake(met.__class__.__name__)) for met in self.metrics]
        self('train_begin', metrics_names=names)
        # A callback may have moved `epoch` forward during `train_begin` (to resume training):
        # shrink the progress bar and let every callback jump there too.
        if self.state_dict['epoch'] != 0:
            self.state_dict['pbar'].first_bar.total -= self.state_dict['epoch']
            for cb in self.callbacks: cb.jump_to_epoch(self.state_dict['epoch'])

    def on_epoch_begin(self)->None:
        "Handle new epoch."
        self.state_dict['num_batch'],self.state_dict['stop_training'] = 0,False
        self('epoch_begin')

    def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->Tuple[Any,Any]:
        "Handle new batch `xb`,`yb` in `train` or validation; callbacks may replace the batch."
        self.state_dict.update(dict(last_input=xb, last_target=yb, train=train,
            stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
        # Metrics only run on validation batches, as in `on_batch_end`. (This used to pass `mets=`,
        # which metrics ignored and which was forwarded to every callback as a stray kwarg.)
        self('batch_begin', call_mets = not self.state_dict['train'])
        return self.state_dict['last_input'], self.state_dict['last_target']

    def on_loss_begin(self, out:Tensor)->Any:
        "Handle start of loss calculation with model output `out`."
        self.state_dict['last_output'] = out
        self('loss_begin', call_mets=False)
        return self.state_dict['last_output']

    def on_backward_begin(self, loss:Tensor)->Tuple[Any,Any]:
        "Handle gradient calculation on `loss`; return the (possibly modified) loss and `skip_bwd`."
        self.smoothener.add_value(loss.detach().cpu())
        self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth
        self('backward_begin', call_mets=False)
        return self.state_dict['last_loss'], self.state_dict['skip_bwd']

    def on_backward_end(self)->Any:
        "Handle end of gradient calculation; return whether the optimizer step should be skipped."
        self('backward_end', call_mets=False)
        return self.state_dict['skip_step']

    def on_step_end(self)->Any:
        "Handle end of optimization step; return whether zeroing the gradients should be skipped."
        self('step_end', call_mets=False)
        return self.state_dict['skip_zero']

    def on_batch_end(self, loss:Tensor)->Any:
        "Handle end of processing one batch with `loss`; return whether the epoch should stop."
        self.state_dict['last_loss'] = loss
        self('batch_end', call_mets = not self.state_dict['train'])
        if self.state_dict['train']:
            self.state_dict['iteration'] += 1
            self.state_dict['num_batch'] += 1
        return self.state_dict['stop_epoch']

    def on_epoch_end(self, val_loss:Tensor)->bool:
        "Epoch is done, process `val_loss`; return whether training should stop."
        self.state_dict['last_metrics'] = [val_loss] if val_loss is not None else [None]
        self('epoch_end', call_mets = val_loss is not None)
        self.state_dict['epoch'] += 1
        return self.state_dict['stop_training']

    def on_train_end(self, exception:Union[bool,Exception])->None:
        "Handle end of training, `exception` is an `Exception` or False if no exceptions during training."
        self('train_end', exception=exception)

    @property
    def skip_validate(self): return self.state_dict['skip_validate']
327
328
class AverageMetric(Callback):
    "Wrap a `func` in a callback for metrics computation."
    def __init__(self, func):
        # Use the function's own `__name__`. A `functools.partial` has none, so use the wrapped
        # function's name; any other nameless callable falls back to its class name.
        named = func if hasattr(func, '__name__') else getattr(func, 'func', func)
        self.func, self.name = func, getattr(named, '__name__', type(named).__name__)
        self.world = num_distrib()

    def on_epoch_begin(self, **kwargs):
        "Set the inner value to 0."
        self.val, self.count = 0.,0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Update metric computation with `last_output` and `last_target`."
        if not is_listy(last_target): last_target=[last_target]
        bs = first_el(last_target).size(0)
        self.count += bs
        val = self.func(last_output, *last_target)
        if self.world:
            # Average the metric over all processes. Presumably cloned so the in-place reduce does
            # not modify the tensor returned by `func`.
            val = val.clone()
            dist.all_reduce(val, op=dist.ReduceOp.SUM)
            val /= self.world
        # Weight by batch size so `on_epoch_end` produces a per-sample mean.
        self.val += bs * val.detach().cpu()

    def on_epoch_end(self, last_metrics, **kwargs):
        "Set the final result in `last_metrics`."
        # NOTE(review): assumes at least one batch was seen; `count == 0` divides by zero.
        return add_metrics(last_metrics, self.val/self.count)
354
355
def annealing_no(start:Number, end:Number, pct:float)->Number:
    "Constant schedule: `end` and `pct` are ignored and `start` is always returned."
    value = start
    return value
358
def annealing_linear(start:Number, end:Number, pct:float)->Number:
    "Interpolate linearly between `start` (pct=0.0) and `end` (pct=1.0)."
    span = end - start
    return start + pct * span
361
def annealing_exp(start:Number, end:Number, pct:float)->Number:
    "Interpolate geometrically between `start` (pct=0.0) and `end` (pct=1.0); `start` must be non-zero."
    ratio = end / start
    return start * ratio ** pct
364
def annealing_cos(start:Number, end:Number, pct:float)->Number:
    "Follow half a cosine wave from `start` (pct=0.0) down/up to `end` (pct=1.0)."
    half_range = (start - end) / 2
    return end + half_range * (np.cos(np.pi * pct) + 1)
368
369
def do_annealing_poly(start:Number, end:Number, pct:float, degree:Number)->Number:
    "Anneal from `start` to `end` following `(1-pct)**degree`; helper for `annealing_poly`."
    return end + (start-end) * (1-pct)**degree

def annealing_poly(degree:Number)->Callable[[Number,Number,float],Number]:
    """Return a function annealing polynomially from `start` to `end` as pct goes from 0.0 to 1.0.

    The returned callable takes `(start, end, pct)` like the other `annealing_*` functions.
    """
    return functools.partial(do_annealing_poly, degree=degree)
375
376
class Scheduler():
    "Used to \"step\" from start,end (`vals`) over `n_iter` iterations on a schedule defined by `func`"
    def __init__(self, vals:StartOptEnd, n_iter:int, func:Optional[AnnealFunc]=None):
        # A tuple gives explicit (start, end); a single value is paired with end=0.
        has_range = is_tuple(vals)
        if has_range: self.start, self.end = vals[0], vals[1]
        else:         self.start, self.end = vals, 0
        # At least one iteration, so `step` never divides by zero.
        self.n_iter = max(1, n_iter)
        # Without an explicit `func`: linear for a range, constant for a single value.
        if func is not None: self.func = func
        elif has_range:      self.func = annealing_linear
        else:                self.func = annealing_no
        self.n = 0

    def restart(self):
        "Go back to the beginning of the schedule."
        self.n = 0

    def step(self)->Number:
        "Advance one iteration and return the scheduled value at the new position."
        self.n += 1
        pct = self.n / self.n_iter
        return self.func(self.start, self.end, pct)

    @property
    def is_done(self)->bool:
        "Whether `n_iter` steps have been taken."
        return self.n >= self.n_iter
396
397
398