Path: blob/master/labml_nn/optimizers/amsgrad.py
4910 views
"""1---2title: AMSGrad Optimizer3summary: A simple PyTorch implementation/tutorial of AMSGrad optimizer.4---56# AMSGrad78This is a [PyTorch](https://pytorch.org) implementation of the paper9[On the Convergence of Adam and Beyond](https://arxiv.org/abs/1904.09237).1011We implement this as an extension to our [Adam optimizer implementation](adam.html).12The implementation it self is really small since it's very similar to Adam.1314We also have an implementation of the synthetic example described in the paper where Adam fails to converge.15"""1617from typing import Dict1819import torch20from torch import nn2122from labml_nn.optimizers import WeightDecay23from labml_nn.optimizers.adam import Adam242526class AMSGrad(Adam):27"""28## AMSGrad Optimizer2930This class extends from Adam optimizer defined in [`adam.py`](adam.html).31Adam optimizer is extending the class `GenericAdaptiveOptimizer`32defined in [`__init__.py`](index.html).33"""34def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,35weight_decay: WeightDecay = WeightDecay(),36optimized_update: bool = True,37amsgrad=True, defaults=None):38"""39### Initialize the optimizer4041* `params` is the list of parameters42* `lr` is the learning rate $\alpha$43* `betas` is a tuple of ($\beta_1$, $\beta_2$)44* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`45* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)46* 'optimized_update' is a flag whether to optimize the bias correction of the second moment47by doing it after adding $\epsilon$48* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam49* `defaults` is a dictionary of default for group values.50This is useful when you want to extend the class `Adam`.51"""52defaults = {} if defaults is None else defaults53defaults.update(dict(amsgrad=amsgrad))5455super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)5657def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):58"""59### Initialize a parameter state6061* `state` is the optimizer state of the parameter (tensor)62* `group` stores optimizer attributes of the parameter group63* `param` is the parameter tensor $\theta_{t-1}$64"""6566# Call `init_state` of Adam optimizer which we are extending67super().init_state(state, group, param)6869# If `amsgrad` flag is `True` for this parameter group, we maintain the maximum of70# exponential moving average of squared gradient71if group['amsgrad']:72state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)7374def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):75"""76### Calculate $m_t$ and and $v_t$ or $\max(v_1, v_2, ..., v_{t-1}, v_t)$7778* `state` is the optimizer state of the parameter (tensor)79* `group` stores optimizer attributes of the parameter group80* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$81"""8283# Get $m_t$ and $v_t$ from *Adam*84m, v = super().get_mv(state, group, grad)8586# If this parameter group is using `amsgrad`87if group['amsgrad']:88# Get $\max(v_1, v_2, ..., v_{t-1})$.89#90# 🗒 The paper uses the notation $\hat{v}_t$ for this, which we don't use91# that here because it confuses with the Adam's usage of the same notation92# for bias corrected exponential moving average.93v_max = state['max_exp_avg_sq']94# Calculate $\max(v_1, v_2, ..., v_{t-1}, v_t)$.95#96# 🤔 I feel you should be taking / maintaining the max of the bias corrected97# second exponential average of squared gradient.98# But this is how it's99# [implemented in PyTorch also](https://github.com/pytorch/pytorch/blob/19f4c5110e8bcad5e7e75375194262fca0a6293a/torch/optim/functional.py#L90).100# I guess it doesn't really matter since bias correction only increases the value101# and it only makes an actual difference during the early few steps of the training.102torch.maximum(v_max, v, out=v_max)103104return m, v_max105else:106# Fall back to *Adam* if the parameter group is not using `amsgrad`107return m, v108109110def _synthetic_experiment(is_adam: bool):111"""112## Synthetic Experiment113114This is the synthetic experiment described in the paper,115that shows a scenario where *Adam* fails.116117The paper (and Adam) formulates the problem of optimizing as118minimizing the expected value of a function, $\mathbb{E}[f(\theta)]$119with respect to the parameters $\theta$.120In the stochastic training setting we do not get hold of the function $f$121it self; that is,122when you are optimizing a NN $f$ would be the function on entire123batch of data.124What we actually evaluate is a mini-batch so the actual function is125realization of the stochastic $f$.126This is why we are talking about an expected value.127So let the function realizations be $f_1, f_2, ..., f_T$ for each time step128of training.129130We measure the performance of the optimizer as the regret,131$$R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$$132where $\theta_t$ is the parameters at time step $t$, and $\theta^*$ is the133optimal parameters that minimize $\mathbb{E}[f(\theta)]$.134135Now lets define the synthetic problem,136137\begin{align}138f_t(x) =139\begin{cases}1401010 x, & \text{for } t \mod 101 = 1 \\141-10 x, & \text{otherwise}142\end{cases}143\end{align}144145where $-1 \le x \le +1$.146The optimal solution is $x = -1$.147148This code will try running *Adam* and *AMSGrad* on this problem.149"""150151# Define $x$ parameter152x = nn.Parameter(torch.tensor([.0]))153# Optimal, $x^* = -1$154x_star = nn.Parameter(torch.tensor([-1]), requires_grad=False)155156def func(t: int, x_: nn.Parameter):157"""158### $f_t(x)$159"""160if t % 101 == 1:161return (1010 * x_).sum()162else:163return (-10 * x_).sum()164165# Initialize the relevant optimizer166if is_adam:167optimizer = Adam([x], lr=1e-2, betas=(0.9, 0.99))168else:169optimizer = AMSGrad([x], lr=1e-2, betas=(0.9, 0.99))170# $R(T)$171total_regret = 0172173from labml import monit, tracker, experiment174175# Create experiment to record results176with experiment.record(name='synthetic', comment='Adam' if is_adam else 'AMSGrad'):177# Run for $10^7$ steps178for step in monit.loop(10_000_000):179# $f_t(\theta_t) - f_t(\theta^*)$180regret = func(step, x) - func(step, x_star)181# $R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$182total_regret += regret.item()183# Track results every 1,000 steps184if (step + 1) % 1000 == 0:185tracker.save(loss=regret, x=x, regret=total_regret / (step + 1))186# Calculate gradients187regret.backward()188# Optimize189optimizer.step()190# Clear gradients191optimizer.zero_grad()192193# Make sure $-1 \le x \le +1$194x.data.clamp_(-1., +1.)195196197if __name__ == '__main__':198# Run the synthetic experiment is *Adam*.199# You can see that Adam converges at $x = +1$200_synthetic_experiment(True)201# Run the synthetic experiment is *AMSGrad*202# You can see that AMSGrad converges to true optimal $x = -1$203_synthetic_experiment(False)204205206