Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/callbacks/rnn.py
781 views
1
"Regroups lr adjustment to seq_len, AR and TAR"
2
from ..torch_core import *
3
from ..callback import *
4
from ..basic_train import Learner, LearnerCallback
5
6
__all__ = ['RNNTrainer']
7
8
class RNNTrainer(LearnerCallback):
    "`Callback` that regroups lr adjustment to seq_len, AR and TAR."
    def __init__(self, learn:Learner, alpha:float=0., beta:float=0.):
        super().__init__(learn)
        # Keep the stashed activations out of the callback's minimal saved state.
        self.not_min += ['raw_out', 'out']
        self.alpha = alpha  # AR (activation regularization) coefficient
        self.beta = beta    # TAR (temporal activation regularization) coefficient

    def on_epoch_begin(self, **kwargs):
        "Reset the hidden state of the model."
        self.learn.model.reset()

    def on_loss_begin(self, last_output:Tuple[Tensor,Tensor,Tensor], **kwargs):
        "Save the extra outputs for later and only returns the true output."
        self.raw_out = last_output[1]
        self.out = last_output[2]
        # Only the actual prediction is forwarded to the loss function.
        return {'last_output': last_output[0]}

    def on_backward_begin(self, last_loss:Rank0Tensor, last_input:Tensor, **kwargs):
        "Apply AR and TAR to `last_loss`."
        # AR: penalize large activations of the last layer.
        if self.alpha != 0.: last_loss += self.alpha * self.out[-1].float().pow(2).mean()
        # TAR: penalize large differences between consecutive timesteps
        # (only meaningful when there is more than one step).
        if self.beta != 0.:
            hid = self.raw_out[-1]
            if len(hid) > 1: last_loss += self.beta * (hid[:,1:] - hid[:,:-1]).float().pow(2).mean()
        return {'last_loss': last_loss}