GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/minBERT/evaluation.py
#!/usr/bin/env python3

'''
Model evaluation functions.

When training your multitask model, you will find it useful to run
model_eval_multitask to evaluate your model on the three tasks in the
development set.

Before submission, your code needs to call test_model_multitask(args, model, device) to generate
your predictions. We'll evaluate these predictions against our labels on our end,
which is how the leaderboard will be updated.

The provided test_model() function in multitask_classifier.py **already does this for you**,
so unless you change it you shouldn't need to call anything from here
explicitly aside from model_eval_multitask.
'''
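
# Example workflow (a minimal sketch, not part of the starter code; `args`, `model`, `device`,
# and the dev dataloaders are assumed to be constructed as in multitask_classifier.py --
# the dataloader names below are illustrative):
#
#   from evaluation import model_eval_multitask, test_model_multitask
#
#   # During training, monitor dev performance on all three tasks:
#   (para_acc, _, _,
#    sst_acc, _, _,
#    sts_corr, _, _) = model_eval_multitask(sst_dev_dataloader, para_dev_dataloader,
#                                           sts_dev_dataloader, model, device)
#
#   # Before submission, write dev and test predictions to the files named in `args`:
#   test_model_multitask(args, model, device)
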

import torch
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, f1_score, recall_score, accuracy_score
from tqdm import tqdm
import numpy as np

from datasets import load_multitask_data, load_multitask_test_data, \
    SentenceClassificationDataset, SentenceClassificationTestDataset, \
    SentencePairDataset, SentencePairTestDataset


TQDM_DISABLE = True

# Evaluate a multitask model for accuracy on SST only.
def model_eval_sst(dataloader, model, device):
    model.eval()  # switch to eval mode, which turns off randomness like dropout
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    for step, batch in enumerate(tqdm(dataloader, desc=f'eval', disable=TQDM_DISABLE)):
        b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
                                                       batch['labels'], batch['sents'], batch['sent_ids']

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model.predict_sentiment(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        b_labels = b_labels.flatten()
        y_true.extend(b_labels)
        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids
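
# Example (a minimal sketch; `sst_dev_dataloader` is an assumed name for a DataLoader built
# from SentenceClassificationDataset, as in multitask_classifier.py):
#
#   dev_acc, dev_f1, *_ = model_eval_sst(sst_dev_dataloader, model, device)
#   print(f"SST dev acc :: {dev_acc:.3f}, macro-F1 :: {dev_f1:.3f}")
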
# Evaluate the multitask model on the dev sets for all three tasks.
def model_eval_multitask(sentiment_dataloader,
                         paraphrase_dataloader,
                         sts_dataloader,
                         model, device):
    model.eval()  # switch to eval mode, which turns off randomness like dropout

    with torch.no_grad():
        para_y_true = []
        para_y_pred = []
        para_sent_ids = []

        # Evaluate paraphrase detection.
        for step, batch in enumerate(tqdm(paraphrase_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                                      batch['token_ids_2'], batch['attention_mask_2'],
                                      batch['labels'], batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.sigmoid().round().flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            para_y_pred.extend(y_hat)
            para_y_true.extend(b_labels)
            para_sent_ids.extend(b_sent_ids)

        paraphrase_accuracy = np.mean(np.array(para_y_pred) == np.array(para_y_true))

        sts_y_true = []
        sts_y_pred = []
        sts_sent_ids = []

        # Evaluate semantic textual similarity.
        for step, batch in enumerate(tqdm(sts_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                                      batch['token_ids_2'], batch['attention_mask_2'],
                                      batch['labels'], batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            sts_y_pred.extend(y_hat)
            sts_y_true.extend(b_labels)
            sts_sent_ids.extend(b_sent_ids)
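
        # np.corrcoef on the two 1-D score arrays below returns a 2x2 correlation matrix;
        # either off-diagonal entry is the Pearson correlation between predicted and gold
        # similarity scores, which is the metric reported for STS.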
        pearson_mat = np.corrcoef(sts_y_pred, sts_y_true)
        sts_corr = pearson_mat[1][0]

        sst_y_true = []
        sst_y_pred = []
        sst_sent_ids = []

        # Evaluate sentiment classification.
        for step, batch in enumerate(tqdm(sentiment_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            b_ids, b_mask, b_labels, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['labels'], batch['sent_ids']

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            logits = model.predict_sentiment(b_ids, b_mask)
            y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            sst_y_pred.extend(y_hat)
            sst_y_true.extend(b_labels)
            sst_sent_ids.extend(b_sent_ids)

        sentiment_accuracy = np.mean(np.array(sst_y_pred) == np.array(sst_y_true))

        print(f'Paraphrase detection accuracy: {paraphrase_accuracy:.3f}')
        print(f'Sentiment classification accuracy: {sentiment_accuracy:.3f}')
        print(f'Semantic Textual Similarity correlation: {sts_corr:.3f}')

        return (paraphrase_accuracy, para_y_pred, para_sent_ids,
                sentiment_accuracy, sst_y_pred, sst_sent_ids,
                sts_corr, sts_y_pred, sts_sent_ids)
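
# The three dev metrics above are reported separately. If a single number is wanted for
# model selection, one simple choice (an assumption, not something this file prescribes)
# is their unweighted mean:
#
#   para_acc, _, _, sst_acc, _, _, sts_corr, _, _ = model_eval_multitask(
#       sst_dev_dataloader, para_dev_dataloader, sts_dev_dataloader, model, device)
#   overall_dev_score = (para_acc + sst_acc + sts_corr) / 3
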
# Generate test-set predictions for all three tasks (no labels are available here).
def model_eval_test_multitask(sentiment_dataloader,
                              paraphrase_dataloader,
                              sts_dataloader,
                              model, device):
    model.eval()  # switch to eval mode, which turns off randomness like dropout

    with torch.no_grad():

        para_y_pred = []
        para_sent_ids = []
        # Evaluate paraphrase detection.
        for step, batch in enumerate(tqdm(paraphrase_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                            batch['token_ids_2'], batch['attention_mask_2'],
                            batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.sigmoid().round().flatten().cpu().numpy()

            para_y_pred.extend(y_hat)
            para_sent_ids.extend(b_sent_ids)

        sts_y_pred = []
        sts_sent_ids = []

        # Evaluate semantic textual similarity.
        for step, batch in enumerate(tqdm(sts_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                            batch['token_ids_2'], batch['attention_mask_2'],
                            batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.flatten().cpu().numpy()

            sts_y_pred.extend(y_hat)
            sts_sent_ids.extend(b_sent_ids)

        sst_y_pred = []
        sst_sent_ids = []

        # Evaluate sentiment classification.
        for step, batch in enumerate(tqdm(sentiment_dataloader, desc=f'eval', disable=TQDM_DISABLE)):
            b_ids, b_mask, b_sent_ids = batch['token_ids'], batch['attention_mask'], batch['sent_ids']

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            logits = model.predict_sentiment(b_ids, b_mask)
            y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()

            sst_y_pred.extend(y_hat)
            sst_sent_ids.extend(b_sent_ids)

        return (para_y_pred, para_sent_ids,
                sst_y_pred, sst_sent_ids,
                sts_y_pred, sts_sent_ids)


def test_model_multitask(args, model, device):
    sst_test_data, num_labels, para_test_data, sts_test_data = \
        load_multitask_data(args.sst_test, args.para_test, args.sts_test, split='test')

    sst_dev_data, num_labels, para_dev_data, sts_dev_data = \
        load_multitask_data(args.sst_dev, args.para_dev, args.sts_dev, split='dev')

    sst_test_data = SentenceClassificationTestDataset(sst_test_data, args)
    sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)

    sst_test_dataloader = DataLoader(sst_test_data, shuffle=True, batch_size=args.batch_size,
                                     collate_fn=sst_test_data.collate_fn)
    sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
                                    collate_fn=sst_dev_data.collate_fn)

    para_test_data = SentencePairTestDataset(para_test_data, args)
    para_dev_data = SentencePairDataset(para_dev_data, args)

    para_test_dataloader = DataLoader(para_test_data, shuffle=True, batch_size=args.batch_size,
                                      collate_fn=para_test_data.collate_fn)
    para_dev_dataloader = DataLoader(para_dev_data, shuffle=False, batch_size=args.batch_size,
                                     collate_fn=para_dev_data.collate_fn)

    sts_test_data = SentencePairTestDataset(sts_test_data, args)
    sts_dev_data = SentencePairDataset(sts_dev_data, args, isRegression=True)

    sts_test_dataloader = DataLoader(sts_test_data, shuffle=True, batch_size=args.batch_size,
                                     collate_fn=sts_test_data.collate_fn)
    sts_dev_dataloader = DataLoader(sts_dev_data, shuffle=False, batch_size=args.batch_size,
                                    collate_fn=sts_dev_data.collate_fn)

    dev_paraphrase_accuracy, dev_para_y_pred, dev_para_sent_ids, \
        dev_sentiment_accuracy, dev_sst_y_pred, dev_sst_sent_ids, dev_sts_corr, \
        dev_sts_y_pred, dev_sts_sent_ids = model_eval_multitask(sst_dev_dataloader,
                                                                para_dev_dataloader,
                                                                sts_dev_dataloader, model, device)

    test_para_y_pred, test_para_sent_ids, test_sst_y_pred, \
        test_sst_sent_ids, test_sts_y_pred, test_sts_sent_ids = \
        model_eval_test_multitask(sst_test_dataloader,
                                  para_test_dataloader,
                                  sts_test_dataloader, model, device)

    with open(args.sst_dev_out, "w+") as f:
        print(f"dev sentiment acc :: {dev_sentiment_accuracy :.3f}")
        f.write(f"id \t Predicted_Sentiment \n")
        for p, s in zip(dev_sst_sent_ids, dev_sst_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.sst_test_out, "w+") as f:
        f.write(f"id \t Predicted_Sentiment \n")
        for p, s in zip(test_sst_sent_ids, test_sst_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.para_dev_out, "w+") as f:
        print(f"dev paraphrase acc :: {dev_paraphrase_accuracy :.3f}")
        f.write(f"id \t Predicted_Is_Paraphrase \n")
        for p, s in zip(dev_para_sent_ids, dev_para_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.para_test_out, "w+") as f:
        f.write(f"id \t Predicted_Is_Paraphrase \n")
        for p, s in zip(test_para_sent_ids, test_para_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.sts_dev_out, "w+") as f:
        print(f"dev sts corr :: {dev_sts_corr :.3f}")
        f.write(f"id \t Predicted_Similiary \n")
        for p, s in zip(dev_sts_sent_ids, dev_sts_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.sts_test_out, "w+") as f:
        f.write(f"id \t Predicted_Similiary \n")
        for p, s in zip(test_sts_sent_ids, test_sts_y_pred):
            f.write(f"{p} , {s} \n")