Path: blob/main/C2 - Advanced Learning Algorithms/week4/optional labs/utils.py
3586 views
from PIL import Image
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout
import numpy as np
from matplotlib.widgets import Slider, Button

plt.style.use('./deeplearning.mplstyle')

# Display names for the columns of X (in order) and for the 0/1 labels in y.
# Underscore-prefixed so `from utils import *` does not export them.
_FEATURE_NAMES = ["Ear Shape", "Face Shape", "Whiskers"]
_CLASS_NAMES = ["Non Cat", "Cat"]


def _binary_entropy(p):
    """Return the binary entropy H(p) in bits, with H(0) = H(1) = 0."""
    if p == 0 or p == 1:
        return 0
    return -p * np.log2(p) - (1 - p) * np.log2(1 - p)


def compute_entropy(y):
    """Compute the entropy of a node from its labels.

    Args:
        y (ndarray): 1-D array of labels; entries equal to 1 are counted as
            the positive class, everything else as negative.

    Returns:
        entropy (float): entropy in bits; 0 for an empty or pure node.
    """
    if len(y) == 0:
        return 0
    # Fraction of positive examples (summing the 1-entries counts them).
    p1 = sum(y[y == 1]) / len(y)
    return _binary_entropy(p1)


def split_dataset(X, node_indices, feature):
    """Partition the examples of a node by the value of one feature.

    Args:
        X (ndarray): data matrix, accessed as X[i][feature].
        node_indices (list): indices of the examples that reach this node.
        feature (int): column of X to split on.

    Returns:
        left_indices (list): indices whose feature value equals 1.
        right_indices (list): all remaining indices.
    """
    left_indices = []
    right_indices = []
    for i in node_indices:
        if X[i][feature] == 1:
            left_indices.append(i)
        else:
            right_indices.append(i)
    return left_indices, right_indices


def compute_information_gain(X, y, node_indices, feature):
    """Compute the information gain of splitting a node on `feature`.

    Args:
        X (ndarray): data matrix.
        y (ndarray): label array indexed by example index.
        node_indices (list): indices of the examples that reach this node.
        feature (int): column of X to split on.

    Returns:
        information_gain (float): parent entropy minus the size-weighted
            entropy of the two children; 0 for an empty node.
    """
    if len(node_indices) == 0:
        # Nothing to split; also avoids dividing by zero in the weights below.
        return 0

    left_indices, right_indices = split_dataset(X, node_indices, feature)

    node_entropy = compute_entropy(y[node_indices])
    left_entropy = compute_entropy(y[left_indices])
    right_entropy = compute_entropy(y[right_indices])

    # Each child contributes in proportion to how many examples it receives.
    w_left = len(left_indices) / len(node_indices)
    w_right = len(right_indices) / len(node_indices)
    weighted_entropy = w_left * left_entropy + w_right * right_entropy
    return node_entropy - weighted_entropy


def get_best_split(X, y, node_indices):
    """Return the feature with the largest positive information gain.

    Args:
        X (ndarray): data matrix of shape (n_examples, n_features).
        y (ndarray): label array indexed by example index.
        node_indices (list): indices of the examples that reach this node.

    Returns:
        best_feature (int): index of the best feature, or -1 if no feature
            yields a strictly positive gain (e.g. the node is already pure).
    """
    best_feature = -1
    max_info_gain = 0
    for feature in range(X.shape[1]):
        info_gain = compute_information_gain(X, y, node_indices, feature)
        if info_gain > max_info_gain:
            max_info_gain = info_gain
            best_feature = feature
    return best_feature


def _print_leaf(branch_name, node_indices, current_depth):
    """Print the indented 'leaf node' line used by build_tree_recursive."""
    formatting = " " * current_depth + "-" * current_depth
    print(formatting, "%s leaf node with indices" % branch_name, node_indices)


def build_tree_recursive(X, y, node_indices, branch_name, max_depth, current_depth, tree):
    """Greedily grow a decision tree, printing each split and leaf.

    Args:
        X (ndarray): data matrix.
        y (ndarray): label array indexed by example index.
        node_indices (list): indices of the examples at the current node.
        branch_name (str): label printed for this node ("Root", "Left", "Right").
        max_depth (int): depth at which nodes become leaves.
        current_depth (int): depth of the current node.
        tree (list): receives one (left_indices, right_indices, feature)
            tuple per split, appended in depth-first (pre-order) order.

    Returns:
        `tree` for nodes below max_depth; None for nodes at max_depth.
        The list is mutated in place either way.
    """
    if current_depth == max_depth:
        _print_leaf(branch_name, node_indices, current_depth)
        return

    best_feature = get_best_split(X, y, node_indices)
    if best_feature == -1:
        # No feature reduces entropy. Using -1 as a column index would
        # silently split on the last feature, so stop here instead.
        _print_leaf(branch_name, node_indices, current_depth)
        return tree

    formatting = "-" * current_depth
    print("%s Depth %d, %s: Split on feature: %d" % (formatting, current_depth, branch_name, best_feature))

    left_indices, right_indices = split_dataset(X, node_indices, best_feature)
    tree.append((left_indices, right_indices, best_feature))

    build_tree_recursive(X, y, left_indices, "Left", max_depth, current_depth + 1, tree)
    build_tree_recursive(X, y, right_indices, "Right", max_depth, current_depth + 1, tree)
    return tree


def generate_node_image(node_indices):
    """Tile a node's example images side by side into a single PIL image.

    Each index i is loaded from "images/<i>.png". The strip is scaled by
    len(node_indices)/10, so nodes holding fewer examples render smaller.
    """
    images = []
    for idx in node_indices:
        # Copy the pixel data so the file handle is closed immediately.
        with Image.open("images/%d.png" % idx) as im:
            images.append(im.copy())
    widths, heights = zip(*(im.size for im in images))

    total_width = sum(widths)
    max_height = max(heights)

    new_im = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for im in images:
        new_im.paste(im, (x_offset, 0))
        x_offset += im.size[0]

    new_im = new_im.resize((int(total_width * len(node_indices) / 10), int(max_height * len(node_indices) / 10)))
    return new_im


def generate_split_viz(node_indices, left_indices, right_indices, feature):
    """Draw one split: the parent node and its two children as image strips.

    Args:
        node_indices (list): example indices at the parent node.
        left_indices (list): example indices with feature value 1.
        right_indices (list): remaining example indices.
        feature (int): index into the feature-name list used for titles.
    """
    G = nx.DiGraph()

    indices_list = [node_indices, left_indices, right_indices]
    for idx, indices in enumerate(indices_list):
        G.add_node(idx, image=generate_node_image(indices))

    G.add_edge(0, 1)
    G.add_edge(0, 2)

    pos = graphviz_layout(G, prog="dot")

    fig = plt.figure()
    ax = plt.subplot(111)
    ax.set_aspect('equal')
    nx.draw_networkx_edges(G, pos, ax=ax, arrows=True, arrowsize=40)

    trans = ax.transData.transform
    trans2 = fig.transFigure.inverted().transform

    feature_name = _FEATURE_NAMES[feature]
    ax_name = ["Splitting on %s" % feature_name, "Left: %s = 1" % feature_name, "Right: %s = 0" % feature_name]
    for idx, n in enumerate(G):
        xx, yy = trans(pos[n])  # data -> display (pixel) coordinates
        xa, ya = trans2((xx, yy))  # display -> figure-fraction coordinates
        # Inset size grows with the number of examples in the node.
        piesize = len(indices_list[idx]) / 9
        p2 = piesize / 2.0
        a = plt.axes([xa - p2, ya - p2, piesize, piesize])
        a.set_aspect('equal')
        a.imshow(G.nodes[n]['image'])
        a.axis('off')
        a.set_title(ax_name[idx])
    ax.axis('off')
    plt.show()


def generate_tree_viz(root_indices, y, tree):
    """Draw a whole tree with each node rendered as a strip of its images.

    Args:
        root_indices (list): example indices at the root.
        y (ndarray): 0/1 label array indexed by example index.
        tree (list): (left_indices, right_indices, feature) tuples as
            appended by build_tree_recursive.

    NOTE(review): the parent bookkeeping below (`root` advancing once per
    entry) and treating every child after the first entry as a leaf assume
    the depth-2 layout [root split, left split, right split].
    """
    G = nx.DiGraph()

    G.add_node(0, image=generate_node_image(root_indices))
    idx = 1
    root = 0

    num_images = [len(root_indices)]

    decision_names = []
    leaf_names = []

    for i, level in enumerate(tree):
        for indices in level[:2]:
            G.add_node(idx, image=generate_node_image(indices))
            G.add_edge(root, idx)

            # For visualization
            num_images.append(len(indices))
            idx += 1
            if i > 0:
                # Label the leaf with its majority class (ties go to class 1).
                label = int(np.mean(y[indices]) >= 0.5)
                leaf_names.append("Leaf node: %s" % _CLASS_NAMES[label])

        decision_names.append("Split on: %s" % _FEATURE_NAMES[level[2]])
        root += 1

    # Titles line up with node ids: split nodes first, then leaves.
    node_names = decision_names + leaf_names
    pos = graphviz_layout(G, prog="dot")

    fig = plt.figure(figsize=(14, 10))
    ax = plt.subplot(111)
    ax.set_aspect('equal')
    nx.draw_networkx_edges(G, pos, ax=ax, arrows=True, arrowsize=40)

    trans = ax.transData.transform
    trans2 = fig.transFigure.inverted().transform

    for idx, n in enumerate(G):
        xx, yy = trans(pos[n])  # data -> display (pixel) coordinates
        xa, ya = trans2((xx, yy))  # display -> figure-fraction coordinates
        piesize = num_images[idx] / 25
        p2 = piesize / 2.0
        a = plt.axes([xa - p2, ya - p2, piesize, piesize])
        a.set_aspect('equal')
        a.imshow(G.nodes[n]['image'])
        a.axis('off')
        # Nodes without a prepared name are drawn untitled.
        if idx < len(node_names):
            a.set_title(node_names[idx], y=-0.8, fontsize=13, loc="left")
    ax.axis('off')
    plt.show()


def plot_entropy():
    """Plot H(p) with a slider that moves a red marker to (p, H(p)).

    Returns:
        slider (Slider): returned so the caller can hold a reference; matplotlib
            widgets stop responding once they are garbage-collected.
    """
    p_array = np.linspace(0, 1, 201)
    h_array = [_binary_entropy(p) for p in p_array]
    _, ax = plt.subplots()
    plt.subplots_adjust(left=0.25, bottom=0.25)
    ax.set_title('p x H(p)')
    ax.set_xlabel('p')
    ax.set_ylabel('H(p)')
    axfreq = plt.axes([0.25, 0.1, 0.65, 0.03])
    ax.plot(p_array, h_array)
    scatter = ax.scatter(0, 0, color='red', zorder=100, s=70)
    slider = Slider(axfreq, 'p', 0, 1, valinit=0, valstep=0.05)

    def update(val):
        # Move the marker to the entropy of the slider's current p.
        scatter.set_offsets((val, _binary_entropy(val)))

    slider.on_changed(update)
    return slider