GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a4/sanity_check.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N 2022-23: Homework 4
sanity_check.py: Sanity Checks for Assignment 4
Sahil Chopra <[email protected]>
Michael Hahn <>
Vera Lin <[email protected]>
Siyan Li <[email protected]>

If you are a student, please don't run overwrite_output_for_sanity_check as it will overwrite the correct output!

Usage:
    sanity_check.py 1d
    sanity_check.py 1e
    sanity_check.py 1f
    sanity_check.py overwrite_output_for_sanity_check
"""
import sys

import numpy as np

from docopt import docopt
from utils import batch_iter
import nltk
from utils import autograder_read_corpus
from vocab import Vocab, VocabEntry

from nmt_model import NMT


import torch
import torch.nn as nn
import torch.nn.utils

#----------
# CONSTANTS
#----------
BATCH_SIZE = 5
EMBED_SIZE = 3
HIDDEN_SIZE = 2
DROPOUT_RATE = 0.0
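# The tiny sizes above are presumably chosen so the saved reference tensors stay
# small and the checks run quickly; they do not reflect a real training configuration.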

def reinitialize_layers(model):
    """ Reinitialize the layer weights to fixed constants for the sanity checks,
    so that model outputs are deterministic and comparable against the saved
    reference tensors.
    """
    def init_weights(m):
        if type(m) == nn.Linear:
            m.weight.data.fill_(0.3)
            if m.bias is not None:
                m.bias.data.fill_(0.1)
        elif type(m) == nn.Embedding:
            m.weight.data.fill_(0.15)
        elif type(m) == nn.Conv1d:
            m.weight.data.fill_(0.15)
        elif type(m) == nn.Dropout:
            # The original line constructed a new nn.Dropout(DROPOUT_RATE) and
            # discarded it, which had no effect; set the module's rate instead.
            m.p = DROPOUT_RATE
        elif type(m) == nn.LSTM:
            for param in m.state_dict():
                getattr(m, param).data.fill_(0.1)
        elif type(m) == nn.LSTMCell:
            for param in m.state_dict():
                getattr(m, param).data.fill_(0.1)
    with torch.no_grad():
        model.apply(init_weights)


def generate_outputs(model, source, target, vocab):
    """ Generate the reference outputs that the 1d/1e/1f sanity checks compare
    against, and save them under ./sanity_check_en_es_data/ (instructor use only;
    see the warning in the module docstring).
    """
    print("-" * 80)
    print("Generating Comparison Outputs")
    reinitialize_layers(model)
    model.gen_sanity_check = True
    model.counter = 0

    # Compute sentence lengths
    source_lengths = [len(s) for s in source]

    # Convert list of lists into tensors
    source_padded = model.vocab.src.to_input_tensor(source, device=model.device)
    target_padded = model.vocab.tgt.to_input_tensor(target, device=model.device)

    # Run the model forward
    with torch.no_grad():
        enc_hiddens, dec_init_state = model.encode(source_padded, source_lengths)
        enc_masks = model.generate_sent_masks(enc_hiddens, source_lengths)
        combined_outputs = model.decode(enc_hiddens, enc_masks, dec_init_state, target_padded)

    # Save Tensors to disk
    torch.save(enc_hiddens, './sanity_check_en_es_data/enc_hiddens.pkl')
    torch.save(dec_init_state, './sanity_check_en_es_data/dec_init_state.pkl')
    torch.save(enc_masks, './sanity_check_en_es_data/enc_masks.pkl')
    torch.save(combined_outputs, './sanity_check_en_es_data/combined_outputs.pkl')
    torch.save(target_padded, './sanity_check_en_es_data/target_padded.pkl')

    # 1f
    # Inputs
    Ybar_t = torch.load('./sanity_check_en_es_data/Ybar_t.pkl')
    enc_hiddens_proj = torch.load('./sanity_check_en_es_data/enc_hiddens_proj.pkl')
    reinitialize_layers(model)
    # Run Tests
    with torch.no_grad():
        dec_state_target, o_t_target, e_t_target = model.step(Ybar_t, dec_init_state, enc_hiddens,
                                                              enc_hiddens_proj, enc_masks)
    torch.save(dec_state_target, './sanity_check_en_es_data/dec_state.pkl')
    torch.save(o_t_target, './sanity_check_en_es_data/o_t.pkl')
    torch.save(e_t_target, './sanity_check_en_es_data/e_t.pkl')

    model.gen_sanity_check = False

def question_1d_sanity_check(model, src_sents, tgt_sents, vocab):
    """ Sanity check for question 1d.
    Compares student output to that of model with dummy data.
    """
    print("Running Sanity Check for Question 1d: Encode")
    print("-" * 80)

    # Configure for Testing
    reinitialize_layers(model)
    source_lengths = [len(s) for s in src_sents]
    source_padded = model.vocab.src.to_input_tensor(src_sents, device=model.device)

    # Load Outputs
    enc_hiddens_target = torch.load('./sanity_check_en_es_data/enc_hiddens.pkl')
    dec_init_state_target = torch.load('./sanity_check_en_es_data/dec_init_state.pkl')

    # Test
    with torch.no_grad():
        enc_hiddens_pred, dec_init_state_pred = model.encode(source_padded, source_lengths)
    assert(enc_hiddens_target.shape == enc_hiddens_pred.shape), "enc_hiddens shape is incorrect: it should be:\n {} but is:\n{}".format(enc_hiddens_target.shape, enc_hiddens_pred.shape)
    assert(np.allclose(enc_hiddens_target.numpy(), enc_hiddens_pred.numpy())), "enc_hiddens is incorrect: it should be:\n {} but is:\n{}".format(enc_hiddens_target, enc_hiddens_pred)
    print("enc_hiddens Sanity Checks Passed!")
    assert(dec_init_state_target[0].shape == dec_init_state_pred[0].shape), "dec_init_state[0] shape is incorrect: it should be:\n {} but is:\n{}".format(dec_init_state_target[0].shape, dec_init_state_pred[0].shape)
    assert(np.allclose(dec_init_state_target[0].numpy(), dec_init_state_pred[0].numpy())), "dec_init_state[0] is incorrect: it should be:\n {} but is:\n{}".format(dec_init_state_target[0], dec_init_state_pred[0])
    print("dec_init_state[0] Sanity Checks Passed!")
    assert(dec_init_state_target[1].shape == dec_init_state_pred[1].shape), "dec_init_state[1] shape is incorrect: it should be:\n {} but is:\n{}".format(dec_init_state_target[1].shape, dec_init_state_pred[1].shape)
    assert(np.allclose(dec_init_state_target[1].numpy(), dec_init_state_pred[1].numpy())), "dec_init_state[1] is incorrect: it should be:\n {} but is:\n{}".format(dec_init_state_target[1], dec_init_state_pred[1])
    print("dec_init_state[1] Sanity Checks Passed!")
    print("-" * 80)
    print("All Sanity Checks Passed for Question 1d: Encode!")
    print("-" * 80)


def question_1e_sanity_check(model, src_sents, tgt_sents, vocab):
    """ Sanity check for question 1e.
    Compares student output to that of model with dummy data.
    """
    print("-" * 80)
    print("Running Sanity Check for Question 1e: Decode")
    print("-" * 80)

    # Load Inputs
    dec_init_state = torch.load('./sanity_check_en_es_data/dec_init_state.pkl')
    enc_hiddens = torch.load('./sanity_check_en_es_data/enc_hiddens.pkl')
    enc_masks = torch.load('./sanity_check_en_es_data/enc_masks.pkl')
    target_padded = torch.load('./sanity_check_en_es_data/target_padded.pkl')

    # Load Outputs
    combined_outputs_target = torch.load('./sanity_check_en_es_data/combined_outputs.pkl')
    print(combined_outputs_target.shape)

    # Configure for Testing
    reinitialize_layers(model)
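    # Replace model.step with a stub that returns pre-saved per-step outputs, so this
    # check exercises only the decode() loop and does not depend on the student's step().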
    COUNTER = [0]
    def stepFunction(Ybar_t, dec_state, enc_hiddens, enc_hiddens_proj, enc_masks):
        dec_state = torch.load('./sanity_check_en_es_data/step_dec_state_{}.pkl'.format(COUNTER[0]))
        o_t = torch.load('./sanity_check_en_es_data/step_o_t_{}.pkl'.format(COUNTER[0]))
        COUNTER[0] += 1
        return dec_state, o_t, None
    model.step = stepFunction

    # Run Tests
    with torch.no_grad():
        combined_outputs_pred = model.decode(enc_hiddens, enc_masks, dec_init_state, target_padded)
    assert(combined_outputs_target.shape == combined_outputs_pred.shape), "combined_outputs shape is incorrect: it should be:\n {} but is:\n{}".format(combined_outputs_target.shape, combined_outputs_pred.shape)
    assert(np.allclose(combined_outputs_pred.numpy(), combined_outputs_target.numpy())), "combined_outputs is incorrect: it should be:\n {} but is:\n{}".format(combined_outputs_target, combined_outputs_pred)
    print("combined_outputs Sanity Checks Passed!")
    print("-" * 80)
    print("All Sanity Checks Passed for Question 1e: Decode!")
    print("-" * 80)

def question_1f_sanity_check(model, src_sents, tgt_sents, vocab):
    """ Sanity check for question 1f.
    Compares student output to that of model with dummy data.
    """
    print("-" * 80)
    print("Running Sanity Check for Question 1f: Step")
    print("-" * 80)
    reinitialize_layers(model)

    # Inputs
    Ybar_t = torch.load('./sanity_check_en_es_data/Ybar_t.pkl')
    dec_init_state = torch.load('./sanity_check_en_es_data/dec_init_state.pkl')
    enc_hiddens = torch.load('./sanity_check_en_es_data/enc_hiddens.pkl')
    enc_masks = torch.load('./sanity_check_en_es_data/enc_masks.pkl')
    enc_hiddens_proj = torch.load('./sanity_check_en_es_data/enc_hiddens_proj.pkl')

    # Output
    dec_state_target = torch.load('./sanity_check_en_es_data/dec_state.pkl')
    o_t_target = torch.load('./sanity_check_en_es_data/o_t.pkl')
    e_t_target = torch.load('./sanity_check_en_es_data/e_t.pkl')

    # Run Tests
    with torch.no_grad():
        dec_state_pred, o_t_pred, e_t_pred = model.step(Ybar_t, dec_init_state, enc_hiddens, enc_hiddens_proj, enc_masks)
    assert(dec_state_target[0].shape == dec_state_pred[0].shape), "decoder_state[0] shape is incorrect: it should be:\n {} but is:\n{}".format(dec_state_target[0].shape, dec_state_pred[0].shape)
    assert(np.allclose(dec_state_target[0].numpy(), dec_state_pred[0].numpy())), "decoder_state[0] is incorrect: it should be:\n {} but is:\n{}".format(dec_state_target[0], dec_state_pred[0])
    print("dec_state[0] Sanity Checks Passed!")
    assert(dec_state_target[1].shape == dec_state_pred[1].shape), "decoder_state[1] shape is incorrect: it should be:\n {} but is:\n{}".format(dec_state_target[1].shape, dec_state_pred[1].shape)
    assert(np.allclose(dec_state_target[1].numpy(), dec_state_pred[1].numpy())), "decoder_state[1] is incorrect: it should be:\n {} but is:\n{}".format(dec_state_target[1], dec_state_pred[1])
    print("dec_state[1] Sanity Checks Passed!")
    assert(np.allclose(o_t_target.numpy(), o_t_pred.numpy())), "combined_output is incorrect: it should be:\n {} but is:\n{}".format(o_t_target, o_t_pred)
    print("combined_output Sanity Checks Passed!")
    assert(np.allclose(e_t_target.numpy(), e_t_pred.numpy())), "e_t is incorrect: it should be:\n {} but is:\n{}".format(e_t_target, e_t_pred)
    print("e_t Sanity Checks Passed!")
    print("-" * 80)
    print("All Sanity Checks Passed for Question 1f: Step!")
    print("-" * 80)


def main():
    """ Main func.
    """
    args = docopt(__doc__)
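    # docopt parses the Usage section of the module docstring; args is a dict whose
    # command keys ('1d', '1e', '1f', 'overwrite_output_for_sanity_check') map to
    # True for the command passed on the command line and False otherwise.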

    # Check Python & PyTorch Versions
    assert (sys.version_info >= (3, 5)), "Please update your installation of Python to version >= 3.5."
    # Compare numeric (major, minor) components rather than raw version strings, so
    # that e.g. "1.10" is not rejected by a lexicographic comparison against "1.6".
    torch_major_minor = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
    assert (torch_major_minor >= (1, 6)), "Please update your installation of PyTorch >= 1.6.0. You have version {}.".format(torch.__version__)

    # Seed the Random Number Generators
    seed = 1234
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed * 13 // 7)

    # Load training data & vocabulary
    train_data_src = autograder_read_corpus('./sanity_check_en_es_data/train_sanity_check.es', 'src')
    train_data_tgt = autograder_read_corpus('./sanity_check_en_es_data/train_sanity_check.en', 'tgt')
    train_data = list(zip(train_data_src, train_data_tgt))

    # Take only the first batch of BATCH_SIZE sentence pairs for the checks
    for src_sents, tgt_sents in batch_iter(train_data, batch_size=BATCH_SIZE, shuffle=True):
        break
    vocab = Vocab.load('./sanity_check_en_es_data/vocab_sanity_check.json')

    # Create NMT Model
    model = NMT(
        embed_size=EMBED_SIZE,
        hidden_size=HIDDEN_SIZE,
        dropout_rate=DROPOUT_RATE,
        vocab=vocab)

    if args['1d']:
        question_1d_sanity_check(model, src_sents, tgt_sents, vocab)
    elif args['1e']:
        question_1e_sanity_check(model, src_sents, tgt_sents, vocab)
    elif args['1f']:
        question_1f_sanity_check(model, src_sents, tgt_sents, vocab)
    elif args['overwrite_output_for_sanity_check']:
        generate_outputs(model, src_sents, tgt_sents, vocab)
    else:
        raise RuntimeError('invalid run mode')


if __name__ == '__main__':
    main()