GitHub Repository: jantic/deoldify
Path: blob/master/fastai/callbacks/fp16.py

"Callback support for half precision (fp16) training. Increases training speed."
from ..torch_core import *
from ..callback import *
from ..basic_train import *
from torch._utils import _unflatten_dense_tensors
from torch.nn.utils import parameters_to_vector

__all__ = ['MixedPrecision']

def get_master(layer_groups:ModuleList, flat_master:bool=False) -> Tuple[List[List[Tensor]], List[List[Tensor]]]:
    "Return two lists, one for the model parameters in FP16 and one for the master parameters in FP32."
    split_params = split_no_wd_params(layer_groups)
    model_params = [[param for param in pg if param.requires_grad] for pg in split_params]
    if flat_master:
        master_params = []
        for lg in model_params:
            if len(lg) != 0:
                mp = parameters_to_vector([param.data.float() for param in lg])
                mp = torch.nn.Parameter(mp, requires_grad=True)
                if mp.grad is None: mp.grad = mp.new(*mp.size())
                master_params.append([mp])
            else: master_params.append([])
        return model_params, master_params
    else:
        master_params = [[param.clone().float().detach() for param in lg] for lg in model_params]
        for mp in master_params:
            for param in mp: param.requires_grad = True
        return model_params, master_params
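
# Note (illustration, hypothetical `learn`): `get_master` pairs every trainable
# parameter group of the (FP16) model with an FP32 copy for the optimizer, e.g.
#   model_params, master_params = get_master(learn.layer_groups)
#   len(model_params) == len(master_params)             # one master group per model group
#   all(p.dtype == torch.float32 for g in master_params for p in g)
# With `flat_master=True` each non-empty master group is a single flattened Parameter.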

def model_g2master_g(model_params:Sequence[Tensor], master_params:Sequence[Tensor], flat_master:bool=False)->None:
    "Copy the `model_params` gradients to `master_params` for the optimizer step."
    if flat_master:
        for model_group,master_group in zip(model_params,master_params):
            if len(master_group) != 0:
                if master_group[0].grad is None: master_group[0].grad = master_group[0].data.new(*master_group[0].data.size())
                master_group[0].grad.data.copy_(parameters_to_vector([p.grad.data.float() for p in model_group]))
    else:
        for model_group,master_group in zip(model_params,master_params):
            for model, master in zip(model_group, master_group):
                if model.grad is not None:
                    if master.grad is None: master.grad = master.data.new(*master.data.size())
                    master.grad.data.copy_(model.grad.data)
                else: master.grad = None

def master2model(model_params:Sequence[Tensor], master_params:Sequence[Tensor], flat_master:bool=False)->None:
    "Copy `master_params` to `model_params`."
    if flat_master:
        for model_group,master_group in zip(model_params,master_params):
            if len(model_group) != 0:
                for model, master in zip(model_group, _unflatten_dense_tensors(master_group[0].data, model_group)):
                    model.data.copy_(master)
    else:
        for model_group,master_group in zip(model_params,master_params):
            for model, master in zip(model_group, master_group): model.data.copy_(master.data)
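
# The three helpers above implement the classic "FP32 master weights" pattern.
# The sketch below (an addition for illustration only, not used anywhere in the
# library) shows how they fit together for a single optimizer step, assuming the
# backward pass ran on a loss that was multiplied by `loss_scale`.
def _fp16_step_sketch(model_params, master_params, opt, loss_scale:float=512., flat_master:bool=False):
    "Illustrative only: one optimizer step using FP32 master weights."
    model_g2master_g(model_params, master_params, flat_master)        # FP16 grads -> FP32 master grads
    for group in master_params:
        for param in group:
            if param.grad is not None: param.grad.div_(loss_scale)    # undo the loss scaling
    opt.step()                                                        # step on the FP32 master weights
    master2model(model_params, master_params, flat_master)            # copy the result back into the FP16 model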

def grad_overflow(param_group):
    for group in param_group:
        for p in group:
            if p.grad is not None:
                s = float(p.grad.data.float().sum())
                if s == float('inf') or s == float('-inf') or s != s: return True
    return False
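
# Note: `s != s` is the standard NaN test (NaN is the only float that is not
# equal to itself), so `grad_overflow` reports True whenever any gradient has
# overflowed to +/-inf or turned into NaN under the current loss scale, e.g.
#   p = torch.nn.Parameter(torch.ones(2, dtype=torch.float16))
#   p.grad = torch.tensor([float('inf'), 1.], dtype=torch.float16)
#   grad_overflow([[p]])    # -> True   (hypothetical example)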

class MixedPrecision(LearnerCallback):
    _order = 999 #Need to run after things that could call on_backward_begin and change the loss
    "Callback that handles mixed-precision training."
    def __init__(self, learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,
                 flat_master:bool=False, max_scale:float=2**24):
        super().__init__(learn)
        self.flat_master,self.dynamic,self.max_noskip,self.clip,self.max_scale = flat_master,dynamic,max_noskip,clip,max_scale
        self.loss_scale = ifnone(loss_scale, 2**16 if dynamic else 512)
        self.not_min += ['model_params', 'master_params']
        assert torch.backends.cudnn.enabled, "Mixed precision training requires cudnn."
        self.opt = None

    def on_train_begin(self, **kwargs:Any)->None:
        "Prepare the master model."
        #Get a copy of the model params in FP32
        self.model_params, self.master_params = get_master(self.learn.layer_groups, self.flat_master)
        #Changes the optimizer so that the optimization step is done in FP32.
        new_opt = self.learn.opt.new_with_params(self.master_params)
        if self.opt is not None:
            self.opt.lr,self.opt.wd = self.learn.opt.lr,self.learn.opt.wd
            new_opt.load_state_dict(self.opt)
        self.learn.opt.opt = new_opt.opt
        self.noskip = 0

    def on_loss_begin(self, last_output:Tensor, **kwargs:Any) -> Tensor:
        "Convert half precision output to FP32 to avoid reduction overflow."
        return {'last_output': to_float(last_output)}

    def on_backward_begin(self, last_loss:Rank0Tensor, **kwargs:Any) -> Rank0Tensor:
        "Scale gradients up by `self.loss_scale` to prevent underflow."
        #To avoid gradient underflow, we scale the gradients
        ret_loss = last_loss * self.loss_scale
        return {'last_loss': ret_loss}

    def on_backward_end(self, **kwargs:Any)->None:
        "Convert the gradients back to FP32 and divide them by the scale."
        if self.dynamic and grad_overflow(self.model_params) and self.loss_scale > 1:
            self.loss_scale /= 2
            self.noskip = 0
            #The step will be skipped since we don't update the master grads so they are all None or zero
        else:
            model_g2master_g(self.model_params, self.master_params, self.flat_master)
            for group in self.master_params:
                for param in group:
                    if param.grad is not None: param.grad.div_(self.loss_scale)
            if self.clip is not None:
                for group in self.master_params: nn.utils.clip_grad_norm_(group, self.clip)
            if not self.dynamic: return
            self.noskip += 1
            if self.noskip >= self.max_noskip and self.loss_scale < self.max_scale:
                self.loss_scale *= 2
                self.noskip = 0
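
    # Note on the dynamic loss-scaling policy implemented above: when any FP16
    # gradient is inf/NaN the master grads are left untouched (so the step is
    # effectively skipped) and the scale is halved; after `max_noskip`
    # consecutive clean steps the scale is doubled again, up to `max_scale`.
    # E.g. with the defaults, one overflow drops the scale from 2**16 to 2**15,
    # and it climbs back only after 1000 iterations without overflow.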

    def on_step_end(self, **kwargs:Any)->None:
        "Update the params from master to model and zero grad."
        #Zeros the gradients of the model since the optimizer is disconnected.
        self.learn.model.zero_grad()
        #Update the params from master to model.
        master2model(self.model_params, self.master_params, self.flat_master)
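
# Usage sketch (an assumption about the surrounding fastai v1 API, not part of
# this module): the callback is normally attached via `Learner.to_fp16()`, which
# also converts the model itself to half precision. Wiring it up by hand would
# look roughly like:
#   from functools import partial
#   learn.model = model2half(learn.model)                              # FP16 model
#   learn.callback_fns.append(partial(MixedPrecision, dynamic=True))   # FP32 master weights + loss scaling
#   learn.fit_one_cycle(1)
# where `learn` is an existing `Learner`; exact helper names may differ by version.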