GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/optimizers/adam_fp16.py
"""
2
---
3
title: Adam Optimizer for Half Precision Training
4
summary: A simple PyTorch implementation/tutorial of Adam optimizer
5
---
6
7
# Adam Optimizer for Half Precision Training
8
"""
9
10
from typing import Dict, Tuple, Optional, Any

import torch
from torch import nn
from torch.optim import Optimizer
from torch.cuda.amp import grad_scaler
from collections import defaultdict, abc

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam


class AdamFP16(Adam):
    """
    ## Adam Optimizer for Half Precision Training

    We extend the [Adam Optimizer](adam.html) but use FP32 to store gradients, moments,
    and a copy of the parameters.
    """

    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
        # Dictionary to store 32-bit gradients. This gets populated by the `GradScalerFP16` defined below.
        self.grad_fp32 = {}
        # Call the [Adam Optimizer](adam.html) initializer
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
        """
        ### Initialize a parameter state

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `param` is the parameter tensor $\theta_{t-1}$

        All the state tensors use FP32.
        """

        # This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
        # Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
        # Maintain an FP32 copy of the parameters
        state['fp32_copy'] = param.to(torch.float)

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
        """
        ### Take an update step for a given parameter tensor

        * `state` is the optimizer state of the parameter (tensor)
        * `group` stores optimizer attributes of the parameter group
        * `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
        * `param` is the parameter tensor $\theta_{t-1}$
        """

        # Get the FP32 copy of the parameters
        param_fp32 = state['fp32_copy']
        # Get the FP32 gradients if available
        grad_fp32 = self.grad_fp32.get(param, None)
        if grad_fp32 is not None:
            del self.grad_fp32[param]
            grad = grad_fp32
        else:
            # Otherwise, convert the gradients to FP32
            grad = grad.to(torch.float)

        # Calculate weight decay
        grad = self.weight_decay(param_fp32, grad, group)

        # Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)

        # Increment $t$, the number of optimizer steps
        state['step'] += 1

        # Perform the *Adam* update
        self.adam_update(state, group, param_fp32, m, v)

        # Copy the updated FP32 parameters back into the half precision parameter tensor
        param.data = param_fp32.to(param.dtype)


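# Note: the two classes in this file cooperate through `optimizer.grad_fp32`. While unscaling,
# `GradScalerFP16` (below) makes an FP32 copy of each parameter's gradient, stores it in that
# dictionary, and unscales the copy; `AdamFP16.step_param` (above) then uses this FP32 gradient
# instead of the half precision `param.grad`.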
class GradScalerFP16(grad_scaler.GradScaler):
    """
    ## Gradient Scaler with half precision gradients

    We extend the PyTorch gradient scaler to pass FP32 gradients to the `AdamFP16` optimizer.
    """

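    # `_unscale_grads_` is an internal hook of PyTorch's `GradScaler` (a private API that may
    # change between PyTorch versions). It is invoked by `unscale_()`, and by `step()` when
    # `unscale_()` hasn't been called explicitly, to divide the gradients by the scale factor
    # and check them for infinities and NaNs.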
    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        with torch.no_grad():
            # Loop through the parameters
            for group in optimizer.param_groups:
                for param in group["params"]:
                    # Skip parameters that have no gradient
                    if param.grad is None:
                        continue
                    # Not implemented for sparse tensors
                    if param.grad.is_sparse:
                        raise NotImplementedError

                    # If we are using the `AdamFP16` optimizer, set `optimizer.grad_fp32[param]` to the FP32 gradients
                    if isinstance(optimizer, AdamFP16):
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad
                    # Otherwise, do not convert the gradients to FP32
                    else:
                        grad = param.grad

                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)

            # Unscale all the gradients
            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(device),
                                                                     per_device_inv_scale.get(device))
        #
        return per_device_found_inf._per_device_tensors
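

# ### Usage sketch
#
# A minimal, hypothetical sketch of how `AdamFP16` and `GradScalerFP16` could be wired together.
# `model`, `data_loader`, and `loss_func` are placeholders and not part of the original
# implementation; only the optimizer/scaler calls come from the classes above and the standard
# PyTorch `GradScaler` interface.
def _training_loop_sketch(model: nn.Module, data_loader, loss_func):
    # Create the optimizer; gradients, moments, and a copy of the parameters are kept in FP32
    optimizer = AdamFP16(model.parameters(), lr=2.5e-4)
    # Create the gradient scaler that hands FP32 gradients to the optimizer
    scaler = GradScalerFP16()

    for data, target in data_loader:
        optimizer.zero_grad()
        # Run the forward pass under autocast so that it uses half precision where appropriate
        with torch.cuda.amp.autocast():
            loss = loss_func(model(data), target)
        # Scale the loss and back-propagate scaled gradients
        scaler.scale(loss).backward()
        # `step` unscales the gradients (populating `optimizer.grad_fp32` through
        # `_unscale_grads_` above) and then calls `optimizer.step()`, unless
        # non-finite gradients were found
        scaler.step(optimizer)
        # Update the scale factor for the next iteration
        scaler.update()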