GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/multi_label/fasttext_module/model.py
import os
import fasttext
import pandas as pd
from copy import deepcopy
from typing import Any, Dict, List, Tuple
from joblib import Parallel, delayed, dump, load
from sklearn.model_selection import ParameterSampler
from fasttext_module.utils import prepend_file_name
from fasttext_module.split import train_test_split_file


__all__ = [
    'FasttextPipeline',
    'fit_and_score',
    'fit_fasttext',
    'score'
]


class FasttextPipeline:
    """
    Fasttext text classification pipeline.

    Parameters
    ----------
    model_id : str
        Unique identifier for the model, the model checkpoint will have this name.

    fasttext_params : dict
        Passed to fasttext.train_supervised(**fasttext_params). Note that
        the input text file does not need to be specified here.

    fasttext_hyper_params : dict
        Controls which parameters will be tuned and their corresponding
        search ranges. e.g. {"dim": [80, 100]}

    fasttext_search_params : dict
        Controls how long to perform the hyperparameter search and what metric to optimize for.

        - n_iter (int) Number of parameter settings sampled from fasttext_hyper_params.
        - random_state (int) Seed for sampling from fasttext_hyper_params.
        - n_jobs (int) Number of jobs to run in parallel. -1 means use all processors.
        - verbose (int) The higher the number, the more messages printed.
        - scoring (str) The metric used to select the best parameters. e.g.
          f1@1, precision@1, recall@1. Valid metrics are precision/recall/f1 followed
          by @k, where k controls how many of the top predictions are evaluated.

    Attributes
    ----------
    model_ : _FastText
        Fasttext model.

    df_tune_results_ : pd.DataFrame
        DataFrame that stores the hyperparameter tuning results, including the
        parameters that were tuned and their corresponding train/test scores.

    best_params_ : dict
        Best hyperparameters chosen to re-fit the model on the entire dataset.
    """

    def __init__(self,
                 model_id: str,
                 fasttext_params: Dict[str, Any],
                 fasttext_hyper_params: Dict[str, List[Any]],
                 fasttext_search_params: Dict[str, Any]):
        self.model_id = model_id
        self.fasttext_params = fasttext_params
        self.fasttext_hyper_params = fasttext_hyper_params
        self.fasttext_search_params = fasttext_search_params
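
    # A hedged illustration of the constructor arguments; the values below are
    # hypothetical placeholders, not taken from this module:
    #
    #   FasttextPipeline(
    #       model_id='fasttext_v1',
    #       fasttext_params={'lr': 0.1, 'minCount': 1},
    #       fasttext_hyper_params={'dim': [80, 100], 'epoch': [5, 10]},
    #       fasttext_search_params={'n_iter': 2, 'random_state': 1234,
    #                               'n_jobs': -1, 'verbose': 1, 'scoring': 'f1@1'})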

    def fit_file(self, fasttext_file_path: str,
                 val_size: float = 0.1, split_random_state: int = 1234):
        """
        Fit the pipeline to the input file path.

        Parameters
        ----------
        fasttext_file_path : str
            Path to the text file, which should already be in the fasttext
            expected format.

        val_size : float, default 0.1
            Proportion of the dataset to include in the validation split.
            The validation set will be used to pick the best parameters from
            the hyperparameter search.

        split_random_state : int, default 1234
            Seed for the split.

        Returns
        -------
        self
        """
        self._tune_fasttext(fasttext_file_path, val_size, split_random_state,
                            **self.fasttext_search_params)
        self.model_ = fit_fasttext(fasttext_file_path, self.fasttext_params, self.best_params_)
        return self

    def _tune_fasttext(self, fasttext_file_path: str, val_size: float, split_random_state: int,
                       n_iter: int, random_state: int, n_jobs: int, verbose: int, scoring: str):
        parameter_sampler = ParameterSampler(self.fasttext_hyper_params, n_iter, random_state)

        # split the input file into intermediate train/validation files;
        # count_train/count_val hold the number of records written to each split
        fasttext_file_path_train = prepend_file_name(fasttext_file_path, 'train')
        fasttext_file_path_val = prepend_file_name(fasttext_file_path, 'val')
        count_train, count_val = train_test_split_file(
            fasttext_file_path, fasttext_file_path_train, fasttext_file_path_val,
            val_size, split_random_state)

        # e.g. a scoring of 'f1@1' evaluates the metrics at k = 1
        k = int(scoring.split('@')[-1])
        parallel = Parallel(n_jobs=n_jobs, verbose=verbose)
        results = parallel(delayed(fit_and_score)(fasttext_file_path_train,
                                                  fasttext_file_path_val,
                                                  self.fasttext_params,
                                                  k,
                                                  param)
                           for param in parameter_sampler)

        df_tune_results = (pd.DataFrame
                           .from_dict(results)
                           .sort_values(f'test_{scoring}', ascending=False))
        self.best_params_ = df_tune_results['params'].iloc[0]
        self.df_tune_results_ = df_tune_results

        # clean up the intermediate train/validation split files to avoid
        # wasting disk space
        for file_path in [fasttext_file_path_train, fasttext_file_path_val]:
            os.remove(file_path)

        return self

    def save(self, directory: str) -> str:
        """
        Saves the pipeline.

        Parameters
        ----------
        directory : str
            The directory to save the model. Will create the directory if it
            doesn't exist.

        Returns
        -------
        model_checkpoint_dir : str
            The directory of the saved model.
        """
        model_checkpoint_dir = os.path.join(directory, self.model_id)
        os.makedirs(model_checkpoint_dir, exist_ok=True)

        # the fasttext model can't be pickled and has its own way of saving,
        # so we save it separately from the rest of the pipeline
        model = self.model_
        model_checkpoint = os.path.join(model_checkpoint_dir, 'model.fasttext')
        model.save_model(model_checkpoint)

        # temporarily detach the un-picklable fasttext model, pickle the rest
        # of the pipeline, then re-attach it
        self.model_ = None
        pipeline_checkpoint = os.path.join(model_checkpoint_dir, 'fasttext_pipeline.pkl')
        dump(self, pipeline_checkpoint)

        self.model_ = model
        return model_checkpoint_dir

    @classmethod
    def load(cls, directory: str):
        """
        Loads the full pipeline from file.

        Parameters
        ----------
        directory : str
            The saved directory returned by calling .save.

        Returns
        -------
        fasttext_pipeline : FasttextPipeline
        """
        pipeline_checkpoint = os.path.join(directory, 'fasttext_pipeline.pkl')
        fasttext_pipeline = load(pipeline_checkpoint)

        model_checkpoint = os.path.join(directory, 'model.fasttext')
        model = fasttext.load_model(model_checkpoint)

        fasttext_pipeline.model_ = model
        return fasttext_pipeline
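
    # A hedged round-trip sketch (the directory name is hypothetical):
    #
    #   checkpoint_dir = pipeline.save('model_checkpoint')
    #   pipeline = FasttextPipeline.load(checkpoint_dir)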

    def score_str(self, fasttext_file_path: str, k: int = 1, round_digits: int = 3) -> str:
        """
        Computes the model evaluation score for the input data and formats
        the metrics into a string, making it easier for logging. This method
        calls score internally.

        Parameters
        ----------
        fasttext_file_path : str
            Path to the text file in the fasttext format.

        k : int, default 1
            Ranking metrics precision/recall/f1 are evaluated for the top k predictions.

        round_digits : int, default 3
            Round decimal points for the metrics returned.

        Returns
        -------
        score_str : str
            e.g. ' metric - num_records: 29740, precision@1: 0.784, recall@1: 0.243, f1@1: 0.371'
        """
        num_records, precision_at_k, recall_at_k, f1_at_k = score(
            self.model_, fasttext_file_path, k, round_digits)

        num_records = f'num_records: {num_records}'
        precision_at_k = f'precision@{k}: {precision_at_k}'
        recall_at_k = f'recall@{k}: {recall_at_k}'
        f1_at_k = f'f1@{k}: {f1_at_k}'
        return f' metric - {num_records}, {precision_at_k}, {recall_at_k}, {f1_at_k}'

    def predict(self, texts: List[str], k: int = 1,
                threshold: float = 0.1,
                on_unicode_error: str = 'strict') -> List[List[Tuple[float, str]]]:
        """
        Given a list of raw text, predict the list of labels and corresponding probabilities.
        k and threshold can be used in conjunction to control the number of labels returned
        for each text in the input list.

        Parameters
        ----------
        texts : list[str]
            A list of raw text/strings.

        k : int, default 1
            Controls the number of returned labels. 1 will return only the most
            probable label.

        threshold : float, default 0.1
            Filters out returned labels whose probability is lower than this value.
            e.g. if k is specified to be 2, but the second most probable label has
            a probability lower than this threshold, then only 1 predicted label
            will be returned.

        on_unicode_error : str, default 'strict'
            Controls the behavior when the input string can't be converted according to the
            encoding rule.

        Returns
        -------
        batch_predictions : list[list[tuple[float, str]]]
            e.g. [[(0.562, '__label__label1'), (0.362, '__label__label2')]]
        """

        # fasttext's own predict method doesn't work well when k and threshold are
        # specified together for batch prediction, because the number of predictions
        # returned for each text in the batch may differ, hence we roll our own
        # predict method to accommodate this.

        # appending a newline at the end of each text is needed for fasttext prediction.
        # note that it should be done after any tokenization, to prevent the tokenizer
        # from modifying the newline symbol
        tokenized_texts = [text + '\n' for text in texts]
        batch_predictions = self.model_.f.multilinePredict(
            tokenized_texts, k, threshold, on_unicode_error)

        return batch_predictions
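
    # A hedged illustration of predict (probabilities and labels are made up):
    #
    #   pipeline.predict(['first text', 'second text'], k=2, threshold=0.2)
    #   -> [[(0.9, '__label__a'), (0.3, '__label__b')],
    #       [(0.8, '__label__c')]]  # second label filtered out by the threshold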


def fit_and_score(fasttext_file_path_train: str,
                  fasttext_file_path_test: str,
                  fasttext_params: Dict[str, Any],
                  k: int,
                  params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Fits the fasttext model and computes the score for a given train and test split
    on a set of parameters.

    Parameters
    ----------
    fasttext_file_path_train : str
        The text file should already be in the fasttext expected format.
        This is used for training the model.

    fasttext_file_path_test : str
        The text file should already be in the fasttext expected format.
        This is used for testing the model on the holdout set.

    fasttext_params : dict
        The fixed set of parameters for fasttext.

    k : int
        Ranking metrics precision/recall/f1 are evaluated for the top k predictions.

    params : dict
        The parameters being tuned. These override any parameters that are
        also specified in fasttext_params.

    Returns
    -------
    result : dict
        Stores the results for the current iteration e.g.::

            {
                'params': {'epoch': 10, 'dim': 85},
                'epoch': 10,
                'dim': 85,
                'train_precision@1': 0.486,
                'train_recall@1': 0.210,
                'train_f1@1': 0.294,
                'test_precision@1': 0.407,
                'test_recall@1': 0.175,
                'test_f1@1': 0.245
            }
    """
    current_model = fit_fasttext(fasttext_file_path_train, fasttext_params, params)

    fasttext_file_path_dict = {
        'train': fasttext_file_path_train,
        'test': fasttext_file_path_test
    }

    # record the sampled parameters, then the metrics for each split
    result = {'params': params}
    result.update(params)
    for group, fasttext_file_path in fasttext_file_path_dict.items():
        num_records, precision_at_k, recall_at_k, f1_at_k = score(
            current_model, fasttext_file_path, k)
        metric = {
            f'{group}_precision@{k}': precision_at_k,
            f'{group}_recall@{k}': recall_at_k,
            f'{group}_f1@{k}': f1_at_k
        }
        result.update(metric)

    return result
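

# note: a module-level function like fit_and_score can be pickled by joblib's
# default backend, which is presumably why it lives outside the class; this is
# an observation, not a requirement documented by the original module.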


def fit_fasttext(fasttext_file_path: str,
                 fasttext_params: Dict[str, Any],
                 params: Dict[str, Any]) -> fasttext.FastText._FastText:
    """
    Fits a fasttext model.

    Parameters
    ----------
    fasttext_file_path : str
        The text file should already be in the fasttext expected format.

    fasttext_params : dict
        The fixed set of parameters for fasttext.

    params : dict
        The parameters being tuned. These override any parameters that are
        also specified in fasttext_params.

    Returns
    -------
    model : _FastText
        Trained fasttext model.
    """
    # the tuned parameters take precedence over the fixed ones, e.g.
    # fasttext_params = {'epoch': 5, 'lr': 0.1} and params = {'epoch': 10}
    # trains with epoch=10, lr=0.1
    current_params = deepcopy(fasttext_params)
    current_params.update(params)
    current_params['input'] = fasttext_file_path
    model = fasttext.train_supervised(**current_params)
    return model


def score(model: fasttext.FastText._FastText,
          fasttext_file_path: str,
          k: int = 1,
          round_digits: int = 3) -> Tuple[int, float, float, float]:
    """
    Computes the model evaluation score, including precision/recall/f1 at k,
    for the input file.

    Parameters
    ----------
    model : _FastText
        Trained fasttext model.

    fasttext_file_path : str
        Path to the text file in the fasttext format.

    k : int, default 1
        Ranking metrics precision/recall/f1 are evaluated for the top k predictions.

    round_digits : int, default 3
        Round decimal points for the metrics returned.

    Returns
    -------
    num_records : int
        Number of records in the file.

    precision_at_k : float

    recall_at_k : float

    f1_at_k : float
    """
    num_records, precision_at_k, recall_at_k = model.test(fasttext_file_path, k)

    # guard against the degenerate case where precision and recall are both 0,
    # which would otherwise raise a ZeroDivisionError
    if precision_at_k + recall_at_k > 0:
        f1_at_k = 2 * (precision_at_k * recall_at_k) / (precision_at_k + recall_at_k)
    else:
        f1_at_k = 0.0

    precision_at_k = round(precision_at_k, round_digits)
    recall_at_k = round(recall_at_k, round_digits)
    f1_at_k = round(f1_at_k, round_digits)
    return num_records, precision_at_k, recall_at_k, f1_at_k
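

# ---------------------------------------------------------------------------
# A minimal end-to-end usage sketch. The file path, directory and parameter
# values below are hypothetical placeholders, not part of the original module;
# the input file is assumed to already be in the fasttext supervised format
# (e.g. lines that look like '__label__some_label some raw text').
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    pipeline = FasttextPipeline(
        model_id='fasttext_example',
        fasttext_params={'lr': 0.1, 'minCount': 1},
        fasttext_hyper_params={'dim': [80, 100], 'epoch': [5, 10]},
        fasttext_search_params={'n_iter': 2, 'random_state': 1234,
                                'n_jobs': 1, 'verbose': 1, 'scoring': 'f1@1'})

    # tune on a train/validation split, then re-fit on the entire file
    pipeline.fit_file('train.txt', val_size=0.1, split_random_state=1234)
    print(pipeline.score_str('train.txt', k=1))

    # persist the pipeline, load it back and predict on raw text
    checkpoint_dir = pipeline.save('model_checkpoint')
    loaded_pipeline = FasttextPipeline.load(checkpoint_dir)
    print(loaded_pipeline.predict(['some raw text'], k=2, threshold=0.1))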