# Path: blob/main/C2 - Advanced Learning Algorithms/week2/optional-labs/backprop/lab_utils_backprop.py
# 3589 views
from sympy import *  # NOTE: keep above `import re`: sympy exports its own `re` (real part), which `import re` must override
import numpy as np
import re

import matplotlib.pyplot as plt
from matplotlib.widgets import TextBox
from matplotlib.widgets import Button
import ipywidgets as widgets

# A typed answer: optional sign, digits, optional fractional part (e.g. "-3", "4.25").
_NUM_FORMAT = re.compile(r"[+-]?\d+(?:\.\d+)?")


def widgvis(fig):
    """Hide the toolbar, header and footer of ``fig``'s canvas.

    NOTE(review): these canvas attributes are presumably those of the ipympl
    (``%matplotlib widget``) backend; confirm before using another backend.
    """
    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False


def between(a, b, x):
    """Return True if ``x`` lies between ``a`` and ``b`` (inclusive).

    ``a`` may be greater or less than ``b``; when ``a == b`` the result is
    True only for ``x == a``.
    """
    lo, hi = (a, b) if a <= b else (b, a)
    return lo <= x <= hi


def near(pt, alist, dist=15):
    """Return ``(True, a)`` for the first item of ``alist`` whose annotation is near ``pt``.

    Each item must already have an ``ao`` annotation (for ``avalue`` this is
    created by ``add_anote``). With ``(x, y)`` the annotation's text position,
    ``pt`` matches when ``x - 5 < pt[0] < x + 20`` and
    ``y - 22.5 < pt[1] < y + 2.5``. Returns ``(False, None)`` if nothing matches.

    NOTE(review): ``dist`` is accepted but unused; the 25-unit window is hard-coded.
    """
    for a in alist:
        x, y = a.ao.get_position()  # text position in data coords, not relative
        x = x - 5
        y = y + 2.5
        if 0 < (pt[0] - x) < 25 and 0 < (y - pt[1]) < 25:
            return (True, a)
    return (False, None)


def inboxes(pt, boxlist):
    """Return ``(True, box)`` for the first box in ``boxlist`` containing ``pt``, else ``(False, None)``."""
    for b in boxlist:
        if b.inbox(pt):
            return (True, b)
    return (False, None)


class avalue:
    """One of the values on the figure that can be filled in."""

    def __init__(self, value, pt, cl):
        self.value = value
        self.cl = cl  # color
        self.pt = pt  # (x, y) where the annotation is drawn

    def add_anote(self, ax):
        """Draw a "?" placeholder for this value on ``ax``; the annotation is kept in ``self.ao``."""
        self.ax = ax
        self.ao = self.ax.annotate("?", self.pt, c=self.cl, fontsize='x-small')


class astring:
    """A string that can be set visible or invisible.

    "Invisible" is implemented by drawing the text in white (presumably over a
    white region of the background image).
    """

    def __init__(self, ax, string, pt, cl):
        self.string = string
        self.cl = cl  # color used when visible
        self.pt = pt  # (x, y) where the text is drawn
        self.ax = ax
        self.ao = self.ax.annotate(self.string, self.pt, c="white", fontsize='x-small')

    def astring_visible(self):
        """Show the text in its own color."""
        self.ao.set_color(self.cl)

    def astring_invisible(self):
        """Hide the text by coloring it white."""
        self.ao.set_color("white")


class abox:
    """One of the boxes in the graph that has a value.

    Box edges are in the data coordinates of ``ax``. The config functions in
    this module pass ``bottom`` numerically larger than ``top``; ``between``
    accepts either order.
    """

    def __init__(self, ax, value, left, bottom, right, top, anpt, cl, adj_anote_obj):
        self.ax = ax
        self.value = value  # correct value for annotation
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.anpt = anpt  # x,y where expression should be listed
        self.cl = cl
        self.ao = self.ax.annotate("?", self.anpt, c=self.cl, fontsize='x-small')
        self.astr = adj_anote_obj  # secondary text (astring) for marking edges, or None

    def inbox(self, pt):
        """True if ``pt`` = (x, y) is within the box, edges included."""
        x, y = pt
        return between(self.top, self.bottom, y) and between(self.left, self.right, x)

    def update_val(self, value, cl=None):
        """Display ``value`` in the box, colored ``cl`` (default: the box's own color)."""
        self.ao.set_text(value)
        self.ao.set_c(cl if cl else self.cl)

    def show_secondary(self):
        """Reveal the secondary text, if any, by coloring it green."""
        if self.astr:
            self.astr.ao.set_c("green")

    def clear_secondary(self):
        """Hide the secondary text, if any, by coloring it white."""
        if self.astr:
            self.astr.ao.set_c("white")


## For debug, put this in the notebook being debugged and be sure to set the out=out parameter
#out = widgets.Output(layout={'border': '1px solid black'})
#out

class plt_network:
    """Interactive exercise: click a box on the network image and type its value.

    Parameters
    ----------
    fn : callable
        ``fn(ax)`` draws the annotations on ``ax`` and returns the list of
        ``abox`` objects (e.g. ``config_nw0`` or ``config_nw1``).
    image : str
        Background image path, read with ``plt.imread``.
    out : optional
        Debug output sink (e.g. an ``ipywidgets.Output``); stored only.
    """

    def __init__(self, fn, image, out=None):
        self.out = out  # debug
        img = plt.imread(image)
        self.fig, self.ax = plt.subplots(figsize=self.sizefig(img))
        boxes = fn(self.ax)
        self.boxes = boxes
        widgvis(self.fig)
        self.ax.xaxis.set_visible(False)
        self.ax.yaxis.set_visible(False)
        self.ax.imshow(img)
        self.fig.text(0.1, 0.9, "Click in boxes to fill in values.")
        self.glist = []  # [axes, TextBox] of the currently open input field
        self.san = []  # selected box awaiting a value (at most one)

        self.cid = self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        # Add the button axes to this figure explicitly rather than to pyplot's current figure.
        self.axreveal = self.fig.add_axes([0.55, 0.02, 0.15, 0.075])  # [left, bottom, width, height]
        self.axhide = self.fig.add_axes([0.76, 0.02, 0.15, 0.075])
        self.breveal = Button(self.axreveal, 'Reveal All')
        self.breveal.on_clicked(self.reveal_values)
        self.bhide = Button(self.axhide, 'Hide All')
        self.bhide.on_clicked(self.hide_values)

    def sizefig(self, img):
        """Return a ``(width, height)`` figsize matching ``img``'s aspect ratio.

        Images wider than 2:1 get width 10; others get height 5. Accepts both
        (rows, cols) and (rows, cols, channels) image arrays.
        """
        iy, ix = np.shape(img)[:2]
        if 10 / 5 < ix / iy:  # x is the limiting size
            figx = 10
            figy = figx * iy / ix
        else:
            figy = 5
            figx = figy * ix / iy
        return (figx, figy)

    def updateval(self, event):
        """TextBox submit callback: check the typed text against the selected box.

        ``event`` is the submitted text. Non-numeric text shows a red "?". A
        number is rounded to 2 decimals; if it equals the box's value it is shown
        in the box color and the secondary text is revealed, otherwise it is shown
        in red and the secondary text hidden. The input field is then closed.
        """
        if not self.san:  # no box selected: nothing to check
            return
        box = self.san[0]
        text = event.strip()
        # fullmatch (not match): "12abc" must be rejected, not passed on to float().
        if not _NUM_FORMAT.fullmatch(text):
            box.update_val('?', 'red')
        else:
            value = float(text)
            newval = int(value) if value.is_integer() else value
            newval = round(newval, 2)
            if newval == box.value:
                box.show_secondary()
                box.update_val(newval)
            else:
                box.update_val(newval, 'red')
                box.clear_secondary()
        self._close_input()

    def _close_input(self):
        """Tear down the open input field (if any) and clear the current selection."""
        for item in self.glist:
            if isinstance(item, TextBox):
                # End typing mode without firing another 'submit', then detach the
                # widget's canvas handlers. Otherwise the removed TextBox stays
                # connected, and its click handler may later call stop_typing() and
                # re-submit the old text against a newly selected box.
                item.eventson = False
                item.stop_typing()
                item.disconnect_events()
        if self.glist:
            self.glist[0].remove()  # the axes hosting the TextBox
        self.glist.clear()
        self.san.clear()

    def onclick(self, event):
        """Canvas click handler: open an input field when a box on the image is clicked."""
        if self.san:  # already waiting for a new value
            return
        # Only clicks on the image axes are meaningful: elsewhere xdata/ydata are None,
        # and on the button/text axes they are in those axes' own coordinates.
        if event.inaxes is not self.ax:
            return
        inbox, box = inboxes((event.xdata, event.ydata), self.boxes)
        if inbox:
            self.san.append(box)
            graphBox = self.fig.add_axes([0.225, 0.02, 0.2, 0.075])  # [left, bottom, width, height]
            txtBox = TextBox(graphBox, "newvalue: ")
            txtBox.on_submit(self.updateval)
            self.glist.append(graphBox)
            self.glist.append(txtBox)

    def reveal_values(self, event):
        """'Reveal All' callback: show every box's correct value and its secondary text."""
        for b in self.boxes:
            b.update_val(b.value)
            b.show_secondary()
        self.fig.canvas.draw_idle()

    def hide_values(self, event):
        """'Hide All' callback: reset every box to "?" and hide its secondary text."""
        for b in self.boxes:
            b.update_val("?")
            b.clear_secondary()
        self.fig.canvas.draw_idle()

#--------------------------------------------------------------------------


def config_nw0(ax):
    """Place the answer boxes for the network a = 2 + 3w, J = a**2 (w = 3) on ``ax``.

    Coordinates are laid out for "./images/C2_W2_BP_network0.PNG". Forward
    values are blue; derivatives are green. Returns the list of ``abox`` objects.
    """
    w = 3
    a = 2 + 3 * w
    J = a ** 2

    # Backward pass (chain rule from J back to w).
    dJ_dJ = 1
    dJ_da = dJ_dJ * (2 * a)
    da_dw = 3
    dJ_dw = dJ_da * da_dw

    # abox args: ax, value, left, bottom, right, top, annotation point, color, secondary text
    box1 = abox(ax, round(a, 2), 307, 140, 352, 100, (315, 128), 'blue', None)
    box2 = abox(ax, round(J, 2), 581, 138, 624, 100, (589, 128), 'blue', None)

    dJ_da_a = astring(ax, r"$\frac{\partial J}{\partial a}=$" + f"{dJ_da}", (291, 186), "green")
    box3 = abox(ax, round(dJ_da, 2), 545, 417, 588, 380, (553, 407), 'green', dJ_da_a)

    dJ_dw_a = astring(ax, r"$\frac{\partial J}{\partial w}=$" + f"{dJ_dw}", (60, 186), "green")
    box4 = abox(ax, round(da_dw, 2), 195, 421, 237, 380, (203, 411), 'green', None)
    box5 = abox(ax, round(dJ_dw, 2), 265, 515, 310, 475, (273, 505), 'green', dJ_dw_a)

    return [box1, box2, box3, box4, box5]


def config_nw1(ax):
    """Place the answer boxes for c = w*x, a = c + b, d = a - y, J = d**2/2 on ``ax``.

    Coordinates are laid out for "./images/C2_W2_BP_Network1.PNG" with
    x = 2, w = -2, b = 8, y = 1. Returns the list of ``abox`` objects.
    """
    x = 2
    w = -2
    b = 8
    y = 1

    c = w * x
    a = c + b
    d = a - y
    J = d ** 2 / 2

    # Backward pass (chain rule from J back to each input).
    dJ_dJ = 1
    dJ_dd = dJ_dJ * (2 * d / 2)
    dd_da = 1
    dJ_da = dJ_dd * dd_da
    da_db = 1
    dJ_db = dJ_da * da_db
    da_dc = 1
    dJ_dc = dJ_da * da_dc
    dc_dw = x
    dJ_dw = dJ_dc * dc_dw

    # abox args: ax, value, left, bottom, right, top, annotation point, color, secondary text
    box1 = abox(ax, round(c, 2), 330, 162, 382, 114, (338, 150), 'blue', None)
    box2 = abox(ax, round(a, 2), 636, 162, 688, 114, (644, 150), 'blue', None)
    box3 = abox(ax, round(d, 2), 964, 162, 1015, 114, (972, 150), 'blue', None)
    box4 = abox(ax, round(J, 2), 1266, 162, 1315, 114, (1274, 150), 'blue', None)

    dJ_dd_a = astring(ax, r"$\frac{\partial J}{\partial d}=$" + f"{dJ_dd}", (967, 208), "green")
    box5 = abox(ax, round(dJ_dd, 2), 1222, 488, 1275, 441, (1230, 478), 'green', dJ_dd_a)

    dJ_da_a = astring(ax, r"$\frac{\partial J}{\partial a}=$" + f"{dJ_da}", (615, 208), "green")
    box6 = abox(ax, round(dd_da, 2), 900, 383, 951, 333, (908, 373), 'green', None)
    box7 = abox(ax, round(dJ_da, 2), 988, 483, 1037, 441, (996, 473), 'green', dJ_da_a)

    dJ_dc_a = astring(ax, r"$\frac{\partial J}{\partial c}=$" + f"{dJ_dc}", (337, 208), "green")
    box8 = abox(ax, round(da_dc, 2), 570, 380, 620, 333, (578, 370), 'green', None)
    box9 = abox(ax, round(dJ_dc, 2), 638, 467, 688, 419, (646, 457), 'green', dJ_dc_a)

    # The dJ/db label must show dJ_db (it previously reused dJ_dc).
    dJ_db_a = astring(ax, r"$\frac{\partial J}{\partial b}=$" + f"{dJ_db}", (474, 252), "green")
    box10 = abox(ax, round(da_db, 2), 563, 582, 615, 533, (571, 572), 'green', None)
    box11 = abox(ax, round(dJ_db, 2), 630, 677, 684, 630, (638, 667), 'green', dJ_db_a)

    dJ_dw_a = astring(ax, r"$\frac{\partial J}{\partial w}=$" + f"{dJ_dw}", (60, 208), "green")
    # NOTE(review): box12 is 150 units wide (191..341) while the other boxes are ~50;
    # the right edge may be a typo for 241 -- confirm against the image.
    box12 = abox(ax, round(dc_dw, 2), 191, 379, 341, 332, (199, 369), 'green', None)
    box13 = abox(ax, round(dJ_dw, 2), 266, 495, 319, 448, (274, 485), 'green', dJ_dw_a)

    return [box1, box2, box3, box4, box5, box6, box7, box8, box9, box10, box11, box12, box13]


# not used
def config_nw2(ax=None):
    """Compute values for a one-neuron logistic network and its squared loss.

    Forward pass: d = x0*w0, e = x1*w1, f = d + e + b, a = 1/(1 + exp(-f)),
    k = y - a, L = k**2. Returns ``(fn, anotes, boxes)``, where ``fn`` is the
    background image path and ``anotes`` are ``avalue`` objects (not yet drawn).
    ``boxes`` is built only when ``ax`` is given; without it the list is empty.
    """
    x0 = 1
    x1 = 2
    w0 = -2
    w1 = 3
    b = -4
    y = 1
    d = x0 * w0
    e = x1 * w1
    f = d + e + b
    g = -f
    h = np.exp(g)
    i = h + 1
    a = 1 / i
    k = y - a
    L = k ** 2

    # Backward pass (chain rule from L back to each input).
    dL_dL = 1
    dL_dk = dL_dL * (2 * k)
    dk_da = -1
    dL_da = dL_dk * dk_da
    da_di = -1 / i ** 2
    dL_di = dL_da * da_di
    di_dh = 1
    dL_dh = dL_di * di_dh
    dh_dg = np.exp(g)
    dL_dg = dL_dh * dh_dg
    dg_df = -1
    dL_df = dL_dg * dg_df
    df_dd = 1
    dL_dd = dL_df * df_dd
    df_de = 1  # f = d + e + b, so df/de = 1
    dL_de = dL_df * df_de
    df_db = 1
    dL_db = dL_df * df_db
    dd_dw0 = x0  # d = x0 * w0
    dL_dw0 = dL_dd * dd_dw0
    de_dw1 = x1  # e = x1 * w1
    dL_dw1 = dL_de * de_dw1

    an1 = avalue(round(d, 2), (270, 265), 'blue')
    an2 = avalue(round(e, 2), (270, 350), 'blue')
    an3 = avalue(round(f, 2), (400, 315), 'blue')
    an4 = avalue(round(g, 2), (540, 315), 'blue')
    an5 = avalue(round(h, 2), (650, 315), 'blue')
    an6 = avalue(round(i, 2), (760, 315), 'blue')
    an7 = avalue(round(a, 2), (890, 315), 'blue')
    an8 = avalue(round(k, 2), (1015, 315), 'blue')
    an9 = avalue(round(L, 2), (1120, 315), 'blue')
    bn1 = avalue(round(dL_dd, 2), (260, 300), 'green')  # d
    bn2 = avalue(round(dL_de, 2), (270, 385), 'green')  # e
    bn3 = avalue(round(dL_df, 2), (408, 350), 'green')  # f
    bn4 = avalue(round(dL_dg, 2), (540, 350), 'green')  # g
    bn5 = avalue(round(dL_dh, 2), (650, 350), 'green')  # h
    bn6 = avalue(round(dL_di, 2), (760, 350), 'green')  # i
    bn7 = avalue(round(dL_da, 2), (890, 350), 'green')  # a
    bn8 = avalue(round(dL_dk, 2), (1015, 350), 'green')  # k
    bn9 = avalue(round(dL_dw0, 2), (210, 300), 'green')  # w0
    bn10 = avalue(round(dL_dw1, 2), (205, 440), 'green')  # w1
    bn11 = avalue(round(dL_db, 2), (345, 385), 'green')  # b

    anotes = [an1, an2, an3, an4, an5, an6, an7, an8, an9,
              bn1, bn2, bn3, bn4, bn5, bn6, bn7, bn8, bn9, bn10, bn11]

    # The original abox call omitted ax, color and secondary text and always raised TypeError.
    # NOTE(review): 'green' and no secondary text are assumptions -- confirm against the image.
    if ax is None:
        boxes = []
    else:
        boxes = [abox(ax, r"$\frac{\partial v}{\partial t}$", 943, 347, 980, 310, (980, 300), 'green', None)]

    fn = "./images/C2_W2_BP_bkground.PNG"
    return fn, anotes, boxes