Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/13/activation_fun_deriv.ipynb
1192 views
Kernel: Python 3 (ipykernel)

Plot some neural net activation functions and their derivatives

Based on sec 4.1 of http://d2l.ai/chapter_multilayer-perceptrons/mlp.html

import jax import jax.numpy as jnp import seaborn as sns import matplotlib.pyplot as plt %matplotlib inline try: import probml_utils as pml except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils as pml
pml.latexify(width_scale_factor=2, fig_height=1.75)
/home/vishal/anaconda3/envs/pyprob/lib/python3.7/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
x = jnp.arange(-4.0, 4.0, 0.1) partial_leaky_relu = lambda x: jax.nn.leaky_relu(x, negative_slope=0.1) fns = [ jax.nn.sigmoid, # jax.nn.relu, partial_leaky_relu, jax.nn.elu, jax.nn.swish, jax.nn.gelu, ] names = [ "sigmoid", #'relu', "leaky-relu", "elu", "swish", "gelu", ] # evaluate functions and their gradients on a grid of points fdict, gdict = {}, {} for fn, name in zip(fns, names): y = fn(x) grad_fn = jax.grad(lambda x: fn(x)) grads = jax.vmap(grad_fn)(x) fdict[name] = y # vector of fun gdict[name] = grads # gradient wrt x(i)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# Plot the funcitons styles = ["r-", "g--", "b-.", "m:", "k-"] ax = plt.subplot() for i, name in enumerate(names): lab = f"{name}" ax.plot(x, fdict[name], styles[i], linewidth=0.7, label=lab) ax.set_ylim(-0.5, 2) ax.legend(frameon=False, loc="upper left", borderaxespad=0.1, labelspacing=0.2) sns.despine() plt.title("Activation function") pml.savefig("activation-funs") plt.show() ax = plt.subplot() for i, name in enumerate(names): lab = f"{name}" ax.plot(x, gdict[name], styles[i], linewidth=0.7, label=lab) ax.set_ylim(-0.5, 2) ax.legend(frameon=False, loc="upper left", borderaxespad=0.1, labelspacing=0.2) sns.despine() plt.title("Gradient of activation function") pml.savefig("activation-funs-grad") plt.show()
/home/vishal/anaconda3/envs/pyprob/lib/python3.7/site-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebookImage in a Jupyter notebook