GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/minBERT/multitask_classifier.py

import time, random, numpy as np, argparse, sys, re, os
from types import SimpleNamespace

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from bert import BertModel
from optimizer import AdamW
from tqdm import tqdm

from datasets import SentenceClassificationDataset, SentencePairDataset, \
    load_multitask_data, load_multitask_test_data

from evaluation import model_eval_sst, test_model_multitask


TQDM_DISABLE = True

# Fix the random seed.
def seed_everything(seed=11711):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


BERT_HIDDEN_SIZE = 768
N_SENTIMENT_CLASSES = 5


class MultitaskBERT(nn.Module):
    '''
    This module uses BERT for 3 tasks:

    - Sentiment classification (predict_sentiment)
    - Paraphrase detection (predict_paraphrase)
    - Semantic Textual Similarity (predict_similarity)
    '''
    def __init__(self, config):
        super(MultitaskBERT, self).__init__()
        # Pretrain mode does not require updating BERT parameters.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        for param in self.bert.parameters():
            if config.option == 'pretrain':
                param.requires_grad = False
            elif config.option == 'finetune':
                param.requires_grad = True
        # Task-specific heads: a minimal sketch standing in for the original
        # "### TODO", not necessarily the repository's final solution. One
        # shared dropout layer, a 5-way linear classifier for sentiment, and
        # linear heads over concatenated pair embeddings for the pair tasks.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.sentiment_classifier = nn.Linear(BERT_HIDDEN_SIZE, N_SENTIMENT_CLASSES)
        self.paraphrase_classifier = nn.Linear(2 * BERT_HIDDEN_SIZE, 1)
        self.similarity_classifier = nn.Linear(2 * BERT_HIDDEN_SIZE, 1)


    def forward(self, input_ids, attention_mask):
        'Takes a batch of sentences and produces embeddings for them.'
        # The final BERT embedding is the hidden state of the [CLS] token (the
        # first token). Returning it directly is a reasonable baseline; later
        # improvements can add layers here. This sketch assumes the course's
        # bert.py, whose BertModel returns a dict with a 'pooler_output' entry.
        output = self.bert(input_ids, attention_mask)
        return self.dropout(output['pooler_output'])


    def predict_sentiment(self, input_ids, attention_mask):
        '''Given a batch of sentences, outputs logits for classifying sentiment.
        There are 5 sentiment classes:
        (0 - negative, 1 - somewhat negative, 2 - neutral, 3 - somewhat positive, 4 - positive)
        Thus, the output contains 5 logits for each sentence.
        '''
        # Sketch: score the [CLS] embedding with the linear sentiment head.
        return self.sentiment_classifier(self.forward(input_ids, attention_mask))


    def predict_paraphrase(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        '''Given a batch of pairs of sentences, outputs a single logit for predicting whether they are paraphrases.
        Note that the output should be unnormalized (a logit); it will be passed to the sigmoid function
        during evaluation, and handled as a logit by the appropriate loss function.
        '''
        # Sketch: embed each sentence, concatenate, and score with a linear head.
        emb_1 = self.forward(input_ids_1, attention_mask_1)
        emb_2 = self.forward(input_ids_2, attention_mask_2)
        return self.paraphrase_classifier(torch.cat([emb_1, emb_2], dim=-1)).squeeze(-1)


    def predict_similarity(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        '''Given a batch of pairs of sentences, outputs a single logit corresponding to how similar they are.
        Note that the output should be unnormalized (a logit).
        '''
        # Sketch: same pair encoding as predict_paraphrase, with its own head.
        emb_1 = self.forward(input_ids_1, attention_mask_1)
        emb_2 = self.forward(input_ids_2, attention_mask_2)
        return self.similarity_classifier(torch.cat([emb_1, emb_2], dim=-1)).squeeze(-1)
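

# Usage sketch for the heads above (shapes are illustrative and assume the
# hypothetical heads defined in __init__):
#   sentiment_logits = model.predict_sentiment(b_ids, b_mask)       # [batch, 5]
#   para_logits = model.predict_paraphrase(ids_1, m_1, ids_2, m_2)  # [batch]
#   sim_logits = model.predict_similarity(ids_1, m_1, ids_2, m_2)   # [batch]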


def save_model(model, optimizer, args, config, filepath):
    save_info = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }

    torch.save(save_info, filepath)
    print(f"Saved the model to {filepath}")


## Currently only trains on the SST dataset.
def train_multitask(args):
    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    # Load the raw data and build the corresponding datasets and dataloaders.
    sst_train_data, num_labels, para_train_data, sts_train_data = load_multitask_data(args.sst_train, args.para_train, args.sts_train, split='train')
    sst_dev_data, num_labels, para_dev_data, sts_dev_data = load_multitask_data(args.sst_dev, args.para_dev, args.sts_dev, split='dev')

    sst_train_data = SentenceClassificationDataset(sst_train_data, args)
    sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)

    sst_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size=args.batch_size,
                                      collate_fn=sst_train_data.collate_fn)
    sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
                                    collate_fn=sst_dev_data.collate_fn)

    # Init model
    config = {'hidden_dropout_prob': args.hidden_dropout_prob,
              'num_labels': num_labels,
              'hidden_size': BERT_HIDDEN_SIZE,
              'data_dir': '.',
              'option': args.option}

    config = SimpleNamespace(**config)

    model = MultitaskBERT(config)
    model = model.to(device)

    lr = args.lr
    optimizer = AdamW(model.parameters(), lr=lr)
    best_dev_acc = 0

    # Run for the specified number of epochs.
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        num_batches = 0
        for batch in tqdm(sst_train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
            b_ids, b_mask, b_labels = (batch['token_ids'],
                                       batch['attention_mask'], batch['labels'])

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)
            b_labels = b_labels.to(device)

            optimizer.zero_grad()
            logits = model.predict_sentiment(b_ids, b_mask)
            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / num_batches

        train_acc, train_f1, *_ = model_eval_sst(sst_train_dataloader, model, device)
        dev_acc, dev_f1, *_ = model_eval_sst(sst_dev_dataloader, model, device)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, optimizer, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss:.3f}, train acc :: {train_acc:.3f}, dev acc :: {dev_acc:.3f}")


def test_model(args):
    with torch.no_grad():
        device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
        saved = torch.load(args.filepath)
        config = saved['model_config']

        model = MultitaskBERT(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"Loaded model to test from {args.filepath}")

        test_model_multitask(args, model, device)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sst_train", type=str, default="data/ids-sst-train.csv")
    parser.add_argument("--sst_dev", type=str, default="data/ids-sst-dev.csv")
    parser.add_argument("--sst_test", type=str, default="data/ids-sst-test-student.csv")

    parser.add_argument("--para_train", type=str, default="data/quora-train.csv")
    parser.add_argument("--para_dev", type=str, default="data/quora-dev.csv")
    parser.add_argument("--para_test", type=str, default="data/quora-test-student.csv")

    parser.add_argument("--sts_train", type=str, default="data/sts-train.csv")
    parser.add_argument("--sts_dev", type=str, default="data/sts-dev.csv")
    parser.add_argument("--sts_test", type=str, default="data/sts-test-student.csv")

    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--option", type=str,
                        help='pretrain: the BERT parameters are frozen; finetune: BERT parameters are updated',
                        choices=('pretrain', 'finetune'), default="pretrain")
    parser.add_argument("--use_gpu", action='store_true')

    parser.add_argument("--sst_dev_out", type=str, default="predictions/sst-dev-output.csv")
    parser.add_argument("--sst_test_out", type=str, default="predictions/sst-test-output.csv")

    parser.add_argument("--para_dev_out", type=str, default="predictions/para-dev-output.csv")
    parser.add_argument("--para_test_out", type=str, default="predictions/para-test-output.csv")

    parser.add_argument("--sts_dev_out", type=str, default="predictions/sts-dev-output.csv")
    parser.add_argument("--sts_test_out", type=str, default="predictions/sts-test-output.csv")

    # hyperparameters
    parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
                        default=1e-5)

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    args.filepath = f'{args.option}-{args.epochs}-{args.lr}-multitask.pt'  # save path
    seed_everything(args.seed)  # fix the seed for reproducibility
    train_multitask(args)
    test_model(args)
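

# Example invocations (flag defaults are defined in get_args above):
#   python multitask_classifier.py --option pretrain --lr 1e-3
#   python multitask_classifier.py --option finetune --lr 1e-5 --use_gpu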