Path: blob/main/minBERT/classifier.py
import time, random, numpy as np, argparse, sys, re, os
from types import SimpleNamespace
import csv

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, f1_score, recall_score, accuracy_score

# Changed with respect to the original model: use the local minBERT
# tokenizer, model, and optimizer implementations.
from tokenizer import BertTokenizer
from bert import BertModel
from optimizer import AdamW
from tqdm import tqdm


TQDM_DISABLE = False


# Fix the random seed.
def seed_everything(seed=11711):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class BertSentimentClassifier(torch.nn.Module):
    '''
    This module performs sentiment classification using BERT embeddings on the SST dataset.

    In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
    Thus, forward() returns one logit for each of the 5 classes.
    '''
    def __init__(self, config):
        super(BertSentimentClassifier, self).__init__()
        self.num_labels = config.num_labels
        self.bert = BertModel.from_pretrained('/home/minbert-default-final-project/bert')

        # Pretrain mode does not require updating the BERT parameters.
        for param in self.bert.parameters():
            if config.option == 'pretrain':
                param.requires_grad = False
            elif config.option == 'finetune':
                param.requires_grad = True

        # Classification head: a small MLP on top of the [CLS] representation.
        self.classifier = nn.Sequential(nn.Linear(config.hidden_size, 64),
                                        nn.ReLU(),
                                        nn.Linear(64, self.num_labels))

    def forward(self, input_ids, attention_mask):
        '''Takes a batch of sentences and returns logits for sentiment classes.'''
        # The final BERT contextualized embedding is the hidden state of the
        # [CLS] token (the first token). Raw logits are returned because the
        # training loop uses F.cross_entropy, which applies softmax internally.
        output_dict = self.bert(input_ids, attention_mask)
        cls = output_dict['pooler_output']
        logits = self.classifier(cls)
        return logits


class SentimentDataset(Dataset):
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('/home/minbert-default-final-project/bert')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        labels = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])
        labels = torch.LongTensor(labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data
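

# Illustrative sketch of how a batch flows through SentimentDataset.collate_fn,
# kept in comments so module behavior is unchanged. The toy sentences, labels,
# and ids below are hypothetical, not taken from SST.
#
#   dataset = SentimentDataset([('a great movie', 4, 'id-0'),
#                               ('dull and slow', 1, 'id-1')], args)
#   batch = dataset.collate_fn(list(dataset))
#   batch['token_ids'].shape    # (2, max_len): padded to the longest sentence
#   batch['attention_mask']     # 1 for real tokens, 0 for padding positions
#   batch['labels']             # tensor([4, 1])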


class SentimentTestDataset(Dataset):
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('/home/minbert-default-final-project/bert')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data


# Load the data: a list of (sentence, label, id) triples for train/dev,
# or (sentence, id) pairs for test.
def load_data(filename, flag='train'):
    num_labels = {}
    data = []
    if flag == 'test':
        with open(filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                data.append((sent, sent_id))
    else:
        with open(filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                if label not in num_labels:
                    num_labels[label] = len(num_labels)
                data.append((sent, label, sent_id))
    print(f"load {len(data)} examples from {filename}")

    if flag == 'train':
        return data, len(num_labels)
    else:
        return data


# Evaluate the model for accuracy and macro F1.
def model_eval(dataloader, model, device):
    model.eval()  # switch to eval mode, which turns off randomness like dropout
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
        b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
                                                       batch['labels'], batch['sents'], batch['sent_ids']

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        b_labels = b_labels.flatten()
        y_true.extend(b_labels)
        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids


def model_test_eval(dataloader, model, device):
    model.eval()  # switch to eval mode, which turns off randomness like dropout
    y_pred = []
    sents = []
    sent_ids = []
    for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
        b_ids, b_mask, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
                                             batch['sents'], batch['sent_ids']

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    return y_pred, sents, sent_ids


def save_model(model, optimizer, args, config, filepath):
    save_info = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }

    torch.save(save_info, filepath)
    print(f"save the model to {filepath}")
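

# A sketch of the inverse of save_model, shown only to document the checkpoint
# layout; this helper is hypothetical and is not called anywhere in this file
# (test() below inlines the same steps). map_location is an assumption so that
# a GPU-saved checkpoint also loads on CPU.
def _load_model_sketch(filepath, device):
    saved = torch.load(filepath, map_location=device)
    model = BertSentimentClassifier(saved['model_config'])
    model.load_state_dict(saved['model'])
    return model.to(device)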


def train(args):
    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    # Load the raw data and build the corresponding datasets and dataloaders.
    train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')

    train_dataset = SentimentDataset(train_data, args)
    dev_dataset = SentimentDataset(dev_data, args)

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                                  collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
                                collate_fn=dev_dataset.collate_fn)

    # Init model.
    config = {'hidden_dropout_prob': args.hidden_dropout_prob,
              'num_labels': num_labels,
              'hidden_size': 768,
              'data_dir': '.',
              'option': args.option}

    config = SimpleNamespace(**config)

    model = BertSentimentClassifier(config)
    model = model.to(device)

    lr = args.lr
    optimizer = AdamW(model.parameters(), lr=lr)
    best_dev_acc = 0

    # Run for the specified number of epochs.
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        num_batches = 0
        for batch in tqdm(train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
            b_ids, b_mask, b_labels = (batch['token_ids'],
                                       batch['attention_mask'], batch['labels'])

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)
            b_labels = b_labels.to(device)

            optimizer.zero_grad()
            logits = model(b_ids, b_mask)
            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / num_batches

        train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
        dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, optimizer, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
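

# Note on the training loss above: `reduction='sum'` divided by args.batch_size
# equals `reduction='mean'` only for full batches; the last, possibly smaller,
# batch is down-weighted per example. A hypothetical alternative would be:
#
#   loss = F.cross_entropy(logits, b_labels.view(-1), reduction='mean')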


def test(args):
    with torch.no_grad():
        device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
        saved = torch.load(args.filepath)
        config = saved['model_config']
        model = BertSentimentClassifier(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"load model from {args.filepath}")

        dev_data = load_data(args.dev, 'valid')
        dev_dataset = SentimentDataset(dev_data, args)
        dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
                                    collate_fn=dev_dataset.collate_fn)

        test_data = load_data(args.test, 'test')
        test_dataset = SentimentTestDataset(test_data, args)
        test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
                                     collate_fn=test_dataset.collate_fn)

        dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
        print('DONE DEV')
        test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
        print('DONE Test')

        with open(args.dev_out, "w+") as f:
            print(f"dev acc :: {dev_acc :.3f}")
            f.write(f"id \t Predicted_Sentiment \n")
            for p, s in zip(dev_sent_ids, dev_pred):
                f.write(f"{p} , {s} \n")

        with open(args.test_out, "w+") as f:
            f.write(f"id \t Predicted_Sentiment \n")
            for p, s in zip(test_sent_ids, test_pred):
                f.write(f"{p} , {s} \n")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--option", type=str,
                        help='pretrain: the BERT parameters are frozen; finetune: BERT parameters are updated',
                        choices=('pretrain', 'finetune'), default="pretrain")
    parser.add_argument("--use_gpu", action='store_true')
    parser.add_argument("--dev_out", type=str, default="cfimdb-dev-output.txt")
    parser.add_argument("--test_out", type=str, default="cfimdb-test-output.txt")

    parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
                        default=1e-5)

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    seed_everything(args.seed)
    # args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'

    print('Training Sentiment Classifier on SST...')
    config = SimpleNamespace(
        filepath='sst-classifier.pt',
        lr=args.lr,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size=args.batch_size,
        hidden_dropout_prob=args.hidden_dropout_prob,
        train='data/ids-sst-train.csv',
        dev='data/ids-sst-dev.csv',
        test='data/ids-sst-test-student.csv',
        option=args.option,
        dev_out='predictions/' + args.option + '-sst-dev-out.csv',
        test_out='predictions/' + args.option + '-sst-test-out.csv'
    )

    train(config)

    print('Evaluating on SST...')
    test(config)

    print('Training Sentiment Classifier on cfimdb...')
    config = SimpleNamespace(
        filepath='cfimdb-classifier.pt',
        lr=args.lr,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size=8,
        hidden_dropout_prob=args.hidden_dropout_prob,
        train='data/ids-cfimdb-train.csv',
        dev='data/ids-cfimdb-dev.csv',
        test='data/ids-cfimdb-test-student.csv',
        option=args.option,
        dev_out='predictions/' + args.option + '-cfimdb-dev-out.csv',
        test_out='predictions/' + args.option + '-cfimdb-test-out.csv'
    )

    train(config)

    print('Evaluating on cfimdb...')
    test(config)
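

# Example invocations (flags as defined in get_args; the suggested learning
# rates come from the --lr help string above):
#   python classifier.py --option pretrain --lr 1e-3 --use_gpu
#   python classifier.py --option finetune --lr 1e-5 --use_gpu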