Path: blob/master/labml_nn/optimizers/radam.py
"""1---2title: Rectified Adam (RAdam) optimizer3summary: A simple PyTorch implementation/tutorial of RAdam optimizer.4---56# Rectified Adam (RAdam) optimizer78This implementation is based on9[the official implementation](https://github.com/LiyuanLucasLiu/RAdam)10of the paper11[On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265).1213We have implemented it in [PyTorch](https://pytorch.org)14as an extension to [our AMSGrad implementation](amsgrad.html)15thus requiring only the modifications to be implemented.1617Adam optimizer sometimes converges to a bad local optima during the initial stages of the training;18especially when training transformers.19Researches use warmups to counter this; for the the initial training steps (warm-up stage)20they use a low learning rate.21This paper identifies the problem to be the high variance of adaptive learning rate22during initial stages of training, and counters it using a new rectification term to23reduce variance.2425The paper also evaluates two variance reduction mechanisms:26* **Adam-2k**: Only compute the adaptive learning rate ($v_t$ in [Adam](adam.html)) during the first 2k steps,27without changing parameters or calculating momentum ($m_t$).28* **Adam-eps**: Adam with large $\epsilon \approx 10^{-4}$.2930## Rectified Adam3132Let $\sigma(g_1, ..., g_t)$ and $\psi(g_1, ..., g_t)$ be the functions to calculate33momentum and adaptive learning rate.34For Adam, they are3536\begin{align}37\sigma(g_1, ..., g_t) &= \frac{(1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i} g_i}{1 - \beta_1^t} \\38\psi(g_1, ..., g_t) &= \sqrt \frac{1 - \beta_2^t}{(1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i} g_i^2}39\end{align}4041### Exponential moving average as simple moving average4243The distribution of exponential moving average can be approximated as a simple moving average.4445\begin{align}46p\Bigg(\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}{1 - \beta_2^t} \Bigg) \approx47p\Bigg(\frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)} \Bigg)48\end{align}4950Here we are taking the simple moving average of the last $f(t,\beta_2)$ gradients.51$f(t,\beta_2)$ satisfies the following,5253\begin{align}54\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} \cdot i}{1 - \beta_2^t} =55\frac{\sum_{i=1}^{f(t,\beta_2)} (t+1-i)}{f(t,\beta_2)}56\end{align}5758which gives,59$$f(t,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1 - \beta_2^t}$$6061### Scaled inverse chi-squared6263From above we have64$$65p\Big( \psi^2(g_1, ..., g_t) \Big) \approx66p\Bigg(\frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)} \Bigg)67$$68where $g_i \sim \mathcal{N}(0, \sigma^2)$.69Note that $sigma$ here is the standard deviation and different from $\sigma(.)$ for momentum.7071[Scaled inverse chi-squared](https://en.wikipedia.org/wiki/Scaled_inverse_chi-squared_distribution)72is the distribution of squared inverse of mean of $p$ normal distributions.73$$74p\Bigg(\frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)} \Bigg)75\sim76\text{Scale-inv} \mathcal{X}^2(\rho,\frac{1}{\sigma^2})77$$78where $\rho = f(t,\beta_2)$.7980### Rectification8182They prove that variance of $\psi(.)$ decreases with $\rho$ when83$\psi^2(.) \sim \text{Scale-inv} \mathcal{X}^2(\rho,\frac{1}{\sigma^2})$.8485Therefore the variance is minimized at maximal $\rho$ which is86$\rho_{\infty} = \frac{2}{1-\beta_2} - 1$. 
### Scaled inverse chi-squared

From above we have

$$
p\Big( \psi^2(g_1, ..., g_t) \Big) \approx
p\Bigg(\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2} \Bigg)
$$

where $g_i \sim \mathcal{N}(0, \sigma^2)$.
Note that $\sigma$ here is the standard deviation of the gradients and is different from $\sigma(.)$ used for the momentum.

The [scaled inverse chi-squared distribution](https://en.wikipedia.org/wiki/Scaled_inverse_chi-squared_distribution)
is the distribution of the reciprocal of the mean of the squares of $p$ normally distributed variables.

$$
\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}
\sim
\text{Scale-inv} \mathcal{X}^2\Big(\rho,\frac{1}{\sigma^2}\Big)
$$

where $\rho = f(t,\beta_2)$.

### Rectification

They prove that the variance of $\psi(.)$ decreases with $\rho$ when
$\psi^2(.) \sim \text{Scale-inv} \mathcal{X}^2\big(\rho,\frac{1}{\sigma^2}\big)$.

Therefore the variance is minimized at the maximal $\rho$, which is
$\rho_{\infty} = \frac{2}{1-\beta_2} - 1$.
Let the minimum variance be $C_{\text{var}}$.

In order to ensure that the adaptive learning
rate $\psi(.)$ has consistent variance, we rectify the variance with $r$

\begin{align}
r = \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}}
\end{align}

### Approximating $Var[\psi(.)]$

They estimate $Var[\psi(.)] \approx \frac{Var[\psi^2(.)]}{4 \mathbb{E}[\psi^2(.)]}$
based on the first-order expansion of $\sqrt{\psi^2(.)}$ around $\mathbb{E}[\psi^2(.)]$
(the delta method): for a random variable $X$,
$Var[\sqrt{X}] \approx \Big(\frac{1}{2\sqrt{\mathbb{E}[X]}}\Big)^2 Var[X] = \frac{Var[X]}{4 \mathbb{E}[X]}$.

From the $\text{Scale-inv} \mathcal{X}^2$ distribution we have,

\begin{align}
\mathbb{E}\big[\psi^2(.)\big] &= \frac{\rho / \sigma^2}{\rho-2} \\
Var\big[\psi^2(.)\big] &= \frac{2 \rho^2 / \sigma^4}{(\rho-2)^2 (\rho - 4)}
\end{align}

which gives,

$$
Var[\psi(.)] \approx \frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}
$$

### Rectification term

We have

\begin{align}
r &= \sqrt{\frac{C_{\text{var}}}{Var\big[\psi(.)\big]}} \\
Var[\psi(.)] &\approx \frac{\rho}{2(\rho-2)(\rho-4)\sigma^2}
\end{align}

where $C_{\text{var}}$ is $Var\big[\psi(.)\big]$ for $\rho_\infty$.
Let $\rho$ at step $t$ be $\rho_t$, and let $r_t$ be the rectification term
at step $t$.

\begin{align}
C_{\text{var}} &\approx \frac{\rho_\infty}{2(\rho_\infty-2)(\rho_\infty-4)\sigma^2} \\
Var[\psi(g_1,...,g_t)] &\approx \frac{\rho_t}{2(\rho_t-2)(\rho_t-4)\sigma^2}
\end{align}

This gives,

\begin{align}
r_t &= \sqrt{\frac{(\rho_t-2)(\rho_t-4)\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\rho_t}}
\end{align}
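Putting the pieces together, the parameter update implemented in the code below applies the
rectification term on top of the usual Adam step, and falls back to SGD with momentum when
$r_t$ is intractable (here $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments):

\begin{align}
\theta_t &\leftarrow \theta_{t-1} - \alpha r_t \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
&& \text{if } r_t \text{ is tractable} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \hat{m}_t
&& \text{otherwise (SGD with momentum)}
\end{align}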
"""

import math
from typing import Any, Dict, Optional

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


class RAdam(AMSGrad):
    """
    ## Rectified Adam Optimizer

    This class extends the AMSGrad optimizer defined in [`amsgrad.py`](amsgrad.html).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 degenerated_to_sgd=True, defaults=None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag whether to optimize the bias correction of the second moment
          by doing it after adding $\epsilon$
        * `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
        * `degenerated_to_sgd` is a flag indicating whether to use SGD with momentum when the
          rectification term $r_t$ is intractable
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `RAdam`.
        """
        self.degenerated_to_sgd = degenerated_to_sgd
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform the *RAdam* update
        self.r_adam_update(state, group, param, m, v)

    @staticmethod
    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:
        """
        ### Calculate the rectification term $r_t$
        """

        # $\beta_2^t$
        beta2_t = beta2 ** step
        # $$\rho_\infty = \frac{2}{1 - \beta_2} - 1$$
        rho_inf = 2 / (1 - beta2) - 1
        # $$\rho_t = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1-\beta_2^t}$$
        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)

        # $r_t$ is tractable when $\rho_t \ge 4$.
        # We are being a little more conservative since it's an approximated value
        if rho >= 5:
            # $$r_t = \sqrt{\frac{(\rho_t-2)(\rho_t-4)\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\rho_t}}$$
            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
            return math.sqrt(r2)
        else:
            return None

    def r_adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                      m: torch.Tensor, v: torch.Tensor):
        """
        ### Do the *RAdam* parameter update

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        * `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$;
          i.e. $\sigma(.)$ and $\psi(.)$ without bias correction
        """

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Calculate the rectification term $r_t$ (`None` if it is intractable)
        r = self.calc_rectification_term(beta2, state['step'])

        # Get learning rate
        lr = self.get_lr(state, group)

        # If $r_t$ is tractable
        if r is not None:
            # Whether to optimize the computation by combining scalar computations
            if self.optimized_update:
                # Denominator $\sqrt{v_t} + \hat{\epsilon}$
                denominator = v.sqrt().add_(group['eps'])
                # Step size $\alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1
                # Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
                #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
                param.data.addcdiv_(m, denominator, value=-step_size)
            # Computation without optimization
            else:
                # Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                # Step size $\frac{\alpha r_t}{1-\beta_1^t}$
                step_size = lr * r / bias_correction1
                # Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \cdot
                #  \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
                param.data.addcdiv_(m, denominator, value=-step_size)

        # If $r_t$ is intractable, do SGD with momentum
        elif self.degenerated_to_sgd:
            # Step size $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # Update parameters
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$
            param.data.add_(m, alpha=-step_size)
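# The following is a small usage sketch added for illustration; it is not part of the
# original implementation. It shows how `RAdam` can be dropped into a standard PyTorch
# training loop, assuming the base optimizer provides the usual `step`/`zero_grad`
# interface (it extends `torch.optim.Optimizer`). The toy model and data are hypothetical.
def _usage_example():
    """
    ### Usage sketch

    Fit a toy linear model with `RAdam`.
    """
    import torch.nn as nn

    # A toy model and some random data
    model = nn.Linear(4, 1)
    x, y = torch.randn(64, 4), torch.randn(64, 1)

    # Create the optimizer with default hyper-parameters
    optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

    # A few training steps
    for _ in range(10):
        loss = ((model(x) - y) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()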
def _test_rectification_term():
    """
    ### Plot $r_t$ against $t$ for various $\beta_2$
    """
    import matplotlib.pyplot as plt
    import numpy as np

    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
    plt.plot(np.arange(1, 5_000),
             [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
    plt.legend(beta2)
    plt.title("Rectification term $r_t$")
    plt.show()


if __name__ == '__main__':
    _test_rectification_term()