GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/optimizers/radam.py
"""
---
title: Rectified Adam (RAdam) optimizer
summary: A simple PyTorch implementation/tutorial of the RAdam optimizer.
---

# Rectified Adam (RAdam) optimizer

This implementation is based on
[the official implementation](https://github.com/LiyuanLucasLiu/RAdam)
of the paper
[On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265).

We have implemented it in [PyTorch](https://pytorch.org)
as an extension to [our AMSGrad implementation](amsgrad.html),
so we only need to implement the modifications.

The Adam optimizer sometimes converges to a bad local optimum during the initial stages of training,
especially when training transformers.
Researchers use warm-ups to counter this: for the initial training steps (the warm-up stage)
they use a low learning rate.
This paper identifies the problem as the high variance of the adaptive learning rate
during the initial stages of training, and counters it with a new rectification term that
reduces the variance.

The paper also evaluates two variance reduction mechanisms:
* **Adam-2k**: Only compute the adaptive learning rate ($v_t$ in [Adam](adam.html)) during the first 2k steps,
without changing parameters or calculating momentum ($m_t$).
* **Adam-eps**: Adam with a large $\epsilon \approx 10^{-4}$.

## Rectified Adam

Let $\sigma(g_1, ..., g_t)$ and $\psi(g_1, ..., g_t)$ be the functions to calculate
momentum and the adaptive learning rate.
For Adam, they are

\begin{align}
\sigma(g_1, ..., g_t) &= \frac{(1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i} g_i}{1 - \beta_1^t} \\
\psi(g_1, ..., g_t) &= \sqrt{\frac{1 - \beta_2^t}{(1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i} g_i^2}}
\end{align}
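
In code, $\sigma(.)$ and $\psi(.)$ are just Adam's bias-corrected first moment and the inverse
square root of its bias-corrected second moment. Here is a minimal sketch of them (ours, not from
the paper or the official code; the function and the toy gradients are only for illustration):

```python
import torch

def sigma_psi(grads, beta1: float = 0.9, beta2: float = 0.999):
    # Compute sigma(g_1, ..., g_t) and psi(g_1, ..., g_t) with running EMAs
    m = torch.zeros_like(grads[0])
    v = torch.zeros_like(grads[0])
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g       # EMA of gradients
        v = beta2 * v + (1 - beta2) * g ** 2  # EMA of squared gradients
    sigma = m / (1 - beta1 ** t)              # bias-corrected momentum, sigma(.)
    psi = ((1 - beta2 ** t) / v) ** 0.5       # adaptive learning rate, psi(.) = 1 / sqrt(v_hat)
    return sigma, psi

sigma, psi = sigma_psi([torch.randn(10) for _ in range(100)])
```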

### Exponential moving average as simple moving average

The distribution of the exponential moving average can be approximated as the distribution of a simple moving average.

\begin{align}
p\Bigg(\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}{1 - \beta_2^t} \Bigg) \approx
p\Bigg(\frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)} \Bigg)
\end{align}

Here we are taking the simple moving average of the last $f(t,\beta_2)$ gradients.
$f(t,\beta_2)$ satisfies the following (equating the centers of mass of the two weighting schemes),

\begin{align}
\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} \cdot i}{1 - \beta_2^t} =
\frac{\sum_{i=1}^{f(t,\beta_2)} (t+1-i)}{f(t,\beta_2)}
\end{align}

which gives,
$$f(t,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1 - \beta_2^t}$$
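
As a sanity check (ours, not from the paper), the closed form above can be verified numerically
against the center-of-mass equation: its right-hand side simplifies to $(t + 1) - \frac{f + 1}{2}$,
so $f = 2t + 1 - 2 \cdot \text{LHS}$.

```python
import numpy as np

def f_closed_form(t: int, beta2: float) -> float:
    # f(t, beta2) = 2 / (1 - beta2) - 1 - 2 t beta2^t / (1 - beta2^t)
    return 2 / (1 - beta2) - 1 - 2 * t * beta2 ** t / (1 - beta2 ** t)

def f_from_center_of_mass(t: int, beta2: float) -> float:
    # Left-hand side of the center-of-mass equation, then solve (t + 1) - (f + 1) / 2 = LHS for f
    i = np.arange(1, t + 1)
    lhs = (1 - beta2) * np.sum(beta2 ** (t - i) * i) / (1 - beta2 ** t)
    return 2 * t + 1 - 2 * lhs

for t in (10, 100, 1_000):
    assert np.isclose(f_closed_form(t, 0.999), f_from_center_of_mass(t, 0.999))
```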

### Scaled inverse chi-squared

From above we have
$$
p\Big( \psi^2(g_1, ..., g_t) \Big) \approx
p\Bigg(\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2} \Bigg)
$$
where $g_i \sim \mathcal{N}(0, \sigma^2)$.
Note that $\sigma$ here is the standard deviation of the gradients, and is different from the momentum function $\sigma(.)$ above.

[Scaled inverse chi-squared](https://en.wikipedia.org/wiki/Scaled_inverse_chi-squared_distribution)
is the distribution of the reciprocal of the mean of $p$ squared normal random variables.
$$
p\Bigg(\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2} \Bigg)
\sim
\text{Scale-inv} \mathcal{X}^2(\rho,\frac{1}{\sigma^2})
$$
where $\rho = f(t,\beta_2)$.

### Rectification

They prove that the variance of $\psi(.)$ decreases as $\rho$ increases, when
$\psi^2(.) \sim \text{Scale-inv} \mathcal{X}^2(\rho,\frac{1}{\sigma^2})$.

Therefore the variance is minimized at the maximal $\rho$, which is
$\rho_{\infty} = \frac{2}{1-\beta_2} - 1$. Let the minimum variance be $C_{\text{var}}$.

In order to ensure that the adaptive learning
rate $\psi(.)$ has consistent variance, we rectify the variance with $r$

\begin{align}
r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}}
\end{align}

### Approximating $Var[\psi(.)]$

They estimate $Var[\psi(.)] \approx \frac{Var[\psi^2(.)]}{4 \mathbb{E}[\psi^2(.)]}$
based on the first-order Taylor expansion of $\sqrt{\psi^2(.)}$ around $\mathbb{E}[\psi^2(.)]$
(the delta method): for a random variable $X$ with mean $\mu$,
$\sqrt{X} \approx \sqrt{\mu} + \frac{X - \mu}{2\sqrt{\mu}}$, which gives $Var[\sqrt{X}] \approx \frac{Var[X]}{4\mu}$.

From the $\text{Scale-inv} \mathcal{X}^2$ distribution we have,

\begin{align}
\mathbb{E}\big[\psi^2(.)\big] &= \frac{\rho / \sigma^2}{\rho-2} \\
Var\big[\psi^2(.)\big] &= \frac{2 \rho^2 / \sigma^4}{(\rho-2)^2 (\rho - 4)}
\end{align}

which gives,
$$
Var[\psi(.)] \approx \frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}
$$
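
As a quick Monte Carlo check of these expressions (ours, not from the paper), we can sample $\rho$
gradients from $\mathcal{N}(0, \sigma^2)$, form $\psi^2 = \rho / \sum_i g_i^2$, and compare the
sample moments against the formulas above; the last line uses the delta-method approximation, so it
only agrees approximately.

```python
import numpy as np

rng = np.random.default_rng(0)
rho, sigma = 200, 1.0                           # rho = f(t, beta2); the formulas need rho > 4
g = rng.normal(0.0, sigma, size=(50_000, rho))  # gradients g_i ~ N(0, sigma^2)
psi2 = rho / np.sum(g ** 2, axis=1)             # psi^2, Scale-inv chi^2 distributed
psi = np.sqrt(psi2)

print(psi2.mean(), rho / sigma ** 2 / (rho - 2))                             # E[psi^2]
print(psi2.var(), 2 * rho ** 2 / sigma ** 4 / ((rho - 2) ** 2 * (rho - 4)))  # Var[psi^2]
print(psi.var(), rho / (2 * (rho - 2) * (rho - 4) * sigma ** 2))             # Var[psi] (approx.)
```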

### Rectification term

We have

\begin{align}
r &= \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}} \\
Var[\psi(.)] &\approx \frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}
\end{align}

where $C_{\text{var}}$ is $Var\big[\psi(.)\big]$ for $\rho_\infty$.
Let $\rho$ at step $t$ be $\rho_t$, and let $r_t$ be the rectification term
at step $t$.

\begin{align}
C_{\text{var}} &\approx \frac{\rho_\infty}{2(\rho_\infty-2)(\rho_\infty-4)\sigma^2} \\
Var[\psi(g_1,...,g_t)] &\approx \frac{\rho_t}{2(\rho_t-2)(\rho_t-4)\sigma^2}
\end{align}
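
Substituting these into $r_t = \sqrt{\frac{C_{\text{var}}}{Var[\psi(g_1,...,g_t)]}}$,
the factor $2$ and $\sigma^2$ cancel,

\begin{align}
r_t &= \sqrt{\frac{\rho_\infty}{2(\rho_\infty-2)(\rho_\infty-4)\sigma^2} \cdot
\frac{2(\rho_t-2)(\rho_t-4)\sigma^2}{\rho_t}}
\end{align}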

This gives,

\begin{align}
r_t &= \sqrt{\frac{(\rho_t-2)(\rho_t-4)\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\rho_t}}
\end{align}
"""

import math
from typing import Dict, Optional

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class RAdam(AMSGrad):
    """
    ## Rectified Adam Optimizer

    This class extends the AMSGrad optimizer defined in [`amsgrad.py`](amsgrad.html).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 degenerated_to_sgd=True, defaults=None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag indicating whether to optimize the bias correction of the second moment
          by doing it after adding $\epsilon$
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `degenerated_to_sgd` is a flag indicating whether to fall back to SGD when the rectification term $r_t$ is intractable
        * `defaults` is a dictionary of default values for parameter groups.
          This is useful when you want to extend the class `RAdam`.
        """
        self.degenerated_to_sgd = degenerated_to_sgd
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform *RAdam* update
        self.r_adam_update(state, group, param, m, v)

    @staticmethod
    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:
        """
        ### Calculate rectification term $r_t$
        """

        # $\beta_2^t$
        beta2_t = beta2 ** step
        # $$\rho_\infty = \frac{2}{1 - \beta_2} - 1$$
        rho_inf = 2 / (1 - beta2) - 1
        # $$\rho_t = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1-\beta_2^t}$$
        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)

        # $r_t$ is tractable when $\rho_t \geq 4$.
        # We are being a little more conservative since it's an approximation.
        if rho >= 5:
            # $$r_t = \sqrt{\frac{(\rho_t-2)(\rho_t-4)\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\rho_t}}$$
            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
            return math.sqrt(r2)
        else:
            return None

    def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                      m: torch.Tensor, v: torch.Tensor):
        """
        ### Do the *RAdam* parameter update

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        * `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$;
          i.e. $\sigma(.)$ and $\psi(.)$ without bias correction
        """

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Calculate the rectification term $r_t$
        r = self.calc_rectification_term(beta2, state['step'])

        # Get learning rate
        lr = self.get_lr(state, group)

        # If $r_t$ is tractable
        if r is not None:
            # Whether to optimize the computation by combining scalar computations
            if self.optimized_update:
                # Denominator $\sqrt{v_t} + \hat{\epsilon}$
                denominator = v.sqrt().add_(group['eps'])
                # Step size $\alpha \sqrt{r_t} \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1
                # Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \sqrt{r_t} \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
                # \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
                param.data.addcdiv_(m, denominator, value=-step_size)
            # Computation without optimization
            else:
                # Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                # Step size $\frac{\alpha \sqrt{r_t}}{1-\beta_1^t}$
                step_size = lr * r / bias_correction1
                # Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \sqrt{r_t} \cdot
                # \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
                param.data.addcdiv_(m, denominator, value=-step_size)

        # If $r_t$ is intractable, fall back to SGD with momentum
        elif self.degenerated_to_sgd:
            # Step size $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # Update parameters
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$
            param.data.add_(m, alpha=-step_size)
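

# A minimal usage sketch (not part of the original implementation): `RAdam` behaves like any other
# PyTorch optimizer. The toy model and data below are assumptions for illustration only.
def _usage_example():
    # A tiny regression model and some random data
    model = torch.nn.Linear(4, 1)
    x, y = torch.randn(32, 4), torch.randn(32, 1)
    # Construct the optimizer with the defaults defined above
    optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    for _ in range(10):
        optimizer.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        # `optimizer.step()` ends up calling `step_param` (and hence `r_adam_update`) for each parameter
        optimizer.step()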


def _test_rectification_term():
    """
    ### Plot $r_t$ against $t$ for various $\beta_2$

    ![Plot of r_t](radam_r_t.png)
    """
    import matplotlib.pyplot as plt
    import numpy as np

    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
    plt.plot(np.arange(1, 5_000), [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
    plt.legend(beta2)
    plt.title("Optimizer")
    plt.show()


if __name__ == '__main__':
    _test_rectification_term()