Path: blob/main/minBERT/evaluation.py
#!/usr/bin/env python3
'''
Model evaluation functions.

When training your multitask model, you will find it useful to run
model_eval_multitask to evaluate your model on the three tasks in the
development set.

Before submission, your code needs to call test_model_multitask(args, model, device)
to generate your predictions. We'll evaluate these predictions against our labels
on our end, which is how the leaderboard will be updated.
The provided test_model() function in multitask_classifier.py **already does this
for you**, so unless you change it you shouldn't need to call anything from here
explicitly aside from model_eval_multitask.
'''

import torch
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
import numpy as np

from datasets import load_multitask_data, \
    SentenceClassificationDataset, SentenceClassificationTestDataset, \
    SentencePairDataset, SentencePairTestDataset


TQDM_DISABLE = True


# Evaluate a multitask model for accuracy on SST only.
def model_eval_sst(dataloader, model, device):
    model.eval()  # switch to eval mode: turns off randomness such as dropout
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []
    for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
        b_ids, b_mask, b_labels, b_sents, b_sent_ids = \
            batch['token_ids'], batch['attention_mask'], \
            batch['labels'], batch['sents'], batch['sent_ids']

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model.predict_sentiment(b_ids, b_mask)
        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        b_labels = b_labels.flatten()
        y_true.extend(b_labels)
        y_pred.extend(preds)
        sents.extend(b_sents)
        sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids


# Evaluate a multitask model on the dev set: accuracy for paraphrase detection
# and sentiment classification, Pearson correlation for semantic textual
# similarity (STS).
def model_eval_multitask(sentiment_dataloader,
                         paraphrase_dataloader,
                         sts_dataloader,
                         model, device):
    model.eval()  # switch to eval mode: turns off randomness such as dropout

    with torch.no_grad():
        para_y_true = []
        para_y_pred = []
        para_sent_ids = []

        # Evaluate paraphrase detection.
        for step, batch in enumerate(tqdm(paraphrase_dataloader, desc='eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                                      batch['token_ids_2'], batch['attention_mask_2'],
                                      batch['labels'], batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
            # Binary prediction: sigmoid then round, i.e. threshold at 0.5.
            y_hat = logits.sigmoid().round().flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            para_y_pred.extend(y_hat)
            para_y_true.extend(b_labels)
            para_sent_ids.extend(b_sent_ids)

        paraphrase_accuracy = np.mean(np.array(para_y_pred) == np.array(para_y_true))

        sts_y_true = []
        sts_y_pred = []
        sts_sent_ids = []

        # Evaluate semantic textual similarity.
        for step, batch in enumerate(tqdm(sts_dataloader, desc='eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_labels, b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                                      batch['token_ids_2'], batch['attention_mask_2'],
                                      batch['labels'], batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            sts_y_pred.extend(y_hat)
            sts_y_true.extend(b_labels)
            sts_sent_ids.extend(b_sent_ids)

        # Pearson correlation between predicted and gold similarity scores.
        pearson_mat = np.corrcoef(sts_y_pred, sts_y_true)
        sts_corr = pearson_mat[1][0]

        sst_y_true = []
        sst_y_pred = []
        sst_sent_ids = []

        # Evaluate sentiment classification.
        for step, batch in enumerate(tqdm(sentiment_dataloader, desc='eval', disable=TQDM_DISABLE)):
            b_ids, b_mask, b_labels, b_sent_ids = \
                batch['token_ids'], batch['attention_mask'], batch['labels'], batch['sent_ids']

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            logits = model.predict_sentiment(b_ids, b_mask)
            y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()
            b_labels = b_labels.flatten().cpu().numpy()

            sst_y_pred.extend(y_hat)
            sst_y_true.extend(b_labels)
            sst_sent_ids.extend(b_sent_ids)

        sentiment_accuracy = np.mean(np.array(sst_y_pred) == np.array(sst_y_true))

        print(f'Paraphrase detection accuracy: {paraphrase_accuracy:.3f}')
        print(f'Sentiment classification accuracy: {sentiment_accuracy:.3f}')
        print(f'Semantic Textual Similarity correlation: {sts_corr:.3f}')

        return (paraphrase_accuracy, para_y_pred, para_sent_ids,
                sentiment_accuracy, sst_y_pred, sst_sent_ids,
                sts_corr, sts_y_pred, sts_sent_ids)
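
# A minimal usage sketch (not part of the original handout) of how
# model_eval_multitask might be called at the end of each training epoch for
# model selection. It assumes the three dev dataloaders were built the same
# way test_model_multitask builds them below; the averaging scheme is an
# assumption, not something this file prescribes.
def example_dev_eval(sst_dev_dataloader, para_dev_dataloader, sts_dev_dataloader,
                     model, device):
    (para_acc, _, _,
     sst_acc, _, _,
     sts_corr, _, _) = model_eval_multitask(sst_dev_dataloader,
                                            para_dev_dataloader,
                                            sts_dev_dataloader,
                                            model, device)
    # One simple scalar for checkpoint selection: the unweighted mean of the
    # three per-task metrics.
    return (para_acc + sst_acc + sts_corr) / 3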

# Generate predictions with a multitask model on the (unlabeled) test sets
# for all three tasks.
def model_eval_test_multitask(sentiment_dataloader,
                              paraphrase_dataloader,
                              sts_dataloader,
                              model, device):
    model.eval()  # switch to eval mode: turns off randomness such as dropout

    with torch.no_grad():
        para_y_pred = []
        para_sent_ids = []

        # Evaluate paraphrase detection.
        for step, batch in enumerate(tqdm(paraphrase_dataloader, desc='eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                            batch['token_ids_2'], batch['attention_mask_2'],
                            batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_paraphrase(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.sigmoid().round().flatten().cpu().numpy()

            para_y_pred.extend(y_hat)
            para_sent_ids.extend(b_sent_ids)

        sts_y_pred = []
        sts_sent_ids = []

        # Evaluate semantic textual similarity.
        for step, batch in enumerate(tqdm(sts_dataloader, desc='eval', disable=TQDM_DISABLE)):
            (b_ids1, b_mask1,
             b_ids2, b_mask2,
             b_sent_ids) = (batch['token_ids_1'], batch['attention_mask_1'],
                            batch['token_ids_2'], batch['attention_mask_2'],
                            batch['sent_ids'])

            b_ids1 = b_ids1.to(device)
            b_mask1 = b_mask1.to(device)
            b_ids2 = b_ids2.to(device)
            b_mask2 = b_mask2.to(device)

            logits = model.predict_similarity(b_ids1, b_mask1, b_ids2, b_mask2)
            y_hat = logits.flatten().cpu().numpy()

            sts_y_pred.extend(y_hat)
            sts_sent_ids.extend(b_sent_ids)

        sst_y_pred = []
        sst_sent_ids = []

        # Evaluate sentiment classification.
        for step, batch in enumerate(tqdm(sentiment_dataloader, desc='eval', disable=TQDM_DISABLE)):
            b_ids, b_mask, b_sent_ids = \
                batch['token_ids'], batch['attention_mask'], batch['sent_ids']

            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            logits = model.predict_sentiment(b_ids, b_mask)
            y_hat = logits.argmax(dim=-1).flatten().cpu().numpy()

            sst_y_pred.extend(y_hat)
            sst_sent_ids.extend(b_sent_ids)

        return (para_y_pred, para_sent_ids,
                sst_y_pred, sst_sent_ids,
                sts_y_pred, sts_sent_ids)

def test_model_multitask(args, model, device):
    sst_test_data, num_labels, para_test_data, sts_test_data = \
        load_multitask_data(args.sst_test, args.para_test, args.sts_test, split='test')

    sst_dev_data, num_labels, para_dev_data, sts_dev_data = \
        load_multitask_data(args.sst_dev, args.para_dev, args.sts_dev, split='dev')

    sst_test_data = SentenceClassificationTestDataset(sst_test_data, args)
    sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)

    # Predictions are written out together with their sent_ids, so shuffling
    # the test dataloaders does not mis-align ids and predictions.
    sst_test_dataloader = DataLoader(sst_test_data, shuffle=True, batch_size=args.batch_size,
                                     collate_fn=sst_test_data.collate_fn)
    sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
                                    collate_fn=sst_dev_data.collate_fn)

    para_test_data = SentencePairTestDataset(para_test_data, args)
    para_dev_data = SentencePairDataset(para_dev_data, args)

    para_test_dataloader = DataLoader(para_test_data, shuffle=True, batch_size=args.batch_size,
                                      collate_fn=para_test_data.collate_fn)
    para_dev_dataloader = DataLoader(para_dev_data, shuffle=False, batch_size=args.batch_size,
                                     collate_fn=para_dev_data.collate_fn)

    sts_test_data = SentencePairTestDataset(sts_test_data, args)
    sts_dev_data = SentencePairDataset(sts_dev_data, args, isRegression=True)

    sts_test_dataloader = DataLoader(sts_test_data, shuffle=True, batch_size=args.batch_size,
                                     collate_fn=sts_test_data.collate_fn)
    sts_dev_dataloader = DataLoader(sts_dev_data, shuffle=False, batch_size=args.batch_size,
                                    collate_fn=sts_dev_data.collate_fn)

    dev_paraphrase_accuracy, dev_para_y_pred, dev_para_sent_ids, \
        dev_sentiment_accuracy, dev_sst_y_pred, dev_sst_sent_ids, \
        dev_sts_corr, dev_sts_y_pred, dev_sts_sent_ids = \
        model_eval_multitask(sst_dev_dataloader,
                             para_dev_dataloader,
                             sts_dev_dataloader, model, device)

    test_para_y_pred, test_para_sent_ids, test_sst_y_pred, \
        test_sst_sent_ids, test_sts_y_pred, test_sts_sent_ids = \
        model_eval_test_multitask(sst_test_dataloader,
                                  para_test_dataloader,
                                  sts_test_dataloader, model, device)

    with open(args.sst_dev_out, "w+") as f:
        print(f"dev sentiment acc :: {dev_sentiment_accuracy:.3f}")
        f.write("id \t Predicted_Sentiment \n")
        for p, s in zip(dev_sst_sent_ids, dev_sst_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.sst_test_out, "w+") as f:
        f.write("id \t Predicted_Sentiment \n")
        for p, s in zip(test_sst_sent_ids, test_sst_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.para_dev_out, "w+") as f:
        print(f"dev paraphrase acc :: {dev_paraphrase_accuracy:.3f}")
        f.write("id \t Predicted_Is_Paraphrase \n")
        for p, s in zip(dev_para_sent_ids, dev_para_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.para_test_out, "w+") as f:
        f.write("id \t Predicted_Is_Paraphrase \n")
        for p, s in zip(test_para_sent_ids, test_para_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.sts_dev_out, "w+") as f:
        print(f"dev sts corr :: {dev_sts_corr:.3f}")
        f.write("id \t Predicted_Similarity \n")
        for p, s in zip(dev_sts_sent_ids, dev_sts_y_pred):
            f.write(f"{p} , {s} \n")

    with open(args.sts_test_out, "w+") as f:
        f.write("id \t Predicted_Similarity \n")
        for p, s in zip(test_sts_sent_ids, test_sts_y_pred):
            f.write(f"{p} , {s} \n")
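
# A minimal sketch (not part of the original handout) of driving
# test_model_multitask directly, for when you are not going through
# multitask_classifier.py's test_model. The file paths below are hypothetical
# placeholders; substitute whatever your training script actually passes as
# `args`. Only the attributes test_model_multitask reads are filled in.
def example_generate_predictions(model, device, batch_size=8):
    from types import SimpleNamespace

    args = SimpleNamespace(
        batch_size=batch_size,
        # Input splits (hypothetical paths).
        sst_dev='data/sst-dev.csv', sst_test='data/sst-test.csv',
        para_dev='data/para-dev.csv', para_test='data/para-test.csv',
        sts_dev='data/sts-dev.csv', sts_test='data/sts-test.csv',
        # Output locations for the prediction files written above.
        sst_dev_out='predictions/sst-dev-output.csv',
        sst_test_out='predictions/sst-test-output.csv',
        para_dev_out='predictions/para-dev-output.csv',
        para_test_out='predictions/para-test-output.csv',
        sts_dev_out='predictions/sts-dev-output.csv',
        sts_test_out='predictions/sts-test-output.csv',
    )
    test_model_multitask(args, model, device)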