GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/two_moons_normalizingFlow.ipynb
Kernel: Python 3

Two Moons Normalizing Flow Using Distrax + Haiku

A neural spline flow, based on the flow example in the distrax documentation. The code to load the two moons dataset is adapted from Chris Waites's jax-flows demo.

!pip install -U dm-haiku distrax optax
Collecting dm-haiku
  Downloading dm_haiku-0.0.6-py3-none-any.whl (309 kB)
Collecting distrax
  Downloading distrax-0.1.2-py3-none-any.whl (272 kB)
Collecting optax
  Downloading optax-0.1.1-py3-none-any.whl (136 kB)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (3.10.0.2)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (1.21.5)
Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (0.8.9)
Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (1.0.0)
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->dm-haiku) (1.15.0)
Requirement already satisfied: tensorflow-probability>=0.15.0 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.16.0)
Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.3.4)
Collecting chex>=0.0.7
  Downloading chex-0.1.1-py3-none-any.whl (70 kB)
Requirement already satisfied: jaxlib>=0.1.67 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.3.2+cuda11.cudnn805)
Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.7->distrax) (0.1.6)
Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.7->distrax) (0.11.2)
Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->distrax) (1.4.1)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->distrax) (3.3.0)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.67->distrax) (2.0)
Requirement already satisfied: gast>=0.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.15.0->distrax) (0.5.3)
Requirement already satisfied: cloudpickle>=1.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.15.0->distrax) (1.3.0)
Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.15.0->distrax) (4.4.2)
Installing collected packages: jmp, chex, optax, dm-haiku, distrax
Successfully installed chex-0.1.1 distrax-0.1.2 dm-haiku-0.0.6 jmp-0.0.2 optax-0.1.1
import matplotlib.pyplot as plt
from IPython.display import clear_output
from sklearn import datasets, preprocessing

import distrax
import jax
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# key = jax.random.PRNGKey(1234)

Plotting the two moons dataset

The code is taken directly from Chris Waites's jax-flows demo. This is the target distribution to which we want to learn a bijection from a simple base distribution, such as a Gaussian.
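To make that idea concrete, here is a minimal sketch (not part of the original notebook) of how distrax composes a base distribution with a bijector: a distrax.Transformed distribution pushes base samples through the bijector and evaluates log-densities via the change-of-variables formula. The affine bijector here is purely illustrative; the flow built below uses far more expressive masked-coupling spline layers.

# Illustrative sketch (assumed example, not from the original notebook):
# a Transformed distribution = base distribution + bijector.
key = jax.random.PRNGKey(0)
base = distrax.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2))
# Promote a scalar affine map to act on the 2-dimensional event.
bijector = distrax.Block(distrax.ScalarAffine(shift=1.0, scale=2.0), ndims=1)
pushforward = distrax.Transformed(base, bijector)
samples = pushforward.sample(seed=key, sample_shape=[5])  # shape (5, 2)
log_densities = pushforward.log_prob(samples)             # shape (5,), uses change of variables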

n_samples = 10000
plot_range = [(-2, 2), (-2, 2)]
n_bins = 100

scaler = preprocessing.StandardScaler()
X, _ = datasets.make_moons(n_samples=n_samples, noise=0.05)
X = scaler.fit_transform(X)

plt.hist2d(X[:, 0], X[:, 1], bins=n_bins, range=plot_range)[-1]
plt.savefig("two-moons-original.pdf")
plt.savefig("two-moons-original.png")
[Figure: 2D histogram of the standardized two moons dataset]

Creating the normalizing flow in Distrax + Haiku

Instead of a uniform distribution, we use a standard normal distribution as the base distribution. This is a natural choice here because the two moons data has been standardized to zero mean and unit variance with sklearn's StandardScaler(). Using a uniform base distribution results in inf and nan losses, since any point mapped outside the uniform's support has zero density.
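The snippet below (an illustration, not part of the original notebook; the specific bounds and test point are arbitrary) shows why: a uniform distribution assigns log-probability -inf outside its support, and a single such point makes the mean negative log-likelihood inf or nan, whereas a Gaussian base always gives a finite value.

# Illustration (assumed example): log-probabilities outside a uniform's
# support are -inf, which poisons the negative log-likelihood.
uniform_base = distrax.Independent(
    distrax.Uniform(low=-jnp.ones(2), high=jnp.ones(2)),
    reinterpreted_batch_ndims=1,
)
normal_base = distrax.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2))
point_outside = jnp.array([1.5, 0.0])
print(uniform_base.log_prob(point_outside))  # expected: -inf (outside the support)
print(normal_base.log_prob(point_outside))   # expected: a finite value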

from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple

# Hyperparams - change these to experiment
flow_num_layers = 8
mlp_num_layers = 4
hidden_size = 1000
num_bins = 8
batch_size = 512
learning_rate = 1e-4
eval_frequency = 100

Array = jnp.ndarray
PRNGKey = Array
Batch = Mapping[str, np.ndarray]
OptState = Any

# Functions to create a distrax normalizing flow


def make_conditioner(
    event_shape: Sequence[int], hidden_sizes: Sequence[int], num_bijector_params: int
) -> hk.Sequential:
    """Creates an MLP conditioner for each layer of the flow."""
    return hk.Sequential(
        [
            hk.Flatten(preserve_dims=-len(event_shape)),
            hk.nets.MLP(hidden_sizes, activate_final=True),
            # We initialize this linear layer to zero so that the flow is initialized
            # to the identity function.
            hk.Linear(np.prod(event_shape) * num_bijector_params, w_init=jnp.zeros, b_init=jnp.zeros),
            hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1),
        ]
    )


def make_flow_model(
    event_shape: Sequence[int], num_layers: int, hidden_sizes: Sequence[int], num_bins: int
) -> distrax.Transformed:
    """Creates the flow model."""
    # Alternating binary mask.
    mask = jnp.arange(0, np.prod(event_shape)) % 2
    mask = jnp.reshape(mask, event_shape)
    mask = mask.astype(bool)

    def bijector_fn(params: Array):
        return distrax.RationalQuadraticSpline(params, range_min=-2.0, range_max=2.0)

    # Number of parameters for the rational-quadratic spline:
    # - `num_bins` bin widths
    # - `num_bins` bin heights
    # - `num_bins + 1` knot slopes
    # for a total of `3 * num_bins + 1` parameters.
    num_bijector_params = 3 * num_bins + 1

    layers = []
    for _ in range(num_layers):
        layer = distrax.MaskedCoupling(
            mask=mask,
            bijector=bijector_fn,
            conditioner=make_conditioner(event_shape, hidden_sizes, num_bijector_params),
        )
        layers.append(layer)
        # Flip the mask after each layer.
        mask = jnp.logical_not(mask)

    # We invert the flow so that the `forward` method is called with `log_prob`.
    flow = distrax.Inverse(distrax.Chain(layers))
    # Making base distribution normal distribution
    mu = jnp.zeros(event_shape)
    sigma = jnp.ones(event_shape)
    base_distribution = distrax.Independent(distrax.MultivariateNormalDiag(mu, sigma))
    return distrax.Transformed(base_distribution, flow)


def load_dataset(split: tfds.Split, batch_size: int) -> Iterator[Batch]:
    # ds = tfds.load("mnist", split=split, shuffle_files=True)
    ds = split
    ds = ds.shuffle(buffer_size=10 * batch_size)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=1000)
    ds = ds.repeat()
    return iter(tfds.as_numpy(ds))


def prepare_data(batch: Batch, prng_key: Optional[PRNGKey] = None) -> Array:
    data = batch.astype(np.float32)
    return data


@hk.without_apply_rng
@hk.transform
def model_sample(key: PRNGKey, num_samples: int) -> Array:
    model = make_flow_model(
        event_shape=TWO_MOONS_SHAPE,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
    )
    return model.sample(seed=key, sample_shape=[num_samples])


@hk.without_apply_rng
@hk.transform
def log_prob(data: Array) -> Array:
    model = make_flow_model(
        event_shape=TWO_MOONS_SHAPE,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
    )
    return model.log_prob(data)


def loss_fn(params: hk.Params, prng_key: PRNGKey, batch: Batch) -> Array:
    data = prepare_data(batch, prng_key)
    # Loss is average negative log likelihood.
    loss = -jnp.mean(log_prob.apply(params, data))
    return loss


@jax.jit
def eval_fn(params: hk.Params, batch: Batch) -> Array:
    data = prepare_data(batch)  # We don't dequantize during evaluation.
    loss = -jnp.mean(log_prob.apply(params, data))
    return loss

Setting up the optimizer

optimizer = optax.adam(learning_rate)


@jax.jit
def update(params: hk.Params, prng_key: PRNGKey, opt_state: OptState, batch: Batch) -> Tuple[hk.Params, OptState]:
    """Single SGD update step."""
    grads = jax.grad(loss_fn)(params, prng_key, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

Training the flow

# Event shape
TWO_MOONS_SHAPE = (2,)

# Create tf dataset from sklearn dataset
dataset = tf.data.Dataset.from_tensor_slices(X)

# Splitting into train/validate ds
train = dataset.skip(2000)
val = dataset.take(2000)

# load_dataset(split: tfds.Split, batch_size: int)
train_ds = load_dataset(train, 512)
valid_ds = load_dataset(val, 512)

# Initializing PRNG and Neural Net params
prng_seq = hk.PRNGSequence(1)
params = log_prob.init(next(prng_seq), np.zeros((1, *TWO_MOONS_SHAPE)))
opt_state = optimizer.init(params)

training_steps = 1000
for step in range(training_steps):
    params, opt_state = update(params, next(prng_seq), opt_state, next(train_ds))
    if step % eval_frequency == 0:
        val_loss = eval_fn(params, next(valid_ds))
        print(f"STEP: {step:5d}; Validation loss: {val_loss:.3f}")
STEP:     0; Validation loss: 2.799
STEP:   100; Validation loss: 1.549
STEP:   200; Validation loss: 1.405
STEP:   300; Validation loss: 1.332
STEP:   400; Validation loss: 1.309
STEP:   500; Validation loss: 1.252
STEP:   600; Validation loss: 1.304
STEP:   700; Validation loss: 1.291
STEP:   800; Validation loss: 1.304
STEP:   900; Validation loss: 1.294
n_samples = 10000
plot_range = [(-2, 2), (-2, 2)]
n_bins = 100

X_transf = model_sample.apply(params, next(prng_seq), num_samples=n_samples)

plt.hist2d(X_transf[:, 0], X_transf[:, 1], bins=n_bins, range=plot_range)[-1]
plt.savefig("two-moons-flow.pdf")
plt.savefig("two-moons-flow.png")
plt.show()
[Figure: 2D histogram of samples drawn from the trained flow, reproducing the two moons shape]
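As a possible follow-up (not part of the original notebook), the learned density itself can be visualized by evaluating the transformed log_prob on a grid. The sketch below reuses the trained params and the log_prob transform defined above; the grid resolution is an arbitrary choice.

# Sketch (assumed follow-up): evaluate the learned log-density on a grid and plot it.
grid = jnp.linspace(-2, 2, 200)
xx, yy = jnp.meshgrid(grid, grid)
points = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)  # shape (200 * 200, 2)
log_density = log_prob.apply(params, points)           # shape (200 * 200,)
plt.imshow(jnp.exp(log_density).reshape(200, 200), origin="lower", extent=[-2, 2, -2, 2])
plt.title("Learned density")
plt.show()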