Path: blob/main/minBERT/multitask_classifier.py
import time, random, numpy as np, argparse, sys, re, os
from types import SimpleNamespace

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from bert import BertModel
from optimizer import AdamW
from tqdm import tqdm

from datasets import SentenceClassificationDataset, SentencePairDataset, \
    load_multitask_data, load_multitask_test_data

from evaluation import model_eval_sst, test_model_multitask


TQDM_DISABLE = True

# fix the random seed
def seed_everything(seed=11711):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


BERT_HIDDEN_SIZE = 768
N_SENTIMENT_CLASSES = 5


class MultitaskBERT(nn.Module):
    '''
    This module should use BERT for 3 tasks:

    - Sentiment classification (predict_sentiment)
    - Paraphrase detection (predict_paraphrase)
    - Semantic Textual Similarity (predict_similarity)
    '''
    def __init__(self, config):
        super(MultitaskBERT, self).__init__()
        # You will want to add layers here to perform the downstream tasks.
        # Pretrain mode does not require updating BERT parameters.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        for param in self.bert.parameters():
            if config.option == 'pretrain':
                param.requires_grad = False
            elif config.option == 'finetune':
                param.requires_grad = True
        ### TODO
        raise NotImplementedError


    def forward(self, input_ids, attention_mask):
        'Takes a batch of sentences and produces embeddings for them.'
        # The final BERT embedding is the hidden state of the [CLS] token (the first token).
        # Here, you can start by just returning the embeddings straight from BERT.
        # When thinking of improvements, you can later try modifying this
        # (e.g., by adding other layers).
        ### TODO
        raise NotImplementedError


    def predict_sentiment(self, input_ids, attention_mask):
        '''Given a batch of sentences, outputs logits for classifying sentiment.
        There are 5 sentiment classes:
        (0 - negative, 1 - somewhat negative, 2 - neutral, 3 - somewhat positive, 4 - positive)
        Thus, your output should contain 5 logits for each sentence.
        '''
        ### TODO
        raise NotImplementedError


    def predict_paraphrase(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        '''Given a batch of pairs of sentences, outputs a single logit for predicting whether they are paraphrases.
        Note that your output should be unnormalized (a logit); it will be passed to the sigmoid function
        during evaluation, and handled as a logit by the appropriate loss function.
        '''
        ### TODO
        raise NotImplementedError


    def predict_similarity(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        '''Given a batch of pairs of sentences, outputs a single logit corresponding to how similar they are.
        Note that your output should be unnormalized (a logit).
        '''
        ### TODO
        raise NotImplementedError
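

# A minimal sketch (hypothetical, not part of the starter code and not the required
# solution) of one way the TODOs above could be filled in: one small head per task on
# top of a shared [CLS] sentence embedding. It assumes minBERT's BertModel.forward
# returns a dict with a 'pooler_output' entry; adjust the key if your bert.py differs.
class MultitaskBERTSketch(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Freeze BERT in 'pretrain' mode, update it in 'finetune' mode.
        for param in self.bert.parameters():
            param.requires_grad = (config.option == 'finetune')
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear heads: 5-way sentiment on a single sentence; single logits for the
        # pair tasks on the concatenation of the two sentence embeddings.
        self.sentiment_head = nn.Linear(BERT_HIDDEN_SIZE, N_SENTIMENT_CLASSES)
        self.paraphrase_head = nn.Linear(2 * BERT_HIDDEN_SIZE, 1)
        self.similarity_head = nn.Linear(2 * BERT_HIDDEN_SIZE, 1)

    def forward(self, input_ids, attention_mask):
        # Use the [CLS] hidden state as the sentence embedding.
        return self.bert(input_ids, attention_mask)['pooler_output']

    def predict_sentiment(self, input_ids, attention_mask):
        return self.sentiment_head(self.dropout(self.forward(input_ids, attention_mask)))

    def predict_paraphrase(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        emb_1 = self.forward(input_ids_1, attention_mask_1)
        emb_2 = self.forward(input_ids_2, attention_mask_2)
        return self.paraphrase_head(self.dropout(torch.cat([emb_1, emb_2], dim=-1))).squeeze(-1)

    def predict_similarity(self,
                           input_ids_1, attention_mask_1,
                           input_ids_2, attention_mask_2):
        # (A common alternative: cosine similarity between the two embeddings,
        # scaled to the 0-5 STS range.)
        emb_1 = self.forward(input_ids_1, attention_mask_1)
        emb_2 = self.forward(input_ids_2, attention_mask_2)
        return self.similarity_head(self.dropout(torch.cat([emb_1, emb_2], dim=-1))).squeeze(-1)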


def save_model(model, optimizer, args, config, filepath):
    save_info = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }

    torch.save(save_info, filepath)
    print(f"Saved the model to {filepath}")


## Currently only trains on the SST dataset
def train_multitask(args):
    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    # Load data
    # Create the data and its corresponding datasets and dataloaders
    sst_train_data, num_labels, para_train_data, sts_train_data = load_multitask_data(args.sst_train, args.para_train, args.sts_train, split='train')
    sst_dev_data, num_labels, para_dev_data, sts_dev_data = load_multitask_data(args.sst_dev, args.para_dev, args.sts_dev, split='train')

    sst_train_data = SentenceClassificationDataset(sst_train_data, args)
    sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)

    sst_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size=args.batch_size,
                                      collate_fn=sst_train_data.collate_fn)
    sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
                                    collate_fn=sst_dev_data.collate_fn)

    # Init model
    config = {'hidden_dropout_prob': args.hidden_dropout_prob,
              'num_labels': num_labels,
              'hidden_size': 768,
              'data_dir': '.',
              'option': args.option}

    config = SimpleNamespace(**config)

    model = MultitaskBERT(config)
    model = model.to(device)

    lr = args.lr
    optimizer = AdamW(model.parameters(), lr=lr)
    best_dev_acc = 0

    # Run for the specified number of epochs
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        num_batches = 0
        for batch in tqdm(sst_train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
            b_ids, b_mask, b_labels = (batch['token_ids'],
                                       batch['attention_mask'], batch['labels'])

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)
            b_labels = b_labels.to(device)

            optimizer.zero_grad()
            logits = model.predict_sentiment(b_ids, b_mask)
            loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / num_batches

        train_acc, train_f1, *_ = model_eval_sst(sst_train_dataloader, model, device)
        dev_acc, dev_f1, *_ = model_eval_sst(sst_dev_dataloader, model, device)

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, optimizer, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss:.3f}, train acc :: {train_acc:.3f}, dev acc :: {dev_acc:.3f}")
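

# A minimal sketch (hypothetical helper, not in the starter code) of how a paraphrase
# training step could compute its loss once a paraphrase dataloader is wired into
# train_multitask. F.binary_cross_entropy_with_logits applies the sigmoid internally,
# which is why predict_paraphrase is required to return an unnormalized logit.
def paraphrase_loss_sketch(model, b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels):
    logits = model.predict_paraphrase(b_ids_1, b_mask_1, b_ids_2, b_mask_2)
    return F.binary_cross_entropy_with_logits(logits.view(-1), b_labels.float().view(-1))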


def test_model(args):
    with torch.no_grad():
        device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
        saved = torch.load(args.filepath)
        config = saved['model_config']

        model = MultitaskBERT(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"Loaded model to test from {args.filepath}")

        test_model_multitask(args, model, device)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sst_train", type=str, default="data/ids-sst-train.csv")
    parser.add_argument("--sst_dev", type=str, default="data/ids-sst-dev.csv")
    parser.add_argument("--sst_test", type=str, default="data/ids-sst-test-student.csv")

    parser.add_argument("--para_train", type=str, default="data/quora-train.csv")
    parser.add_argument("--para_dev", type=str, default="data/quora-dev.csv")
    parser.add_argument("--para_test", type=str, default="data/quora-test-student.csv")

    parser.add_argument("--sts_train", type=str, default="data/sts-train.csv")
    parser.add_argument("--sts_dev", type=str, default="data/sts-dev.csv")
    parser.add_argument("--sts_test", type=str, default="data/sts-test-student.csv")

    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--option", type=str,
                        help='pretrain: the BERT parameters are frozen; finetune: BERT parameters are updated',
                        choices=('pretrain', 'finetune'), default="pretrain")
    parser.add_argument("--use_gpu", action='store_true')

    parser.add_argument("--sst_dev_out", type=str, default="predictions/sst-dev-output.csv")
    parser.add_argument("--sst_test_out", type=str, default="predictions/sst-test-output.csv")

    parser.add_argument("--para_dev_out", type=str, default="predictions/para-dev-output.csv")
    parser.add_argument("--para_test_out", type=str, default="predictions/para-test-output.csv")

    parser.add_argument("--sts_dev_out", type=str, default="predictions/sts-dev-output.csv")
    parser.add_argument("--sts_test_out", type=str, default="predictions/sts-test-output.csv")

    # hyperparameters
    parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
                        default=1e-5)

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    args.filepath = f'{args.option}-{args.epochs}-{args.lr}-multitask.pt'  # save path
    seed_everything(args.seed)  # fix the seed for reproducibility
    train_multitask(args)
    test_model(args)
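
# Example invocations (flags as defined in get_args above; the --lr help text suggests
# 1e-3 for 'pretrain' and 1e-5 for 'finetune'):
#   python multitask_classifier.py --option pretrain --lr 1e-3 --use_gpu
#   python multitask_classifier.py --option finetune --lr 1e-5 --use_gpu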