Path: blob/master/labml_nn/optimizers/noam.py
"""1---2title: Noam optimizer from Attention is All You Need paper3summary: >4This is a tutorial/implementation of Noam optimizer.5Noam optimizer has a warm-up period and then an exponentially decaying learning rate.6---78# Noam Optimizer910This is the [PyTorch](https://pytorch.org) implementation of optimizer introduced in the paper11[Attention Is All You Need](https://arxiv.org/abs/1706.03762).12"""13from typing import Dict1415from labml_nn.optimizers import WeightDecay16from labml_nn.optimizers.amsgrad import AMSGrad171819class Noam(AMSGrad):20"""21## Noam Optimizer2223This class extends from Adam optimizer defined in [`adam.py`](adam.html).24"""2526def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,27weight_decay: WeightDecay = WeightDecay(),28optimized_update: bool = True,29amsgrad=False,30warmup=0, d_model=512, defaults=None):31"""32### Initialize the optimizer3334* `params` is the list of parameters35* `lr` is the learning rate $\alpha$36* `betas` is a tuple of ($\beta_1$, $\beta_2$)37* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`38* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)39* 'optimized_update' is a flag whether to optimize the bias correction of the second moment40by doing it after adding $\epsilon$41* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam42* `warmup` number of warmup steps43* `d_model` model size; i.e. number of dimensions in the transformer44* `defaults` is a dictionary of default for group values.45This is useful when you want to extend the class `AdamWarmup`.46"""4748defaults = {} if defaults is None else defaults49defaults.update(dict(warmup=warmup))50super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)51self.d_model = d_model5253def get_lr(self, state: Dict[str, any], group: Dict[str, any]):54"""55### Get learning-rate5657$$\alpha \frac{1}{\sqrt{d_{model}}} \min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$58where $w$ is the number of warmup steps.59"""60# $$\min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$61factor = min(state['step'] ** (-0.5), state['step'] * group['warmup'] ** (-1.5))62# $$\alpha \frac{1}{\sqrt{d_{model}}} \min \bigg(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\bigg)$$63return group['lr'] * self.d_model ** (-0.5) * factor646566def _test_noam_lr():67"""68### Plot learning rate for different warmups and model sizes697071"""72import matplotlib.pyplot as plt73import numpy as np74from torch import nn7576model = nn.Linear(10, 10)77opts = [Noam(model.parameters(), d_model=512, warmup=4000, lr=1),78Noam(model.parameters(), d_model=512, warmup=8000, lr=1),79Noam(model.parameters(), d_model=2048, warmup=2000, lr=1)]80plt.plot(np.arange(1, 20000), [[opt.get_lr({'step': i}, opt.defaults) for opt in opts] for i in range(1, 20000)])81plt.legend(["512:4000", "512:8000", "2048:2000"])82plt.title("Learning Rate")83plt.show()848586if __name__ == '__main__':87_test_noam_lr()888990