Path: blob/master/model_selection/prob_calibration/calibration_module/utils.py
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from typing import Dict, List, Tuple, Optional
from sklearn.utils import check_consistent_length, column_or_1d
from sklearn.calibration import calibration_curve


__all__ = [
    'compute_calibration_error',
    'create_binned_data',
    'get_bin_boundaries',
    'compute_binary_score',
    'compute_calibration_summary',
]


def compute_calibration_error(
        y_true: np.ndarray,
        y_prob: np.ndarray,
        n_bins: int = 15,
        round_digits: int = 4) -> float:
    """
    Computes the calibration error for binary classification via binning
    data points into the specified number of bins. Samples with similar
    ``y_prob`` will be grouped into the same bin. The bin boundary is
    determined by having a similar number of samples within each bin.

    Parameters
    ----------
    y_true : 1d ndarray
        Binary true targets.

    y_prob : 1d ndarray
        Raw probability/score of the positive class.

    n_bins : int, default 15
        A bigger number of bins requires more data. In general,
        the larger the number of bins, the closer the estimate
        will be to the true calibration error.

    round_digits : int, default 4
        Round the calibration error metric.

    Returns
    -------
    calibration_error : float
        RMSE between the average positive label and predicted probability
        within each bin.
    """
    y_true = column_or_1d(y_true)
    y_prob = column_or_1d(y_prob)
    check_consistent_length(y_true, y_prob)

    binned_y_true, binned_y_prob = create_binned_data(y_true, y_prob, n_bins)

    # looping shouldn't be a source of bottleneck as n_bins should be a small number.
    bin_errors = 0.0
    for bin_y_true, bin_y_prob in zip(binned_y_true, binned_y_prob):
        avg_y_true = np.mean(bin_y_true)
        avg_y_score = np.mean(bin_y_prob)
        bin_error = (avg_y_score - avg_y_true) ** 2
        bin_errors += bin_error * len(bin_y_true)

    calibration_error = math.sqrt(bin_errors / len(y_true))
    return round(calibration_error, round_digits)


def create_binned_data(
        y_true: np.ndarray,
        y_prob: np.ndarray,
        n_bins: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Bin ``y_true`` and ``y_prob`` by the distribution of the data,
    i.e. each bin will contain approximately an equal number of
    data points. Bins are sorted based on ascending order of ``y_prob``.

    Parameters
    ----------
    y_true : 1d ndarray
        Binary true targets.

    y_prob : 1d ndarray
        Raw probability/score of the positive class.

    n_bins : int
        A bigger number of bins requires more data.

    Returns
    -------
    binned_y_true/binned_y_prob : list of 1d ndarray
        Each element in the list stores the data for that bin.
    """
    sorted_indices = np.argsort(y_prob)
    sorted_y_true = y_true[sorted_indices]
    sorted_y_prob = y_prob[sorted_indices]
    binned_y_true = np.array_split(sorted_y_true, n_bins)
    binned_y_prob = np.array_split(sorted_y_prob, n_bins)
    return binned_y_true, binned_y_prob


def get_bin_boundaries(binned_y_prob: List[np.ndarray]) -> np.ndarray:
    """
    Given ``binned_y_prob`` from ``create_binned_data``, get the
    boundaries for each bin.

    Parameters
    ----------
    binned_y_prob : list
        Each element in the list stores the data for that bin.

    Returns
    -------
    bins : 1d ndarray
        Boundaries for each bin.
    """
    bins = []
    for i in range(len(binned_y_prob) - 1):
        last_prob = binned_y_prob[i][-1]
        next_first_prob = binned_y_prob[i + 1][0]
        bins.append((last_prob + next_first_prob) / 2.0)

    bins.append(1.0)
    return np.array(bins)


def compute_binary_score(
        y_true: np.ndarray,
        y_prob: np.ndarray,
        round_digits: int = 4) -> Dict[str, float]:
    """
    Compute various evaluation metrics for binary classification,
    including auc, precision, recall, f1, log loss and brier score.
    The precision and recall reported here are the values at the
    threshold that gives the best f1 score.

    Parameters
    ----------
    y_true : 1d ndarray
        Binary true targets.

    y_prob : 1d ndarray
        Raw probability/score of the positive class.

    round_digits : int, default 4
        Round the evaluation metric.

    Returns
    -------
    metrics_dict : dict
        Metrics are stored in key value pairs. ::

        {
            'auc': 0.82,
            'precision': 0.56,
            'recall': 0.61,
            'f1': 0.59,
            'log_loss': 0.42,
            'brier': 0.12
        }
    """
    auc = round(metrics.roc_auc_score(y_true, y_prob), round_digits)
    log_loss = round(metrics.log_loss(y_true, y_prob), round_digits)
    brier_score = round(metrics.brier_score_loss(y_true, y_prob), round_digits)

    precision, recall, threshold = metrics.precision_recall_curve(y_true, y_prob)
    f1 = 2 * (precision * recall) / (precision + recall)

    mask = ~np.isnan(f1)
    f1 = f1[mask]
    precision = precision[mask]
    recall = recall[mask]

    best_index = np.argmax(f1)
    precision = round(precision[best_index], round_digits)
    recall = round(recall[best_index], round_digits)
    f1 = round(f1[best_index], round_digits)
    return {
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'log_loss': log_loss,
        'brier': brier_score
    }


def compute_calibration_summary(
        eval_dict: Dict[str, pd.DataFrame],
        label_col: str = 'label',
        score_col: str = 'score',
        n_bins: int = 15,
        strategy: str = 'quantile',
        round_digits: int = 4,
        show: bool = True,
        save_plot_path: Optional[str] = None) -> pd.DataFrame:
    """
    Plots the calibration curve and computes the summary statistics for each model.

    Parameters
    ----------
    eval_dict : dict
        We can evaluate multiple calibration models' performance in one go.
        The key is the model name used to distinguish different calibration
        models, the value is the dataframe that stores the binary true targets
        and the predicted score for the positive class.

    label_col : str
        Column name for the dataframe in ``eval_dict`` that stores the binary true targets.

    score_col : str
        Column name for the dataframe in ``eval_dict`` that stores the predicted score.

    n_bins : int, default 15
        Number of bins to discretize the calibration curve plot and calibration error statistics.
        A bigger number requires more data, but will be closer to the true calibration error.

    strategy : {'uniform', 'quantile'}, default 'quantile'
        Strategy used to define the boundary of the bins.

        - uniform: The bins have identical widths.
        - quantile: The bins have the same number of samples and depend on the predicted score.

    round_digits : int, default 4
        Round the evaluation metric.

    show : bool, default True
        Whether to show the plots in the console or Jupyter notebook.

    save_plot_path : str, default None
        Path where the calibration plot will be saved. If None, the plot is not saved.

    Returns
    -------
    df_metrics : pd.DataFrame
        Corresponding metrics for each input dataframe.
    """

    fig, (ax1, ax2) = plt.subplots(2)

    # estimator_metrics stores list of dict, e.g.
    # [{'auc': 0.776, 'name': 'xgb'}]
    estimator_metrics = []
    for name, df_eval in eval_dict.items():
        prob_true, prob_pred = calibration_curve(
            df_eval[label_col],
            df_eval[score_col],
            n_bins=n_bins,
            strategy=strategy)

        calibration_error = compute_calibration_error(
            df_eval[label_col], df_eval[score_col], n_bins, round_digits)
        metrics_dict = compute_binary_score(df_eval[label_col], df_eval[score_col], round_digits)
        metrics_dict['calibration_error'] = calibration_error
        metrics_dict['name'] = name
        estimator_metrics.append(metrics_dict)

        ax1.plot(prob_pred, prob_true, 's-', label=name)
        ax2.hist(df_eval[score_col], range=(0, 1), bins=n_bins, label=name, histtype='step', lw=2)

    ax1.plot([0, 1], [0, 1], 'k:', label='perfect')

    ax1.set_xlabel('Fraction of positives (Predicted)')
    ax1.set_ylabel('Fraction of positives (Actual)')
    ax1.set_ylim([-0.05, 1.05])
    ax1.legend(loc='upper left', ncol=2)
    ax1.set_title('Calibration Plots (Reliability Curve)')

    ax2.set_xlabel('Predicted scores')
    ax2.set_ylabel('Count')
    ax2.set_title('Histogram of Predicted Scores')
    ax2.legend(loc='upper right', ncol=2)

    plt.tight_layout()
    if show:
        plt.show()

    if save_plot_path is not None:
        save_dir = os.path.dirname(save_plot_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        fig.savefig(save_plot_path, dpi=300, bbox_inches='tight')

    plt.close(fig)

    df_metrics = pd.DataFrame(estimator_metrics)
    return df_metrics
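

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): exercises the public
# functions above on synthetic data. The column names 'label'/'score', the
# model name 'uncalibrated' and the simulated data are illustrative
# assumptions, not something the module requires.
if __name__ == '__main__':
    rng = np.random.RandomState(1234)
    n_samples = 5000

    # simulate raw scores in [0, 1] and binary labels whose positive rate
    # tracks the score, so the fake "model" is roughly calibrated by construction
    score = rng.uniform(0, 1, size=n_samples)
    label = rng.binomial(1, p=score)

    # stand-alone metrics on the raw arrays
    print('calibration error:', compute_calibration_error(label, score, n_bins=10))
    print('binary scores:', compute_binary_score(label, score))

    # full summary: one dataframe per model, keyed by a display name;
    # show=False keeps the demo non-interactive, and no plot is saved
    df_eval = pd.DataFrame({'label': label, 'score': score})
    df_metrics = compute_calibration_summary(
        {'uncalibrated': df_eval},
        label_col='label',
        score_col='score',
        n_bins=10,
        show=False,
        save_plot_path=None)
    print(df_metrics)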