GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/flow_2d_mlp.ipynb
Kernel: Python 3


Mapping a 2d standard Gaussian to a more complex distribution using an invertible MLP

Author: George Papamakarios

Based on the example by Eric Jang from https://blog.evjang.com/2018/01/nf1.html

Reproduces Figure 23.1 of the book Probabilistic Machine Learning: Advanced Topics by Kevin P. Murphy
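
The flow is trained by maximum likelihood via the change-of-variables formula (standard for normalizing flows; stated here for reference). If $z$ has the base density $p_Z$ (a 2d standard Gaussian) and $x = f(z)$ for the invertible MLP $f$, then

$$\log p_X(x) = \log p_Z\big(f^{-1}(x)\big) + \log \left|\det J_{f^{-1}}(x)\right|,$$

and the training objective below is the negative of this quantity averaged over samples from the target.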

Imports and definitions

from typing import Sequence

import distrax
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

Array = jnp.ndarray
PRNGKey = Array

prng = hk.PRNGSequence(42)

Create flow model

class Parameter(hk.Module):
    """Helper Haiku module for defining model parameters."""

    def __init__(self, module_name: str, param_name: str, shape: Sequence[int], init: hk.initializers.Initializer):
        """Initializer.

        Args:
          module_name: name of the module.
          param_name: name of the parameter.
          shape: shape of the parameter.
          init: initializer of the parameter value.
        """
        super().__init__(name=module_name)
        self._param = hk.get_parameter(param_name, shape=shape, init=init)

    def __call__(self) -> Array:
        return self._param


class LeakyRelu(distrax.Lambda):
    """Leaky ReLU elementwise bijector."""

    def __init__(self, slope: Array):
        """Initializer.

        Args:
          slope: the slope for x < 0. Must be positive.
        """
        forward = lambda x: jnp.where(x >= 0.0, x, x * slope)
        inverse = lambda y: jnp.where(y >= 0.0, y, y / slope)
        forward_log_det_jacobian = lambda x: jnp.where(x >= 0.0, 0.0, jnp.log(slope))
        inverse_log_det_jacobian = lambda y: jnp.where(y >= 0.0, 0.0, -jnp.log(slope))
        super().__init__(
            forward=forward,
            inverse=inverse,
            forward_log_det_jacobian=forward_log_det_jacobian,
            inverse_log_det_jacobian=inverse_log_det_jacobian,
            event_ndims_in=0,
        )


def make_model() -> distrax.Transformed:
    """Creates the flow model."""
    num_layers = 6
    layers = []
    for _ in range(num_layers - 1):
        # Each intermediate layer is an affine transformation followed by a
        # leaky ReLU nonlinearity.
        matrix = Parameter("affine", "matrix", shape=[2, 2], init=hk.initializers.Identity())()
        bias = Parameter("affine", "bias", shape=[2], init=hk.initializers.TruncatedNormal(2.0))()
        affine = distrax.UnconstrainedAffine(matrix, bias)
        slope = Parameter("nonlinearity", "slope", shape=[2], init=jnp.ones)()
        nonlinearity = distrax.Block(LeakyRelu(slope), 1)
        layers.append(distrax.Chain([nonlinearity, affine]))
    # The final layer is just an affine transformation.
    matrix = Parameter("affine", "matrix", shape=[2, 2], init=hk.initializers.Identity())()
    bias = Parameter("affine", "bias", shape=[2], init=jnp.zeros)()
    affine = distrax.UnconstrainedAffine(matrix, bias)
    layers.append(affine)

    flow = distrax.Chain(layers[::-1])
    base = distrax.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2))
    return distrax.Transformed(base, flow)


@hk.without_apply_rng
@hk.transform
def model_log_prob(x: Array) -> Array:
    model = make_model()
    return model.log_prob(x)


@hk.without_apply_rng
@hk.transform
def model_sample(key: PRNGKey, num_samples: int) -> Array:
    model = make_model()
    return model.sample(seed=key, sample_shape=[num_samples])
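
As an optional sanity check (a sketch, not part of the original notebook), the LeakyRelu bijector can be exercised on its own, since it is a plain distrax bijector with no Haiku parameters: forward followed by inverse should be the identity, and the elementwise log-det Jacobian is log(slope) for negative inputs and 0 otherwise.

# Sanity-check sketch for the LeakyRelu bijector (assumes the definitions above).
slope = jnp.array(0.5)
bij = LeakyRelu(slope)
x = jnp.array([-1.5, 0.0, 2.0])
y = bij.forward(x)
print(jnp.allclose(bij.inverse(y), x))   # expected: True
print(bij.forward_log_det_jacobian(x))   # log(0.5) for the negative entry, 0.0 elsewhere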

Define target distribution

def target_sample(key: PRNGKey, num_samples: int) -> Array:
    """Generates samples from the target distribution.

    Args:
      key: a PRNG key.
      num_samples: number of samples to generate.

    Returns:
      An array of shape [num_samples, 2] containing the samples.
    """
    key1, key2 = jax.random.split(key)
    x = 0.6 * jax.random.normal(key1, [num_samples])
    y = 0.8 * x**2 + 0.2 * jax.random.normal(key2, [num_samples])
    return jnp.concatenate([y[:, None], x[:, None]], axis=-1)


# Plot samples from target distribution.
data = target_sample(next(prng), num_samples=1000)
plt.plot(data[:, 0], data[:, 1], ".", color="red", label="Target")
plt.axis("equal")
plt.title("Samples from target distribution")
plt.legend();
[Figure: samples from the target distribution]
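
Reading off the sampler above, the target is a "banana"-shaped distribution:

$$x_2 \sim \mathcal{N}(0,\ 0.6^2), \qquad x_1 = 0.8\,x_2^2 + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0,\ 0.2^2),$$

with samples returned in the order $[x_1, x_2]$.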

Train model

# Initialize model parameters.
params = model_sample.init(next(prng), next(prng), num_samples=1)

# Plot samples from the untrained model.
x = target_sample(next(prng), num_samples=1000)
y = model_sample.apply(params, next(prng), num_samples=1000)
plt.plot(x[:, 0], x[:, 1], ".", color="red", label="Target")
plt.plot(y[:, 0], y[:, 1], ".", color="green", label="Model")
plt.axis("equal")
plt.title("Samples from untrained model")
plt.legend();
[Figure: samples from the untrained model vs. the target]
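
As an optional baseline (a sketch, not in the original notebook; the key choice is arbitrary), the untrained model's average negative log-likelihood on held-out target samples can be recorded for comparison after training:

# Baseline held-out NLL of the untrained model (sketch).
# A separate key is used so the notebook's PRNG sequence is left untouched.
val = target_sample(jax.random.PRNGKey(0), num_samples=1000)
print(-jnp.mean(model_log_prob.apply(params, val)))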
# Loss function is negative log likelihood.
loss_fn = jax.jit(lambda params, x: -jnp.mean(model_log_prob.apply(params, x)))

# Optimizer.
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

# Training loop.
for i in range(5000):
    data = target_sample(next(prng), num_samples=100)
    loss, g = jax.value_and_grad(loss_fn)(params, data)
    updates, opt_state = optimizer.update(g, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0:
        print(f"Step {i}, loss = {loss:.3f}")
Step 0, loss = 3.890
Step 100, loss = 2.155
Step 200, loss = 1.884
Step 300, loss = 1.783
Step 400, loss = 1.435
Step 500, loss = 1.248
Step 600, loss = 1.212
Step 700, loss = 1.223
Step 800, loss = 1.412
Step 900, loss = 1.269
Step 1000, loss = 1.122
Step 1100, loss = 0.997
Step 1200, loss = 0.970
Step 1300, loss = 0.940
Step 1400, loss = 1.032
Step 1500, loss = 1.028
Step 1600, loss = 0.884
Step 1700, loss = 0.972
Step 1800, loss = 1.913
Step 1900, loss = 1.150
Step 2000, loss = 0.941
Step 2100, loss = 0.834
Step 2200, loss = 1.294
Step 2300, loss = 1.011
Step 2400, loss = 0.831
Step 2500, loss = 0.988
Step 2600, loss = 0.878
Step 2700, loss = 0.917
Step 2800, loss = 0.898
Step 2900, loss = 0.741
Step 3000, loss = 0.849
Step 3100, loss = 0.880
Step 3200, loss = 0.934
Step 3300, loss = 0.739
Step 3400, loss = 0.916
Step 3500, loss = 0.943
Step 3600, loss = 0.948
Step 3700, loss = 0.943
Step 3800, loss = 1.004
Step 3900, loss = 1.116
Step 4000, loss = 1.711
Step 4100, loss = 0.907
Step 4200, loss = 1.067
Step 4300, loss = 1.030
Step 4400, loss = 0.814
Step 4500, loss = 0.867
Step 4600, loss = 1.015
Step 4700, loss = 0.914
Step 4800, loss = 0.912
Step 4900, loss = 1.047
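
The loop above re-derives the gradient through the jitted loss but runs the optimizer update un-jitted; a common refinement (a sketch using the same loss_fn and optimizer, not the notebook's code) is to fuse the whole step into one jitted function:

# Fused training step (sketch): NLL gradient plus Adam update in a single jitted call.
@jax.jit
def train_step(params, opt_state, data):
    loss, g = jax.value_and_grad(loss_fn)(params, data)
    updates, opt_state = optimizer.update(g, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

The loop body would then reduce to params, opt_state, loss = train_step(params, opt_state, data).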
# Plot samples from the trained model.
x = target_sample(next(prng), num_samples=1000)
y = model_sample.apply(params, next(prng), num_samples=1000)
plt.plot(x[:, 0], x[:, 1], ".", color="red", label="Target")
plt.plot(y[:, 0], y[:, 1], ".", color="green", label="Model")
plt.axis("equal")
plt.title("Samples from trained model")
plt.legend();
[Figure: samples from the trained model vs. the target]
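
To put a number on the improvement (again a sketch, mirroring the baseline above), the held-out negative log-likelihood of the trained model can be computed on fresh target samples:

# Held-out NLL of the trained model (sketch); compare with the untrained baseline.
test = target_sample(jax.random.PRNGKey(1), num_samples=1000)
print(-jnp.mean(model_log_prob.apply(params, test)))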

Create plot with intermediate distributions

@hk.without_apply_rng
@hk.transform
def model_sample_intermediate(key: PRNGKey, num_samples: int) -> Sequence[Array]:
    model = make_model()
    samples = []
    x = model.distribution.sample(seed=key, sample_shape=[num_samples])
    samples.append(x)
    for layer in model.bijector.bijectors[::-1]:
        x = layer.forward(x)
        samples.append(x)
    return samples


xs = model_sample_intermediate.apply(params, next(prng), num_samples=2000)
plt.rcParams["figure.figsize"] = [2 * len(xs), 3]
fig, axs = plt.subplots(1, len(xs))
fig.tight_layout()

color = xs[0][:, 1]
cm = plt.cm.get_cmap("gnuplot2")

for i, (x, ax) in enumerate(zip(xs, axs)):
    ax.scatter(x[:, 0], x[:, 1], s=10, cmap=cm, c=color)
    ax.axis("equal")
    if i == 0:
        title = "Base distribution"
    else:
        title = f"Layer {i}"
    ax.set_title(title)
[Figure: samples after each layer of the flow, from the base distribution (left) to the final layer (right)]
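
The panels above show samples only; a complementary view (a sketch reusing model_log_prob, with an assumed plotting range) evaluates the learned density itself on a grid:

# Contour plot of the trained density (sketch; the grid extent is a guess).
xx, yy = jnp.meshgrid(jnp.linspace(-1.0, 3.0, 200), jnp.linspace(-2.0, 2.0, 200))
grid = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)
logp = model_log_prob.apply(params, grid).reshape(xx.shape)
plt.contourf(xx, yy, jnp.exp(logp), levels=50, cmap="gnuplot2")
plt.axis("equal")
plt.title("Learned model density");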