Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/optimizers/sophia.py
4937 views
1
"""
2
---
3
title: Sophia Optimizer
4
summary: A simple PyTorch implementation/tutorial of Sophia optimizer
5
---
6
7
# Sophia Optimizer
8
9
This is a [PyTorch](https://pytorch.org) implementation of *Sophia-G* from paper
10
[Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://arxiv.org/abs/2305.14342).
11
Official implementation is available at [Liuhong99/Sophia](https://github.com/Liuhong99/Sophia).
12
13
Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant
to non-convexity and rapid changes of the Hessian than Newton's method, and also uses a low-cost
pre-conditioner.
16
17
Sophia keeps an exponential moving average (EMA) of diagonal Hessian estimates across iterations.
The diagonal Hessian estimate $\hat{h}_t$ is computed only every $k$ steps.
19
20
\begin{align}
21
h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t \ \ \ \ \text{ if } t \text{ mod } k = 1; \text{ else } h_t = h_{t-1}
22
\end{align}
23
24
Sophia uses an EMA of gradients $m_t$, only considers positive entries of
the diagonal Hessian, and applies per-coordinate clipping to the update.
26
27
\begin{align}
28
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1)g_t \\
29
\theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{ \max \{h_t, \epsilon \} }, \rho \bigg)
30
\end{align}
31
32
where $\epsilon$ is a very small value to prevent division by $0$.
33
34
### Gauss-Newton-Bartlett (GNB) estimator
35
36
\begin{align}
37
\hat{L}(\theta) &= \frac{1}{B} \sum^{B}_{b=1} \ell_{CE} \big( f(\theta, x_b), \hat{y}_b \big) \\
38
\hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta)
39
\end{align}
40
41
where $x_b$ are the inputs,
42
$B$ is the batch size (number of inputs/tokens),
43
$\ell_{CE}$ is cross entropy loss, and
44
$\hat{y}_b$ are sampled from the logits $f(\theta, x_b)$.
45
46
Note that this Hessian estimate is always non-negative (it is an element-wise square scaled by $B$),
so the EMA $h_t$ is non-negative too and we can replace $\max \{h_t, \epsilon \}$ with $h_t + \epsilon$.
48
49
Sophia with the Gauss-Newton-Bartlett (GNB) estimator is called **Sophia-G**.
50
51
Here is an [experiment](../transformers/basic/with_sophia.html) that uses Sophia-G to train a transformer.
52
"""
53
54
from typing import Dict, Any, Tuple, Optional
55
56
import torch
57
from torch import nn
58
59
from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
60
61
62
class Sophia(GenericAdaptiveOptimizer):
    r"""
    ## Sophia-G Optimizer

    We extend the class `GenericAdaptiveOptimizer` defined in [`__init__.py`](index.html)
    to implement the Sophia optimizer.

    `step_param` only reads the Hessian EMA $h_t$; it is written solely by
    `update_hessian`, which the caller is responsible for invoking (every $k$ steps).
    """

    def __init__(self, params,
                 lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,
                 rho: float = 0.03,
                 weight_decay: WeightDecay = WeightDecay(),
                 defaults: Optional[Dict[str, Any]] = None):
        r"""
        ### Initialize the optimizer

        * `params` is the list of parameters
        * `lr` is the maximum learning rate $\eta \rho$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\epsilon$
        * `rho` is $\rho$; it must be positive because $\eta = \frac{lr}{\rho}$
        * `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)
        * `defaults` is a dictionary of default for group values.
         This is useful when you want to extend the class `Sophia`.
         It is copied, so the caller's dictionary is left unmodified.

        Raises `ValueError` if `rho` is not positive.
        """
        # $\eta = lr / \rho$ is computed on every step; a zero $\rho$ would divide
        # by zero and a negative one would make the clipping range $[-\rho, \rho]$ empty.
        # `not rho > 0.0` also rejects NaN.
        if not rho > 0.0:
            raise ValueError(f"Invalid rho value: {rho}")

        # Copy so that the dictionary passed in by the caller is not mutated
        defaults = {} if defaults is None else dict(defaults)
        defaults.update(weight_decay.defaults())
        defaults.update(dict(rho=rho))
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        r"""
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        # Exponential moving average of Hessian diagonal, $h_t$
        state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)

    @torch.no_grad()
    def update_hessian(self, n_tokens_training_batch: int):
        r"""
        ### Update the EMA of Hessian diagonal $h_t$

        * `n_tokens_training_batch` is the number of tokens/inputs in the batch $B$

        \begin{align}
        \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
        h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
        \end{align}

        This uses the gradients currently stored in `p.grad`, so it should be called
        right after back-propagating the loss on sampled labels $\hat{y}_b$.
        It runs under `torch.no_grad()` so the state update is never tracked by autograd,
        even if the gradients were created with `create_graph=True`.
        """

        # Iterate through parameter groups
        for group in self.param_groups:
            # $\beta_2$
            _, beta2 = group['betas']
            # Iterate through parameters
            for p in group['params']:
                # Skip parameters without gradients
                if p.grad is None:
                    continue

                # Get optimizer state
                state = self.state[p]

                # Initialize state if empty
                if len(state) == 0:
                    self.init_state(state, group, p)

                # Update EMA Hessian diagonal
                #
                # \begin{align}
                # \hat{h}_t &= B \cdot \nabla_\theta \hat{L} (\theta) \odot \nabla_\theta \hat{L} (\theta) \\
                # h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
                # \end{align}
                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
        r"""
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$

        We do the following parameter update,

        \begin{align}
        \theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)
        \end{align}

        The Hessian EMA $h_t$ is only read here; it is refreshed by `update_hessian`.
        """

        # Calculate weight decay
        grad = self.weight_decay(param, grad, group)

        # Get $\beta_1$ ($\beta_2$ is only used by `update_hessian`)
        beta1, _ = group['betas']
        # Get $\rho$
        rho = group['rho']

        # Get $m_{t-1}$ and $h_{t}$
        m, hessian = state['exp_avg'], state['hessian']

        # In-place calculation of $m_t$
        # $$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
        m.mul_(beta1).add_(grad, alpha=1 - beta1)

        # Increment $t$ the number of optimizer steps
        state['step'] += 1

        # Get maximum learning rate $\eta \rho$
        lr = group['lr']

        # $\eta$
        eta = lr / rho

        # $$\operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)$$
        # Clipping bounds each coordinate's step to at most $\eta \rho = lr$ in magnitude
        ratio = (m / (hessian + group['eps'])).clamp(-rho, rho)

        # $$\theta_{t + 1} \leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)$$
        param.data.add_(ratio, alpha=-eta)
192
193