GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/optimizers/ada_belief.py
"""
---
title: AdaBelief optimizer
summary: A simple PyTorch implementation/tutorial of AdaBelief optimizer.
---

# AdaBelief Optimizer

This is based on the AdaBelief
[official implementation](https://github.com/juntang-zhuang/Adabelief-Optimizer)
of the paper
[AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468).

This is implemented in [PyTorch](https://pytorch.org) as an extension to [RAdam](radam.html).

The main difference between the Adam optimizer and AdaBelief is how the
adaptive learning rate is calculated:
instead of dividing by the exponential moving average of the squared gradients,
AdaBelief divides by the exponential moving average of the variance of the
gradients, i.e. their squared deviation from the momentum $m_t$.

\begin{align}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
\textcolor{cyan}{s_t} &\textcolor{cyan}{\leftarrow} \textcolor{cyan}{\beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2} \\
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
\textcolor{cyan}{\hat{s}_t} &\textcolor{cyan}{\leftarrow} \frac{\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}}{\textcolor{cyan}{1-\beta_2^t}} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\textcolor{cyan}{\hat{s}_t}} + \epsilon}
\end{align}

šŸ¤” The paper calculates the variance as $(g_t - m_t)^2$,
but I feel it should use the bias-corrected momentum,
$(g_t - \textcolor{orange}{\hat{m}_t})^2$.
I guess this doesn't affect things much, because the bias correction
is $\approx 1$ after the initial training steps
(for example, with $\beta_1 = 0.9$, $1 - \beta_1^t > 0.99$ once $t \ge 44$).
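
For intuition, here is a minimal, self-contained sketch of the two second-moment
updates side by side (added for this note, not from the paper); `g` stands in for
an observed gradient and `m`, `v`, `s`, `beta1`, `beta2` are illustrative names.

```python
import torch

beta1, beta2 = 0.9, 0.999
g = torch.randn(4)          # an observed gradient
m = v = s = torch.zeros(4)  # first- and second-moment accumulators

# Adam: exponential moving average of the squared gradient
v = beta2 * v + (1 - beta2) * g ** 2
# AdaBelief: exponential moving average of the squared deviation of the
# gradient from its own moving average (the "belief" in the gradient)
m = beta1 * m + (1 - beta1) * g
s = beta2 * s + (1 - beta2) * (g - m) ** 2
```

When the observed gradient stays close to its moving average, $s_t$ is small and
AdaBelief takes larger steps; when gradients are noisy, $s_t$ grows and the step shrinks.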
"""

from typing import Dict, Any

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.radam import RAdam


class AdaBelief(RAdam):
    """
    ## AdaBelief Optimizer

    This class extends the RAdam optimizer defined in [`radam.py`](radam.html).
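
    A minimal usage sketch (an assumption of this note; `model`, `data_loader` and
    `loss_func` are placeholders), relying on the base class providing the standard
    `torch.optim.Optimizer` interface:

    ```python
    optimizer = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)
    for data, target in data_loader:
        optimizer.zero_grad()
        loss = loss_func(model(data), target)
        loss.backward()
        optimizer.step()
    ```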
"""
51
52
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
53
weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
54
degenerate_to_sgd=True,
55
rectify=True, defaults=None):
56
"""
57
### Initialize the optimizer
58
59
* `params` is the list of parameters
60
* `lr` is the learning rate $\alpha$
61
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
62
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
63
* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
64
* `optimized_update` is a flag whether to optimize the bias correction of the second moment
65
by doing it after adding $\epsilon$
66
* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam
67
* `degenerate_to_sgd` whether to use sgd when the rectification term $r_t$ is intractable
68
* `rectify` is whether to use RAdam update
69
* `defaults` is a dictionary of default for group values.
70
This is useful when you want to extend the class `AdaBelief`.
71
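
        A hypothetical sketch of such an extension (the names `MyAdaBelief` and
        `my_hyper_param` are made up for illustration); the extra value is passed
        through `defaults` so it can be read from each parameter group:

        ```python
        class MyAdaBelief(AdaBelief):
            def __init__(self, params, my_hyper_param=0.1, **kwargs):
                # Store the extra hyper-parameter as a per-group default
                super().__init__(params, defaults=dict(my_hyper_param=my_hyper_param), **kwargs)
        ```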
"""
72
73
defaults = {} if defaults is None else defaults
74
super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
75
self.rectify = rectify
76
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """
        state['step'] = 0
        # Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of variance
        state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

        # If `amsgrad` flag is `True` for this parameter group, we maintain the maximum of
        # the exponential moving average of the variance
        if group['amsgrad']:
            # Maintains the max of all exponential moving averages of the variance
            state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        """
        ### Calculate $m_t$ and $s_t$ or $\max(s_1, s_2, ..., s_{t-1}, s_t)$

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        """

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $m_{t-1}$ and $s_{t-1}$
        m, s = state['exp_avg'], state['exp_avg_var']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # Difference between the gradient and the momentum
        grad_residual = grad - m
        # In-place calculation of $s_t$
        # $$s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2$$
        s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

        # If this parameter group is using `amsgrad`
        if group['amsgrad']:
            # Get $\max(s_1, s_2, ..., s_{t-1})$.
            s_max = state['max_exp_avg_var']
            # Calculate $\max(s_1, s_2, ..., s_{t-1}, s_t)$.
            torch.maximum(s_max, s, out=s_max)

            return m, s_max
        else:
            # $m_t$ and $s_t$ otherwise
            return m, s

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $s_t$
        m, s = self.get_ms(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        if not self.rectify:
            # Perform the *Adam* update, defined in [`adam.py`](adam.html), with
            # $\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}$ in place of $v_t$.
            self.adam_update(state, group, param, m, s + group['eps'])
        else:
            # Perform the *Rectified Adam* update, defined in [`radam.py`](radam.html), with
            # $\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}$ in place of $v_t$.
            self.r_adam_update(state, group, param, m, s + group['eps'])
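

# What follows is a small synthetic check added for this write-up; it is a hedged
# sketch rather than part of the paper's reference code, and it assumes the base
# optimizer exposes the usual `torch.optim.Optimizer` interface (`zero_grad`, `step`).
def _synthetic_test():
    """
    ### Minimize a simple quadratic with AdaBelief

    Optimizes a single tensor to minimize the sum of squares of its entries
    and prints the final loss, as a quick sanity check of the update step.
    """
    # A single parameter tensor to optimize
    theta = nn.Parameter(torch.ones(10))
    # Create the optimizer with its default hyper-parameters
    optimizer = AdaBelief([theta], lr=1e-2)
    for _ in range(500):
        optimizer.zero_grad()
        # Quadratic loss with its minimum at zero
        loss = (theta ** 2).sum()
        loss.backward()
        optimizer.step()
    print(f'loss after 500 steps: {loss.item():.6f}')


if __name__ == '__main__':
    _synthetic_test()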