GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/minBERT/classifier.py
import time, random, numpy as np, argparse, sys, re, os
from types import SimpleNamespace
import csv

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, f1_score, recall_score, accuracy_score

# Project-local modules (modified with respect to the original BERT implementation).
from tokenizer import BertTokenizer
from bert import BertModel
from optimizer import AdamW
from tqdm import tqdm


TQDM_DISABLE = False

# Fix the random seeds for reproducibility.
def seed_everything(seed=11711):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class BertSentimentClassifier(torch.nn.Module):
    '''
    This module performs sentiment classification using BERT embeddings on the SST dataset.

    In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
    Thus, your forward() should return one logit for each of the 5 classes.
    '''
    def __init__(self, config):
        super(BertSentimentClassifier, self).__init__()
        self.num_labels = config.num_labels
        self.bert = BertModel.from_pretrained('/home/minbert-default-final-project/bert')

        # Pretrain mode does not require updating the BERT parameters.
        for param in self.bert.parameters():
            if config.option == 'pretrain':
                param.requires_grad = False
            elif config.option == 'finetune':
                param.requires_grad = True

        ### TODO
        # Classification head: a small MLP on top of the [CLS] representation.
        self.classifier = nn.Sequential(nn.Linear(config.hidden_size, 64),
                                        nn.ReLU(),
                                        nn.Linear(64, self.num_labels))

    def forward(self, input_ids, attention_mask):
        '''Takes a batch of sentences and returns logits for sentiment classes.'''
        # The final BERT contextualized embedding is the hidden state of the [CLS] token (the first token).
        # HINT: consider what the appropriate output to return is, given that the
        # training loop currently uses F.cross_entropy as the loss function.
        ### TODO
        output_dict = self.bert(input_ids, attention_mask)
        cls = output_dict['pooler_output']
        logits = self.classifier(cls)
        return logits


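# A minimal forward-pass sketch (hypothetical shapes, assuming a batch of 8
# sentences padded to 32 tokens and the standard 30522-token BERT vocabulary).
# forward() returns raw, unnormalized logits of shape [batch_size, num_labels];
# F.cross_entropy in the training loop applies log-softmax internally, so no
# softmax is applied here.
#
#   model = BertSentimentClassifier(config)
#   input_ids = torch.randint(0, 30522, (8, 32))
#   attention_mask = torch.ones(8, 32, dtype=torch.long)
#   logits = model(input_ids, attention_mask)   # shape: [8, 5]
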
class SentimentDataset(Dataset):
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('/home/minbert-default-final-project/bert')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        labels = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])
        labels = torch.LongTensor(labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data

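# Sketch of the intended usage (this mirrors what train() does below): the raw
# (sentence, label, id) tuples are wrapped by SentimentDataset, and collate_fn
# pads each batch to the longest sentence in that batch.
#
#   dataset = SentimentDataset(train_data, args)
#   loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
#                       collate_fn=dataset.collate_fn)
#   batch = next(iter(loader))
#   batch['token_ids'].shape   # [batch_size, max_len_in_batch]
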
class SentimentTestDataset(Dataset):
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('/home/minbert-default-final-project/bert')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data

# Load the data: a list of (sentence, label, id) tuples for train/valid,
# or (sentence, id) tuples for test.
def load_data(filename, flag='train'):
    num_labels = {}
    data = []
    if flag == 'test':
        with open(filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                data.append((sent, sent_id))
    else:
        with open(filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                if label not in num_labels:
                    num_labels[label] = len(num_labels)
                data.append((sent, label, sent_id))
    print(f"Loaded {len(data)} examples from {filename}")

    if flag == 'train':
        return data, len(num_labels)
    else:
        return data

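# The input files are tab-separated with a header row; the column names below
# are taken from the DictReader lookups above, but the example values are
# purely illustrative:
#
#   id        sentence                                  sentiment
#   1234abcd  a gorgeous , witty , seductive movie .    4
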
# Evaluate the model on labeled data, returning accuracy and macro-F1.
def model_eval(dataloader, model, device):
    model.eval()  # switch to eval mode, which turns off randomness like dropout
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
        b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
                                                       batch['labels'], batch['sents'], batch['sent_ids']

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        b_labels = b_labels.flatten()
        y_true.extend(b_labels)
        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids


# Evaluate the model on unlabeled test data, returning only the predictions.
def model_test_eval(dataloader, model, device):
    model.eval()  # switch to eval mode, which turns off randomness like dropout
    y_pred = []
    sents = []
    sent_ids = []
    for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
        b_ids, b_mask, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
                                             batch['sents'], batch['sent_ids']

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    return y_pred, sents, sent_ids


def save_model(model, optimizer, args, config, filepath):
    save_info = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }

    torch.save(save_info, filepath)
    print(f"Saved the model to {filepath}")

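# A minimal restore sketch (this mirrors what test() does below; the file name
# is illustrative):
#
#   saved = torch.load('sst-classifier.pt')
#   model = BertSentimentClassifier(saved['model_config'])
#   model.load_state_dict(saved['model'])
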
def train(args):
    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    # Load the data and create the corresponding datasets and dataloaders.
    train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')

    train_dataset = SentimentDataset(train_data, args)
    dev_dataset = SentimentDataset(dev_data, args)

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                                  collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
                                collate_fn=dev_dataset.collate_fn)

    # Init model
    config = {'hidden_dropout_prob': args.hidden_dropout_prob,
              'num_labels': num_labels,
              'hidden_size': 768,
              'data_dir': '.',
              'option': args.option}

    config = SimpleNamespace(**config)

    model = BertSentimentClassifier(config)
    model = model.to(device)

    lr = args.lr
    optimizer = AdamW(model.parameters(), lr=lr)
    best_dev_acc = 0

    # Run for the specified number of epochs.
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        num_batches = 0
        for batch in tqdm(train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
            b_ids, b_mask, b_labels = (batch['token_ids'],
                                       batch['attention_mask'], batch['labels'])

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)
            b_labels = b_labels.to(device)

            optimizer.zero_grad()
            logits = model(b_ids, b_mask)
            # Sum the per-example losses and normalize by the nominal batch size.
            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / num_batches

        train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
        dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, optimizer, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")

def test(args):
    with torch.no_grad():
        device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
        saved = torch.load(args.filepath)
        config = saved['model_config']
        model = BertSentimentClassifier(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"Loaded model from {args.filepath}")

        dev_data = load_data(args.dev, 'valid')
        dev_dataset = SentimentDataset(dev_data, args)
        dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=dev_dataset.collate_fn)

        test_data = load_data(args.test, 'test')
        test_dataset = SentimentTestDataset(test_data, args)
        test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)

        dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
        print('DONE DEV')
        test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
        print('DONE Test')
        with open(args.dev_out, "w+") as f:
            print(f"dev acc :: {dev_acc :.3f}")
            f.write(f"id \t Predicted_Sentiment \n")
            for p, s in zip(dev_sent_ids, dev_pred):
                f.write(f"{p} , {s} \n")

        with open(args.test_out, "w+") as f:
            f.write(f"id \t Predicted_Sentiment \n")
            for p, s in zip(test_sent_ids, test_pred):
                f.write(f"{p} , {s} \n")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--option", type=str,
                        help='pretrain: the BERT parameters are frozen; finetune: BERT parameters are updated',
                        choices=('pretrain', 'finetune'), default="pretrain")
    parser.add_argument("--use_gpu", action='store_true')
    parser.add_argument("--dev_out", type=str, default="cfimdb-dev-output.txt")
    parser.add_argument("--test_out", type=str, default="cfimdb-test-output.txt")

    parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
                        default=1e-5)

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = get_args()
    seed_everything(args.seed)
    # args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'

    print('Training Sentiment Classifier on SST...')
    config = SimpleNamespace(
        filepath='sst-classifier.pt',
        lr=args.lr,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size=args.batch_size,
        hidden_dropout_prob=args.hidden_dropout_prob,
        train='data/ids-sst-train.csv',
        dev='data/ids-sst-dev.csv',
        test='data/ids-sst-test-student.csv',
        option=args.option,
        dev_out='predictions/' + args.option + '-sst-dev-out.csv',
        test_out='predictions/' + args.option + '-sst-test-out.csv'
    )

    train(config)

    print('Evaluating on SST...')
    test(config)

    print('Training Sentiment Classifier on cfimdb...')
    config = SimpleNamespace(
        filepath='cfimdb-classifier.pt',
        lr=args.lr,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size=8,
        hidden_dropout_prob=args.hidden_dropout_prob,
        train='data/ids-cfimdb-train.csv',
        dev='data/ids-cfimdb-dev.csv',
        test='data/ids-cfimdb-test-student.csv',
        option=args.option,
        dev_out='predictions/' + args.option + '-cfimdb-dev-out.csv',
        test_out='predictions/' + args.option + '-cfimdb-test-out.csv'
    )

    train(config)

    print('Evaluating on cfimdb...')
    test(config)
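# Typical invocations (the flags are the ones defined in get_args() above;
# the exact values are only suggestions taken from the help strings):
#
#   python classifier.py --option pretrain --lr 1e-3 --batch_size 64 --use_gpu
#   python classifier.py --option finetune --lr 1e-5 --batch_size 64 --use_gpu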