# Path: blob/master/model_selection/partial_dependence/partial_dependence.py
# 2585 views
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import ceil
from joblib import Parallel, delayed
from matplotlib.gridspec import GridSpec


__all__ = ['PartialDependenceExplainer']


class PartialDependenceExplainer:
    """
    Partial Dependence explanation [1]_.

    - Supports scikit-learn like classification and regression estimators.
    - Works for both numerical and categorical columns.

    Parameters
    ----------
    estimator : sklearn-like estimator
        Model that was fitted on the data. If it exposes both ``classes_``
        and ``predict_proba`` it is treated as a classifier, otherwise its
        ``predict`` method is used as a regressor.

    n_grid_points : int, default 50
        Number of grid points used in replacement
        for the original numeric data. Only used
        if the targeted column is numeric. For categorical
        column, the number of grid points will always be
        the distinct number of categories in that column.
        Smaller number of grid points serves as an
        approximation for the total number of unique
        points and will result in faster computation.

    batch_size : int or 'auto', default 'auto'
        Compute partial dependence prediction batch by batch to save
        memory usage. Every row of a batch is repeated once per grid point
        before being passed to the estimator. With 'auto' the batch size
        will be ceil(number of rows in the data / the number of grid points
        used). An int must be at least 1.

    n_jobs : int, default 1
        Number of jobs to run in parallel, if the model already fits
        extremely fast on the data, then specify 1 so that there's no
        overhead of spawning different processes to do the computation.

    verbose : int, default 1
        The verbosity level: if non zero, progress messages are printed.
        Above 50, the output is sent to stdout. The frequency of the messages
        increases with the verbosity level. If it is more than 10, all
        iterations are reported.

    pre_dispatch : int or str, default '2*n_jobs'
        Controls the number of jobs that get dispatched during parallel
        execution. Reducing this number can be useful to avoid an
        explosion of memory consumption when more jobs get dispatched
        than CPUs can process. Possible inputs:

        - None, in which case all the jobs are immediately
          created and spawned. Use this for lightweight and
          fast-running jobs, to avoid delays due to on-demand
          spawning of the jobs
        - An int, giving the exact number of total jobs that are
          spawned
        - A string, giving an expression as a function of n_jobs,
          as in '2*n_jobs'

    Attributes
    ----------
    feature_name_ : str
        The input feature_name to the .fit unmodified, will
        be used in subsequent methods.

    feature_type_ : str
        The input feature_type to the .fit unmodified, will
        be used in subsequent methods.

    feature_grid_ : 1d ndarray
        Unique grid points that were used to generate the
        partial dependence result.

    results_ : list of DataFrame
        Partial dependence result, with one row per input row and one
        column per grid point. If it's a classification
        estimator then each index of the list is the result
        for each class. On the other hand, if it's a regression
        estimator, it will be a list with 1 element.

    References
    ----------
    .. [1] `Python partial dependence plot toolbox
       <https://github.com/SauceCat/PDPbox>`_
    """

    def __init__(self, estimator, n_grid_points = 50, batch_size = 'auto',
                 n_jobs = 1, verbose = 1, pre_dispatch = '2*n_jobs'):
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.estimator = estimator
        self.batch_size = batch_size
        self.pre_dispatch = pre_dispatch
        self.n_grid_points = n_grid_points

    def fit(self, data, feature_name, feature_type):
        """
        Obtain the partial dependence result.

        Parameters
        ----------
        data : DataFrame, shape [n_samples, n_features]
            Input data to the estimator/model.

        feature_name : str
            Feature's name in the data that we wish to explain.

        feature_type : str, {'num', 'cat'}
            Specify whether feature_name is a numerical or
            categorical column. Any value other than 'num' is
            handled as categorical.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``data`` has no rows or ``batch_size`` is invalid.
        """
        n_rows = data.shape[0]
        if n_rows == 0:
            raise ValueError('data must contain at least one row')

        # check whether it's a classification or regression model
        estimator = self.estimator
        try:
            n_classes = estimator.classes_.size
            is_classifier = True
            predict = estimator.predict_proba
        except AttributeError:
            # for regression problem, still set the
            # number of classes to 1 to initialize
            # the loop later downstream
            n_classes = 1
            is_classifier = False
            predict = estimator.predict

        target = data[feature_name]
        unique_target = np.unique(target)
        if feature_type == 'num':
            if self.n_grid_points >= unique_target.size:
                feature_grid = unique_target
            else:
                # when the number of required grid points is smaller than the number
                # of unique values, we choose the percentile points to make sure the
                # grid points span widely across the whole value range
                percentile = np.percentile(
                    target, np.linspace(0, 100, self.n_grid_points))
                feature_grid = np.unique(percentile)

            feature_cols = feature_grid
        else:
            feature_grid = unique_target
            feature_cols = np.asarray(['{}_{}'.format(feature_name, category)
                                       for category in unique_target])

        # compute prediction batch by batch to save memory usage
        batch_size = self._get_batch_size(n_rows, feature_grid.size)
        parallel = Parallel(
            n_jobs = self.n_jobs, verbose = self.verbose, pre_dispatch = self.pre_dispatch)
        outputs = parallel(delayed(_predict_batch)(data_batch,
                                                   feature_grid,
                                                   feature_name,
                                                   is_classifier,
                                                   n_classes,
                                                   predict)
                           for data_batch in _data_iter(data, batch_size))

        # outputs holds one list per batch (one DataFrame per class); zip
        # regroups them per class so each class' batches can be stacked
        results = []
        for output in zip(*outputs):
            result = pd.concat(output, ignore_index = True)
            result.columns = feature_cols
            results.append(result)

        self.results_ = results
        self.feature_name_ = feature_name
        self.feature_grid_ = feature_grid
        self.feature_type_ = feature_type
        return self

    def _get_batch_size(self, n_rows, n_grid_points):
        """Resolve the ``batch_size`` setting into a positive number of rows per batch."""
        batch_size = self.batch_size
        if isinstance(batch_size, str):
            if batch_size != 'auto':
                raise ValueError(
                    "batch_size must be a positive int or 'auto', got {!r}".format(batch_size))

            # each row is repeated once per grid point inside _predict_batch, so
            # this keeps every expanded batch at roughly n_rows rows
            return ceil(n_rows / n_grid_points)

        if batch_size < 1:
            raise ValueError(
                "batch_size must be a positive int or 'auto', got {!r}".format(batch_size))
        return int(batch_size)

    def plot(self, centered = True, target_class = 0):
        """
        Use the partial dependence result to generate
        a partial dependence plot (using matplotlib).

        Parameters
        ----------
        centered : bool, default True
            Center the partial dependence plot by subtracting the mean of the
            first column of the result table from the mean of every column,
            i.e. the first grid point's value will serve as the baseline
            (centered at 0) for all other values. The +/- one standard
            deviation band is computed from the uncentered results and is
            shifted along with the curve.

        target_class : int, default 0
            The target class to show for the partial dependence result,
            for regression task, we can leave the default number unmodified,
            but for classification task, we should specify the target class
            parameter to meet our needs.

        Returns
        -------
        figure : matplotlib.gridspec.GridSpec
            The 5 x 1 grid layout (not a matplotlib Figure). Its first row
            holds the title axes and the remaining rows hold the partial
            dependence axes; the axes are created with ``plt.subplot`` on
            matplotlib's current figure.
        """
        figure = GridSpec(5, 1)
        ax1 = plt.subplot(figure[0, :])
        self._plot_title(ax1)
        ax2 = plt.subplot(figure[1:, :])
        self._plot_content(ax2, centered, target_class)
        return figure

    def _plot_title(self, ax):
        """Write the title and the grid-size subtitle on ``ax`` and hide its axis."""
        font_family = 'Arial'
        title = 'Partial Dependence Plot for {}'.format(self.feature_name_)
        subtitle = 'Number of unique grid points: {}'.format(self.feature_grid_.size)
        title_fontsize = 15
        subtitle_fontsize = 12

        ax.set_facecolor('white')
        ax.text(
            0, 0.7, title,
            fontsize = title_fontsize, fontname = font_family)
        ax.text(
            0, 0.4, subtitle, color = 'grey',
            fontsize = subtitle_fontsize, fontname = font_family)
        ax.axis('off')

    def _plot_content(self, ax, centered, target_class):
        """Draw the mean curve, its +/- 1 std band and a dashed zero line on ``ax``."""
        # pd (partial dependence)
        pd_linewidth = 2
        pd_markersize = 5
        pd_color = '#1A4E5D'
        fill_alpha = 0.2
        fill_color = '#66C2D7'
        zero_linewidth = 1.5
        zero_color = '#E75438'
        xlabel_fontsize = 10

        results = self.results_[target_class]
        feature_cols = results.columns
        if self.feature_type_ == 'cat':
            # ticks = all the unique categories
            x = range(len(feature_cols))
            ax.set_xticks(x)
            ax.set_xticklabels(feature_cols)
        else:
            x = feature_cols

        # named pd_mean (not pd) so the pandas module alias is not shadowed
        values = results.values
        pd_mean = values.mean(axis = 0)
        if centered:
            # the first column's mean serves as the baseline (centered at 0);
            # mean(col_j) - mean(col_0) equals the mean of row-wise centered values
            pd_mean -= pd_mean[0]

        pd_std = values.std(axis = 0)
        upper = pd_mean + pd_std
        lower = pd_mean - pd_std

        ax.plot(
            x, pd_mean, color = pd_color, linewidth = pd_linewidth,
            marker = 'o', markersize = pd_markersize)
        ax.plot(
            x, [0] * pd_mean.size, color = zero_color,
            linestyle = '--', linewidth = zero_linewidth)
        ax.fill_between(x, upper, lower, alpha = fill_alpha, color = fill_color)
        ax.set_xlabel(self.feature_name_, fontsize = xlabel_fontsize)
        self._modify_axis(ax)

    def _modify_axis(self, ax):
        """Apply tick colors, hide all spines and add a light dashed grid on ``ax``."""
        tick_labelsize = 8
        tick_colors = '#9E9E9E'
        tick_labelcolor = '#424242'

        ax.tick_params(
            axis = 'both', which = 'major', colors = tick_colors,
            labelsize = tick_labelsize, labelcolor = tick_labelcolor)

        ax.set_facecolor('white')
        ax.get_yaxis().tick_left()
        ax.get_xaxis().tick_bottom()
        for direction in ('top', 'left', 'right', 'bottom'):
            ax.spines[direction].set_visible(False)

        for axis in ('x', 'y'):
            ax.grid(True, 'major', axis, ls = '--', lw = .5, c = 'k', alpha = .3)


def _data_iter(data, batch_size):
    """
    Used by PartialDependenceExplainer to loop through the data by batch.

    Yields consecutive positional slices of at most ``batch_size`` rows,
    each with its index reset to 0..n-1.
    """
    n_rows = data.shape[0]
    for start in range(0, n_rows, batch_size):
        # .iloc makes the slice explicitly positional regardless of the index type
        yield data.iloc[start:start + batch_size].reset_index(drop = True)


def _predict_batch(data_batch, feature_grid, feature_name,
                   is_classifier, n_classes, predict):
    """
    Used by PartialDependenceExplainer to generate prediction by batch.

    Every row of ``data_batch`` is repeated once per grid point with the
    ``feature_name`` column replaced by that grid point. The repeated data is
    passed to ``predict`` in a single call and the predictions are reshaped
    to (rows in batch, number of grid points).

    Returns
    -------
    list of DataFrame
        One DataFrame per class (a single one for regression).
    """
    n_batch_rows = data_batch.shape[0]
    n_grid = feature_grid.size

    # repeat the row positions and use them to slice the data to create the
    # repeated data instead of creating the repetition using the values, i.e.
    # np.repeat(data_batch.values, repeats = n_grid, axis = 0)
    # this prevents everything from getting converted to a different data type, e.g.
    # if there is 1 object type column then everything would get converted to object
    positions = np.repeat(np.arange(n_batch_rows), repeats = n_grid)
    ice_data = data_batch.iloc[positions].copy()
    ice_data[feature_name] = np.tile(feature_grid, n_batch_rows)

    # np.asarray lets estimators that return lists or pandas objects work with
    # the column slicing and reshape below
    prediction = np.asarray(predict(ice_data))
    results = []
    for n_class in range(n_classes):
        if is_classifier:
            result = prediction[:, n_class]
        else:
            result = prediction

        # reshape tiled data back to original batch's shape
        reshaped = result.reshape((n_batch_rows, n_grid))
        results.append(pd.DataFrame(reshaped))

    return results