GitHub Repository: jantic/deoldify
Path: blob/master/fastai/callbacks/lr_finder.py
"Tools to help find the optimal learning rate for training"
from ..torch_core import *
from ..basic_data import DataBunch
from ..callback import *
from ..basic_train import Learner, LearnerCallback

__all__ = ['LRFinder']

class LRFinder(LearnerCallback):
    "Causes `learn` to go on a mock training from `start_lr` to `end_lr` for `num_it` iterations."
    def __init__(self, learn:Learner, start_lr:float=1e-7, end_lr:float=10, num_it:int=100, stop_div:bool=True):
        super().__init__(learn)
        self.data,self.stop_div = learn.data,stop_div
        self.sched = Scheduler((start_lr, end_lr), num_it, annealing_exp)
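
    # With `annealing_exp` the learning rate is annealed geometrically, so iteration i uses
    # roughly lr_i = start_lr * (end_lr/start_lr) ** (i/num_it); with the defaults
    # (1e-7 -> 10 over 100 iterations) the LR grows by about (10/1e-7) ** (1/100) ≈ 1.2x per batch.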

    def on_train_begin(self, pbar, **kwargs:Any)->None:
        "Initialize optimizer and learner hyperparameters."
        setattr(pbar, 'clean_on_interrupt', True)
        self.learn.save('tmp')
        self.opt = self.learn.opt
        self.opt.lr = self.sched.start
        self.stop,self.best_loss = False,0.
        return {'skip_validate': True}
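
    # The 'tmp' checkpoint saved in `on_train_begin` is restored in `on_train_end`, and the
    # returned {'skip_validate': True} asks fastai's callback machinery to skip validation,
    # so the sweep is a throwaway mock run that leaves the trained weights untouched.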

    def on_batch_end(self, iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)->None:
        "Determine if loss has run away and we should stop."
        if iteration==0 or smooth_loss < self.best_loss: self.best_loss = smooth_loss
        self.opt.lr = self.sched.step()
        if self.sched.is_done or (self.stop_div and (smooth_loss > 4*self.best_loss or torch.isnan(smooth_loss))):
            # We use the smoothed loss to decide on stopping since it's less noisy.
            return {'stop_epoch': True, 'stop_training': True}
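
    # These flags are consumed by fastai's callback machinery to end the mock run early,
    # either once the schedule finishes or once the smoothed loss diverges (exceeds 4x the
    # best value seen so far, or becomes NaN).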

    def on_train_end(self, **kwargs:Any)->None:
        "Clean up the `learn` model weights disturbed during LRFinder exploration."
        self.learn.load('tmp', purge=False)
        if hasattr(self.learn.model, 'reset'): self.learn.model.reset()
        for cb in self.callbacks:
            if hasattr(cb, 'reset'): cb.reset()
        print('LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.')
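
# Usage sketch (illustrative; assumes fastai v1 with a `Learner` already built on a `DataBunch`):
# `Learner.lr_find` registers this callback and runs the mock training sweep, while the
# `Recorder` callback logs the (lr, loss) pairs that `recorder.plot()` then displays.
#
#   learn = Learner(data, model)                                      # any model/DataBunch pair
#   learn.lr_find(start_lr=1e-7, end_lr=10, num_it=100, stop_div=True)
#   learn.recorder.plot()   # pick an LR somewhat below the point of steepest loss descent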