Path: blob/master/deep_learning/multi_label/fasttext_module/model.py
import os
import fasttext
import pandas as pd
from copy import deepcopy
from typing import Any, Dict, List, Tuple
from joblib import Parallel, delayed, dump, load
from sklearn.model_selection import ParameterSampler
from fasttext_module.utils import prepend_file_name
from fasttext_module.split import train_test_split_file


__all__ = [
    'FasttextPipeline',
    'fit_and_score',
    'fit_fasttext',
    'score'
]


class FasttextPipeline:
    """
    Fasttext text classification pipeline.

    Parameters
    ----------
    model_id : str
        Unique identifier for the model, the model checkpoint will have this name.

    fasttext_params : dict
        Keyword arguments passed to fasttext.train_supervised. Note that
        we do not need to specify the input text file under this parameter.

    fasttext_hyper_params : dict
        Controls which parameters and their corresponding ranges will be tuned.
        e.g. {"dim": [80, 100]}

    fasttext_search_params : dict
        Controls how long to perform the hyperparameter search and what metric to optimize for.

        - n_iter (int) Number of parameter settings that are sampled from fasttext_hyper_params.
        - random_state (int) Seed for sampling from fasttext_hyper_params.
        - n_jobs (int) Number of jobs to run in parallel. -1 means use all processors.
        - verbose (int) The higher the number, the more messages printed.
        - scoring (str) The metric used for selecting the best parameters, e.g.
          f1@1, precision@1, recall@1. The valid metrics are precision/recall/f1 followed
          by @k, where k controls the top k predictions that we'll be evaluating.

    Attributes
    ----------
    model_ : _FastText
        Fasttext model.

    df_tune_results_ : pd.DataFrame
        DataFrame that stores the hyperparameter tuning results, including the
        parameters that were tuned and their corresponding train/test scores.

    best_params_ : dict
        Best hyperparameters chosen to re-fit the model on the entire dataset.
    """

    def __init__(self,
                 model_id: str,
                 fasttext_params: Dict[str, Any],
                 fasttext_hyper_params: Dict[str, List[Any]],
                 fasttext_search_params: Dict[str, Any]):
        self.model_id = model_id
        self.fasttext_params = fasttext_params
        self.fasttext_hyper_params = fasttext_hyper_params
        self.fasttext_search_params = fasttext_search_params

    def fit_file(self, fasttext_file_path: str,
                 val_size: float=0.1, split_random_state: int=1234):
        """
        Fit the pipeline to the input file path.

        Parameters
        ----------
        fasttext_file_path : str
            The text file should already be in the fasttext expected format.

        val_size : float, default 0.1
            Proportion of the dataset to include in the validation split.
            The validation set will be used to pick the best parameters from
            the hyperparameter search.

        split_random_state : int, default 1234
            Seed for the split.

        Returns
        -------
        self
        """
        self._tune_fasttext(fasttext_file_path, val_size, split_random_state,
                            **self.fasttext_search_params)
        self.model_ = fit_fasttext(fasttext_file_path, self.fasttext_params, self.best_params_)
        return self

    def _tune_fasttext(self, fasttext_file_path: str, val_size: float, split_random_state: int,
                       n_iter: int, random_state: int, n_jobs: int, verbose: int, scoring: str):
        parameter_sampler = ParameterSampler(self.fasttext_hyper_params, n_iter,
                                             random_state=random_state)

        fasttext_file_path_train = prepend_file_name(fasttext_file_path, 'train')
        fasttext_file_path_val = prepend_file_name(fasttext_file_path, 'val')
        count_train, count_val = train_test_split_file(
            fasttext_file_path, fasttext_file_path_train, fasttext_file_path_val,
            val_size, split_random_state)

        k = int(scoring.split('@')[-1])
        parallel = Parallel(n_jobs=n_jobs, verbose=verbose)
        results = parallel(delayed(fit_and_score)(fasttext_file_path_train,
                                                  fasttext_file_path_val,
                                                  self.fasttext_params,
                                                  k,
                                                  param)
                           for param in parameter_sampler)

        df_tune_results = (pd.DataFrame
                           .from_dict(results)
                           .sort_values(f'test_{scoring}', ascending=False))
        self.best_params_ = df_tune_results['params'].iloc[0]
        self.df_tune_results_ = df_tune_results

        # clean up the intermediate train/validation split files so they don't
        # hog unneeded disk space
        for file_path in [fasttext_file_path_train, fasttext_file_path_val]:
            os.remove(file_path)

        return self

    def save(self, directory: str) -> str:
        """
        Saves the pipeline.

        Parameters
        ----------
        directory : str
            The directory to save the model. Will create the directory if it
            doesn't exist.

        Returns
        -------
        model_checkpoint_dir : str
            The directory of the saved model.
        """
        model_checkpoint_dir = os.path.join(directory, self.model_id)
        if not os.path.isdir(model_checkpoint_dir):
            os.makedirs(model_checkpoint_dir, exist_ok=True)

        # some models can't be pickled and have their own way of saving,
        # hence the fasttext model is saved separately from the rest of the pipeline
        model = self.model_
        model_checkpoint = os.path.join(model_checkpoint_dir, 'model.fasttext')
        model.save_model(model_checkpoint)

        self.model_ = None
        pipeline_checkpoint = os.path.join(model_checkpoint_dir, 'fasttext_pipeline.pkl')
        dump(self, pipeline_checkpoint)

        self.model_ = model
        return model_checkpoint_dir

    @classmethod
    def load(cls, directory: str):
        """
        Loads the full model from file.

        Parameters
        ----------
        directory : str
            The saved directory returned by calling .save.

        Returns
        -------
        model : FasttextPipeline
        """
        pipeline_checkpoint = os.path.join(directory, 'fasttext_pipeline.pkl')
        fasttext_pipeline = load(pipeline_checkpoint)

        model_checkpoint = os.path.join(directory, 'model.fasttext')
        model = fasttext.load_model(model_checkpoint)

        fasttext_pipeline.model_ = model
        return fasttext_pipeline

    def score_str(self, fasttext_file_path: str, k: int=1, round_digits: int=3) -> str:
        """
        Computes the model evaluation score for the input data and formats
        it into a string, making it easier for logging. This method calls
        score internally.

        Parameters
        ----------
        fasttext_file_path : str
            Path to the text file in the fasttext format.

        k : int, default 1
            Ranking metrics precision/recall/f1 are evaluated for the top k predictions.

        round_digits : int, default 3
            Round decimal points for the metrics returned.

        Returns
        -------
        score_str : str
            e.g. ' metric - num_records: 29740, precision@1: 0.784, recall@1: 0.243, f1@1: 0.371'
        """
        num_records, precision_at_k, recall_at_k, f1_at_k = score(
            self.model_, fasttext_file_path, k, round_digits)

        num_records = f'num_records: {num_records}'
        precision_at_k = f'precision@{k}: {precision_at_k}'
        recall_at_k = f'recall@{k}: {recall_at_k}'
        f1_at_k = f'f1@{k}: {f1_at_k}'
        return f' metric - {num_records}, {precision_at_k}, {recall_at_k}, {f1_at_k}'

    def predict(self, texts: List[str], k: int=1,
                threshold: float=0.1,
                on_unicode_error: str='strict') -> List[List[Tuple[float, str]]]:
        """
        Given a list of raw text, predict the list of labels and corresponding probabilities.
        We can use k and threshold in conjunction to control the number of labels returned
        for each text in the input list.

        Parameters
        ----------
        texts : list[str]
            A list of raw text/strings.

        k : int, default 1
            Controls the number of returned labels. 1 will return the single most probable label.

        threshold : float, default 0.1
            Filters out returned labels whose probability is lower than this value.
            e.g. if k is specified to be 2, but one of the top 2 labels has a probability
            lower than this threshold, then only 1 predicted label will be returned.

        on_unicode_error : str, default 'strict'
            Controls the behavior when the input string can't be converted according to the
            encoding rule.

        Returns
        -------
        batch_predictions : list[list[tuple[float, str]]]
            e.g. [[(0.562, '__label__label1'), (0.362, '__label__label2')]]
        """

        # fasttext's own predict method doesn't work well when k and threshold are
        # specified together for batch prediction, because the number of predictions
        # returned for each text in the batch may differ, hence we roll our own
        # predict method to accommodate this.

        # appending the new line at the end of the text is needed for fasttext prediction.
        # note that it should be done after the tokenization to prevent the tokenizer
        # from modifying the new line symbol
        tokenized_texts = [text + '\n' for text in texts]
        batch_predictions = self.model_.f.multilinePredict(
            tokenized_texts, k, threshold, on_unicode_error)

        return batch_predictions


def fit_and_score(fasttext_file_path_train: str,
                  fasttext_file_path_test: str,
                  fasttext_params: Dict[str, Any],
                  k: int,
                  params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Fits the fasttext model and computes the score for a given train and test split
    on a set of parameters.

    Parameters
    ----------
    fasttext_file_path_train : str
        The text file should already be in the fasttext expected format.
        This is used for training the model.

    fasttext_file_path_test : str
        The text file should already be in the fasttext expected format.
        This is used for evaluating the model on the holdout set.

    fasttext_params : dict
        The fixed set of parameters for fasttext.

    k : int
        Ranking metrics precision/recall/f1 are evaluated for the top k predictions.

    params : dict
        The parameters that are tuned. Will override any parameters that
        are specified in fasttext_params.

    Returns
    -------
    result : dict
        Stores the results for the current iteration e.g.::

            {
                'params': {'epoch': 10, 'dim': 85},
                'epoch': 10,
                'dim': 85,
                'train_precision@1': 0.486,
                'train_recall@1': 0.210,
                'train_f1@1': 0.294,
                'test_precision@1': 0.407,
                'test_recall@1': 0.175,
                'test_f1@1': 0.245
            }
    """
    current_model = fit_fasttext(fasttext_file_path_train, fasttext_params, params)

    fasttext_file_path_dict = {
        'train': fasttext_file_path_train,
        'test': fasttext_file_path_test
    }

    result = {'params': params}
    result.update(params)
    for group, fasttext_file_path in fasttext_file_path_dict.items():
        num_records, precision_at_k, recall_at_k, f1_at_k = score(
            current_model, fasttext_file_path, k)
        metric = {
            f'{group}_precision@{k}': precision_at_k,
            f'{group}_recall@{k}': recall_at_k,
            f'{group}_f1@{k}': f1_at_k
        }
        result.update(metric)

    return result


def fit_fasttext(fasttext_file_path: str,
                 fasttext_params: Dict[str, Any],
                 params: Dict[str, Any]) -> fasttext.FastText._FastText:
    """
    Fits a fasttext model.

    Parameters
    ----------
    fasttext_file_path : str
        The text file should already be in the fasttext expected format.

    fasttext_params : dict
        The fixed set of parameters for fasttext.

    params : dict
        The parameters that are tuned. Will override any parameters that
        are specified in fasttext_params.

    Returns
    -------
    model : _FastText
        Trained fasttext model.
    """
    current_params = deepcopy(fasttext_params)
    current_params.update(params)
    current_params['input'] = fasttext_file_path
    model = fasttext.train_supervised(**current_params)
    return model


def score(model: fasttext.FastText._FastText,
          fasttext_file_path: str,
          k: int=1,
          round_digits: int=3) -> Tuple[int, float, float, float]:
    """
    Computes the model evaluation score, including precision/recall/f1 at k,
    for the input file.

    Parameters
    ----------
    model : _FastText
        Trained fasttext model.

    fasttext_file_path : str
        Path to the text file in the fasttext format.

    k : int, default 1
        Ranking metrics precision/recall/f1 are evaluated for the top k predictions.

    round_digits : int, default 3
        Round decimal points for the metrics returned.

    Returns
    -------
    num_records : int
        Number of records in the file.

    precision_at_k : float

    recall_at_k : float

    f1_at_k : float
    """
    num_records, precision_at_k, recall_at_k = model.test(fasttext_file_path, k)

    # guard against division by zero when both precision and recall are 0
    if precision_at_k + recall_at_k == 0:
        f1_at_k = 0.0
    else:
        f1_at_k = 2 * (precision_at_k * recall_at_k) / (precision_at_k + recall_at_k)

    precision_at_k = round(precision_at_k, round_digits)
    recall_at_k = round(recall_at_k, round_digits)
    f1_at_k = round(f1_at_k, round_digits)
    return num_records, precision_at_k, recall_at_k, f1_at_k
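

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The model id,
# file path, and parameter values below are hypothetical placeholders; the
# input file is assumed to already be in the fasttext supervised format
# (one record per line, labels prefixed with __label__).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    fasttext_params = {'thread': 4, 'lr': 0.1}
    fasttext_hyper_params = {'dim': [80, 100], 'epoch': [5, 10]}
    fasttext_search_params = {
        'n_iter': 2,
        'random_state': 1234,
        'n_jobs': -1,
        'verbose': 1,
        'scoring': 'f1@1'
    }

    # tune on a train/validation split, then re-fit on the full file
    pipeline = FasttextPipeline('example_model', fasttext_params,
                                fasttext_hyper_params, fasttext_search_params)
    pipeline.fit_file('train.txt')  # hypothetical fasttext-formatted file
    print(pipeline.score_str('train.txt', k=1))

    # round-trip the pipeline through save/load, then run a batch prediction
    checkpoint_dir = pipeline.save('models')
    loaded_pipeline = FasttextPipeline.load(checkpoint_dir)
    print(loaded_pipeline.predict(['some raw text to classify'], k=2, threshold=0.1))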