Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/general_optimizer.py
781 views
1
from .torch_core import *
2
from torch.optim import Optimizer
3
import types
4
5
# Names exported by a star-import of this module.
__all__ = [
    'StatScope',
    'Statistic',
    'ConstStatistic',
    'AvgStatistic',
    'AvgSquare',
    'GeneralOptimizer',
]
6
7
# Scopes a statistic can be attached to; members are numbered 1..5 in this order.
StatScope = Enum('StatScope', ['Global', 'Group', 'Layer', 'Channel', 'Weight'])
8
9
@dataclass
class Statistic():
    "Description of one tracked quantity; subclasses provide `new_step`, `accumulate` and `update`."
    name: str
    param: float = 0.9  # e.g. the momentum of an exponential moving average
    scope: StatScope = StatScope.Weight
    init: float = 0.  # value the buffer starts from

    @property
    def buf(self):
        "Key under which the buffer of this statistic is stored."
        return '{}_buffer'.format(self.name)

    def new_step(self):
        "Set state when computing statistics for Global or Group"
        raise NotImplementedError

    def accumulate(self, val):
        "Add `val` to statistic"
        raise NotImplementedError

    def update(self, state, param, val=None, step=None):
        "Update state with accumulated, or `val` (if `Weight` or `Layer` scope)"
        raise NotImplementedError
30
31
class ConstStatistic(Statistic):
    "Statistic without a buffer: `update` simply hands back the hyper-parameter `param`."
    @property
    def buf(self):
        "No buffer is kept, so there is no buffer key."
        return None

    def new_step(self): pass

    # `val` is accepted (and ignored) so the signature agrees with `Statistic.accumulate`;
    # the optimizer calls `stat.accumulate(grad)` on every `Global`/`Group` statistic, which
    # raised a TypeError with the former zero-argument signature.
    def accumulate(self, val=None): pass

    def update(self, state, param, val=None, step=None):
        "Return `param` unchanged; `state`, `val` and `step` are ignored."
        return param
37
38
@dataclass
class CounterStat(Statistic):
    "Statistic whose `update` returns the previous state plus one, starting from 0."
    def __post_init__(self):
        # The name passed in becomes the buffer key, then `name` itself is cleared.
        self._buf = self.name
        self.name = None
        self.init = 0

    @property
    def buf(self):
        "Buffer key: the name this counter was created with."
        return self._buf

    def new_step(self):
        "Nothing to reset."
        return None

    def accumulate(self, val):
        "Values are ignored."
        return None

    def update(self, state, param, val=None, step=None):
        "Increment the counter; `param`, `val` and `step` are ignored."
        return state + 1
46
47
@dataclass
class AvgStatistic(Statistic):
    """Running average of the values it is fed.

    Each update computes `state*param + new`, where `new` is scaled by `1-param` when
    `decay` is set. With `debias` (and a `step` count supplied) the result is divided by
    `1 - param**step`; the value returned, and therefore stored by the caller, is the
    debiased one.
    """
    decay:bool=False
    debias:bool=False

    def new_step(self):
        "Reset the accumulator used for `Global`/`Group` scope."
        self.val,self.count = 0.,0

    def accumulate(self, val):
        "Add the reduction of `val` to the accumulator and count it."
        self.count += 1
        self.val += self._get_val1(val)

    def _get_val1(self, val):
        "Reduce `val` to a single value (its mean)."
        return val.mean()

    def _get_val2(self, state, val, param):
        "Add `val` (scaled by `1-param` when decaying) to `state` in place."
        # Keyword `alpha=` replaces the deprecated positional-scalar overload `add_(alpha, other)`.
        return state.add_(val, alpha=1-param) if self.decay else state.add_(val)

    def _get_val3(self, state, val, param):
        "Add the per-row mean of `val` (rows = first dimension) to `state` in place."
        v = val.view(val.size(0), -1).mean(1)
        return state.add_(v, alpha=1-param) if self.decay else state.add_(v)

    def update(self, state, param, val=None, step=None):
        "Return the new value of the statistic given the previous `state` and hyper-parameter `param`."
        debias = self.debias and step is not None
        # `state` holds the value debiased at the previous step. Decaying it directly would
        # compound the correction (a constant input g would give g/(1-param**2) at step 2),
        # so the previous factor is undone first. At step 1 there is nothing to undo.
        undo = (1 - param ** (step - 1)) if debias and step > 1 else 1.
        carry = param * undo
        new_scale = 1-param if self.decay else 1.
        if self.scope == StatScope.Weight:
            # `state` is a tensor
            res = self._get_val2(state.mul_(carry), val, param)
        elif self.scope == StatScope.Channel:
            # `state` is a tensor of size n_channels
            res = self._get_val3(state.mul_(carry), val, param)
        # For everything else, `state` is a scalar
        elif self.scope == StatScope.Layer: res = state*carry + self._get_val1(val) * new_scale
        elif self.count != 0: res = state*carry + self.val/self.count * new_scale
        elif debias:
            # Nothing was accumulated: keep the raw average, only refresh the debias factor
            # so the stored value stays consistent with the current `step`.
            return state * undo / (1 - param ** step)
        else: return state
        if debias: res /= (1 - param ** step)
        return res
76
77
class AvgSquare(AvgStatistic):
    "Running average of squared values; same options as `AvgStatistic`, with `decay` on by default."

    def __init__(self, name:str, param:float=0.9, scope=StatScope.Weight, init:float=0., decay:bool=True, debias:bool=False):
        super().__init__(name, param=param, scope=scope, init=init, decay=decay, debias=debias)

    def _get_val1(self, val):
        "Mean of the squared entries of `val`."
        return torch.norm(val).pow(2)/val.numel()

    def _get_val2(self, state, val, param):
        "Add `val*val` (scaled by `1-param` when decaying) to `state` in place."
        # Keyword `value=` replaces the deprecated positional-scalar overload `addcmul_(value, t1, t2)`.
        return state.addcmul_(val, val, value=1-param) if self.decay else state.addcmul_(val, val)

    def _get_val3(self, state, val, param):
        "Add the square of the per-row mean of `val` to `state` in place."
        # NOTE(review): this squares the row means, whereas `_get_val1` averages the squares;
        # confirm which one is intended for `Channel` scope.
        v = val.view(val.size(0), -1).mean(1)
        return state.addcmul_(v, v, value=1-param) if self.decay else state.addcmul_(v, v)
88
89
class GeneralOptimizer(Optimizer):
    """Optimizer that keeps the given `stats` up to date and then calls `on_step` on every parameter.

    Statistics are stored in `self.state`: under `'global'`, under `f'group{i}'` for each
    parameter group, and under the parameter itself for `Layer`/`Channel`/`Weight` scopes.
    `on_step`, when given, is bound to the optimizer and called as `on_step(p, group, group_idx)`.
    """
    def __init__(self, params, stats=None, on_step:Callable=None):
        # Every named statistic exposes its `param` as a per-group hyper-parameter.
        defaults = {s.name:s.param for s in listify(stats) if s.name is not None}
        super().__init__(params, defaults)
        self.global_stats,self.group_stats,self.layer_stats,self.channel_stats,self.weight_stats = self._split_stats(stats)
        self.init_stats()
        if on_step is not None: self.on_step = types.MethodType(on_step, self)

    def step(self, closure=None):
        "Evaluate `closure` if given, refresh the statistics, apply `on_step`; return the closure's result (or `None`)."
        loss = None
        # The closure is evaluated first since it is what (re)computes the gradients.
        if closure is not None: loss = closure()
        self.update_stats()
        for i,pg in enumerate(self.param_groups):
            for p in pg['params']:
                if p.grad is not None: self.on_step(p, pg, i)
        return loss

    def on_step(self, p, group, group_idx):
        "Default update: plain gradient descent with the group's `lr`."
        # Keyword `alpha=` replaces the deprecated positional-scalar overload `add_(alpha, other)`.
        p.data.add_(p.grad.data, alpha=-group['lr'])

    def _split_stats(self, stats):
        "Split `stats` into one list per `StatScope`, adding a `step` counter where a statistic asks for debiasing."
        all_stats = listify(stats)
        splits = [[stat for stat in all_stats if stat.scope==scope] for scope in StatScope]
        global_stats,group_stats,layer_stats,channel_stats,weight_stats = splits
        # Layer, channel and weight statistics share one state dict per parameter, so a single
        # counter serves all three. It goes at the front of `layer_stats` (the first list
        # processed for a parameter) so `step` is incremented before the statistics read it.
        # It must be inserted into the real list: inserting into the concatenation
        # `layer+channel+weight` only modified a temporary and the counter was lost.
        per_param = layer_stats + channel_stats + weight_stats
        targets = ((global_stats, global_stats, StatScope.Global),
                   (group_stats, group_stats, StatScope.Group),
                   (layer_stats, per_param, StatScope.Layer))
        for split,candidates,scope in targets:
            if any(getattr(st, 'debias', False) for st in candidates):
                split.insert(0, CounterStat('step', scope=scope))
        return splits

    def _init_stats(self, stats, data=None):
        "Initial buffers for `stats`: the bare `init` value, or a tensor shaped like `data` filled with it."
        return {stat.buf: stat.init if data is None
                else torch.zeros_like(data) + stat.init for stat in stats if stat.buf is not None}

    def init_stats(self):
        "Create the state entries for every scope."
        self.state['global'] = self._init_stats(self.global_stats)
        for i,pg in enumerate(self.param_groups):
            self.state[f'group{i}'] = self._init_stats(self.group_stats)
            for p in pg['params']:
                self.state[p] = self._init_stats(self.layer_stats)
                # Channel buffers have one entry per index of the first dimension.
                self.state[p].update(self._init_stats(self.channel_stats, p.data.view(p.data.size(0), -1).mean(1)))
                self.state[p].update(self._init_stats(self.weight_stats, p.data))

    def _set_bufs(self, p, stats, pg, val=None):
        "Update the buffers of `stats` stored under key `p`, using hyper-parameters from group `pg`."
        d = self.state[p]
        for stat in stats:
            if stat.buf is None: continue
            # The step counter has no name, hence no hyper-parameter in the group; indexing
            # `pg[None]` raised a KeyError.
            param = pg[stat.name] if stat.name is not None else None
            d[stat.buf] = stat.update(d[stat.buf], param, val=val, step=d.get('step', None))

    def update_stats(self):
        "Recompute all statistics from the current gradients."
        for stat in self.global_stats: stat.new_step()
        for i,pg in enumerate(self.param_groups):
            for stat in self.group_stats: stat.new_step()
            for p in pg['params']:
                if p.grad is not None:
                    for stat in self.global_stats + self.group_stats: stat.accumulate(p.grad.data)
                    self._set_bufs(p, self.layer_stats+self.channel_stats+self.weight_stats, pg, p.grad.data)
            self._set_bufs(f'group{i}', self.group_stats, pg)
        # Global statistics take their hyper-parameters from the first parameter group.
        self._set_bufs('global', self.global_stats, self.param_groups[0])
139
140
141