# GitHub repository: ethen8181/machine-learning
# Path: blob/master/deep_learning/seq2seq/translation_mt5/seq2seq_eval.py
"""
Evaluate a pretrained seq2seq Marian translation model on the WMT16 multimodal
task (Multi30k) dataset.
"""
import os
import torch
import random
import evaluate
import numpy as np
from datasets import load_dataset
from dataclasses import dataclass
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
# local helper module that ships alongside this script
from translation_utils import download_file, create_translation_data

@dataclass
class Config:
    cache_dir: str = "./translation"
    data_dir: str = os.path.join(cache_dir, "wmt16")
    source_lang: str = 'de'
    target_lang: str = 'en'

    batch_size: int = 16
    num_workers: int = 4
    seed: int = 42
    max_source_length: int = 128
    max_target_length: int = 128

    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_checkpoint: str = "Helsinki-NLP/opus-mt-de-en"

    def __post_init__(self):
        # seed every random number generator we rely on for reproducibility
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)

        # eagerly load the tokenizer and model so downstream code can
        # reference them directly off the config object
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_checkpoint,
            cache_dir=self.cache_dir
        )
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.model_checkpoint,
            cache_dir=self.cache_dir
        )
        print('# of parameters: ', self.model.num_parameters())

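# The helpers below rely on module-level globals (config, rouge_score,
# sacrebleu_score) that are created in the __main__ block further down.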
def batch_tokenize_fn(examples):
    """
    Generate the input_ids and labels fields for a huggingface dataset / dataset dict.

    Truncation is enabled, capping each sentence at the max length. Padding is done
    later by the data collator, so examples are padded to the longest length within
    a mini-batch rather than across the whole dataset.
    """
    sources = examples[config.source_lang]
    targets = examples[config.target_lang]
    model_inputs = config.tokenizer(sources, max_length=config.max_source_length, truncation=True)

    # set up the tokenizer for targets:
    # huggingface expects the target token ids to be stored in the labels field
    # (newer transformers versions favor the text_target argument over this context manager)
    with config.tokenizer.as_target_tokenizer():
        labels = config.tokenizer(targets, max_length=config.max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

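# Illustrative shape of one mapped example (the token ids here are made up):
# {'input_ids': [182, 703, 2, 0], 'attention_mask': [1, 1, 1, 1], 'labels': [91, 442, 0]}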
def compute_metrics(eval_pred):
    """
    Note: we can run trainer.predict on our eval/test dataset to see what a sample
    eval_pred object looks like when implementing a custom compute_metrics function.
    """
    predictions, labels = eval_pred
    # decode generated translations into text
    decoded_preds = config.tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, config.tokenizer.pad_token_id)
    # decode reference translations into text
    decoded_labels = config.tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_score.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        rouge_types=["rouge1", "rouge2", "rougeL"]
    )
    score = sacrebleu_score.compute(
        predictions=decoded_preds,
        references=decoded_labels
    )
    result["sacrebleu"] = score["score"]
    return {k: round(v, 4) for k, v in result.items()}

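# Illustrative return value (the numbers are made up; rouge scores fall in [0, 1]
# while sacrebleu is on a 0-100 scale):
# {'rouge1': 0.83, 'rouge2': 0.65, 'rougeL': 0.8, 'sacrebleu': 35.12}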
def create_wmt16_data_files(config: Config):
    # files are downloaded from
    # http://www.statmt.org/wmt16/multimodal-task.html
    urls = [
        'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
        'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
        'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz'
    ]

    for url in urls:
        download_file(url, config.data_dir)

    # convert each raw split into a single tab separated file with one
    # source/target sentence pair per row
    data_files = {}
    for split in ["train", "val", "test"]:
        source_input_path = os.path.join(config.data_dir, f"{split}.{config.source_lang}")
        target_input_path = os.path.join(config.data_dir, f"{split}.{config.target_lang}")
        output_path = os.path.join(config.cache_dir, f"{split}.tsv")
        create_translation_data(source_input_path, target_input_path, output_path)
        data_files[split] = [output_path]

    return data_files

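# With the default cache_dir, data_files ends up looking like:
# {'train': ['./translation/train.tsv'], 'val': ['./translation/val.tsv'],
#  'test': ['./translation/test.tsv']}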
if __name__ == "__main__":
    config = Config()

    data_files = create_wmt16_data_files(config)
    dataset_dict = load_dataset(
        'csv',
        delimiter='\t',
        column_names=[config.source_lang, config.target_lang],
        data_files=data_files
    )
    dataset_dict_tokenized = dataset_dict.map(
        batch_tokenize_fn,
        batched=True
    )

    model_name = config.model_checkpoint.split("/")[-1]
    output_dir = os.path.join(
        config.cache_dir,
        f"{model_name}_{config.source_lang}-{config.target_lang}"
    )
    # predict_with_generate makes the trainer call model.generate during
    # evaluation, so compute_metrics receives generated token ids instead of logits
    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_eval_batch_size=config.batch_size,
        predict_with_generate=True
    )
    # pads inputs dynamically per batch and pads labels with -100 so they are
    # ignored by the loss (compute_metrics maps -100 back to pad tokens before decoding)
    data_collator = DataCollatorForSeq2Seq(config.tokenizer, model=config.model)
    rouge_score = evaluate.load("rouge", cache_dir=config.cache_dir)
    sacrebleu_score = evaluate.load("sacrebleu", cache_dir=config.cache_dir)
    trainer = Seq2SeqTrainer(
        config.model,
        args,
        train_dataset=dataset_dict_tokenized["train"],
        eval_dataset=dataset_dict_tokenized["val"],
        data_collator=data_collator,
        tokenizer=config.tokenizer,
        compute_metrics=compute_metrics,
    )
    print(trainer.evaluate())
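    # Example run (metric values are placeholders; actual numbers depend on the
    # checkpoint and library versions):
    #   python seq2seq_eval.py
    #   {'eval_loss': ..., 'eval_rouge1': ..., 'eval_rouge2': ..., 'eval_rougeL': ...,
    #    'eval_sacrebleu': ..., 'eval_runtime': ..., ...}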