Path: blob/master/labml_nn/optimizers/adam_warmup.py
"""1---2title: Adam optimizer with warm-up3summary: A simple PyTorch implementation/tutorial of Adam optimizer with warm-up.4---56# Adam Optimizer with Warmup78This extends [AMSGrad optimizer](amsgrad.html) and adds a warmup stage.9"""1011from typing import Dict1213from labml_nn.optimizers import WeightDecay14from labml_nn.optimizers.amsgrad import AMSGrad151617class AdamWarmup(AMSGrad):18"""19## Adam Optimizer with Warmup2021This class extends from AMSGrad optimizer defined in [`amsgrad.py`](amsgrad.html).22"""23def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,24weight_decay: WeightDecay = WeightDecay(),25optimized_update: bool = True,26amsgrad=False, warmup=0, defaults=None):27"""28### Initialize the optimizer2930* `params` is the list of parameters31* `lr` is the learning rate $\alpha$32* `betas` is a tuple of ($\beta_1$, $\beta_2$)33* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`34* `weight_decay` is an instance of class `WeightDecay` defined in [`__init__.py`](index.html)35* 'optimized_update' is a flag whether to optimize the bias correction of the second moment36by doing it after adding $\epsilon$37* `amsgrad` is a flag indicating whether to use AMSGrad or fallback to plain Adam38* `warmup` number of warmup steps39* `defaults` is a dictionary of default for group values.40This is useful when you want to extend the class `AdamWarmup`.41"""4243defaults = {} if defaults is None else defaults44defaults.update(dict(warmup=warmup))45super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)4647def get_lr(self, state: Dict[str, any], group: Dict[str, any]):48"""49### Get learning-rate5051$$\alpha \min \bigg(1, \frac{t}{w}\bigg)$$52where $w$ is the number of warmup steps.53"""54# If we are in warmup stage55if group['warmup'] > state['step']:56# A linearly increasing learning rate from $0$ to $\alpha$57return 1e-8 + state['step'] * group['lr'] / group['warmup']58else:59# Constant learning rate $\alpha$60return group['lr']616263