def _plot_samples(x_samples, y_samples, inside, radius):
    """Scatter-plot the sampled points, save the figure and show it.

    Points with ``inside`` True are drawn as blue circles, the rest as red
    diamonds. Axis limits follow ``radius`` so the whole sampling square
    [-radius, radius]^2 is visible for any radius.
    """
    if pml.is_latexify_enabled():
        fig_size, marker_size = None, 1
    else:
        fig_size, marker_size = (5, 5), 4

    _, ax = plt.subplots(figsize=fig_size)
    ax.plot(x_samples[inside], y_samples[inside], "bo", markersize=marker_size)
    ax.plot(x_samples[~inside], y_samples[~inside], "rD", markersize=marker_size)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_aspect("equal")
    ax.set_xlim(-radius, radius)
    ax.set_ylim(-radius, radius)
    pml.savefig("mcEstimatePi")
    plt.show()


def estimate_pi(num_samples=5000, radius=2):
    """Estimate pi by Monte Carlo sampling in a square, print and plot it.

    Draws ``num_samples`` points uniformly from the square
    [-radius, radius]^2 and marks those inside the circle of radius
    ``radius``. The mean of ``4 * radius**2 * inside`` estimates the circle's
    area (pi * radius**2); dividing by ``radius**2`` gives the estimate of pi.
    Prints the estimate, ``jnp.pi`` and the standard error of the estimate,
    then plots the points and saves the figure via ``pml.savefig``.

    Reads and advances the module-level PRNG ``key``.

    Args:
        num_samples: number of points to draw.
        radius: half-width of the sampling square and radius of the circle.

    Returns:
        ``(pi_estimate, std_err)`` as Python floats.
    """
    global key
    # Split first so that each coordinate consumes its own fresh key, and the
    # key written back to the global has never been used for sampling.
    # Reusing a consumed key would make the next call's draws repeat these.
    key, x_key, y_key = jax.random.split(key, 3)

    uniform = distrax.Uniform(low=-radius, high=radius)
    x_uniform_samples = uniform.sample(seed=x_key, sample_shape=num_samples)
    y_uniform_samples = uniform.sample(seed=y_key, sample_shape=num_samples)

    inside = x_uniform_samples**2 + y_uniform_samples**2 <= radius**2

    # Each sample is (area of the square) * indicator(inside the circle),
    # an unbiased estimate of the circle's area pi * radius**2.
    samples = 4 * (radius**2) * inside
    integral_estimate = jnp.mean(samples)
    pi_estimate = integral_estimate / (radius**2)
    # pi_estimate is the area estimate scaled by 1 / radius**2, so its
    # standard error is scaled by the same factor.
    std_err = jnp.sqrt(jnp.var(samples) / num_samples) / (radius**2)

    print(f"the estimated pi = {float(pi_estimate):f}")
    print(f"the standard pi = {float(jnp.pi):f}")
    print(f"stderr = {float(std_err):f}")

    _plot_samples(x_uniform_samples, y_uniform_samples, inside, radius)
    return float(pi_estimate), float(std_err)
# Module-level call: runs the Monte Carlo estimate (default sample count and
# radius), prints the results and saves/shows the plot whenever this file is
# executed or imported.
estimate_pi()