Path: blob/master/labml_nn/optimizers/configs.py
"""1---2title: Configurable optimizer module3summary: This implements a configurable module for optimizers.4---56# Configurable Optimizer7"""89from typing import Tuple1011import torch1213from labml.configs import BaseConfigs, option, meta_config14from labml_nn.optimizers import WeightDecay151617class OptimizerConfigs(BaseConfigs):18"""19<a id="OptimizerConfigs"></a>2021## Optimizer Configurations22"""2324# Optimizer25optimizer: torch.optim.Adam2627# Weight decay28weight_decay_obj: WeightDecay29# Whether weight decay is decoupled;30# i.e. weight decay is not added to gradients31weight_decouple: bool = True32# Weight decay33weight_decay: float = 0.034# Whether weight decay is absolute or should be multiplied by learning rate35weight_decay_absolute: bool = False3637# Whether the adam update is optimized (different epsilon)38optimized_adam_update: bool = True3940# Parameters to be optimized41parameters: any4243# Learning rate $\alpha$44learning_rate: float = 0.0145# Beta values $(\beta_1, \beta_2)$ for Adam46betas: Tuple[float, float] = (0.9, 0.999)47# Epsilon $\epsilon$ for adam48eps: float = 1e-084950# Momentum for SGD51momentum: float = 0.552# Whether to use AMSGrad53amsgrad: bool = False5455# Number of warmup optimizer steps56warmup: int = 2_00057# Total number of optimizer steps (for cosine decay)58total_steps: int = int(1e10)5960# Whether to degenerate to SGD in AdaBelief61degenerate_to_sgd: bool = True6263# Whether to use Rectified Adam in AdaBelief64rectify: bool = True6566# Model embedding size for Noam optimizer67d_model: int6869rho: float7071def __init__(self):72super().__init__(_primary='optimizer')737475meta_config(OptimizerConfigs.parameters)767778@option(OptimizerConfigs.weight_decay_obj, 'L2')79def _weight_decay(c: OptimizerConfigs):80return WeightDecay(c.weight_decay, c.weight_decouple, c.weight_decay_absolute)818283@option(OptimizerConfigs.optimizer, 'SGD')84def _sgd_optimizer(c: OptimizerConfigs):85return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum,86weight_decay=c.weight_decay)878889@option(OptimizerConfigs.optimizer, 'Adam')90def _adam_optimizer(c: OptimizerConfigs):91if c.amsgrad:92from labml_nn.optimizers.amsgrad import AMSGrad93return AMSGrad(c.parameters,94lr=c.learning_rate, betas=c.betas, eps=c.eps,95optimized_update=c.optimized_adam_update,96weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad)97else:98from labml_nn.optimizers.adam import Adam99return Adam(c.parameters,100lr=c.learning_rate, betas=c.betas, eps=c.eps,101optimized_update=c.optimized_adam_update,102weight_decay=c.weight_decay_obj)103104105@option(OptimizerConfigs.optimizer, 'AdamW')106def _adam_warmup_optimizer(c: OptimizerConfigs):107from labml_nn.optimizers.adam_warmup import AdamWarmup108return AdamWarmup(c.parameters,109lr=c.learning_rate, betas=c.betas, eps=c.eps,110weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup)111112113@option(OptimizerConfigs.optimizer, 'RAdam')114def _radam_optimizer(c: OptimizerConfigs):115from labml_nn.optimizers.radam import RAdam116return RAdam(c.parameters,117lr=c.learning_rate, betas=c.betas, eps=c.eps,118weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,119degenerated_to_sgd=c.degenerate_to_sgd)120121122@option(OptimizerConfigs.optimizer, 'AdaBelief')123def _ada_belief_optimizer(c: OptimizerConfigs):124from labml_nn.optimizers.ada_belief import AdaBelief125return AdaBelief(c.parameters,126lr=c.learning_rate, betas=c.betas, eps=c.eps,127weight_decay=c.weight_decay_obj, 
amsgrad=c.amsgrad,128degenerate_to_sgd=c.degenerate_to_sgd,129rectify=c.rectify)130131132@option(OptimizerConfigs.optimizer, 'Noam')133def _noam_optimizer(c: OptimizerConfigs):134from labml_nn.optimizers.noam import Noam135return Noam(c.parameters,136lr=c.learning_rate, betas=c.betas, eps=c.eps,137weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup,138d_model=c.d_model)139140141@option(OptimizerConfigs.optimizer, 'Sophia')142def _sophia_optimizer(c: OptimizerConfigs):143from labml_nn.optimizers.sophia import Sophia144return Sophia(c.parameters,145lr=c.learning_rate, betas=c.betas, eps=c.eps,146weight_decay=c.weight_decay_obj, rho=c.rho)147148149@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')150def _noam_optimizer(c: OptimizerConfigs):151from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay152return AdamWarmupCosineDecay(c.parameters,153lr=c.learning_rate, betas=c.betas, eps=c.eps,154weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,155warmup=c.warmup, total_steps=c.total_steps)156157158
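
# ### Usage sketch (editor's illustrative addition, not part of the original module)
#
# `OptimizerConfigs` is meant to be embedded in an experiment's own configs.
# Because `__init__` passes `_primary='optimizer'`, an option function of the
# outer configs class can return the whole `OptimizerConfigs` object and labml
# will substitute the computed optimizer for it. The sketch below shows that
# pattern; the `ExperimentConfigs` class, its `model` attribute, and the chosen
# hyper-parameter values are assumptions made only for this illustration:
#
#     from labml.configs import BaseConfigs, option
#
#     class ExperimentConfigs(BaseConfigs):
#         # Model whose parameters will be optimized (placeholder)
#         model: torch.nn.Module
#         # Optimizer computed from the embedded `OptimizerConfigs`
#         optimizer: torch.optim.Adam
#
#     @option(ExperimentConfigs.optimizer)
#     def _optimizer(c: ExperimentConfigs):
#         opt_conf = OptimizerConfigs()
#         # Parameters the optimizer should update
#         opt_conf.parameters = c.model.parameters()
#         # Select one of the options registered above by its name,
#         # e.g. `'Adam'`, `'Noam'`, or `'AdamWarmupCosineDecay'`
#         opt_conf.optimizer = 'Adam'
#         opt_conf.learning_rate = 2.5e-4
#         # Returning the configs object hands back its primary value,
#         # i.e. the constructed optimizer
#         return opt_conf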