Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/optimizers/noam.py
4910 views
1
"""
2
---
3
title: Noam optimizer from Attention is All You Need paper
4
summary: >
5
This is a tutorial/implementation of Noam optimizer.
6
Noam optimizer has a warm-up period and then an exponentially decaying learning rate.
7
---
8
9
# Noam Optimizer
10
11
This is the [PyTorch](https://pytorch.org) implementation of optimizer introduced in the paper
12
[Attention Is All You Need](https://arxiv.org/abs/1706.03762).
13
"""
14
from typing import Dict
15
16
from labml_nn.optimizers import WeightDecay
17
from labml_nn.optimizers.amsgrad import AMSGrad
18
19
20
class Noam(AMSGrad):
21
"""
22
## Noam Optimizer
23
24
This class extends from Adam optimizer defined in [`adam.py`](adam.html).
25
"""
26
27
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
28
weight_decay: WeightDecay = WeightDecay(),
29
optimized_update: bool = True,
30
amsgrad=False,
31
warmup=0, d_model=512, defaults=None):
32
"""
33
### Initialize the optimizer
34
35
* `params` is the list of parameters
36
* `lr` is the learning rate $\alpha$
37
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
38
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
39
* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
40
* 'optimized_update' is a flag whether to optimize the bias correction of the second moment
41
by doing it after adding $\epsilon$
42
* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
43
* `warmup` number of warmup steps
44
* `d_model` model size; i.e. number of dimensions in the transformer
45
* `defaults` is a dictionary of default for group values.
46
This is useful when you want to extend the class `AdamWarmup`.
47
"""
48
49
defaults = {} if defaults is None else defaults
50
defaults.update(dict(warmup=warmup))
51
super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
52
self.d_model = d_model
53
54
def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
55
"""
56
### Get learning-rate
57
58
$$\alpha \frac{1}{\sqrt{d_{model}}} \min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$
59
where $w$ is the number of warmup steps.
60
"""
61
# $$\min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$
62
factor = min(state['step'] ** (-0.5), state['step'] * group['warmup'] ** (-1.5))
63
# $$\alpha \frac{1}{\sqrt{d_{model}}} \min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$
64
return group['lr'] * self.d_model ** (-0.5) * factor
65
66
67
def _test_noam_lr():
68
"""
69
### Plot learning rate for different warmups and model sizes
70
71
![Plot of learning rate](noam_lr.png)
72
"""
73
import matplotlib.pyplot as plt
74
import numpy as np
75
from torch import nn
76
77
model = nn.Linear(10, 10)
78
opts = [Noam(model.parameters(), d_model=512, warmup=4000, lr=1),
79
Noam(model.parameters(), d_model=512, warmup=8000, lr=1),
80
Noam(model.parameters(), d_model=2048, warmup=2000, lr=1)]
81
plt.plot(np.arange(1, 20000), [[opt.get_lr({'step': i}, opt.defaults) for opt in opts] for i in range(1, 20000)])
82
plt.legend(["512:4000", "512:8000", "2048:2000"])
83
plt.title("Learning Rate")
84
plt.show()
85
86
87
if __name__ == '__main__':
88
_test_noam_lr()
89
90