Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
greyhatguy007
GitHub Repository: greyhatguy007/Machine-Learning-Specialization-Coursera
Path: blob/main/C2 - Advanced Learning Algorithms/week2/C2W2A1/lab_utils_softmax.py
3520 views
1
import numpy as np
2
import matplotlib.pyplot as plt
3
plt.style.use('./deeplearning.mplstyle')
4
import tensorflow as tf
5
from IPython.display import display, Markdown, Latex
6
from matplotlib.widgets import Slider
7
from lab_utils_common import dlc
8
9
10
def plt_softmax(my_softmax):
    """Show an interactive bar plot of a softmax function applied to four inputs.

    Four horizontal sliders set the inputs z0..z3. Each ranges over 0.1-10.0
    in steps of 0.1, with initial values 1, 2, 3 and 4. The left panel draws
    the current z values as bars on a 0..10 axis. The right panel draws
    ``my_softmax(z)`` as bars on a 0..1 axis. Both panels are redrawn
    whenever any slider moves.

    Args:
        my_softmax: callable that receives a 1-D NumPy array holding the four
            slider values and returns four numbers to plot. The result must
            be iterable and usable as a bar-width sequence.
    """
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    # Leave room below the two panels for the stacked slider rows.
    plt.subplots_adjust(bottom=0.35)

    # One slider per input, stacked upward with z0 at the bottom.
    # Axes rects are [left, bottom, width, height] in figure coordinates.
    slider_bottoms = (0.10, 0.15, 0.20, 0.25)
    sliders = []
    for i, bottom in enumerate(slider_bottoms):
        slider_ax = fig.add_axes([0.15, bottom, 0.30, 0.03])
        sliders.append(
            Slider(slider_ax, f'z{i}', 0.1, 10.0, valinit=i + 1, valstep=0.1)
        )

    def current_z():
        """Return the current slider values, in z0..z3 order, as an array."""
        return np.array([s.val for s in sliders])

    z_names = np.array([f'z{i}' for i in range(len(sliders))])
    z_bars = ax[0].barh(z_names, height=0.6, width=current_z(),
                        left=None, align='center')
    ax[0].set_xlim([0, 10])
    ax[0].set_title("z input to softmax")

    a_names = np.array([f'a{i}' for i in range(len(sliders))])
    a_bars = ax[1].barh(a_names, height=0.6, width=my_softmax(current_z()),
                        left=None, align='center', color=dlc["dldarkred"])
    ax[1].set_xlim([0, 1])
    ax[1].set_title("softmax(z)")

    def update(val):
        """Slider callback: resize both bar sets from the current slider values.

        ``val`` (the moved slider's new value) is unused. All four sliders
        are re-read so the softmax always sees the full input vector.
        """
        z = current_z()
        for patch, width in zip(z_bars.patches, z):
            patch.set_width(width)
        for patch, width in zip(a_bars.patches, my_softmax(z)):
            patch.set_width(width)
        fig.canvas.draw_idle()

    for slider in sliders:
        slider.on_changed(update)

    # Matplotlib widgets only stay responsive while something holds a strong
    # reference to them. Here the sliders are reachable only through a
    # reference cycle (slider -> observers -> update -> sliders), so with a
    # non-blocking backend they could be garbage-collected once this function
    # returns. Pin them on the figure, which pyplot keeps alive until closed.
    fig._softmax_sliders = sliders

    plt.show()
57