"""
2
---
3
title: AMSGrad Optimizer
4
summary: A simple PyTorch implementation/tutorial of AMSGrad optimizer.
5
---
6
7
# AMSGrad
8
9
This is a [PyTorch](https://pytorch.org) implementation of the paper
10
[On the Convergence of Adam and Beyond](https://arxiv.org/abs/1904.09237).
11
12
We implement this as an extension to our [Adam optimizer implementation](adam.html).
13
The implementation it self is really small since it's very similar to Adam.
14
15
We also have an implementation of the synthetic example described in the paper where Adam fails to converge.
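
The only change from [Adam](adam.html) is in the second moment estimate:
this implementation keeps a running maximum of $v_t$ and uses that maximum,
instead of $v_t$ itself, in the parameter update,

\begin{align}
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\tilde{v}_t &= \max(\tilde{v}_{t-1}, v_t)
\end{align}

so the per-parameter effective learning rate is non-increasing, which is the
property the paper relies on for its convergence guarantee.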
"""

from typing import Dict

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam


class AMSGrad(Adam):
    """
    ## AMSGrad Optimizer

    This class extends the Adam optimizer defined in [`adam.py`](adam.html).
    The Adam optimizer in turn extends the class `GenericAdaptiveOptimizer`
    defined in [`__init__.py`](index.html).
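
    A minimal usage sketch (the linear model and random batch below are only
    placeholders for illustration):

    ```python
    model = nn.Linear(10, 1)
    optimizer = AMSGrad(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

    loss = model(torch.randn(16, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ```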
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=True, defaults=None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag indicating whether to optimize the bias correction of the
          second moment by doing it after adding $\epsilon$
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `defaults` is a dictionary of default values for parameter groups.
          This is useful when you want to extend the class `Adam`.
        """
        defaults = {} if defaults is None else defaults
        defaults.update(dict(amsgrad=amsgrad))

        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Call `init_state` of the Adam optimizer, which we are extending
        super().init_state(state, group, param)

        # If the `amsgrad` flag is `True` for this parameter group, we maintain the maximum of
        # the exponential moving average of the squared gradient
        if group['amsgrad']:
            state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):
        """
        ### Calculate $m_t$ and $v_t$ or $\max(v_1, v_2, ..., v_{t-1}, v_t)$

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        """

        # Get $m_t$ and $v_t$ from *Adam*
        m, v = super().get_mv(state, group, grad)

        # If this parameter group is using `amsgrad`
        if group['amsgrad']:
            # Get $\max(v_1, v_2, ..., v_{t-1})$.
            #
            # 🗒 The paper uses the notation $\hat{v}_t$ for this, which we don't use
            # here because it would clash with Adam's usage of the same notation
            # for the bias-corrected exponential moving average.
            v_max = state['max_exp_avg_sq']
            # Calculate $\max(v_1, v_2, ..., v_{t-1}, v_t)$.
            #
            # 🤔 I feel you should be taking / maintaining the max of the bias-corrected
            # second exponential moving average of the squared gradient.
            # But this is how it's
            # [implemented in PyTorch also](https://github.com/pytorch/pytorch/blob/19f4c5110e8bcad5e7e75375194262fca0a6293a/torch/optim/functional.py#L90).
            # I guess it doesn't really matter since bias correction only increases the value
            # and it only makes an actual difference during the first few steps of training.
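            #
            # (For reference, Adam's bias-corrected estimate is $\hat{v}_t = \frac{v_t}{1 - \beta_2^t}$,
            # and $1 - \beta_2^t \to 1$ as training progresses, so taking the max of $v_t$
            # rather than of $\hat{v}_t$ only differs early on.)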
            torch.maximum(v_max, v, out=v_max)

            return m, v_max
        else:
            # Fall back to *Adam* if the parameter group is not using `amsgrad`
            return m, v


def _synthetic_experiment(is_adam: bool):
    """
    ## Synthetic Experiment

    This is the synthetic experiment described in the paper,
    which shows a scenario where *Adam* fails.

    The paper (and Adam) formulates the optimization problem as
    minimizing the expected value of a function, $\mathbb{E}[f(\theta)]$,
    with respect to the parameters $\theta$.
    In the stochastic training setting we do not get hold of the function $f$
    itself; that is,
    when you are optimizing a neural network, $f$ would be the function over the
    entire batch of data.
    What we actually evaluate is a mini-batch, so the actual function is a
    realization of the stochastic $f$.
    This is why we are talking about an expected value.
    So let the function realizations be $f_1, f_2, ..., f_T$ for each time step
    of training.

    We measure the performance of the optimizer as the regret,
    $$R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$$
    where $\theta_t$ are the parameters at time step $t$, and $\theta^*$ are the
    optimal parameters that minimize $\mathbb{E}[f(\theta)]$.

    Now let's define the synthetic problem,

    \begin{align}
    f_t(x) =
    \begin{cases}
    1010 x, & \text{for } t \mod 101 = 1 \\
    -10 x, & \text{otherwise}
    \end{cases}
    \end{align}

    where $-1 \le x \le +1$.
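    Over each cycle of $101$ steps, one step contributes $1010 x$ and the other $100$
    steps contribute $-10 x$ each, so
    $$\sum_{t=1}^{101} f_t(x) = 1010 x - 1000 x = 10 x,$$
    and in expectation the objective increases with $x$.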
    The optimal solution is $x = -1$.

    This code will try running *Adam* and *AMSGrad* on this problem.
    """

    # Define $x$ parameter
    x = nn.Parameter(torch.tensor([.0]))
    # Optimal, $x^* = -1$
    x_star = nn.Parameter(torch.tensor([-1]), requires_grad=False)

    def func(t: int, x_: nn.Parameter):
        """
        ### $f_t(x)$
        """
        if t % 101 == 1:
            return (1010 * x_).sum()
        else:
            return (-10 * x_).sum()

    # Initialize the relevant optimizer
    if is_adam:
        optimizer = Adam([x], lr=1e-2, betas=(0.9, 0.99))
    else:
        optimizer = AMSGrad([x], lr=1e-2, betas=(0.9, 0.99))
    # $R(T)$
    total_regret = 0

    from labml import monit, tracker, experiment

    # Create experiment to record results
    with experiment.record(name='synthetic', comment='Adam' if is_adam else 'AMSGrad'):
        # Run for $10^7$ steps
        for step in monit.loop(10_000_000):
            # $f_t(\theta_t) - f_t(\theta^*)$
            regret = func(step, x) - func(step, x_star)
            # $R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$
            total_regret += regret.item()
            # Track results every 1,000 steps
            if (step + 1) % 1000 == 0:
                tracker.save(loss=regret, x=x, regret=total_regret / (step + 1))
            # Calculate gradients
            regret.backward()
            # Optimize
            optimizer.step()
            # Clear gradients
            optimizer.zero_grad()

            # Make sure $-1 \le x \le +1$
            x.data.clamp_(-1., +1.)


if __name__ == '__main__':
    # Run the synthetic experiment with *Adam*.
    # You can see that Adam converges to $x = +1$
    _synthetic_experiment(True)
    # Run the synthetic experiment with *AMSGrad*.
    # You can see that AMSGrad converges to the true optimum $x = -1$
    _synthetic_experiment(False)