# Path: blob/main/C2 - Advanced Learning Algorithms/week2/optional-labs/lab_utils_relu.py
# 3584 views
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
plt.style.use('./deeplearning.mplstyle')
from matplotlib.widgets import Slider
from lab_utils_common import dlc


def widgvis(fig):
    """Hide the toolbar, header and footer of an interactive figure canvas.

    Args:
        fig: matplotlib Figure. Its canvas must accept the ``toolbar_visible``,
            ``header_visible`` and ``footer_visible`` attributes
            (NOTE(review): presumably an ipywidgets-based backend; confirm).
    """
    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False


def plt_base(ax):
    """Draw the target curve and the initial state of the three ReLU units.

    The target is piecewise linear on [0, 3], built from three 100-sample
    segments with slopes -2, 1 and 3. Unit 0 is fixed at w=-2, b=2. Units 1
    and 2 start at w=0, b=0 and are later driven by sliders.

    Args:
        ax: sequence of four Axes.
            ax[0]: target curve and summed output a2.
            ax[1], ax[2], ax[3]: units 0, 1 and 2.

    Returns:
        (X, arts): the x-sample array, and the line artists that depend on
        the adjustable units. Callers remove and redraw these artists when
        the units change. Unit 0's lines are not included because that unit
        never changes.
    """
    X = np.linspace(0, 3, 3*100)
    # Index slices 0:100, 100:200, 200:300 cover x < 1, 1 <= x < 2, x >= 2.
    y = np.r_[-2*X[0:100] + 2, 1*X[100:200] - 3 + 2, 3*X[200:300] - 7 + 2]
    w00 = -2
    b00 = 2
    # Adjustable units start at zero. The trailing values reproduce the
    # target exactly: relu(-2x+2) + relu(x-1) + relu(2x-4).
    w01 = 0  # 1
    b01 = 0  # -1
    w02 = 0  # 2
    b02 = 0  # -4
    ax[0].plot(X, y, color=dlc["dlblue"], label="target")
    arts = []
    arts.extend(plt_yhat(ax[0], X, w00, b00, w01, b01, w02, b02))
    _ = plt_unit(ax[1], X, w00, b00)  # Fixed: not tracked for redraw
    arts.extend(plt_unit(ax[2], X, w01, b01))
    arts.extend(plt_unit(ax[3], X, w02, b02))
    return (X, arts)


def plt_yhat(ax, X, w00, b00, w01, b01, w02, b02):
    """Plot the network output a2 = sum of three ReLU units on ``ax``.

    Args:
        ax: Axes to draw on.
        X: x-sample array.
        w00, b00, w01, b01, w02, b02: scalar weight/bias of units 0, 1, 2.

    Returns:
        The list of Line2D objects returned by ``ax.plot``.
    """
    yhat = (np.maximum(0, w00 * X + b00)
            + np.maximum(0, w01 * X + b01)
            + np.maximum(0, w02 * X + b02))
    lp = ax.plot(X, yhat, lw=2, color=dlc["dlorange"], label="a2")
    return (lp)


def plt_unit(ax, X, w, b):
    """Plot one unit's pre-activation z = w*X + b and its ReLU a = max(0, z).

    Args:
        ax: Axes to draw on.
        X: x-sample array.
        w, b: scalar weight and bias of the unit.

    Returns:
        [z_line, a_line]: the two Line2D artists, so callers can remove them.
    """
    z = w * X + b
    a = np.maximum(0, z)
    # Use the explicit color= keyword, consistent with plt_yhat, rather than
    # passing the colour as a positional fmt string.
    lpa = ax.plot(X, z, color=dlc["dlblue"], label="z")
    lpb = ax.plot(X, a, color=dlc["dlmagenta"], lw=1, label="a")
    return ([lpa[0], lpb[0]])

# If output is needed for debugging, put this in a cell and call it ahead of time.
# Output will be below that cell.
#from ipywidgets import Output #this line stays here
#output = Output()             #this line stays here
#display(output)               #this line goes in notebook


def plt_relu_ex():
    """Build the interactive "Explore Non-Linear Activation" figure.

    Layout:
        - A large left panel (ax[0]) shows the target and the summed output a2.
        - Three small right panels show units 0, 1 and 2.
        - Unit 0 is fixed (w=-2, b=2).
        - Units 1 and 2 are driven by four sliders (w1, b1, w2, b2).

    Returns:
        [sw1, sw2, sb1, sb2, artists]: the sliders and the live artist list.
        The caller must keep this reference so the sliders stay responsive.
    """
    artists = []

    fig = plt.figure()
    fig.suptitle("Explore Non-Linear Activation")

    # Grid cell (row 2, col 0) is left empty; the sliders are placed there.
    gs = GridSpec(3, 2, width_ratios=[2, 1], height_ratios=[1, 1, 1])
    ax1 = fig.add_subplot(gs[0:2, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 1])
    ax4 = fig.add_subplot(gs[2, 1])
    ax = [ax1, ax2, ax3, ax4]

    widgvis(fig)

    # Slider axes in figure coordinates: [left, bottom, width, height].
    axb2 = fig.add_axes([0.15, 0.10, 0.30, 0.03])
    axw2 = fig.add_axes([0.15, 0.15, 0.30, 0.03])
    axb1 = fig.add_axes([0.15, 0.20, 0.30, 0.03])
    axw1 = fig.add_axes([0.15, 0.25, 0.30, 0.03])

    sw1 = Slider(axw1, 'w1', -4.0, 4.0, valinit=0, valstep=0.1)
    sb1 = Slider(axb1, 'b1', -4.0, 4.0, valinit=0, valstep=0.1)
    sw2 = Slider(axw2, 'w2', -4.0, 4.0, valinit=0, valstep=0.1)
    sb2 = Slider(axb2, 'b2', -4.0, 4.0, valinit=0, valstep=0.1)

    X, lp = plt_base(ax)
    artists.extend(lp)

    #@output.capture()  # enable together with the ipywidgets lines above to debug
    def update(val):
        """Slider callback: redraw a2 and units 1 and 2 from slider values."""
        # Remove the previous slider-dependent lines before drawing new ones.
        for artist in artists:
            artist.remove()
        artists.clear()
        # Unit 0 is fixed; these values must match plt_base.
        w00 = -2
        b00 = 2
        w01 = sw1.val  # 1
        b01 = sb1.val  # -1
        w02 = sw2.val  # 2
        b02 = sb2.val  # -4
        artists.extend(plt_yhat(ax[0], X, w00, b00, w01, b01, w02, b02))
        artists.extend(plt_unit(ax[2], X, w01, b01))
        artists.extend(plt_unit(ax[3], X, w02, b02))

    for slider in (sw1, sb1, sw2, sb2):
        slider.on_changed(update)

    ax[0].set_title(" Match Target ")
    ax[0].legend()
    ax[0].set_xlabel("x")
    ax[1].set_title("Unit 0 (fixed) ")
    ax[1].legend()
    ax[2].set_title("Unit 1")
    ax[2].legend()
    ax[3].set_title("Unit 2")
    ax[3].legend()
    plt.tight_layout()

    plt.show()
    return ([sw1, sw2, sb1, sb2, artists])  # returned to keep a live reference to sliders