Path: blob/master/labml_nn/optimizers/ada_belief.py
"""1---2title: AdaBelief optimizer3summary: A simple PyTorch implementation/tutorial of AdaBelief optimizer.4---56# AdaBelief Optimizer78This is based from AdaBelief9[official implementation](https://github.com/juntang-zhuang/Adabelief-Optimizer)10of the paper11[AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468).1213This is implemented in [PyTorch](https://pytorch.org) as an extension to [RAdam](radam.html).1415The main difference between Adam optimizer and AdaBelief is that,16how it calculates the adaptive learning rate;17instead of dividing by the exponential moving average of square of the gradients,18AdaBelief divides by the exponential mean of variance.1920\begin{align}21m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\22\textcolor{cyan}{s_t} &\textcolor{cyan}{\leftarrow} \textcolor{cyan}{\beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2} \\23\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\24\textcolor{cyan}{\hat{s}_t} &\textcolor{cyan}{\leftarrow} \frac{\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}}{\textcolor{cyan}{1-\beta_2^t}} \\25\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\textcolor{cyan}{\hat{s}_t}} + \epsilon}26\end{align}2728š¤ The paper calculates variance as $(g_t - m_t)^2$,29but I feel it should use the bias corrected momentum30$(g_t - \textcolor{orange}{\hat{m}_t})^2$.31I guess this doesn't affect things much because32bias correction is $\approx 1$ after the initial training steps.33"""3435from typing import Dict, Any3637import torch38from torch import nn3940from labml_nn.optimizers import WeightDecay41from labml_nn.optimizers.radam import RAdam424344class AdaBelief(RAdam):45"""46## AdaBelief Optimizer4748This class extends from RAdam optimizer defined in [`radam.py`](radam.html).49"""5051def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,52weight_decay: WeightDecay = WeightDecay(), amsgrad=False,53degenerate_to_sgd=True,54rectify=True, defaults=None):55"""56### Initialize the optimizer5758* `params` is the list of parameters59* `lr` is the learning rate $\alpha$60* `betas` is a tuple of ($\beta_1$, $\beta_2$)61* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`62* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)63* `optimized_update` is a flag whether to optimize the bias correction of the second moment64by doing it after adding $\epsilon$65* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam66* `degenerate_to_sgd` whether to use sgd when the rectification term $r_t$ is intractable67* `rectify` is whether to use RAdam update68* `defaults` is a dictionary of default for group values.69This is useful when you want to extend the class `AdaBelief`.70"""7172defaults = {} if defaults is None else defaults73super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)74self.rectify = rectify7576def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):77"""78### Initialize a parameter state7980* `state` is the optimizer state of the parameter (tensor)81* `group` stores optimizer attributes of the parameter group82* `param` is the parameter tensor $\theta_{t-1}$83"""84state['step'] = 085# Exponential moving average of gradient values86state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)87# Exponential moving average of variance88state['exp_avg_var'] = torch.zeros_like(param, 
memory_format=torch.preserve_format)8990# If `amsgrad` flag is `True` for this parameter group, we maintain the maximum of91# exponential moving average of variance92if group['amsgrad']:93# Maintains max of all exp. moving avg. of sq. grad. values94state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)9596def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):97"""98### Calculate $m_t$ and $s_t$ or $\max(s_1, s_2, ..., s_{t-1}, s_t)$99100* `state` is the optimizer state of the parameter (tensor)101* `group` stores optimizer attributes of the parameter group102* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$103"""104105# Get $\beta_1$ and $\beta_2$106beta1, beta2 = group['betas']107108# Get $m_{t-1}$ and $s_{t-1}$109m, s = state['exp_avg'], state['exp_avg_var']110111# In-place calculation of $m_t$112# $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$113m.mul_(beta1).add_(grad, alpha=1 - beta1)114# Difference between gradient and momentum115grad_residual = grad - m116# In-place calculation of $s_t$117# $$s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2$$118s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)119120# If this parameter group is using `amsgrad`121if group['amsgrad']:122# Get $\max(s_1, s_2, ..., s_{t-1})$.123s_max = state['max_exp_avg_var']124# Calculate $\max(s_1, s_2, ..., s_{t-1}, s_t)$.125torch.maximum(s_max, s, out=s_max)126127return m, s_max128else:129# $m_t$ and $s_t$ otherwise130return m, s131132def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):133"""134### Take an update step for a given parameter tensor135136* `state` is the optimizer state of the parameter (tensor)137* `group` stores optimizer attributes of the parameter group138* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$139* `param` is the parameter tensor $\theta_{t-1}$140"""141142# Calculate weight decay143grad = self.weight_decay(param, grad, group)144145# Get $m_t$ and $v_t$146m, s = self.get_ms(state, group, grad)147148# Increment $t$ the number of optimizer steps149state['step'] += 1150151if not self.rectify:152# Perform *Adam* update, defined in [`adam.py`](adam.html), with153# $\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}$ in place of $v_t$.154self.adam_update(state, group, param, m, s + group['eps'])155else:156# Perform *Rectified Adam* update defined in [`radam.py`](radam.html), with157# $\textcolor{cyan}{s_t} + \textcolor{red}{\epsilon}$ in place of $v_t$.158self.r_adam_update(state, group, param, m, s + group['eps'])159160161
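

if __name__ == '__main__':
    # A minimal usage sketch (an illustrative addition, not part of the annotated
    # implementation above). It fits a tiny linear least-squares problem with
    # `AdaBelief` to show that the optimizer plugs into the usual PyTorch
    # `zero_grad()` / `backward()` / `step()` loop; the data and hyper-parameters
    # here are arbitrary choices for the sketch.
    torch.manual_seed(0)
    x = torch.randn(256, 4)
    true_w = torch.tensor([[1.0], [-2.0], [0.5], [3.0]])
    y = x @ true_w

    model = nn.Linear(4, 1, bias=False)
    optimizer = AdaBelief(model.parameters(), lr=1e-2)

    for _ in range(500):
        optimizer.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        optimizer.step()

    # The mean squared error should have dropped close to zero by now
    print(f'final loss: {loss.item():.6f}')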