"""
2
---
3
title: Adam Optimizer
4
summary: A simple PyTorch implementation/tutorial of Adam optimizer
5
---
6
7
# Adam Optimizer
8
9
This is a [PyTorch](https://pytorch.org) implementation of popular optimizer *Adam* from paper
10
[Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980).
11
12
*Adam* update is,
13
14
\begin{align}
15
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
16
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
17
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
18
\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
19
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
20
\end{align}
21
22
where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper parameters.
23
$m_t$ and $v_t$ are first and second order moments.
24
$\hat{m}_t$ and $\hat{v}_t$ are biased corrected moments.
25
$\epsilon$ is used as a fix for division by zero error, but also acts as a form of a hyper-parameter
26
that acts against variance in gradients.
27
28
Effective step taken assuming $\epsilon = 0$ is,
29
$$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\hat{v}_t}$$
30
This is bounded by,
31
$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1-\beta_2}}$$
32
when $1-\beta_1 \gt \sqrt{1-\beta_2}$
33
and
34
$$\vert \Delta t\vert \le \alpha$$
35
otherwise.
36
And in most common scenarios,
37
$$\vert \Delta t \vert \approx \alpha$$
38
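
Below is a minimal illustrative sketch of a single Adam step written directly from the
equations above, on a plain tensor with made-up names (`theta`, `grad`, `m`, `v`, `t`);
the class below implements the same update in an optimized, in-place form.

```python
import torch

# Hyper-parameters (illustrative values)
alpha, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8

theta = torch.zeros(4)       # parameter theta_{t-1}
m = torch.zeros_like(theta)  # first moment m_{t-1}
v = torch.zeros_like(theta)  # second moment v_{t-1}

grad = torch.randn(4)        # gradient g_t
t = 1                        # number of optimizer steps taken

m = beta1 * m + (1 - beta1) * grad       # m_t
v = beta2 * v + (1 - beta2) * grad ** 2  # v_t
m_hat = m / (1 - beta1 ** t)             # bias-corrected m_t
v_hat = v / (1 - beta2 ** t)             # bias-corrected v_t
theta = theta - alpha * m_hat / (v_hat.sqrt() + eps)
```

With $t = 1$ and zero-initialized moments, $\hat{m}_t = g_t$ and $\hat{v}_t = g_t^2$,
so each element moves by roughly $\alpha$, matching $\vert \Delta t \vert \approx \alpha$ above.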
"""

import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay


class Adam(GenericAdaptiveOptimizer):
    """
    ## Adam Optimizer

    We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
    to implement the Adam optimizer.
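
    Below is a minimal usage sketch; the model, data, and hyper-parameter values are
    placeholders, and `optimizer.step()` is implemented by the `GenericAdaptiveOptimizer`
    base class.

    ```python
    import torch
    from torch import nn

    from labml_nn.optimizers.adam import Adam

    model = nn.Linear(4, 1)
    optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = (model(x) - y).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ```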
"""

    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag indicating whether to optimize the bias correction of the
          second moment by doing it after adding $\epsilon$
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `Adam` (see the sketch below).
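
        As an illustrative sketch of the `defaults` argument, a hypothetical extension could
        pass extra per-group values through it (the class and flag names here are made up):

        ```python
        class MyAdam(Adam):
            def __init__(self, params, my_flag: bool = False, **kwargs):
                # `my_flag` becomes available as `group['my_flag']` in each parameter group
                super().__init__(params, defaults=dict(my_flag=my_flag), **kwargs)
        ```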
"""
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        """
        ### Calculate $m_t$ and $v_t$

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
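
        The snippet below is a small, self-contained check (with made-up tensors) that the
        in-place operations used in this method match the moment update equations:

        ```python
        import torch

        beta1, beta2 = 0.9, 0.999
        m, v = torch.zeros(3), torch.zeros(3)
        grad = torch.randn(3)

        # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        assert torch.allclose(m, (1 - beta1) * grad)
        assert torch.allclose(v, (1 - beta2) * grad ** 2)
        ```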
"""

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $m_{t-1}$ and $v_{t-1}$
        m, v = state['exp_avg'], state['exp_avg_sq']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # In-place calculation of $v_t$
        # $$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        """
        ### Get learning rate

        This returns the modified learning rate based on the state.
        For *Adam* this is just the specified learning rate for the parameter group, $\alpha$.
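
        Subclasses can override this to schedule the learning rate from the optimizer state.
        The sketch below is hypothetical (this class only returns `group['lr']`) and shows a
        simple linear warmup:

        ```python
        from labml_nn.optimizers.adam import Adam

        class AdamWithWarmup(Adam):
            def __init__(self, params, warmup: int = 1_000, **kwargs):
                super().__init__(params, **kwargs)
                self.warmup = warmup

            def get_lr(self, state, group):
                # Scale the learning rate linearly for the first `warmup` steps
                return group['lr'] * min(1.0, state['step'] / self.warmup)
        ```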
"""
        return group['lr']

    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
        """
        ### Do the *Adam* parameter update

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        * `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$.

        This computes the following update:

        \begin{align}
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
        \end{align}

        Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars while the others are tensors,
        we rearrange the calculation to optimize the computation.

        \begin{align}
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot
                \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\
        \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
                \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
        \end{align}

        where
        $$\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$$
        is what we should specify as the hyper-parameter.
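
        The following sketch (with made-up values) checks numerically that the rearranged
        update matches the original one when $\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$:

        ```python
        import math

        import torch

        beta1, beta2, eps, lr, t = 0.9, 0.999, 1e-8, 1e-3, 10
        m, v = torch.rand(3), torch.rand(3) + 0.1

        bc1, bc2 = 1 - beta1 ** t, 1 - beta2 ** t

        # Original form: bias-correct v, then add eps
        step_a = (lr / bc1) * m / (v.sqrt() / math.sqrt(bc2) + eps)
        # Rearranged form: add eps_hat to sqrt(v) directly
        eps_hat = math.sqrt(bc2) * eps
        step_b = (lr * math.sqrt(bc2) / bc1) * m / (v.sqrt() + eps_hat)

        assert torch.allclose(step_a, step_b)
        ```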
"""

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Get learning rate
        lr = self.get_lr(state, group)

        # Whether to optimize the computation
        if self.optimized_update:
            # $\sqrt{v_t} + \hat{\epsilon}$
            denominator = v.sqrt().add_(group['eps'])
            # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
            #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
            param.data.addcdiv_(m, denominator, value=-step_size)
        # Computation without optimization
        else:
            # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            # $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot
            #  \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
            param.data.addcdiv_(m, denominator, value=-step_size)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform *Adam* update
        self.adam_update(state, group, param, m, v)