Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/callbacks/general_sched.py
781 views
1
from ..core import *
2
from ..callback import *
3
from ..basic_train import Learner, LearnerCallback
4
5
__all__ = ['GeneralScheduler', 'TrainingPhase']
6
7
@dataclass
class TrainingPhase:
    "One stage of training lasting `length` iterations, with its own hyper-parameter schedules."
    length:int

    def __post_init__(self)->None:
        # name -> Scheduler, filled in by `schedule_hp`. It is assigned here rather than
        # declared as a dataclass field, so it is not part of the generated __init__/__eq__/__repr__.
        self.scheds = {}

    def schedule_hp(self, name, vals, anneal=None):
        "Register a schedule for hyper-parameter `name` running over `vals` with the `anneal` function."
        # Every schedule in a phase spans the same number of iterations: the phase `length`.
        new_sched = Scheduler(vals, self.length, anneal)
        self.scheds[name] = new_sched
        # Returning the phase itself lets callers chain several `schedule_hp` calls.
        return self
17
18
class GeneralScheduler(LearnerCallback):
    "Schedule multiple `TrainingPhase` for a `Learner`, one phase after another."
    def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):
        super().__init__(learn)
        # When `start_epoch` is given, `on_train_begin` hands it back to the training loop as 'epoch'.
        self.phases,self.start_epoch = phases,start_epoch

    def on_train_begin(self, epoch:int, **kwargs:Any)->Optional[dict]:
        "Rewind all phase schedulers and apply the first phase's starting values."
        res = {'epoch':self.start_epoch} if self.start_epoch is not None else None
        self.start_epoch = ifnone(self.start_epoch, epoch)
        self.scheds = [p.scheds for p in self.phases]
        self.opt = self.learn.opt
        # Rewind every phase, not only the first. Nothing else in this class restarts a
        # phase's schedulers, so a later phase reused by a second `fit` would otherwise
        # keep whatever progress it made during the previous run.
        for sched in self.scheds:
            for v in sched.values(): v.restart()
        # Only the first phase's starting values are applied up front. An empty `phases`
        # has nothing to apply; `on_batch_end` will then request a stop on the first batch.
        if self.scheds:
            for k,v in self.scheds[0].items(): self.opt.set_stat(k, v.start)
        self.idx_s = 0
        return res

    def jump_to_epoch(self, epoch:int)->None:
        "Fast-forward the schedules by `epoch` epochs' worth of training batches."
        # Stop requests returned by `on_batch_end` are deliberately ignored while fast-forwarding.
        for _ in range(len(self.learn.data.train_dl) * epoch):
            self.on_batch_end(True)

    def on_batch_end(self, train, **kwargs:Any)->Optional[dict]:
        "Step every hyper-parameter of the current phase, move on when it completes, stop after the last one."
        if not train: return None
        if self.idx_s >= len(self.scheds): return {'stop_training': True, 'stop_epoch': True}
        sched = self.scheds[self.idx_s]
        for k,v in sched.items(): self.opt.set_stat(k, v.step())
        # `TrainingPhase.schedule_hp` builds every scheduler of a phase with the phase's `length`,
        # so the first scheduler's progress stands for the whole phase. A phase with no
        # scheduled hyper-parameters has nothing to step and is treated as finished.
        if not sched or next(iter(sched.values())).is_done: self.idx_s += 1
        return None
47