GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/armijo_sgd.py
#https://github.com/IssamLaradji/stochastic_line_search/blob/master/sgd_armijo.py

import superimport  # pyprobml convention: superimport auto-installs missing packages

import torch
import copy
import torch.nn as nn
import torch.nn.functional as F

class SGD_Armijo(torch.optim.Optimizer):
    """SGD optimizer whose step size is chosen by a backtracking (Armijo) line search."""

    def __init__(self,
                 model,
                 batch_size,
                 dataset_size,
                 init_step_size=1,
                 sigma=0.1,
                 beta=0.9,
                 beta_2=None):

        defaults = dict(
            batch_size=batch_size,
            init_step_size=init_step_size,
            dataset_size=dataset_size,
            sigma=sigma,
            beta=beta,
            beta_2=beta_2)

        super().__init__(model.parameters(), defaults)

        self.model = model
        self.state['step'] = 0
        self.state['step_size'] = init_step_size

    def step(self, closure):
        step_size = reset_step(self.state, self.defaults)

        # call the closure to get the loss and compute gradients
        loss = closure()

        # save the current parameters
        x_current = copy.deepcopy(self.param_groups)

        # save the gradient at the current parameters
        gradient, grad_norm = self.model.get_grads()

        # only run the line search if the gradient norm is large enough
        with torch.no_grad():
            if grad_norm >= 1e-8:
                # backtrack until the sufficient-decrease condition is satisfied
                found = 0
                step_size_old = step_size

                for e in range(100):
                    # try a prospective step
                    self._try_update(step_size, x_current, gradient)

                    # compute the loss at the trial point; no need to compute gradients
                    loss_temp = closure(compute_grad=False)

                    wolfe_results = wolfe_line_search(step_size=step_size,
                                                      step_size_old=step_size_old,
                                                      loss=loss,
                                                      grad_norm=grad_norm,
                                                      loss_temp=loss_temp,
                                                      params=self.defaults)

                    found, step_size, step_size_old = wolfe_results

                    if found == 1:
                        break

                # fall back to a tiny step if no acceptable step size was found
                if found == 0:
                    self._try_update(1e-6, x_current, gradient)

            else:
                self._try_update(step_size, x_current, gradient)

        # save the new step size
        self.state['step_size'] = step_size
        self.state['step'] = self.state['step'] + 1

        return float(loss)

    def _try_update(self, step_size, x_current, gradient):
        for i, group in enumerate(self.param_groups):
            for j, p in enumerate(group['params']):
                # update the model parameters with an SGD step from the saved point
                p.data = x_current[i]['params'][j] - \
                    step_size * gradient[i][j]
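
# Note (added for clarity): the helper functions below implement a backtracking
# (Armijo) line search. A trial step size eta is accepted when
#     f(x - eta * g) <= f(x) - sigma * eta * ||g||^2,
# and otherwise shrunk via eta <- beta * eta. Between optimizer steps,
# reset_step tentatively grows the step size by a factor of
# beta_2 ** (batch_size / dataset_size), capped at 10.0.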


# ==============================================
# Helpers


def wolfe_line_search(step_size, step_size_old, loss, grad_norm,
                      loss_temp, params):
    found = 0

    # Armijo sufficient-decrease condition:
    # accept if loss_temp <= loss - step_size * sigma * ||grad||^2
    break_condition = loss_temp - \
        (loss - step_size * params['sigma'] * grad_norm**2)

    if break_condition <= 0:
        found = 1
    else:
        # decrease the step size by a multiplicative factor
        step_size = step_size * params['beta']

    return found, step_size, step_size_old


def reset_step(state, params):
    step_size = state['step_size']

    if 'beta_2' in params and params['beta_2'] is not None:
        beta_2 = params['beta_2']
    else:
        beta_2 = 2.0

    # try to increase the step size, capped at a maximum of 10.0
    step_size = min(
        step_size * beta_2**(params['batch_size'] / params['dataset_size']),
        10.0)

    return step_size


###########

def get_grads(param_groups):
    grad_norm = 0
    gradient = []

    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]

    for i, group in enumerate(param_groups):
        grad_group = []
        for j, p in enumerate(group['params']):
            # copy the gradient so later parameter updates cannot modify it
            grad_copy = torch.zeros_like(p.grad.data).copy_(p.grad.data)
            grad_group.append(grad_copy)
            grad_norm = grad_norm + torch.sum(torch.mul(grad_copy, grad_copy))

        gradient.append(grad_group)

    return gradient, torch.sqrt(grad_norm)


class ArmijoModel(nn.Module):
    """Wraps a model and a loss so that SGD_Armijo can re-evaluate the loss via a closure."""

    def __init__(self, model, objective):
        super().__init__()
        self.model = model
        self.objective = objective
        # the SGD_Armijo optimizer must be attached after construction,
        # e.g. wrapped.opt = SGD_Armijo(wrapped, ...)
        self.opt = None

    def forward(self, x):
        return self.model(x)

    def get_grads(self):
        return get_grads(list(self.parameters()))

    def step(self, batch):
        self.train()
        X, y = batch

        def closure(compute_grad=True):
            if compute_grad:
                self.zero_grad()

            logits = self.forward(X)
            loss = self.objective(logits, y)

            if compute_grad:
                loss.backward()

            return float(loss)

        minibatch_loss = self.opt.step(closure)

        return float(minibatch_loss), self.opt.state['step_size']
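

# ----------------------------------------------------------------------
# Minimal usage sketch (added; not part of the original script). The model,
# data, and hyperparameters below are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)

    # synthetic regression data: 100 examples with 5 features
    N, D = 100, 5
    X = torch.randn(N, D)
    y = torch.randn(N, 1)

    base_model = nn.Sequential(nn.Linear(D, 16), nn.ReLU(), nn.Linear(16, 1))
    model = ArmijoModel(base_model, objective=nn.MSELoss())

    # the wrapper expects the line-search optimizer to be attached as `.opt`
    model.opt = SGD_Armijo(model,
                           batch_size=N,
                           dataset_size=N,
                           init_step_size=1,
                           sigma=0.1,
                           beta=0.9)

    for epoch in range(5):
        loss, step_size = model.step((X, y))
        print(f"epoch {epoch}: loss={loss:.4f}, step size={step_size:.4f}")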