Path: blob/master/labml_nn/optimizers/sophia.py
"""1---2title: Sophia Optimizer3summary: A simple PyTorch implementation/tutorial of Sophia optimizer4---56# Sophia Optimizer78This is a [PyTorch](https://pytorch.org) implementation of *Sophia-G* from paper9[Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://arxiv.org/abs/2305.14342).10Official implementation is available at [Liuhong99/Sophia](https://github.com/Liuhong99/Sophia).1112Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant13to non-convexity and rapid change of Hessian than Newton’s method, and also uses a low-cost14pre-conditioner.1516Sophia keeps diagonal Hessian estimates with EMA across iterations.17The diagonal Hessian $\hat{h}_t$ is calculated every $k$ steps.1819\begin{align}20h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t \ \ \ \ \text{ if } t \text{ mod } k = 1; \text{ else } h_t = h_{t-1}21\end{align}2223Sophia uses EMA of gradients $m_t$, only considers positive entries of24the diagonal Hessian and does per-coordinate clipping to the update.2526\begin{align}27m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1)g_t \\28\theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{ \max \{h_t, \epsilon \} }, \rho \bigg)29\end{align}3031where $\epsilon$ is a very small value to prevent division by $0$.3233### Gauss-Newton-Bartlett (GNB) estimator3435\begin{align}36\hat{L}(\theta) &= \frac{1}{B} \sum^{B}_{b=1} \ell_{CE} \big( f(\theta, x_b), \hat{y}_b \big) \\37\hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta)38\end{align}3940where $x_b$ are the inputs,41$B$ is the batch size (number of inputs/tokens),42$\ell_{CE}$ is cross entropy loss, and43$\hat{y}_b$ are sampled from the logits $f(\theta, x_b)$.4445Note that this hessian estimate is always positive and therefore we46can replace $\max \{h_t, \epsilon \}$ with $h_t + \epsilon$.4748Sophia with Gauss-Newton-Bartlett (GNB) estimator is **Sophia-G**4950Here is an [experiment](../transformers/basic/with_sophia.html) that uses Sophia-G to train a transformer.51"""5253from typing import Dict, Any, Tuple, Optional5455import torch56from torch import nn5758from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay596061class Sophia(GenericAdaptiveOptimizer):62"""63## Sophia-G Optimizer6465We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)66to implement the Sophia optimizer.67"""6869def __init__(self, params,70lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,71rho: float = 0.03,72weight_decay: WeightDecay = WeightDecay(),73defaults: Optional[Dict[str, Any]] = None):74"""75### Initialize the optimizer7677* `params` is the list of parameters78* `lr` is the maximum learning rate $\eta \rho$79* `betas` is a tuple of ($\beta_1$, $\beta_2$)80* `eps` is $\epsilon$81* `pho` is $\rho$82* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)83* `defaults` is a dictionary of default for group values.84This is useful when you want to extend the class `Adam`.85"""86defaults = {} if defaults is None else defaults87defaults.update(weight_decay.defaults())88defaults.update(dict(rho=rho))89super().__init__(params, defaults, lr, betas, eps)9091self.weight_decay = weight_decay9293def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):94"""95### Initialize a parameter state9697* `state` is the optimizer state of the parameter (tensor)98* `group` stores optimizer attributes of the 

Here is an [experiment](../transformers/basic/with_sophia.html) that uses Sophia-G to train a transformer.
"""

from typing import Dict, Any, Tuple, Optional

import torch
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay


class Sophia(GenericAdaptiveOptimizer):
    """
    ## Sophia-G Optimizer

    We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
    to implement the Sophia optimizer.
    """

    def __init__(self, params,
                 lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,
                 rho: float = 0.03,
                 weight_decay: WeightDecay = WeightDecay(),
                 defaults: Optional[Dict[str, Any]] = None):
        """
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the maximum learning rate $\eta \rho$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\epsilon$
        * `rho` is $\rho$
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `defaults` is a dictionary of defaults for group values.
         This is useful when you want to extend the class `Sophia`.
        """
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        defaults.update(dict(rho=rho))
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of the Hessian diagonal, $h_t$
        state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    def update_hessian(self, n_tokens_training_batch: int):
        """
        ### Update the EMA of the Hessian diagonal $h_t$

        * `n_tokens_training_batch` is the number of tokens/inputs in the batch $B$

        \begin{align}
        \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
        h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
        \end{align}
        """

        # Iterate through parameter groups
        for group in self.param_groups:
            # $\beta_2$
            _, beta2 = group['betas']
            # Iterate through parameters
            for p in group['params']:
                # Skip parameters without gradients
                if p.grad is None:
                    continue

                # Get optimizer state
                state = self.state[p]

                # Initialize state if empty
                if len(state) == 0:
                    self.init_state(state, group, p)

                # Update the EMA of the Hessian diagonal
                #
                # \begin{align}
                # \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
                # h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
                # \end{align}
                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$

        We do the following parameter update,

        \begin{align}
        \theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)
        \end{align}
        """

        # Apply weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $\beta_1$ and $\beta_2$
        beta1, beta2 = group['betas']
        # Get $\rho$
        rho = group['rho']

        # Get $m_{t-1}$ and $h_t$
        m, hessian = state['exp_avg'], state['hessian']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Get the maximum learning rate $\eta \rho$
        lr = group['lr']

        # $\eta$
        eta = lr / rho

        # $$\operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)$$
        ratio = (m / (hessian + group['eps'])).clamp(-rho, rho)

        # $$\theta_{t + 1} \leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)$$
        param.data.add_(ratio, alpha=-eta)
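

# The following is an illustrative sketch, not part of the optimizer or the paper's
# official code: a tiny smoke test that runs a few Sophia-G iterations on random
# classification data to show how `step()` and `update_hessian()` fit together,
# including sampling $\hat{y}_b$ from the logits for the GNB Hessian estimate.
# The function name `_smoke_test`, the toy sizes, and the choice $k = 5$ are
# arbitrary assumptions for this example.
def _smoke_test():
    torch.manual_seed(0)
    # Toy data: $64$ samples with $8$ features and $4$ classes
    x, y = torch.randn(64, 8), torch.randint(0, 4, (64,))
    # A single linear layer stands in for the model $f(\theta, x_b)$
    model = nn.Linear(8, 4)
    optimizer = Sophia(model.parameters(), lr=1e-3)

    for t in range(50):
        # Back-propagate the training loss and take an optimizer step
        optimizer.zero_grad()
        nn.functional.cross_entropy(model(x), y).backward()
        optimizer.step()

        # Every $k = 5$ steps, sample labels $\hat{y}_b$ from the logits,
        # back-propagate the GNB loss, and refresh the Hessian diagonal EMA
        if t % 5 == 0:
            logits = model(x)
            sampled = torch.distributions.Categorical(logits=logits).sample()
            optimizer.zero_grad()
            nn.functional.cross_entropy(logits, sampled).backward()
            optimizer.update_hessian(x.shape[0])

    # Report the final training loss
    print(f'final loss: {nn.functional.cross_entropy(model(x), y).item():.4f}')


if __name__ == '__main__':
    _smoke_test()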