Path: blob/master/big_data/h2o/h2o_explainers.py
2581 views
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec


class H2OPartialDependenceExplainer:
    """
    Partial Dependence explanation for binary classification H2O models.
    Works for both numerical and categorical (enum) features.

    Parameters
    ----------
    h2o_model : H2OEstimator
        H2O Model that was already fitted on the data.

    Attributes
    ----------
    feature_name_ : str
        The input feature_name to the .fit unmodified, will
        be used in subsequent method.

    is_cat_col_ : bool
        Whether the feature we're aiming to explain is a
        categorical feature or not.

    partial_dep_ : DataFrame
        A pandas dataframe that contains three columns, the
        feature's value and their corresponding mean prediction
        and standard deviation of the prediction. e.g.

        feature_name    mean_response  stddev_response
        3000.000000     0.284140       0.120659
        318631.578947   0.134414       0.076054
        634263.157895   0.142961       0.083630

        The feature_name column will be the actual feature_name that
        we pass to the .fit method whereas the mean_response and
        stddev_response column will be fixed columns generated.
    """

    def __init__(self, h2o_model):
        self.h2o_model = h2o_model

    def fit(self, data, feature_name, n_bins=20):
        """
        Obtain the partial dependence result.

        Parameters
        ----------
        data : H2OFrame, shape [n_samples, n_features]
            Input data to the H2O estimator/model.

        feature_name : str
            Feature name in the data that we wish to explain.

        n_bins : int, default 20
            Number of bins used. For categorical columns, the value is
            raised when necessary so that it is at least the distinct
            level count of the feature.

        Returns
        -------
        self
        """
        self.is_cat_col_ = data[feature_name].isfactor()[0]
        if self.is_cat_col_:
            # never request fewer bins than there are categorical levels
            n_levels = len(data[feature_name].levels()[0])
            n_bins = max(n_levels, n_bins)

        partial_dep = self.h2o_model.partial_plot(
            data, cols=[feature_name], nbins=n_bins, plot=False)
        self.feature_name_ = feature_name
        # a single column was requested, hence only the first table is used
        self.partial_dep_ = partial_dep[0].as_data_frame()
        return self

    def plot(self, centered=True, plot_stddev=True):
        """
        Use the partial dependence result to generate
        a partial dependence plot (using matplotlib).

        Must be called after .fit, as it reads feature_name_
        and partial_dep_.

        Parameters
        ----------
        centered : bool, default True
            Center the partial dependence plot by subtracting the first
            row's mean response from every mean response, i.e. the first
            value will serve as the baseline (centered at 0) for all
            other values.

        plot_stddev : bool, default True
            Apart from plotting the mean partial dependence, also show the
            standard deviation as a fill between.

        Returns
        -------
        matplotlib.gridspec.GridSpec
            The 5 x 1 grid layout used to place the title axes (first row)
            and the content axes (remaining rows). The axes themselves are
            drawn on matplotlib's current figure.
        """
        layout = GridSpec(5, 1)

        title_ax = plt.subplot(layout[0, :])
        self._plot_title(title_ax)

        content_ax = plt.subplot(layout[1:, :])
        self._plot_content(content_ax, centered, plot_stddev)
        return layout

    def _plot_title(self, ax):
        """Write the title and subtitle text onto ``ax`` and hide its axis."""
        font_family = 'Arial'
        title = "Partial Dependence Plot for '{}' feature".format(self.feature_name_)
        subtitle = 'Number of unique grid points: {}'.format(self.partial_dep_.shape[0])
        title_fontsize = 15
        subtitle_fontsize = 12

        ax.set_facecolor('white')
        ax.text(
            0, 0.7, title,
            fontsize=title_fontsize, fontname=font_family)
        ax.text(
            0, 0.4, subtitle, color='grey',
            fontsize=subtitle_fontsize, fontname=font_family)
        ax.axis('off')

    def _plot_content(self, ax, centered, plot_stddev):
        """
        Draw the mean response line, the zero baseline and, optionally,
        the standard deviation band onto ``ax``.

        The stored partial_dep_ table is only read, never modified, so
        the method can be called repeatedly with different options.
        """
        # pd (partial dependence)
        pd_linewidth = 2
        pd_markersize = 5
        pd_color = '#1A4E5D'
        fill_alpha = 0.2
        fill_color = '#66C2D7'
        zero_linewidth = 1.5
        zero_color = '#E75438'
        xlabel_fontsize = 10

        pd_mean = self.partial_dep_['mean_response']
        if centered:
            # center the partial dependence plot by subtracting the first
            # row's value from every value, i.e. the first value will serve
            # as the baseline (centered at 0) for all other values.
            # A new Series is built on purpose: an in-place ``-=`` could
            # write through to self.partial_dep_ and corrupt later plots.
            pd_mean = pd_mean - pd_mean.iloc[0]

        x = self.partial_dep_[self.feature_name_]

        ax.plot(
            x, pd_mean, color=pd_color, linewidth=pd_linewidth,
            marker='o', markersize=pd_markersize)
        ax.plot(
            x, [0] * pd_mean.size, color=zero_color,
            linestyle='--', linewidth=zero_linewidth)

        if plot_stddev:
            std = self.partial_dep_['stddev_response']
            upper = pd_mean + std
            lower = pd_mean - std
            ax.fill_between(x, upper, lower, alpha=fill_alpha, color=fill_color)

        ax.set_xlabel(self.feature_name_, fontsize=xlabel_fontsize)
        self._modify_axis(ax)

    def _modify_axis(self, ax):
        """Apply the shared tick, spine and grid styling to ``ax``."""
        tick_labelsize = 8
        tick_colors = '#9E9E9E'
        tick_labelcolor = '#424242'

        ax.tick_params(
            axis='both', which='major', colors=tick_colors,
            labelsize=tick_labelsize, labelcolor=tick_labelcolor)

        ax.set_facecolor('white')
        ax.get_yaxis().tick_left()
        ax.get_xaxis().tick_bottom()
        for direction in ('top', 'left', 'right', 'bottom'):
            ax.spines[direction].set_visible(False)

        # dashed, semi-transparent major grid lines on both axes
        ax.grid(
            True, which='major', axis='both',
            linestyle='--', linewidth=.5, color='k', alpha=.3)