Path: blob/master/examples/vision/integrated_gradients.py
"""1Title: Model interpretability with Integrated Gradients2Author: [A_K_Nain](https://twitter.com/A_K_Nain)3Date created: 2020/06/024Last modified: 2020/06/025Description: How to obtain integrated gradients for a classification model.6Accelerator: None7"""89"""10## Integrated Gradients1112[Integrated Gradients](https://arxiv.org/abs/1703.01365) is a technique for13attributing a classification model's prediction to its input features. It is14a model interpretability technique: you can use it to visualize the relationship15between input features and model predictions.1617Integrated Gradients is a variation on computing18the gradient of the prediction output with regard to features of the input.19To compute integrated gradients, we need to perform the following steps:20211. Identify the input and the output. In our case, the input is an image and the22output is the last layer of our model (dense layer with softmax activation).23242. Compute which features are important to a neural network25when making a prediction on a particular data point. To identify these features, we26need to choose a baseline input. A baseline input can be a black image (all pixel27values set to zero) or random noise. The shape of the baseline input needs to be28the same as our input image, e.g. (299, 299, 3).29303. Interpolate the baseline for a given number of steps. The number of steps represents31the steps we need in the gradient approximation for a given input image. The number of32steps is a hyperparameter. The authors recommend using anywhere between3320 and 1000 steps.34354. Preprocess these interpolated images and do a forward pass.365. Get the gradients for these interpolated images.376. 
"""
To read in-depth about integrated gradients and why this method works,
consider reading this excellent
[article](https://distill.pub/2020/attribution-baselines/).

**References:**

- Integrated Gradients original [paper](https://arxiv.org/abs/1703.01365)
- [Original implementation](https://github.com/ankurtaly/Integrated-Gradients)
"""

"""
## Setup
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from IPython.display import Image, display

import tensorflow as tf
import keras
from keras import layers
from keras.applications import xception

# Size of the input image
img_size = (299, 299, 3)

# Load Xception model with imagenet weights
model = xception.Xception(weights="imagenet")

# The local path to our target image
img_path = keras.utils.get_file("elephant.jpg", "https://i.imgur.com/Bvro0YD.png")
display(Image(img_path))

"""
## Integrated Gradients algorithm
"""


def get_img_array(img_path, size=(299, 299)):
    # `img` is a PIL image of size 299x299
    img = keras.utils.load_img(img_path, target_size=size)
    # `array` is a float32 Numpy array of shape (299, 299, 3)
    array = keras.utils.img_to_array(img)
    # We add a dimension to transform our array into a "batch"
    # of size (1, 299, 299, 3)
    array = np.expand_dims(array, axis=0)
    return array


def get_gradients(img_input, top_pred_idx):
    """Computes the gradients of outputs w.r.t the input image.

    Args:
        img_input: 4D image tensor
        top_pred_idx: Predicted label for the input image

    Returns:
        Gradients of the predictions w.r.t img_input
    """
    images = tf.cast(img_input, tf.float32)

    with tf.GradientTape() as tape:
        tape.watch(images)
        preds = model(images)
        top_class = preds[:, top_pred_idx]

    grads = tape.gradient(top_class, images)
    return grads
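The `GradientTape` pattern above can be sanity-checked on a toy linear stand-in (not the Xception model): for `y = x @ w`, the gradient of `y` with respect to `x` is simply `w` transposed, so we can verify the tape returns exactly that.

```python
import tensorflow as tf

# Toy linear "model": y = x @ w
w = tf.constant([[2.0], [3.0], [5.0]])
x = tf.constant([[1.0, 1.0, 1.0]])

with tf.GradientTape() as tape:
    # `x` is a constant tensor, so it must be watched explicitly --
    # just like the cast image tensor in `get_gradients` above.
    tape.watch(x)
    y = tf.matmul(x, w)

grads = tape.gradient(y, x)
print(grads.numpy())  # [[2. 3. 5.]]
```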
def get_integrated_gradients(img_input, top_pred_idx, baseline=None, num_steps=50):
    """Computes Integrated Gradients for a predicted label.

    Args:
        img_input (ndarray): Original image
        top_pred_idx: Predicted label for the input image
        baseline (ndarray): The baseline image to start with for interpolation
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients. These
            steps alone determine the integral approximation error. By default,
            num_steps is set to 50.

    Returns:
        Integrated gradients w.r.t input image
    """
    # If baseline is not provided, start with a black image
    # having the same size as the input image.
    if baseline is None:
        baseline = np.zeros(img_size).astype(np.float32)
    else:
        baseline = baseline.astype(np.float32)

    # 1. Do interpolation.
    img_input = img_input.astype(np.float32)
    interpolated_image = [
        baseline + (step / num_steps) * (img_input - baseline)
        for step in range(num_steps + 1)
    ]
    interpolated_image = np.array(interpolated_image).astype(np.float32)

    # 2. Preprocess the interpolated images
    interpolated_image = xception.preprocess_input(interpolated_image)

    # 3. Get the gradients
    grads = []
    for i, img in enumerate(interpolated_image):
        img = tf.expand_dims(img, axis=0)
        grad = get_gradients(img, top_pred_idx=top_pred_idx)
        grads.append(grad[0])
    grads = tf.convert_to_tensor(grads, dtype=tf.float32)

    # 4. Approximate the integral using the trapezoidal rule
    grads = (grads[:-1] + grads[1:]) / 2.0
    avg_grads = tf.reduce_mean(grads, axis=0)

    # 5. Calculate integrated gradients and return
    integrated_grads = (img_input - baseline) * avg_grads
    return integrated_grads
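Step 4 above is the trapezoidal rule in disguise: averaging each pair of consecutive gradients and then taking the mean over all pairs equals the trapezoidal approximation of the integral over [0, 1]. For gradients that vary linearly along the interpolation path the rule is exact, as this small numpy-only check (toy numbers, no model involved) illustrates:

```python
import numpy as np

num_steps = 50
# Toy "gradients" that ramp linearly from 1 to 3 along the path;
# the exact average gradient over the path is then 2.0.
grads = np.linspace(1.0, 3.0, num_steps + 1)

# Same computation as step 4 of `get_integrated_gradients`
pairwise = (grads[:-1] + grads[1:]) / 2.0
avg_grad = pairwise.mean()

print(avg_grad)  # ~2.0 -- the trapezoidal rule is exact for a linear integrand
```

For real networks the gradients are not linear in the interpolation coefficient, which is why `num_steps` matters.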
def random_baseline_integrated_gradients(
    img_input, top_pred_idx, num_steps=50, num_runs=2
):
    """Computes integrated gradients averaged over random baseline images.

    Args:
        img_input (ndarray): 3D image
        top_pred_idx: Predicted label for the input image
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients. These
            steps alone determine the integral approximation error. By default,
            num_steps is set to 50.
        num_runs: number of baseline images to generate

    Returns:
        Averaged integrated gradients for `num_runs` baseline images
    """
    # 1. List to keep track of Integrated Gradients (IG) for all the images
    integrated_grads = []

    # 2. Get the integrated gradients for all the baselines
    for run in range(num_runs):
        baseline = np.random.random(img_size) * 255
        igrads = get_integrated_gradients(
            img_input=img_input,
            top_pred_idx=top_pred_idx,
            baseline=baseline,
            num_steps=num_steps,
        )
        integrated_grads.append(igrads)

    # 3. Return the average integrated gradients for the image
    integrated_grads = tf.convert_to_tensor(integrated_grads)
    return tf.reduce_mean(integrated_grads, axis=0)
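One property worth noting about the averaging above: integrated gradients satisfy completeness for every individual baseline (the attributions sum to `f(input) - f(baseline)`), so averaging over random baselines averages explanations rather than accumulating approximation error. A toy check with the same hypothetical 1-D model `f(x) = x ** 2` used earlier:

```python
import numpy as np


def f(x):
    # Hypothetical 1-D "model"
    return x**2


def toy_ig(x, baseline, num_steps=50):
    # 1-D integrated gradients for f; the gradient of f is 2x
    alphas = np.linspace(0.0, 1.0, num_steps + 1)
    grads = 2.0 * (baseline + alphas * (x - baseline))
    avg_grad = np.mean((grads[:-1] + grads[1:]) / 2.0)
    return (x - baseline) * avg_grad


rng = np.random.default_rng(0)
x = 2.0
baselines = rng.uniform(-1.0, 1.0, size=3)

# Completeness holds for every baseline: IG == f(x) - f(baseline)
errors = [abs(toy_ig(x, b) - (f(x) - f(b))) for b in baselines]
print(max(errors))  # ~0.0
```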
"""
## Helper class for visualizing gradients and integrated gradients
"""


class GradVisualizer:
    """Plot gradients of the outputs w.r.t an input image."""

    def __init__(self, positive_channel=None, negative_channel=None):
        if positive_channel is None:
            self.positive_channel = [0, 255, 0]
        else:
            self.positive_channel = positive_channel

        if negative_channel is None:
            self.negative_channel = [255, 0, 0]
        else:
            self.negative_channel = negative_channel

    def apply_polarity(self, attributions, polarity):
        if polarity == "positive":
            return np.clip(attributions, 0, 1)
        else:
            return np.clip(attributions, -1, 0)

    def apply_linear_transformation(
        self,
        attributions,
        clip_above_percentile=99.9,
        clip_below_percentile=70.0,
        lower_end=0.2,
    ):
        # 1. Get the thresholds
        m = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_above_percentile
        )
        e = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_below_percentile
        )

        # 2. Transform the attributions by a linear function f(x) = a*x + b such that
        # f(m) = 1.0 and f(e) = lower_end
        transformed_attributions = (1 - lower_end) * (np.abs(attributions) - e) / (
            m - e
        ) + lower_end

        # 3. Make sure that the sign of transformed attributions is the same as
        # the original attributions
        transformed_attributions *= np.sign(attributions)

        # 4. Only keep values that are bigger than the lower_end
        transformed_attributions *= transformed_attributions >= lower_end

        # 5. Clip values and return
        transformed_attributions = np.clip(transformed_attributions, 0.0, 1.0)
        return transformed_attributions
    def get_thresholded_attributions(self, attributions, percentage):
        if percentage == 100.0:
            return np.min(attributions)

        # 1. Flatten the attributions
        flatten_attr = attributions.flatten()

        # 2. Get the sum of the attributions
        total = np.sum(flatten_attr)

        # 3. Sort the attributions from largest to smallest.
        sorted_attributions = np.sort(np.abs(flatten_attr))[::-1]

        # 4. Calculate the percentage of the total sum that each attribution
        # and the values above it contribute.
        cum_sum = 100.0 * np.cumsum(sorted_attributions) / total

        # 5. Threshold the attributions by the percentage
        indices_to_consider = np.where(cum_sum >= percentage)[0][0]

        # 6. Select the desired attributions and return
        attributions = sorted_attributions[indices_to_consider]
        return attributions

    def binarize(self, attributions, threshold=0.001):
        return attributions > threshold

    def morphological_cleanup_fn(self, attributions, structure=np.ones((4, 4))):
        closed = ndimage.grey_closing(attributions, structure=structure)
        opened = ndimage.grey_opening(closed, structure=structure)
        return opened
    def draw_outlines(
        self,
        attributions,
        percentage=90,
        connected_component_structure=np.ones((3, 3)),
    ):
        # 1. Binarize the attributions.
        attributions = self.binarize(attributions)

        # 2. Fill the gaps
        attributions = ndimage.binary_fill_holes(attributions)

        # 3. Compute connected components
        connected_components, num_comp = ndimage.label(
            attributions, structure=connected_component_structure
        )

        # 4. Sum up the attributions for each component
        total = np.sum(attributions[connected_components > 0])
        component_sums = []
        for comp in range(1, num_comp + 1):
            mask = connected_components == comp
            component_sum = np.sum(attributions[mask])
            component_sums.append((component_sum, mask))

        # 5. Compute the percentage of top components to keep
        sorted_sums_and_masks = sorted(component_sums, key=lambda x: x[0], reverse=True)
        sorted_sums = list(zip(*sorted_sums_and_masks))[0]
        cumulative_sorted_sums = np.cumsum(sorted_sums)
        cutoff_threshold = percentage * total / 100
        cutoff_idx = np.where(cumulative_sorted_sums >= cutoff_threshold)[0][0]
        if cutoff_idx > 2:
            cutoff_idx = 2

        # 6. Set the values for the kept components
        border_mask = np.zeros_like(attributions)
        for i in range(cutoff_idx + 1):
            border_mask[sorted_sums_and_masks[i][1]] = 1

        # 7. Make the mask hollow and show only the border
        eroded_mask = ndimage.binary_erosion(border_mask, iterations=1)
        border_mask[eroded_mask] = 0

        # 8. Return the outlined mask
        return border_mask
    def process_grads(
        self,
        image,
        attributions,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
    ):
        if polarity not in ["positive", "negative"]:
            raise ValueError(
                "Allowed polarity values: 'positive' or 'negative' "
                f"but provided {polarity}"
            )
        if clip_above_percentile < 0 or clip_above_percentile > 100:
            raise ValueError("clip_above_percentile must be in [0, 100]")

        if clip_below_percentile < 0 or clip_below_percentile > 100:
            raise ValueError("clip_below_percentile must be in [0, 100]")

        # 1. Apply polarity
        if polarity == "positive":
            attributions = self.apply_polarity(attributions, polarity=polarity)
            channel = self.positive_channel
        else:
            attributions = self.apply_polarity(attributions, polarity=polarity)
            attributions = np.abs(attributions)
            channel = self.negative_channel

        # 2. Take average over the channels
        attributions = np.average(attributions, axis=2)

        # 3. Apply linear transformation to the attributions
        attributions = self.apply_linear_transformation(
            attributions,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            lower_end=0.0,
        )

        # 4. Cleanup
        if morphological_cleanup:
            attributions = self.morphological_cleanup_fn(
                attributions, structure=structure
            )
        # 5. Draw the outlines
        if outlines:
            attributions = self.draw_outlines(
                attributions, percentage=outlines_component_percentage
            )

        # 6. Expand the channel axis and convert to RGB
        attributions = np.expand_dims(attributions, 2) * channel

        # 7. Superimpose on the original image
        if overlay:
            attributions = np.clip((attributions * 0.8 + image), 0, 255)
        return attributions
    def visualize(
        self,
        image,
        gradients,
        integrated_gradients,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
        figsize=(15, 8),
    ):
        # 1. Make two copies of the original image
        img1 = np.copy(image)
        img2 = np.copy(image)

        # 2. Process the normal gradients
        grads_attr = self.process_grads(
            image=img1,
            attributions=gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        # 3. Process the integrated gradients
        igrads_attr = self.process_grads(
            image=img2,
            attributions=integrated_gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        _, ax = plt.subplots(1, 3, figsize=figsize)
        ax[0].imshow(image)
        ax[1].imshow(grads_attr.astype(np.uint8))
        ax[2].imshow(igrads_attr.astype(np.uint8))

        ax[0].set_title("Input")
        ax[1].set_title("Normal gradients")
        ax[2].set_title("Integrated gradients")
        plt.show()


"""
## Let's test-drive it
"""

# 1. Convert the image to a numpy array
img = get_img_array(img_path)

# 2. Keep a copy of the original image
orig_img = np.copy(img[0]).astype(np.uint8)

# 3. Preprocess the image
img_processed = tf.cast(xception.preprocess_input(img), dtype=tf.float32)

# 4. Get model predictions
preds = model.predict(img_processed)
top_pred_idx = tf.argmax(preds[0])
print("Predicted:", top_pred_idx, xception.decode_predictions(preds, top=1)[0])

# 5. Get the gradients of the last layer for the predicted label
grads = get_gradients(img_processed, top_pred_idx=top_pred_idx)

# 6. Get the integrated gradients
igrads = random_baseline_integrated_gradients(
    np.copy(orig_img), top_pred_idx=top_pred_idx, num_steps=50, num_runs=2
)
# 7. Process the gradients and plot
vis = GradVisualizer()
vis.visualize(
    image=orig_img,
    gradients=grads[0].numpy(),
    integrated_gradients=igrads.numpy(),
    clip_above_percentile=99,
    clip_below_percentile=0,
)

vis.visualize(
    image=orig_img,
    gradients=grads[0].numpy(),
    integrated_gradients=igrads.numpy(),
    clip_above_percentile=95,
    clip_below_percentile=28,
    morphological_cleanup=True,
    outlines=True,
)

"""
## Relevant Chapters from Deep Learning with Python

- [Chapter 10: Interpreting what ConvNets learn](https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn)
"""