Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/02/softmax_plot.ipynb
1192 views
Kernel: prob_ml

Softmax Distribution

import jax import jax.numpy as jnp import matplotlib.pyplot as plt import seaborn as sns try: from probml_utils import savefig, latexify except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils import savefig, latexify
# Apply the probml_utils plotting style with fig_height=1.5 (units as defined
# by probml_utils). Per the helper's own warning, this has no effect unless
# the LATEXIFY environment variable is set.
latexify(fig_height=1.5)
/home/tensorboy/dev/env/lib/python3.8/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
def softmax(a):
    """Return the softmax of the logits ``a`` as a JAX array.

    The normalisation runs over *all* entries of ``a`` (no axis argument),
    so the returned values sum to one.

    Args:
        a: array-like of logits (list, tuple, NumPy or JAX array).

    Returns:
        A floating-point JAX array with the same shape as ``a``.
    """
    # Multiplying by 1.0 promotes integer logits to floating point.
    logits = 1.0 * jnp.array(a)
    if logits.size == 0:
        # Nothing to normalise; jnp.max would raise on an empty array.
        return logits
    # Softmax is invariant to adding a constant to every logit. Subtracting
    # the maximum keeps every exponent <= 0, so exp() cannot overflow to inf
    # (which would otherwise turn the result into nan for large logits).
    shifted = logits - jnp.max(logits)
    e = jnp.exp(shifted)
    return e / jnp.sum(e)
def plot_softmax_distribution(T, a, save_name, fig=None, axs=None):
    """Draw one bar chart of ``softmax(a / T[i])`` per temperature in ``T``.

    Args:
        T: sequence of temperatures; one subplot is drawn per entry.
        a: array-like of logits, divided element-wise by each temperature.
        save_name: name handed to ``savefig``; an empty/falsy value skips
            saving.
        fig: optional existing matplotlib figure. When given without
            ``axs``, the subplots are created inside it.
        axs: optional Axes, or sequence of Axes (one per temperature), to
            draw into. When omitted, a new 1 x len(T) row is created.

    Returns:
        None. The figure is drawn (and optionally saved) as a side effect.
    """
    logits = jnp.asarray(a)  # lets plain lists be divided by a temperature
    ind = jnp.arange(1, len(logits) + 1)

    if axs is None:
        # squeeze=False always yields a 2-D array of Axes, so a single
        # temperature (len(T) == 1) is indexable like any other case.
        if fig is None:
            fig, grid = plt.subplots(1, len(T), sharey="row", squeeze=False)
        else:
            grid = fig.subplots(1, len(T), sharey="row", squeeze=False)
        axs = grid[0]
    elif not hasattr(axs, "__len__"):
        # A single Axes object was passed; treat it as a one-element row.
        axs = [axs]
    if fig is None:
        fig = axs[0].figure

    axs[0].set_ylabel(r"$S(a | T)$")
    for temperature, ax in zip(T, axs):
        ax.bar(ind, softmax(logits / temperature))
        ax.set_xlabel("logits (a)")
        ax.set_ylim(0, 1)
        ax.set_title(f"T = {temperature}")
        # Despine each drawn Axes explicitly rather than relying on
        # whichever figure happens to be current.
        sns.despine(ax=ax)

    fig.tight_layout()
    if save_name:
        # NOTE(review): savefig comes from probml_utils; presumably it saves
        # the current pyplot figure — confirm when passing a custom ``fig``.
        savefig(save_name)
# Temperatures to compare, one subplot each. Dividing the logits by a large
# temperature pushes softmax towards uniform; T = 1 leaves them unscaled.
T_array = [100, 2, 1]
# Example logits shared by all subplots.
a = jnp.array([3, 0, 1])
# Plain string literal: the name contains no placeholders, so no f-string.
plot_softmax_distribution(T_array, a, "softmax_temp")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) /home/tensorboy/dev/env/lib/python3.8/site-packages/probml_utils/plotting.py:80: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook

Interactive figure for softmax distribution

from ipywidgets import interact


@interact(T=(1, 100))
def generate_interactinve_graph(T):
    """Redraw the softmax bar chart for the temperature picked on the slider.

    ``interact`` supplies ``T`` from a slider over the range 1..100 and calls
    this function again whenever the slider value changes.
    """
    logits = jnp.array([3, 0, 1])
    positions = jnp.arange(1, len(logits) + 1)

    # A fresh figure per call, so each slider value gets its own plot.
    plt.figure()
    plt.bar(positions, softmax(logits / T))

    # Axis decoration; these calls are independent of one another.
    plt.xlabel("logits (a)")
    plt.ylabel("$S(a | T)$")
    plt.ylim(0, 1)
    plt.title(f"T = {T}")

    sns.despine()
interactive(children=(IntSlider(value=50, description='T', min=1), Output()), _dom_classes=('widget-interact',…