Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ethen8181
GitHub Repository: ethen8181/machine-learning
Path: blob/master/big_data/h2o/h2o_explainers.py
2581 views
1
import matplotlib.pyplot as plt
2
from matplotlib.gridspec import GridSpec
3
4
5
class H2OPartialDependenceExplainer:
    """
    Partial Dependence explanation for binary classification H2O models.
    Works for both numerical and categorical (enum) features.

    Parameters
    ----------
    h2o_model : H2OEstimator
        H2O Model that was already fitted on the data.

    Attributes
    ----------
    feature_name_ : str
        The input feature_name to the .fit unmodified, will
        be used in subsequent method.

    is_cat_col_ : bool
        Whether the feature we're aiming to explain is a
        categorical feature or not.

    partial_dep_ : DataFrame
        A pandas dataframe that contains three columns, the
        feature's value and their corresponding mean prediction
        and standard deviation of the prediction. e.g.

        feature_name    mean_response    stddev_response
        3000.000000     0.284140         0.120659
        318631.578947   0.134414         0.076054
        634263.157895   0.142961         0.083630

        The feature_name column will be the actual feature_name that
        we pass to the .fit method whereas the mean_response and
        stddev_response column will be fixed columns generated.
        Plotting never modifies this dataframe.
    """

    def __init__(self, h2o_model):
        self.h2o_model = h2o_model

    def fit(self, data, feature_name, n_bins=20):
        """
        Obtain the partial dependence result.

        Parameters
        ----------
        data : H2OFrame, shape [n_samples, n_features]
            Input data to the H2O estimator/model.

        feature_name : str
            Feature name in the data that we wish to explain.

        n_bins : int, default 20
            Number of bins used. For categorical columns the value is
            raised to at least the number of distinct levels.

        Returns
        -------
        self
        """
        self.is_cat_col_ = data[feature_name].isfactor()[0]
        if self.is_cat_col_:
            # a categorical feature needs one grid point per level,
            # so never request fewer bins than there are levels
            n_levels = len(data[feature_name].levels()[0])
            n_bins = max(n_levels, n_bins)

        partial_dep = self.h2o_model.partial_plot(
            data, cols=[feature_name], nbins=n_bins, plot=False)
        self.feature_name_ = feature_name
        # a single column was requested, hence a single result table
        self.partial_dep_ = partial_dep[0].as_data_frame()
        return self

    def plot(self, centered=True, plot_stddev=True):
        """
        Use the partial dependence result to generate
        a partial dependence plot (using matplotlib).

        Parameters
        ----------
        centered : bool, default True
            Center the partial dependence plot by subtracting the first
            grid point's mean response from every mean response, i.e. the
            first value will serve as the baseline (centered at 0) for
            all other values.

        plot_stddev : bool, default True
            Apart from plotting the mean partial dependence, also show the
            standard deviation as a fill between.

        Returns
        -------
        matplotlib.gridspec.GridSpec
            The 5 x 1 grid layout the title axis (first row) and the
            content axis (remaining rows) were placed on.
        """
        figure = GridSpec(5, 1)
        ax1 = plt.subplot(figure[0, :])
        self._plot_title(ax1)
        ax2 = plt.subplot(figure[1:, :])
        self._plot_content(ax2, centered, plot_stddev)
        return figure

    def _plot_title(self, ax):
        """Write the title and the grid point count subtitle onto ``ax``."""
        font_family = 'Arial'
        title = "Partial Dependence Plot for '{}' feature".format(self.feature_name_)
        subtitle = 'Number of unique grid points: {}'.format(self.partial_dep_.shape[0])
        title_fontsize = 15
        subtitle_fontsize = 12

        ax.set_facecolor('white')
        ax.text(
            0, 0.7, title,
            fontsize=title_fontsize, fontname=font_family)
        ax.text(
            0, 0.4, subtitle, color='grey',
            fontsize=subtitle_fontsize, fontname=font_family)
        # the axis only serves as a text canvas, hide ticks and frame
        ax.axis('off')

    def _response_bands(self, centered):
        """
        Compute the series that are drawn by the content plot.

        Parameters
        ----------
        centered : bool
            Whether to subtract the first mean response from all of them.

        Returns
        -------
        tuple
            ``(x, mean, lower, upper)`` where ``lower``/``upper`` are the
            (possibly centered) mean minus/plus one standard deviation.
        """
        pd_mean = self.partial_dep_['mean_response']
        if centered:
            # build a new series instead of using ``-=``: an in-place
            # update can write through to ``partial_dep_`` and leave the
            # fitted result permanently centered
            pd_mean = pd_mean - pd_mean.iloc[0]

        std = self.partial_dep_['stddev_response']
        x = self.partial_dep_[self.feature_name_]
        return x, pd_mean, pd_mean - std, pd_mean + std

    def _plot_content(self, ax, centered, plot_stddev):
        """Draw the partial dependence curve, zero line and std band on ``ax``."""
        # pd (partial dependence)
        pd_linewidth = 2
        pd_markersize = 5
        pd_color = '#1A4E5D'
        fill_alpha = 0.2
        fill_color = '#66C2D7'
        zero_linewidth = 1.5
        zero_color = '#E75438'
        xlabel_fontsize = 10

        x, pd_mean, lower, upper = self._response_bands(centered)

        ax.plot(
            x, pd_mean, color=pd_color, linewidth=pd_linewidth,
            marker='o', markersize=pd_markersize)
        # dashed reference line at zero
        ax.plot(
            x, [0] * pd_mean.size, color=zero_color,
            linestyle='--', linewidth=zero_linewidth)

        if plot_stddev:
            ax.fill_between(x, upper, lower, alpha=fill_alpha, color=fill_color)

        ax.set_xlabel(self.feature_name_, fontsize=xlabel_fontsize)
        self._modify_axis(ax)

    def _modify_axis(self, ax):
        """Apply the shared look: muted ticks, no spines, dashed grid."""
        tick_labelsize = 8
        tick_colors = '#9E9E9E'
        tick_labelcolor = '#424242'

        ax.tick_params(
            axis='both', which='major', colors=tick_colors,
            labelsize=tick_labelsize, labelcolor=tick_labelcolor)

        ax.set_facecolor('white')
        ax.get_yaxis().tick_left()
        ax.get_xaxis().tick_bottom()
        for direction in ('top', 'left', 'right', 'bottom'):
            ax.spines[direction].set_visible(False)

        for axis in ('x', 'y'):
            ax.grid(True, which='major', axis=axis, ls='--', lw=.5, c='k', alpha=.3)
172