Path: blob/master/labml_nn/optimizers/adam_fp16.py
"""1---2title: Adam Optimizer for Half Precision Training3summary: A simple PyTorch implementation/tutorial of Adam optimizer4---56# Adam Optimizer for Half Precision Training7"""89from typing import Dict, Tuple, Optional, Any1011import torch12from torch import nn13from torch.optim import Optimizer14from torch.cuda.amp import grad_scaler15from collections import defaultdict, abc1617from labml_nn.optimizers import WeightDecay18from labml_nn.optimizers.adam import Adam192021class AdamFP16(Adam):22"""23## Adam Optimizer for Half Precision Training2425We extend [Adam Optimizer](adam.html) but use FP32 to store gradients and moments.26"""2728def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,29weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,30defaults: Optional[Dict[str, Any]] = None):31# Parameter to store 32 bit gradients. This get populated by the `GradScaler` defined below.32self.grad_fp32 = {}33# Call the [Adam Optimizer](adam.html) initializer34super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)3536def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):37"""38### Initialize a parameter state3940* `state` is the optimizer state of the parameter (tensor)41* `group` stores optimizer attributes of the parameter group42* `param` is the parameter tensor $\theta_{t-1}$4344All the state tensors use FP32.45"""4647# This is the number of optimizer steps taken on the parameter, $t$48state['step'] = 049# Exponential moving average of gradients, $m_t$50state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)51# Exponential moving average of squared gradient values, $v_t$52state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)53# Maintain a FP32 copy of the parameters54state['fp32_copy'] = param.to(torch.float)5556def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):57"""58### Take an update step for a given parameter tensor5960* `state` is the optimizer state of the parameter (tensor)61* `group` stores optimizer attributes of the parameter group62* `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$63* `param` is the parameter tensor $\theta_{t-1}$64"""6566# Get the FP32 parameters67param_fp32 = state['fp32_copy']68# Get the FP32 gradients if available69grad_fp32 = self.grad_fp32.get(param, None)70if grad_fp32 is not None:71del self.grad_fp32[param]72grad = grad_fp3273else:74# Otherwise, convert the gradients to FP3275grad = grad.to(torch.float)7677# Calculate weight decay78grad = self.weight_decay(param_fp32, grad, group)7980# Get $m_t$ and $v_t$81m, v = self.get_mv(state, group, grad)8283# Increment $t$ the number of optimizer steps84state['step'] += 18586# Perform *Adam* update87self.adam_update(state, group, param_fp32, m, v)8889# Set the parameters90param.data = param_fp32.to(param.dtype)919293class GradScalerFP16(grad_scaler.GradScaler):94"""95## Gradient Scaler with half precision gradients9697We extend PyTorch gradient scaler to use FP32 gradients.98"""99100def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,101allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:102per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)103per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)104105per_device_and_dtype_grads 
= defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]106107with torch.no_grad():108# Loop through parameters109for group in optimizer.param_groups:110for param in group["params"]:111# Skip non-trainable parameters112if param.grad is None:113continue114# Not implemented for sparse tensors115if param.grad.is_sparse:116raise NotImplementedError117118# If we are using the `AdamFP16` optimizer set `optimizer.grad_fp32[param]` to the FP32 gradients119if isinstance(optimizer, AdamFP16):120grad = param.grad.to(torch.float)121optimizer.grad_fp32[param] = grad122# Otherwise, do not convert the gradients to FP32123else:124grad = param.grad125126per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)127128# Unscale all the gradients129for device, per_dtype_grads in per_device_and_dtype_grads.items():130for grads in per_dtype_grads.values():131torch._amp_foreach_non_finite_check_and_unscale_(grads,132per_device_found_inf.get(device),133per_device_inv_scale.get(device))134#135return per_device_found_inf._per_device_tensors136137138
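

# ### Usage sketch
#
# A minimal sketch (not part of the original module) showing how `AdamFP16` and
# `GradScalerFP16` are intended to be combined in a half precision training step.
# The model, data, and loss below are hypothetical placeholders, and a CUDA device
# is assumed since the parameters and gradients are FP16.
def _usage_sketch():
    import torch.nn.functional as F

    # A toy FP16 model on the GPU
    model = nn.Linear(16, 4).cuda().half()
    # The optimizer keeps FP32 copies of the parameters, moments, and gradients
    optimizer = AdamFP16(model.parameters(), lr=1e-3)
    # The scaler hands FP32 gradients to `AdamFP16` through `optimizer.grad_fp32`
    scaler = GradScalerFP16()

    # Hypothetical batch of inputs and targets
    x = torch.randn(32, 16, device='cuda', dtype=torch.float16)
    y = torch.randint(0, 4, (32,), device='cuda')

    # Zero the gradients, compute the loss, and scale it before back-propagation
    # so that small FP16 gradients do not underflow
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x).to(torch.float), y)
    scaler.scale(loss).backward()
    # `scaler.step` unscales the gradients (populating `optimizer.grad_fp32`)
    # and calls `optimizer.step()` if no infinities or NaNs were found
    scaler.step(optimizer)
    # Update the scale for the next iteration
    scaler.update()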