GitHub Repository: GRAAL-Research/deepparse
Path: blob/main/models_evaluation/evaluate_model.py
# pylint: disable=too-many-locals
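# Evaluate a deepparse AddressParser (fasttext or bpemb) on per-country address test
# files and save the per-country test accuracies as JSON files under
# ./models_evaluation/results/<results_type>/.
#
# Example invocation (assumed to be run from the repository root; paths are illustrative):
#   python -m models_evaluation.evaluate_model fasttext ./data/test ./data/incomplete_test ./models/model.ckpt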

import argparse
import json
import os

from deepparse.parser import AddressParser
from models_evaluation.tools import (
    train_country_file,
    zero_shot_eval_country_file,
    test_on_country_data,
)


def main(args):
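    # Results are written under ./models_evaluation/results/<results_type>/.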
    results_type = args.results_type
    saving_dir = os.path.join(".", "models_evaluation", "results", results_type)
    os.makedirs(saving_dir, exist_ok=True)

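    # Instantiate the address parser for the requested model type on device 0.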
    address_parser = AddressParser(model_type=args.model_type, device=0)
    directory_path = args.test_directory
    test_files = os.listdir(directory_path)
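    # Test accuracies are collected per country, separated into countries seen during
    # training and countries evaluated in a zero-shot setting.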
    training_test_results = {}
    zero_shot_test_results = {}
    for idx, test_file in enumerate(test_files):
        results, country = test_on_country_data(address_parser, test_file, directory_path, args)
        print(f"{idx + 1} files done of {len(test_files)}.")

        if train_country_file(test_file):
            training_test_results.update({country: results["test_accuracy"]})
        elif zero_shot_eval_country_file(test_file):
            zero_shot_test_results.update({country: results["test_accuracy"]})
        else:
            print(f"Error with the identification of test file type {test_file}.")

    training_base_string = "training_test_results"
    training_incomplete_base_string = "training_incomplete_test_results"
    zero_shot_base_string = "zero_shot_test_results"

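    # Save the per-country test accuracies for the training countries.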
    with open(
        os.path.join(saving_dir, f"{training_base_string}_{args.model_type}.json"),
        "w",
        encoding="utf-8",
    ) as file:
        json.dump(training_test_results, file, ensure_ascii=False)

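    # Save the per-country test accuracies for the zero-shot countries.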
    with open(
        os.path.join(saving_dir, f"{zero_shot_base_string}_{args.model_type}.json"),
        "w",
        encoding="utf-8",
    ) as file:
        json.dump(zero_shot_test_results, file, ensure_ascii=False)

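    # Repeat the evaluation on the files from the incomplete test directory.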
    incomplete_test_directory = args.incomplete_test_directory
    incomplete_test_files = os.listdir(incomplete_test_directory)
    incomplete_training_test_results = {}
    for idx, incomplete_test_file in enumerate(incomplete_test_files):
        results, country = test_on_country_data(address_parser, incomplete_test_file, incomplete_test_directory, args)
        print(f"{idx + 1} files done of {len(incomplete_test_files)}.")

        if train_country_file(incomplete_test_file):
            incomplete_training_test_results.update({country: results["test_accuracy"]})
        else:
            print(f"Error with the identification of test file type {incomplete_test_file}.")

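    # Save the per-country test accuracies for the incomplete test data.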
    with open(
        os.path.join(saving_dir, f"{training_incomplete_base_string}_{args.model_type}.json"),
        "w",
        encoding="utf-8",
    ) as file:
        json.dump(incomplete_training_test_results, file, ensure_ascii=False)


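# Command-line interface. Note that model_path and batch_size are not read directly in
# main(); the full args namespace is forwarded to test_on_country_data, which is assumed
# to consume them.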
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "model_type",
        type=str,
help="Model type to retrain.",
81
choices=["fasttext", "bpemb"],
82
)
    parser.add_argument("test_directory", type=str, help="Path to the test directory.")
    parser.add_argument(
        "incomplete_test_directory",
        type=str,
        help="Path to the incomplete test directory.",
    )
    parser.add_argument("model_path", type=str, help="Path to the model to evaluate on.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2048,
        help="Batch size of the data to evaluate on.",
    )
    parser.add_argument(
        "--results_type",
        type=str,
        default="actual",
        help="Whether or not the evaluation is for new models.",
    )
    args_parser = parser.parse_args()

    main(args_parser)