Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/utils/training_utils.py
694 views
1
from utils.hparams import hparams
2
3
4
class RSQRTSchedule(object):
    """Inverse-square-root learning-rate schedule with linear warmup.

    On every call to :meth:`step` the learning rate is recomputed as::

        warmup      = min(num_updates / warmup_updates, 1.0)  # 1.0 if warmup_updates <= 0
        rsqrt_decay = max(warmup_updates, num_updates, 1) ** -0.5
        lr          = max(base_lr * warmup * rsqrt_decay * hidden_size ** -0.5, MIN_LR)

    The rate therefore ramps up linearly over ``warmup_updates`` steps and then
    decays proportionally to ``1 / sqrt(num_updates)``. Each new value is
    written into every parameter group of the wrapped optimizer.

    Hyperparameters are read once, at construction time, from the global
    ``hparams`` mapping: ``'lr'`` (base rate), ``'warmup_updates'`` and
    ``'hidden_size'``.
    """

    # Lower bound on the produced learning rate. With warmup enabled, the
    # warmup factor is 0 at step 0, so this is also the step-0 rate.
    MIN_LR = 1e-7

    def __init__(self, optimizer):
        """Attach the schedule to *optimizer* and apply the step-0 rate.

        Args:
            optimizer: object exposing ``param_groups``, a sequence of dicts
                with an ``'lr'`` entry (presumably a ``torch.optim``
                optimizer -- confirm against callers).
        """
        super().__init__()
        self.optimizer = optimizer
        self.constant_lr = hparams['lr']
        self.warmup_updates = hparams['warmup_updates']
        self.hidden_size = hparams['hidden_size']
        # Placeholder only: step(0) below recomputes self.lr and pushes it
        # into every parameter group.
        self.lr = self.constant_lr
        self.step(0)

    def step(self, num_updates):
        """Recompute the learning rate for *num_updates* and apply it.

        Args:
            num_updates: number of optimizer updates performed so far.

        Returns:
            The new learning rate (also stored in ``self.lr`` and written to
            every parameter group of the optimizer).
        """
        warmup_updates = self.warmup_updates
        if warmup_updates > 0:
            warmup = min(num_updates / warmup_updates, 1.0)
        else:
            # Warmup disabled: avoid dividing by zero and start at full rate.
            warmup = 1.0
        # Decay is held flat while num_updates < warmup_updates. The extra 1
        # keeps the base positive so ``0 ** -0.5`` cannot raise when
        # warmup_updates == 0 and num_updates == 0.
        rsqrt_decay = max(warmup_updates, num_updates, 1) ** -0.5
        rsqrt_hidden = self.hidden_size ** -0.5
        self.lr = max(
            self.constant_lr * warmup * rsqrt_decay * rsqrt_hidden,
            self.MIN_LR,
        )
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr

    def get_lr(self):
        """Return the learning rate currently set on the first parameter group."""
        return self.optimizer.param_groups[0]['lr']
28
29