Path: blob/main/models_evaluation/evaluate_model.py
# pylint: disable=too-many-locals
import argparse
import json
import os

from deepparse.parser import AddressParser
from models_evaluation.tools import (
    train_country_file,
    zero_shot_eval_country_file,
    test_on_country_data,
)


def main(args):
    results_type = args.results_type
    saving_dir = os.path.join(".", "models_evaluation", "results", results_type)
    os.makedirs(saving_dir, exist_ok=True)

    address_parser = AddressParser(model_type=args.model_type, device=0)

    # Evaluate the model on every country file in the test directory and split the
    # accuracies by file type: countries seen during training versus zero-shot countries.
    directory_path = args.test_directory
    test_files = os.listdir(directory_path)
    training_test_results = {}
    zero_shot_test_results = {}
    for idx, test_file in enumerate(test_files):
        results, country = test_on_country_data(address_parser, test_file, directory_path, args)
        print(f"{idx + 1} files done of {len(test_files)}.")

        if train_country_file(test_file):
            training_test_results.update({country: results["test_accuracy"]})
        elif zero_shot_eval_country_file(test_file):
            zero_shot_test_results.update({country: results["test_accuracy"]})
        else:
            print(f"Error with the identification of test file type {test_file}.")

    training_base_string = "training_test_results"
    training_incomplete_base_string = "training_incomplete_test_results"
    zero_shot_base_string = "zero_shot_test_results"

    # Save the per-country accuracies to JSON, one file per result category.
    with open(
        os.path.join(saving_dir, f"{training_base_string}_{args.model_type}.json"),
        "w",
        encoding="utf-8",
    ) as file:
        json.dump(training_test_results, file, ensure_ascii=False)

    with open(
        os.path.join(saving_dir, f"{zero_shot_base_string}_{args.model_type}.json"),
        "w",
        encoding="utf-8",
    ) as file:
        json.dump(zero_shot_test_results, file, ensure_ascii=False)

    # Repeat the evaluation on the incomplete-address test files (training countries only).
    incomplete_test_directory = args.incomplete_test_directory
    incomplete_test_files = os.listdir(incomplete_test_directory)
    incomplete_training_test_results = {}
    for idx, incomplete_test_file in enumerate(incomplete_test_files):
        results, country = test_on_country_data(address_parser, incomplete_test_file, incomplete_test_directory, args)
        print(f"{idx + 1} files done of {len(incomplete_test_files)}.")

        if train_country_file(incomplete_test_file):
            incomplete_training_test_results.update({country: results["test_accuracy"]})
        else:
            print(f"Error with the identification of test file type {incomplete_test_file}.")

    with open(
        os.path.join(saving_dir, f"{training_incomplete_base_string}_{args.model_type}.json"),
        "w",
        encoding="utf-8",
    ) as file:
        json.dump(incomplete_training_test_results, file, ensure_ascii=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "model_type",
        type=str,
        help="Model type to evaluate.",
        choices=["fasttext", "bpemb"],
    )
    parser.add_argument("test_directory", type=str, help="Path to the test directory.")
    parser.add_argument(
        "incomplete_test_directory",
        type=str,
        help="Path to the incomplete test directory.",
    )
    parser.add_argument("model_path", type=str, help="Path to the model to evaluate on.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2048,
        help="Batch size of the data to evaluate on.",
    )
    parser.add_argument(
        "--results_type",
        type=str,
        default="actual",
        help="Whether or not the evaluation is for new models.",
    )
    args_parser = parser.parse_args()

    main(args_parser)
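
# Example invocation, as a sketch only: the data directories and checkpoint path below are
# assumptions for illustration, not files shipped with the repository. It assumes the script
# is launched from the repository root so that the models_evaluation package is importable:
#
#   python -m models_evaluation.evaluate_model fasttext \
#       ./data/test ./data/incomplete_test ./checkpoints/fasttext.ckpt \
#       --batch_size 1024 --results_type actual
#
# A run writes three JSON files under ./models_evaluation/results/<results_type>/ (training,
# zero-shot, and incomplete training results), each mapping a country name to its test accuracy.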