Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/02/change_of_vars_demo1d.ipynb
1193 views
Kernel: Python 3 (ipykernel)

Monte Carlo approximation of the density of $y = x^2$ when $x$ is uniformly distributed

import jax.numpy as jnp from jax import random import matplotlib.pyplot as plt import seaborn as sns try: from probml_utils import savefig, latexify, is_latexify_enabled except: %pip install git+https://github.com/probml/probml-utils.git from probml_utils import savefig, latexify, is_latexify_enabled
# Configure figure sizing for the plots below (width_scale_factor=1, fig_height=2).
# Per the cell output, this only emits a warning and does not latexify unless the
# LATEXIFY environment variable is set.
latexify(fig_height=2, width_scale_factor=1)
/home/rohit_khoiwal/.local/lib/python3.8/site-packages/probml_utils/plotting.py:26: UserWarning: LATEXIFY environment variable not set, not latexifying warnings.warn("LATEXIFY environment variable not set, not latexifying")
# Support of the uniform distribution x ~ Unif(lower_limit, upper_limit).
lower_limit = -1
upper_limit = 1

# Plotting grid derived from the limits, so grid, density and sampler stay in sync
# if the limits are edited.
x_samples = jnp.linspace(lower_limit, upper_limit, 200)
px = 1 / (upper_limit - lower_limit)  # constant density of the uniform
px_uniform = jnp.full_like(x_samples, px)


def square_fn(x):
    """Transformation y = x**2 whose induced density we study."""
    return x**2


y = square_fn(x_samples)

# analytic
# For x ~ Unif(-c, c) and y = x**2, both roots x = +/-sqrt(y) lie in the support, so
# p(y) = 2 * px / |dy/dx| = 2 * px / (2 * sqrt(y)) = px / sqrt(y) on (0, c**2].
# The formula is only valid for symmetric limits, hence the guard.
if lower_limit != -upper_limit:
    raise ValueError("analytic p(y) assumes symmetric limits [-c, c]")
# p(y) diverges at y = 0; this small offset keeps the plotted curve bounded.
# It is a plotting regularizer, not part of the true density.
pdf_eps = 1e-2
y_pdf = px / jnp.sqrt(y + pdf_eps)

# monte carlo
n = 1000
key = random.PRNGKey(0)
uniform_samples = random.uniform(key, shape=(n, 1), minval=lower_limit, maxval=upper_limit)
fn_samples = square_fn(uniform_samples)

# Exact E[x**2] for x ~ Unif(a, b) is (a**2 + a*b + b**2) / 3 (= 1/3 on [-1, 1]),
# printed next to the Monte Carlo estimate for comparison.
exact_mean = (lower_limit**2 + lower_limit * upper_limit + upper_limit**2) / 3
print(f"Monte Carlo E[y] = {float(jnp.mean(fn_samples)):.6f} (exact {exact_mean:.6f})")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
0.31545982
# Three panels: the uniform density p(x), the analytic density of y = x**2,
# and a normalized histogram of the Monte Carlo samples of y.
if not is_latexify_enabled():
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 3))
else:
    fig, ax = plt.subplots(nrows=1, ncols=3)

ax[0].set_title("Uniform distribution")
ax[0].plot(x_samples, px_uniform, "-")
ax[0].set_xlabel("$x$")
ax[0].set_ylabel("$p(x)$")

ax[1].set_title("Analytical $p(y)$, $y = x^2$")
ax[1].plot(y, y_pdf, "-", linewidth=2)
ax[1].set_xlabel("$y$")
ax[1].set_ylabel("$p(y)$")

ax[2].set_title("Monte Carlo approximation")
# histplot replaces the deprecated distplot; stat="density" matches the old
# norm_hist=True. The (n, 1) samples are flattened into a plain list so seaborn
# treats them as a single long-form sample rather than wide-form columns.
sns.histplot(
    x=fn_samples.ravel().tolist(),
    bins=20,
    stat="density",
    ax=ax[2],
    edgecolor="k",
    linewidth=1,
)
ax[2].set_xlabel("$y$")
# The histogram is normalized to a density, so it is not a raw frequency count.
ax[2].set_ylabel("Density")

sns.despine()
savefig("changeOfVars")
plt.show()
/home/rohit_khoiwal/.local/lib/python3.8/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms). warnings.warn(msg, FutureWarning) /home/rohit_khoiwal/.local/lib/python3.8/site-packages/probml_utils/plotting.py:79: UserWarning: set FIG_DIR environment variable to save figures warnings.warn("set FIG_DIR environment variable to save figures")
Image in a Jupyter notebook