GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a2/sgd.py
#!/usr/bin/env python

import pickle
import glob
import random
import numpy as np
import os.path as op

# Save parameters every few SGD iterations as a fail-safe
SAVE_PARAMS_EVERY = 5000

def load_saved_params():
    """
    A helper function that loads previously saved parameters and resets
    the iteration start.
    """
    st = 0
    for f in glob.glob("saved_params_*.npy"):
        # Filenames look like saved_params_<iteration>.npy
        it = int(op.splitext(op.basename(f))[0].split("_")[2])
        if it > st:
            st = it

    if st > 0:
        params_file = "saved_params_%d.npy" % st
        state_file = "saved_state_%d.pickle" % st
        params = np.load(params_file)
        with open(state_file, "rb") as f:
            state = pickle.load(f)
        return st, params, state
    else:
        return st, None, None


def save_params(iteration, params):
    """Save the parameters and the RNG state for the given iteration."""
    params_file = "saved_params_%d.npy" % iteration
    np.save(params_file, params)
    with open("saved_state_%d.pickle" % iteration, "wb") as f:
        pickle.dump(random.getstate(), f)
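
# Round-trip sketch (illustrative values, not part of the assignment):
#
#     save_params(5000, x)                     # writes saved_params_5000.npy
#                                              #   and saved_state_5000.pickle
#     st, params, state = load_saved_params() # st == 5000 if that is the
#                                              # newest checkpoint on disk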


def sgd(f, x0, step, iterations, postprocessing=None, useSaved=False,
        PRINT_EVERY=10):
    """ Stochastic Gradient Descent

    Implement the stochastic gradient descent method in this function.

    Arguments:
    f -- the function to optimize; it takes a single argument and
         returns two outputs, the loss and the gradient with respect
         to that argument
    x0 -- the initial point to start SGD from
    step -- the step size (learning rate) for SGD
    iterations -- total iterations to run SGD for
    postprocessing -- postprocessing function applied to the parameters
         if necessary. In the case of word2vec we will need to
         normalize the word vectors to have unit length.
    PRINT_EVERY -- print the running loss every PRINT_EVERY iterations

    Return:
    x -- the parameter value after SGD finishes
    """

    # Halve the learning rate every ANNEAL_EVERY iterations
    ANNEAL_EVERY = 20000

    if useSaved:
        start_iter, oldx, state = load_saved_params()
        if start_iter > 0:
            x0 = oldx
            # Integer division so the restored step matches the number of
            # annealing halvings actually applied before the checkpoint
            step *= 0.5 ** (start_iter // ANNEAL_EVERY)

        if state:
            random.setstate(state)
    else:
        start_iter = 0

    x = x0

    if not postprocessing:
        postprocessing = lambda x: x

    exploss = None

    for it in range(start_iter + 1, iterations + 1):
        # You might want to print the progress every few iterations.

        loss = None
        ### YOUR CODE HERE (~2 lines)
        loss, grad = f(x)
        x -= step * grad
        ### END YOUR CODE

        x = postprocessing(x)
        if it % PRINT_EVERY == 0:
            # Report an exponential moving average of the loss to smooth
            # minibatch noise; "is None" so a legitimate loss of 0.0 does
            # not reset the average
            if exploss is None:
                exploss = loss
            else:
                exploss = .95 * exploss + .05 * loss
            print("iter %d: %f" % (it, exploss))

        if it % SAVE_PARAMS_EVERY == 0 and useSaved:
            save_params(it, x)

        if it % ANNEAL_EVERY == 0:
            step *= 0.5

    return x
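
# Postprocessing sketch (an assumption, not the course's exact call): to keep
# word2vec vectors at unit length, one could pass a row normalizer such as
#
#     normalize_rows = lambda x: x / np.linalg.norm(x, axis=1, keepdims=True)
#     vectors = sgd(f, x0, 0.3, 40000, postprocessing=normalize_rows)
#
# where the rows of x0 are the word vectors; normalize_rows is named here
# purely for illustration.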


def sanity_check():
    # f(x) = (||x||^2, 2x): loss and gradient of a simple quadratic bowl
    quad = lambda x: (np.sum(x ** 2), x * 2)

    print("Running sanity checks...")
    t1 = sgd(quad, 0.5, 0.01, 1000, PRINT_EVERY=100)
    print("test 1 result:", t1)
    assert abs(t1) <= 1e-6

    t2 = sgd(quad, 0.0, 0.01, 1000, PRINT_EVERY=100)
    print("test 2 result:", t2)
    assert abs(t2) <= 1e-6

    t3 = sgd(quad, -1.5, 0.01, 1000, PRINT_EVERY=100)
    print("test 3 result:", t3)
    assert abs(t3) <= 1e-6

    print("-" * 40)
    print("ALL TESTS PASSED")
    print("-" * 40)


if __name__ == "__main__":
    sanity_check()
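
# Usage sketch (an assumption, not code from this repo): with useSaved=True,
# sgd resumes from the newest saved_params_*.npy in the working directory,
# restores the iteration counter and RNG state, and re-applies annealing:
#
#     x = sgd(f, x0, step=0.3, iterations=40000, useSaved=True)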