# Path: blob/master/deep_learning/seq2seq/translation_mt5/seq2seq_eval.py
# 2581 views
"""Run a seq2seq Marian translation model evaluation on wmt16 dataset."""
import os
import random
from dataclasses import dataclass

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
from translation_utils import download_file, create_translation_data


@dataclass
class Config:
    """Evaluation configuration: paths, languages, seeds, and the pretrained model.

    Instantiating this class seeds all RNGs and downloads/loads the tokenizer
    and model from the Hugging Face hub (cached under ``cache_dir``).
    """

    cache_dir: str = "./translation"
    # NOTE: evaluated at class-definition time against the class-level default
    # of cache_dir above; overriding cache_dir at construction does NOT move data_dir.
    data_dir: str = os.path.join(cache_dir, "wmt16")
    source_lang: str = 'de'
    target_lang: str = 'en'

    batch_size: int = 16
    num_workers: int = 4
    seed: int = 42
    # hard cap on tokenized sequence length; longer sentences are truncated
    max_source_length: int = 128
    max_target_length: int = 128

    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_checkpoint: str = "Helsinki-NLP/opus-mt-de-en"

    def __post_init__(self):
        # seed every RNG source for reproducible evaluation
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_checkpoint,
            cache_dir=self.cache_dir
        )
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.model_checkpoint,
            cache_dir=self.cache_dir
        )
        print('# of parameters: ', self.model.num_parameters())


def batch_tokenize_fn(examples):
    """
    Generate the input_ids and labels field for huggingface dataset/dataset dict.

    Truncation is enabled where we cap the sentence to the max length. Padding
    will be done later in a data collator, so we pad examples to the longest
    length within a mini-batch and not the whole dataset.

    Relies on the module-level ``config`` created in ``__main__``.
    """
    sources = examples[config.source_lang]
    targets = examples[config.target_lang]
    model_inputs = config.tokenizer(sources, max_length=config.max_source_length, truncation=True)

    # setup the tokenizer for targets,
    # huggingface expects the target tokenized ids to be stored in the labels field
    with config.tokenizer.as_target_tokenizer():
        labels = config.tokenizer(targets, max_length=config.max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


def compute_metrics(eval_pred):
    """
    Compute ROUGE and sacreBLEU on a batch of generated translations.

    note: we can run trainer.predict on our eval/test dataset to see what a sample
    eval_pred object would look like when implementing custom compute metrics function

    Relies on the module-level ``config``, ``rouge_score`` and
    ``sacrebleu_score`` objects created in ``__main__``.
    """
    predictions, labels = eval_pred
    # some models return extra tensors (e.g. hidden states) alongside the
    # generated ids; keep only the token ids in that case
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Decode generated translations into text
    decoded_preds = config.tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, config.tokenizer.pad_token_id)
    # Decode reference translations into text
    decoded_labels = config.tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_score.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        rouge_types=["rouge1", "rouge2", "rougeL"]
    )
    score = sacrebleu_score.compute(
        predictions=decoded_preds,
        references=decoded_labels
    )
    result["sacrebleu"] = score["score"]
    return {k: round(v, 4) for k, v in result.items()}


def create_wmt16_data_files(config: Config):
    """
    Download the wmt16 multimodal-task archives and convert each split into a
    tab-separated source/target file.

    Returns a ``{split: [tsv_path]}`` mapping suitable for
    ``datasets.load_dataset('csv', data_files=...)``.
    """
    # files are downloaded from
    # http://www.statmt.org/wmt16/multimodal-task.html
    urls = [
        'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
        'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
        'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz'
    ]

    for url in urls:
        download_file(url, config.data_dir)

    data_files = {}
    for split in ["train", "val", "test"]:
        source_input_path = os.path.join(config.data_dir, f"{split}.{config.source_lang}")
        target_input_path = os.path.join(config.data_dir, f"{split}.{config.target_lang}")
        output_path = os.path.join(config.cache_dir, f"{split}.tsv")
        create_translation_data(source_input_path, target_input_path, output_path)
        data_files[split] = [output_path]

    return data_files


if __name__ == "__main__":
    config = Config()

    data_files = create_wmt16_data_files(config)
    dataset_dict = load_dataset(
        'csv',
        delimiter='\t',
        column_names=[config.source_lang, config.target_lang],
        data_files=data_files
    )
    dataset_dict_tokenized = dataset_dict.map(
        batch_tokenize_fn,
        batched=True
    )

    model_name = config.model_checkpoint.split("/")[-1]
    output_dir = os.path.join(
        config.cache_dir,
        f"{model_name}_{config.source_lang}-{config.target_lang}"
    )
    # predict_with_generate makes trainer.evaluate() run autoregressive
    # generation so compute_metrics receives generated token ids
    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_eval_batch_size=config.batch_size,
        predict_with_generate=True
    )
    # dynamic padding per mini-batch (tokenization above leaves examples unpadded)
    data_collator = DataCollatorForSeq2Seq(config.tokenizer, model=config.model)
    rouge_score = evaluate.load("rouge", cache_dir=config.cache_dir)
    sacrebleu_score = evaluate.load("sacrebleu", cache_dir=config.cache_dir)
    trainer = Seq2SeqTrainer(
        config.model,
        args,
        train_dataset=dataset_dict_tokenized["train"],
        eval_dataset=dataset_dict_tokenized["val"],
        data_collator=data_collator,
        tokenizer=config.tokenizer,
        compute_metrics=compute_metrics,
    )
    print(trainer.evaluate())