Path: blob/master/labml_nn/optimizers/adam.py
"""1---2title: Adam Optimizer3summary: A simple PyTorch implementation/tutorial of Adam optimizer4---56# Adam Optimizer78This is a [PyTorch](https://pytorch.org) implementation of popular optimizer *Adam* from paper9[Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980).1011*Adam* update is,1213\begin{align}14m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\15v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\16\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\17\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\18\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}19\end{align}2021where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper parameters.22$m_t$ and $v_t$ are first and second order moments.23$\hat{m}_t$ and $\hat{v}_t$ are biased corrected moments.24$\epsilon$ is used as a fix for division by zero error, but also acts as a form of a hyper-parameter25that acts against variance in gradients.2627Effective step taken assuming $\epsilon = 0$ is,28$$\Delta t = \alpha \cdot \frac{\hat{m}_t}{\hat{v}_t}$$29This is bounded by,30$$\vert \Delta t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1-\beta_2}}$$31when $1-\beta_1 \gt \sqrt{1-\beta_2}$32and33$$\vert \Delta t\vert \le \alpha$$34otherwise.35And in most common scenarios,36$$\vert \Delta t \vert \approx \alpha$$37"""3839import math40from typing import Dict, Any, Tuple, Optional4142import torch43from labml import tracker44from torch import nn4546from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay474849class Adam(GenericAdaptiveOptimizer):50"""51## Adam Optimizer5253We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)54to implement the Adam optimizer.55"""5657def __init__(self, params,58lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,59weight_decay: WeightDecay = WeightDecay(),60optimized_update: bool = True,61defaults: Optional[Dict[str, Any]] = None):62"""63### Initialize the optimizer6465* `params` is the list of parameters66* `lr` is the learning rate $\alpha$67* `betas` is a tuple of ($\beta_1$, $\beta_2$)68* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`69* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)70* `optimized_update` is a flag whether to optimize the bias correction of the second moment71by doing it after adding $\epsilon$72* `defaults` is a dictionary of default for group values.73This is useful when you want to extend the class `Adam`.74"""75defaults = {} if defaults is None else defaults76defaults.update(weight_decay.defaults())77super().__init__(params, defaults, lr, betas, eps)7879self.weight_decay = weight_decay80self.optimized_update = optimized_update8182def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):83"""84### Initialize a parameter state8586* `state` is the optimizer state of the parameter (tensor)87* `group` stores optimizer attributes of the parameter group88* `param` is the parameter tensor $\theta_{t-1}$89"""9091# This is the number of optimizer steps taken on the parameter, $t$92state['step'] = 093# Exponential moving average of gradients, $m_t$94state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)95# Exponential moving average of squared gradient values, $v_t$96state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)9798def get_mv(self, state: Dict[str, Any], group: Dict[str, 


class Adam(GenericAdaptiveOptimizer):
    """
    ## Adam Optimizer

    We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
    to implement the Adam optimizer.
    """

    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
        * `weight_decay` is an instance of the class `WeightDecay` defined in [`__init__.py`](index.html)
        * `optimized_update` is a flag indicating whether to optimize the bias correction of the
          second moment by doing it after adding $\epsilon$
        * `defaults` is a dictionary of defaults for group values.
          This is useful when you want to extend the class `Adam`.
        """
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
        """
        ### Calculate $m_t$ and $v_t$

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        """

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']

        # Get $m_{t-1}$ and $v_{t-1}$
        m, v = state['exp_avg'], state['exp_avg_sq']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        # In-place calculation of $v_t$
        # $$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v

    def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
        """
        ### Get learning rate

        This returns the modified learning rate based on the state.
        For *Adam* this is just the specified learning rate for the parameter group,
        $\alpha$.
        """
        return group['lr']

    def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
        """
        ### Do the *Adam* parameter update

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        * `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$

        This computes the following,

        \begin{align}
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
        \end{align}

        Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors,
        we rearrange this calculation to optimize the computation.

        \begin{align}
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
        \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot
                 \frac{m_t / (1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon} \\
        \theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
                 \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}} \\
        \end{align}

        where
        $$\hat{\epsilon} = \sqrt{1-\beta_2^t} \cdot \epsilon$$
        is what we should specify as the hyper-parameter.
        """

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
        bias_correction1 = 1 - beta1 ** state['step']
        # Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
        bias_correction2 = 1 - beta2 ** state['step']

        # Get learning rate
        lr = self.get_lr(state, group)

        # Whether to optimize the computation
        if self.optimized_update:
            # $\sqrt{v_t} + \hat{\epsilon}$
            denominator = v.sqrt().add_(group['eps'])
            # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot
            #  \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$
            param.data.addcdiv_(m, denominator, value=-step_size)
        # Computation without optimization
        else:
            # $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$
            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            # $\frac{\alpha}{1-\beta_1^t}$
            step_size = lr / bias_correction1
            # $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot
            #  \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$
            param.data.addcdiv_(m, denominator, value=-step_size)
    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform the *Adam* update
        self.adam_update(state, group, param, m, v)
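

# A minimal usage sketch, added for illustration (it is not part of the original
# module). It assumes that `GenericAdaptiveOptimizer` builds on
# `torch.optim.Optimizer` and provides the usual `step()` / `zero_grad()` calls
# that end up invoking `step_param` above; the name `_usage_sketch` is made up.
def _usage_sketch():
    # A toy regression problem
    model = nn.Linear(4, 1)
    x = torch.randn(32, 4)
    y = torch.randn(32, 1)

    # The optimizer defined in this file
    optimizer = Adam(model.parameters(), lr=1e-2)

    for _ in range(100):
        # Mean squared error loss
        loss = ((model(x) - y) ** 2).mean()
        # Clear old gradients, compute new ones, and take an Adam step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return float(loss)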