Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
GRAAL-Research
GitHub Repository: GRAAL-Research/deepparse
Path: blob/main/models_evaluation/tools.py
1231 views
1
# pylint: disable=too-many-locals
2
3
import json
4
import os
5
6
import pandas as pd
7
import pycountry
8
9
from deepparse.dataset_container import PickleDatasetContainer
10
from deepparse.parser import AddressParser
11
12
13
def clean_up_name(country: str) -> str:
    """
    Shorten the verbose pycountry name of a few countries.

    The keywords are checked in the order listed below and the first one found in
    ``country`` decides the returned name. A name containing none of the keywords
    is returned unchanged.
    """
    short_names = {
        "Korea": "South Korea",
        "Russian Federation": "Russia",
        "Venezuela": "Venezuela",
        "Moldova": "Moldova",
        "Bosnia": "Bosnia",
    }
    for keyword, short_name in short_names.items():
        if keyword in country:
            return short_name
    return country
28
29
30
# Data files (one per country, named "<two-letter code>.p") of the countries we trained on.
train_test_files = [
    f"{country_code}.p"
    for country_code in (
        "br",
        "us",
        "kp",
        "ru",
        "de",
        "fr",
        "nl",
        "ch",
        "fi",
        "es",
        "cz",
        "gb",
        "mx",
        "no",
        "ca",
        "it",
        "au",
        "dk",
        "pl",
        "at",
    )
]
53
54
55
def train_country_file(file: str) -> bool:
    """
    Tell whether ``file`` is listed in ``train_test_files``, i.e. is the data file of a
    training country (as referenced in our article).
    """
    is_training_file = file in train_test_files
    return is_training_file
60
61
62
# Data files (one per country, named "<two-letter code>.p") of the countries we did not train on.
other_test_files = [
    f"{country_code}.p"
    for country_code in (
        "ie",
        "rs",
        "uz",
        "ua",
        "za",
        "py",
        "gr",
        "dz",
        "by",
        "se",
        "pt",
        "hu",
        "is",
        "co",
        "lv",
        "my",
        "ba",
        "in",
        "re",
        "hr",
        "ee",
        "nc",
        "jp",
        "nz",
        "sg",
        "ro",
        "bd",
        "sk",
        "ar",
        "kz",
        "ve",
        "id",
        "bg",
        "cy",
        "bm",
        "md",
        "si",
        "lt",
        "ph",
        "be",
        "fo",
    )
]
106
107
108
def zero_shot_eval_country_file(file: str) -> bool:
    """
    Tell whether ``file`` is listed in ``other_test_files``, i.e. is the data file of a
    zero-shot country (as referenced in our article).
    """
    is_zero_shot_file = file in other_test_files
    return is_zero_shot_file
113
114
115
def convert_2_letters_name_into_country_name(country_file_name: str) -> str:
    """
    Convert a country data file name (e.g. ``"fr.p"``) into a cleaned-up country name.

    The ``.p`` part is removed, the remainder is upper-cased and looked up in pycountry
    as an alpha-2 code, and the name found is passed through ``clean_up_name``.
    """
    alpha_2_code = country_file_name.replace(".p", "").upper()
    pycountry_entry = pycountry.countries.get(alpha_2=alpha_2_code)
    return clean_up_name(pycountry_entry.name)
119
120
121
def convert_two_letter_name_to_country_name_in_json(file_path: str) -> None:
    """
    Rewrite, in place, a results JSON file so that its keys are complete country names.

    Every top-level key of the JSON object is converted with
    ``convert_2_letters_name_into_country_name``; the values are kept untouched.

    Args:
        file_path: Path of the JSON results file to read and then overwrite.
    """
    # Be explicit about the encoding (as the other writers of this module are) instead of
    # relying on the platform's locale default.
    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    new_data = {convert_2_letters_name_into_country_name(country_name): value for country_name, value in data.items()}

    with open(file_path, "w", encoding="utf-8") as file:
        json.dump(new_data, file)
134
135
136
def test_on_country_data(address_parser: AddressParser, file: str, directory_path: str, args) -> tuple:
    """
    Evaluate ``address_parser`` on the test data of a single country.

    The data file ``file`` found in ``directory_path`` is wrapped in a
    ``PickleDatasetContainer`` and handed to ``address_parser.test``. ``args`` must expose
    the ``batch_size``, ``model_type`` and ``model_path`` attributes.

    Returns:
        The pair ``(results, country)``: what ``address_parser.test`` returned and the
        country name derived from ``file``.
    """
    country = convert_2_letters_name_into_country_name(file)
    print(f"Testing on test files {country}")

    test_container = PickleDatasetContainer(os.path.join(directory_path, file), is_training_container=False)

    return (
        address_parser.test(
            test_container,
            batch_size=args.batch_size,
            num_workers=4,
            logging_path=f"./checkpoints/{args.model_type}",
            checkpoint=args.model_path,
        ),
        country,
    )
155
156
157
def make_table(data_type: str, root_path: str = ".", with_attention: bool = False):
    """
    Generate a Markdown table of the test results of each model and write it to disk.

    The results are read from the ``{data_type}_test_results_{model}.json`` files found in
    ``root_path``; each file must hold a JSON object mapping a country to its result. The
    table is written to ``tables/actual/{data_type}_table.md`` (relative to the current
    working directory). To keep the table short, two countries are laid out side by side
    on each row.

    Args:
        data_type: Prefix of the results file names; also used to name the output file.
        root_path: Directory holding the results JSON files.
        with_attention: Whether to also include the results of the attention models.
    """
    table_dir = os.path.join("tables", "actual")
    os.makedirs(table_dir, exist_ok=True)

    # (results file suffix, column title) of every model, in column order.
    models = [("fasttext", "FastText (%)"), ("bpemb", "BPEmb (%)")]
    if with_attention:
        models = [
            ("fasttext", "FastText (%)"),
            ("fasttext_attention", "FastTextAtt (%)"),
            ("bpemb", "BPEmb (%)"),
            ("bpemb_attention", "BPEmbAtt (%)"),
        ]
    columns_name = ["Country"] + [column_title for _, column_title in models]

    all_results = []
    for model_name, _ in models:
        results_path = os.path.join(root_path, f"{data_type}_test_results_{model_name}.json")
        # The context manager closes the file right after loading instead of leaking the handle.
        with open(results_path, "r", encoding="utf-8") as results_file:
            all_results.append(json.load(results_file))

    # One entry per country: its name followed by the result of each model. The files are
    # matched by position (not by key) and the country name comes from the first file.
    entries = [
        [country_results[0][0], *(model_res[1] for model_res in country_results)]
        for country_results in zip(*(model_results.items() for model_results in all_results))
    ]

    # We format the data to have two pairs of columns for a less long table.
    formatted_data = [entries[idx] + entries[idx + 1] for idx in range(0, len(entries) - 1, 2)]
    if len(entries) % 2:
        # Odd number of countries: the last one has no neighbour, so the right half of its
        # row is filled with zeros (whatever the number of countries is).
        formatted_data.append(entries[-1] + [0] * len(columns_name))

    table = pd.DataFrame(formatted_data, columns=columns_name * 2).round(2).to_markdown(index=False)

    with open(os.path.join(table_dir, f"{data_type}_table.md"), "w", encoding="utf-8") as file:
        file.write(table)
219
220
221
def make_table_rst(data_type: str, root_path: str = ".", with_attention: bool = False):
    # pylint: disable=too-many-locals
    """
    Generate a Sphinx RST ``list-table`` of the test results of each model and write it to disk.

    The results are read from the ``{data_type}_test_results_{model}.json`` files found in
    ``root_path``; each file must hold a JSON object mapping a country to its result. The
    table is written to ``tables/actual/{data_type}_table.rst`` (relative to the current
    working directory). To keep the table short, two countries are laid out side by side
    on each row.

    Args:
        data_type: Prefix of the results file names; also used to name the output file.
        root_path: Directory holding the results JSON files.
        with_attention: Whether to also include the results of the attention models.
    """
    table_dir = os.path.join("tables", "actual")
    os.makedirs(table_dir, exist_ok=True)

    # (results file suffix, column title) of every model, in column order.
    models = [("fasttext", "FastText (%)"), ("bpemb", "BPEmb (%)")]
    if with_attention:
        models = [
            ("fasttext", "FastText (%)"),
            ("fasttext_attention", "FastTextAtt (%)"),
            ("bpemb", "BPEmb (%)"),
            ("bpemb_attention", "BPEmbAtt (%)"),
        ]
    columns_name = ["Country"] + [column_title for _, column_title in models]

    all_results = []
    for model_name, _ in models:
        results_path = os.path.join(root_path, f"{data_type}_test_results_{model_name}.json")
        # The context manager closes the file right after loading instead of leaking the handle.
        with open(results_path, "r", encoding="utf-8") as results_file:
            all_results.append(json.load(results_file))

    # One entry per country: its name followed by the result of each model. The files are
    # matched by position (not by key) and the country name comes from the first file.
    entries = [
        [country_results[0][0], *(model_res[1] for model_res in country_results)]
        for country_results in zip(*(model_results.items() for model_results in all_results))
    ]

    # We format the data to have two pairs of columns for a less long table.
    formatted_data = [entries[idx] + entries[idx + 1] for idx in range(0, len(entries) - 1, 2)]
    if len(entries) % 2:
        # Odd number of countries: the last one has no neighbour, so the right half of its
        # row is filled with zeros (whatever the number of countries is).
        formatted_data.append(entries[-1] + [0] * len(columns_name))

    table = pd.DataFrame(formatted_data, columns=columns_name * 2).round(2)

    new_line_prefix = "\t\t"
    lines = [".. list-table::\n", new_line_prefix + ":header-rows: 1\n", "\n"]

    # The header row comes first, then one RST row per table row.
    table_rows = [list(table.columns)]
    table_rows.extend(list(row) for _, row in table.iterrows())
    for cells in table_rows:
        for idx, cell in enumerate(cells):
            # The first cell of a row carries the "*" marker that opens a new list-table row.
            row_marker = "*" if idx == 0 else ""
            lines.append(f"{new_line_prefix}{row_marker}\t- {cell}\n")

    with open(os.path.join(table_dir, f"{data_type}_table.rst"), "w", encoding="utf-8") as file:
        file.writelines(lines)
300
301
302
def make_comparison_table(results_a_file_name: str, results_b_file_name: str, root_path: str = "."):
    """
    Generate a Markdown table comparing the results of two models and write it to disk.

    Both results files are read from ``root_path``; each must hold a JSON object mapping a
    country to its result. The table is written to
    ``tables/comparison/{results_a_file_name}_vs_{results_b_file_name}_table.md`` (relative
    to the current working directory). To keep the table short, two countries are laid out
    side by side on each row.

    Args:
        results_a_file_name: File name of the results of model A.
        results_b_file_name: File name of the results of model B.
        root_path: Directory holding the two results files.
    """
    table_dir = os.path.join("tables", "comparison")
    os.makedirs(table_dir, exist_ok=True)

    # Context managers close the files right after loading instead of leaking the handles.
    with open(os.path.join(root_path, results_a_file_name), "r", encoding="utf-8") as results_file:
        model_a_res = json.load(results_file)
    with open(os.path.join(root_path, results_b_file_name), "r", encoding="utf-8") as results_file:
        model_b_res = json.load(results_file)

    columns_name = ["Country", r"Model A (%)", r"Model B (%)"]

    # One entry per country. The two files are matched by position (not by key) and the
    # country name comes from the results of model A.
    entries = [
        [country, res_a, res_b] for (country, res_a), (_, res_b) in zip(model_a_res.items(), model_b_res.items())
    ]

    # We format the data to have two pairs of columns for a less long table.
    formatted_data = [entries[idx] + entries[idx + 1] for idx in range(0, len(entries) - 1, 2)]
    if len(entries) % 2:
        # Odd number of countries: the last one has no neighbour, so the right half of its
        # row is left empty (whatever the number of countries is).
        formatted_data.append(entries[-1] + [None] * len(columns_name))

    table = pd.DataFrame(formatted_data, columns=columns_name * 2).round(2).to_markdown(index=False)

    output_file_name = f"{results_a_file_name}_vs_{results_b_file_name}_table.md"
    with open(os.path.join(table_dir, output_file_name), "w", encoding="utf-8") as file:
        file.write(table)
344
345