Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/02/discrete_prob_dist_plot.ipynb
1193 views
Kernel: Python [conda env:pyprobml]

Discrete Probability Distribution Plot

import jax import jax.numpy as jnp import matplotlib.pyplot as plt import seaborn as sns try: from probml_utils import savefig, latexify except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils import savefig, latexify
# Apply probml_utils' LaTeX figure styling. Per the UserWarning it emits, this is a
# no-op unless the LATEXIFY environment variable is set.
# NOTE(review): width_scale_factor=2 presumably scales figures to half the default
# width — confirm in probml_utils.latexify.
latexify(width_scale_factor=2)
/home/patel_zeel/probml-utils/probml_utils/plotting.py:25: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
# Bar graphs showing a uniform discrete distribution and another with full mass on one value.
N = 4


def make_graph(probs, N, save_name, fig=None, ax=None):
    """Draw a bar chart of a discrete distribution over the outcomes x = 1, ..., N.

    Args:
        probs: Length-N sequence of probabilities, one per outcome x = 1..N.
        N: Number of outcomes. Sets the x ticks (1..N) and the y ticks
            (multiples of 1/N between 0 and 1).
        save_name: Base file name passed to ``savefig`` (no ``.pdf``/``.png``
            suffix). An empty string (or None) skips saving.
        fig: Optional existing Figure to draw into.
        ax: Optional existing Axes to draw into. If omitted, the current axes
            of ``fig`` are used, or a new figure/axes pair is created when
            ``fig`` is also omitted. If ``ax`` is given without ``fig``, the
            figure is taken from ``ax``.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()  # draw on the supplied figure instead of failing on ax=None
    elif fig is None:
        fig = ax.figure  # honour a caller-supplied axes rather than discarding it

    x = jnp.arange(1, N + 1)
    ax.bar(x, probs, align="center")
    ax.set_xlim([0.5, N + 0.5])  # half a bar width of padding around x = 1..N
    ax.set_xticks(x)
    ax.set_yticks(jnp.linspace(0, 1, N + 1))
    ax.set_xlabel("$x$")
    ax.set_ylabel("$Pr(X=x)$")
    # Remove the top/right spines of the axes we drew on, not whatever figure is current.
    sns.despine(ax=ax)
    if save_name:
        # NOTE(review): savefig receives no figure argument here, so it presumably
        # saves the current pyplot figure — confirm when passing in fig/ax.
        savefig(save_name)
    return fig, ax


uniform_probs = jnp.repeat(1.0 / N, N)
# Do not add .pdf or .png as it is automatically added by savefig method
make_graph(uniform_probs, N, "uniform_histogram")

# All probability mass on x = 1; built from N so it stays the right length if N changes.
delta_probs = jnp.zeros(N).at[0].set(1.0)
make_graph(delta_probs, N, "delta_histogram");
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) /home/patel_zeel/probml-utils/probml_utils/plotting.py:65: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook

Image in a Jupyter notebook

Demo

You can see different examples of discrete distributions in the following demo. Use its sliders to change the random seed (`random_state`) and the number of outcomes (`N`).

from ipywidgets import interact


@interact(random_state=(1, 10), N=(2, 10))
def generate_random(random_state, N):
    """Plot a randomly generated distribution over N outcomes.

    Draws N uniform(0, 1) weights from a JAX PRNG key seeded with
    ``random_state``, normalizes them to sum to 1, and plots them with
    ``make_graph`` without saving. The sliders cover seeds 1-10 and N in 2-10.
    """
    seed_key = jax.random.PRNGKey(random_state)
    weights = jax.random.uniform(seed_key, shape=(N,))
    normalized = weights / jnp.sum(weights)
    # Empty save_name -> make_graph skips savefig.
    _, bar_ax = make_graph(normalized, N, "")
    # Finer y ticks (steps of 0.1) than make_graph's multiples of 1/N.
    bar_ax.set_yticks(jnp.linspace(0, 1, 11))
interactive(children=(IntSlider(value=5, description='random_state', max=10, min=1), IntSlider(value=6, descri…