Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/vision/integrated_gradients.py
8109 views
1
"""
2
Title: Model interpretability with Integrated Gradients
3
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
4
Date created: 2020/06/02
5
Last modified: 2020/06/02
6
Description: How to obtain integrated gradients for a classification model.
7
Accelerator: None
8
"""
9
10
"""
11
## Integrated Gradients
12
13
[Integrated Gradients](https://arxiv.org/abs/1703.01365) is a technique for
14
attributing a classification model's prediction to its input features. It is
15
a model interpretability technique: you can use it to visualize the relationship
16
between input features and model predictions.
17
18
Integrated Gradients is a variation on computing
19
the gradient of the prediction output with regard to features of the input.
20
To compute integrated gradients, we need to perform the following steps:
21
22
1. Identify the input and the output. In our case, the input is an image and the
23
output is the last layer of our model (dense layer with softmax activation).
24
25
2. Compute which features are important to a neural network
26
when making a prediction on a particular data point. To identify these features, we
27
need to choose a baseline input. A baseline input can be a black image (all pixel
28
values set to zero) or random noise. The shape of the baseline input needs to be
29
the same as our input image, e.g. (299, 299, 3).
30
31
3. Interpolate the baseline for a given number of steps. The number of steps represents
32
the steps we need in the gradient approximation for a given input image. The number of
33
steps is a hyperparameter. The authors recommend using anywhere between
34
20 and 1000 steps.
35
36
4. Preprocess these interpolated images and do a forward pass.
37
5. Get the gradients for these interpolated images.
38
6. Approximate the gradients integral using the trapezoidal rule.
39
40
To read in-depth about integrated gradients and why this method works,
41
consider reading this excellent
42
[article](https://distill.pub/2020/attribution-baselines/).
43
44
**References:**
45
46
- Integrated Gradients original [paper](https://arxiv.org/abs/1703.01365)
47
- [Original implementation](https://github.com/ankurtaly/Integrated-Gradients)
48
"""
49
50
"""
51
## Setup
52
"""
53
54
55
import numpy as np
56
import matplotlib.pyplot as plt
57
from scipy import ndimage
58
from IPython.display import Image, display
59
60
import tensorflow as tf
61
import keras
62
from keras import layers
63
from keras.applications import xception
64
65
# Size of the input image
66
img_size = (299, 299, 3)
67
68
# Load Xception model with imagenet weights
69
model = xception.Xception(weights="imagenet")
70
71
# The local path to our target image
72
img_path = keras.utils.get_file("elephant.jpg", "https://i.imgur.com/Bvro0YD.png")
73
display(Image(img_path))
74
75
"""
76
## Integrated Gradients algorithm
77
"""
78
79
80
def get_img_array(img_path, size=(299, 299)):
    """Load the image at `img_path` and return it as a (1, H, W, 3) batch.

    Args:
        img_path: Path to the image file on disk.
        size: Target (height, width) the image is resized to.

    Returns:
        A float32 NumPy array of shape (1, size[0], size[1], 3).
    """
    # Load as a PIL image, resized to `size`.
    pil_img = keras.utils.load_img(img_path, target_size=size)
    # Convert to a float32 array of shape (H, W, 3), then prepend a batch
    # axis so the model can consume it directly.
    return np.expand_dims(keras.utils.img_to_array(pil_img), axis=0)
89
90
91
def get_gradients(img_input, top_pred_idx):
    """Compute gradients of the top predicted class score w.r.t. the input image.

    Args:
        img_input: 4D image tensor (a batch of images).
        top_pred_idx: Predicted label index for the input image.

    Returns:
        Gradients of the prediction for `top_pred_idx` w.r.t. `img_input`.
    """
    imgs = tf.cast(img_input, tf.float32)

    with tf.GradientTape() as tape:
        # Inputs are not trainable variables, so the tape must be told
        # explicitly to track them.
        tape.watch(imgs)
        class_score = model(imgs)[:, top_pred_idx]

    return tape.gradient(class_score, imgs)
110
111
112
def get_integrated_gradients(img_input, top_pred_idx, baseline=None, num_steps=50):
    """Compute Integrated Gradients for a predicted label.

    Args:
        img_input (ndarray): Original image.
        top_pred_idx: Predicted label index for the input image.
        baseline (ndarray): Baseline image to interpolate from. Defaults to
            an all-black image of shape `img_size`.
        num_steps: Number of interpolation steps between the baseline and the
            input. This alone determines the integral approximation error.
            Defaults to 50.

    Returns:
        Integrated gradients w.r.t. the input image.
    """
    # Default baseline: a black image with the same shape as the input.
    if baseline is None:
        baseline = np.zeros(img_size, dtype=np.float32)
    else:
        baseline = baseline.astype(np.float32)

    img_input = img_input.astype(np.float32)

    # 1. Linearly interpolate num_steps + 1 images from baseline to input.
    delta = img_input - baseline
    interpolated = np.asarray(
        [baseline + (step / num_steps) * delta for step in range(num_steps + 1)],
        dtype=np.float32,
    )

    # 2. Preprocess the interpolated images for the Xception model.
    interpolated = xception.preprocess_input(interpolated)

    # 3. Gradient of the predicted class w.r.t. each interpolated image.
    grads = []
    for interp_img in interpolated:
        batched = tf.expand_dims(interp_img, axis=0)
        grads.append(get_gradients(batched, top_pred_idx=top_pred_idx)[0])
    grads = tf.convert_to_tensor(grads, dtype=tf.float32)

    # 4. Approximate the path integral with the trapezoidal rule: average
    # adjacent gradients, then average over all steps.
    avg_grads = tf.reduce_mean((grads[:-1] + grads[1:]) / 2.0, axis=0)

    # 5. Scale the averaged gradients by the input-minus-baseline difference.
    return (img_input - baseline) * avg_grads
160
161
162
def random_baseline_integrated_gradients(
    img_input, top_pred_idx, num_steps=50, num_runs=2
):
    """Average Integrated Gradients over several random-noise baselines.

    Args:
        img_input (ndarray): 3D image.
        top_pred_idx: Predicted label index for the input image.
        num_steps: Number of interpolation steps per run. Defaults to 50.
        num_runs: Number of random baseline images to generate and average.

    Returns:
        Integrated gradients averaged over `num_runs` baseline images.
    """
    # Accumulate one attribution map per random baseline.
    all_igrads = []

    for _ in range(num_runs):
        # Uniform random noise scaled to the [0, 255) pixel range.
        noise_baseline = np.random.random(img_size) * 255
        all_igrads.append(
            get_integrated_gradients(
                img_input=img_input,
                top_pred_idx=top_pred_idx,
                baseline=noise_baseline,
                num_steps=num_steps,
            )
        )

    # Average the attribution maps across all runs.
    return tf.reduce_mean(tf.convert_to_tensor(all_igrads), axis=0)
196
197
198
"""
199
## Helper class for visualizing gradients and integrated gradients
200
"""
201
202
203
class GradVisualizer:
    """Plot gradients of the outputs w.r.t an input image.

    Processes raw (or integrated) gradient attribution maps into RGB
    overlays: polarity filtering, percentile-based rescaling, optional
    morphological cleanup and region outlining, then superimposition on
    the original image.
    """

    def __init__(self, positive_channel=None, negative_channel=None):
        # RGB color used to render positive attributions (default: green).
        if positive_channel is None:
            self.positive_channel = [0, 255, 0]
        else:
            self.positive_channel = positive_channel

        # RGB color used to render negative attributions (default: red).
        if negative_channel is None:
            self.negative_channel = [255, 0, 0]
        else:
            self.negative_channel = negative_channel

    def apply_polarity(self, attributions, polarity):
        # Keep one sign of the attributions only: positive values clipped to
        # [0, 1], otherwise negative values clipped to [-1, 0].
        if polarity == "positive":
            return np.clip(attributions, 0, 1)
        else:
            return np.clip(attributions, -1, 0)

    def apply_linear_transformation(
        self,
        attributions,
        clip_above_percentile=99.9,
        clip_below_percentile=70.0,
        lower_end=0.2,
    ):
        """Linearly rescale attributions so the chosen percentile band maps
        onto [lower_end, 1.0]; values below the band are zeroed out."""
        # 1. Get the thresholds
        m = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_above_percentile
        )
        e = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_below_percentile
        )

        # 2. Transform the attributions by a linear function f(x) = a*x + b such that
        # f(m) = 1.0 and f(e) = lower_end
        # NOTE(review): divides by (m - e); degenerate if both thresholds
        # coincide — presumably rare for real attribution maps, verify.
        transformed_attributions = (1 - lower_end) * (np.abs(attributions) - e) / (
            m - e
        ) + lower_end

        # 3. Make sure that the sign of transformed attributions is the same as original attributions
        transformed_attributions *= np.sign(attributions)

        # 4. Only keep values that are bigger than the lower_end
        transformed_attributions *= transformed_attributions >= lower_end

        # 5. Clip values and return
        transformed_attributions = np.clip(transformed_attributions, 0.0, 1.0)
        return transformed_attributions

    def get_thresholded_attributions(self, attributions, percentage):
        """Return the attribution magnitude at which the largest-magnitude
        values account for `percentage` percent of the total sum."""
        if percentage == 100.0:
            return np.min(attributions)

        # 1. Flatten the attributions
        flatten_attr = attributions.flatten()

        # 2. Get the sum of the attributions
        total = np.sum(flatten_attr)

        # 3. Sort the attributions from largest to smallest.
        sorted_attributions = np.sort(np.abs(flatten_attr))[::-1]

        # 4. Calculate the percentage of the total sum that each attribution
        # and the values about it contribute.
        cum_sum = 100.0 * np.cumsum(sorted_attributions) / total

        # 5. Threshold the attributions by the percentage
        indices_to_consider = np.where(cum_sum >= percentage)[0][0]

        # 6. Select the desired attributions and return
        attributions = sorted_attributions[indices_to_consider]
        return attributions

    def binarize(self, attributions, threshold=0.001):
        # Boolean mask of attributions above a small positive threshold.
        return attributions > threshold

    def morphological_cleanup_fn(self, attributions, structure=np.ones((4, 4))):
        # Grey closing then opening removes small speckles and fills small
        # holes while preserving the larger attribution regions.
        # NOTE(review): `structure` is a mutable default argument; safe only
        # because it is never mutated here.
        closed = ndimage.grey_closing(attributions, structure=structure)
        opened = ndimage.grey_opening(closed, structure=structure)
        return opened

    def draw_outlines(
        self,
        attributions,
        percentage=90,
        connected_component_structure=np.ones((3, 3)),
    ):
        """Reduce the attribution map to hollow outlines around the largest
        connected regions that together hold `percentage` of the mass."""
        # 1. Binarize the attributions.
        attributions = self.binarize(attributions)

        # 2. Fill the gaps
        attributions = ndimage.binary_fill_holes(attributions)

        # 3. Compute connected components
        connected_components, num_comp = ndimage.label(
            attributions, structure=connected_component_structure
        )

        # 4. Sum up the attributions for each component
        total = np.sum(attributions[connected_components > 0])
        component_sums = []
        for comp in range(1, num_comp + 1):
            mask = connected_components == comp
            component_sum = np.sum(attributions[mask])
            component_sums.append((component_sum, mask))

        # 5. Compute the percentage of top components to keep
        sorted_sums_and_masks = sorted(component_sums, key=lambda x: x[0], reverse=True)
        sorted_sums = list(zip(*sorted_sums_and_masks))[0]
        cumulative_sorted_sums = np.cumsum(sorted_sums)
        cutoff_threshold = percentage * total / 100
        cutoff_idx = np.where(cumulative_sorted_sums >= cutoff_threshold)[0][0]
        # Cap at the top 3 components so the outline stays readable.
        if cutoff_idx > 2:
            cutoff_idx = 2

        # 6. Set the values for the kept components
        border_mask = np.zeros_like(attributions)
        for i in range(cutoff_idx + 1):
            border_mask[sorted_sums_and_masks[i][1]] = 1

        # 7. Make the mask hollow and show only the border
        eroded_mask = ndimage.binary_erosion(border_mask, iterations=1)
        border_mask[eroded_mask] = 0

        # 8. Return the outlined mask
        return border_mask

    def process_grads(
        self,
        image,
        attributions,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
    ):
        """Run the full attribution-processing pipeline and return an RGB
        image-sized array ready for plotting."""
        if polarity not in ["positive", "negative"]:
            raise ValueError(f""" Allowed polarity values: 'positive' or 'negative'
                                    but provided {polarity}""")
        if clip_above_percentile < 0 or clip_above_percentile > 100:
            raise ValueError("clip_above_percentile must be in [0, 100]")

        if clip_below_percentile < 0 or clip_below_percentile > 100:
            raise ValueError("clip_below_percentile must be in [0, 100]")

        # 1. Apply polarity
        if polarity == "positive":
            attributions = self.apply_polarity(attributions, polarity=polarity)
            channel = self.positive_channel
        else:
            attributions = self.apply_polarity(attributions, polarity=polarity)
            # Negative attributions are made positive so the remaining
            # pipeline (percentile scaling, overlay) works on magnitudes.
            attributions = np.abs(attributions)
            channel = self.negative_channel

        # 2. Take average over the channels
        attributions = np.average(attributions, axis=2)

        # 3. Apply linear transformation to the attributions
        attributions = self.apply_linear_transformation(
            attributions,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            lower_end=0.0,
        )

        # 4. Cleanup
        if morphological_cleanup:
            attributions = self.morphological_cleanup_fn(
                attributions, structure=structure
            )
        # 5. Draw the outlines
        if outlines:
            attributions = self.draw_outlines(
                attributions, percentage=outlines_component_percentage
            )

        # 6. Expand the channel axis and convert to RGB
        attributions = np.expand_dims(attributions, 2) * channel

        # 7.Superimpose on the original image
        if overlay:
            attributions = np.clip((attributions * 0.8 + image), 0, 255)
        return attributions

    def visualize(
        self,
        image,
        gradients,
        integrated_gradients,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
        figsize=(15, 8),
    ):
        """Plot the input image, processed plain gradients, and processed
        integrated gradients side by side."""
        # 1. Make two copies of the original image
        img1 = np.copy(image)
        img2 = np.copy(image)

        # 2. Process the normal gradients
        grads_attr = self.process_grads(
            image=img1,
            attributions=gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        # 3. Process the integrated gradients
        igrads_attr = self.process_grads(
            image=img2,
            attributions=integrated_gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        _, ax = plt.subplots(1, 3, figsize=figsize)
        ax[0].imshow(image)
        ax[1].imshow(grads_attr.astype(np.uint8))
        ax[2].imshow(igrads_attr.astype(np.uint8))

        ax[0].set_title("Input")
        ax[1].set_title("Normal gradients")
        ax[2].set_title("Integrated gradients")
        plt.show()
449
450
451
"""
## Let's test-drive it
"""

# 1. Convert the image to numpy array
img = get_img_array(img_path)

# 2. Keep a copy of the original image (uint8, for display and overlay)
orig_img = np.copy(img[0]).astype(np.uint8)

# 3. Preprocess the image for the Xception model
img_processed = tf.cast(xception.preprocess_input(img), dtype=tf.float32)

# 4. Get model predictions
preds = model.predict(img_processed)
top_pred_idx = tf.argmax(preds[0])
print("Predicted:", top_pred_idx, xception.decode_predictions(preds, top=1)[0])

# 5. Get the gradients of the last layer for the predicted label
grads = get_gradients(img_processed, top_pred_idx=top_pred_idx)

# 6. Get the integrated gradients — note these are computed from the raw
# (unprocessed) image; preprocessing happens inside get_integrated_gradients.
igrads = random_baseline_integrated_gradients(
    np.copy(orig_img), top_pred_idx=top_pred_idx, num_steps=50, num_runs=2
)

# 7. Process the gradients and plot
vis = GradVisualizer()
vis.visualize(
    image=orig_img,
    gradients=grads[0].numpy(),
    integrated_gradients=igrads.numpy(),
    clip_above_percentile=99,
    clip_below_percentile=0,
)

# Same comparison with cleanup and outlining enabled.
vis.visualize(
    image=orig_img,
    gradients=grads[0].numpy(),
    integrated_gradients=igrads.numpy(),
    clip_above_percentile=95,
    clip_below_percentile=28,
    morphological_cleanup=True,
    outlines=True,
)

"""
## Relevant Chapters from Deep Learning with Python
- [Chapter 10: Interpreting what ConvNets learn](https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn)
"""
501
502