GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/23/flow_spline_mnist.ipynb


Spline Flow using JAX, Haiku, Optax and Distrax

In this notebook we implement a spline flow to fit a distribution to the MNIST dataset. We use RationalQuadraticSpline, a piecewise rational-quadratic spline, together with masked coupling layers, as described in the paper Neural Spline Flows by Conor Durkan, Artur Bekasov, Iain Murray, and George Papamakarios.

This notebook replicates the original Distrax example code with minor modifications.

To implement the rational-quadratic splines with coupling flows, we will use the following libraries:

  • JAX - NumPy on CPU, GPU, and TPU, with automatic differentiation.

  • Haiku - a JAX-based neural network library.

  • Optax - a gradient processing and optimization library for JAX.

  • Distrax - a lightweight library of probability distributions and bijectors.

Installing required libraries in Colab

!pip install -qq -U optax distrax dm-haiku

Importing all required libraries and packages

from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple

try:
    import distrax
except ModuleNotFoundError:
    %pip install -qq distrax
    import distrax

try:
    import haiku as hk
except ModuleNotFoundError:
    %pip install -qq dm-haiku
    import haiku as hk

import jax
import jax.numpy as jnp
import numpy as np

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax

try:
    import tensorflow_datasets as tfds
except ModuleNotFoundError:
    %pip install -qq tensorflow tensorflow_datasets
    import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

Array = jnp.ndarray
PRNGKey = Array
Batch = Mapping[str, np.ndarray]
OptState = Any

MNIST_IMAGE_SHAPE = (28, 28, 1)
batch_size = 128

Conditioner

Let $u \in \mathbb{R}^D$ be the input. The input is split into two equal subspaces $(u^A, u^B)$, each in $\mathbb{R}^{d}$, where $d = D/2$.

Let us assume we have a bijection $\hat{f}(\cdot\,;\theta): \mathbb{R}^d \to \mathbb{R}^d$ parameterized by $\theta$.

We define a single coupling layer as a function $f: \mathbb{R}^D \to \mathbb{R}^D$, written $x = f(u)$, as follows:

$$x^A = \hat{f}(u^A; \Theta(u^B))$$

$$x^B = u^B$$

$$x = (x^A, x^B)$$

In other words, the input $u$ is split into $(u^A, u^B)$ and the output $(x^A, x^B)$ is combined back into $x$ using a binary mask $b$. The single coupling layer $f: \mathbb{R}^D \to \mathbb{R}^D$, $x = f(u)$, can therefore be written as a single equation:

$$x = f(u) = b \odot u + (1-b) \odot \hat{f}(u;\, \Theta(b \odot u))$$

We will implement the full flow by chaining multiple coupling layers. The mask $b$ is flipped between layers so that every dimension is transformed by some layer, which makes the flow more expressive.
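To make the masking concrete, here is a minimal, self-contained sketch (ours, not part of the original notebook). It applies the coupling equation above with a toy elementwise shift standing in for the spline bijector $\hat{f}$, and a hypothetical toy_conditioner standing in for $\Theta$:

import jax.numpy as jnp

# Toy input of dimension D = 4 and an alternating binary mask b.
u = jnp.array([0.1, 0.2, 0.3, 0.4])
b = jnp.array([1.0, 0.0, 1.0, 0.0])


def toy_conditioner(masked_u):
    # Hypothetical conditioner: derives a shift from the masked (unchanged) part only.
    return jnp.sum(masked_u) * jnp.ones_like(masked_u)


def f_hat(u, shift):
    # Toy bijector: an elementwise shift (stands in for the rational-quadratic spline).
    return u + shift


# x = b ⊙ u + (1 - b) ⊙ f_hat(u; Θ(b ⊙ u))
shift = toy_conditioner(b * u)
x = b * u + (1.0 - b) * f_hat(u, shift)
print(x)  # Entries where b == 1 pass through unchanged; the rest are shifted.

# Between layers the mask is flipped, so the other half gets transformed next.
b_next = 1.0 - b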

The function $\Theta$ is called the conditioner; we implement it with a stack of linear layers and ReLU activations.

def make_conditioner(
    event_shape: Sequence[int], hidden_sizes: Sequence[int], num_bijector_params: int
) -> hk.Sequential:
    """Creates an MLP conditioner for each layer of the flow."""
    return hk.Sequential(
        [
            hk.Flatten(preserve_dims=-len(event_shape)),
            hk.nets.MLP(hidden_sizes, activate_final=True),
            # We initialize this linear layer to zero so that the flow is initialized
            # to the identity function.
            hk.Linear(
                np.prod(event_shape) * num_bijector_params,
                w_init=jnp.zeros,
                b_init=jnp.zeros,
            ),
            hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1),
        ]
    )
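As a quick shape check (a sketch we add here, not part of the original notebook), we can wrap the conditioner in hk.transform and confirm that for a batch of MNIST-shaped inputs it outputs one parameter vector of length num_bijector_params per pixel. The hidden sizes [500, 500] match the hyperparameters used later in the notebook:

import haiku as hk
import jax
import numpy as np

num_bins = 4
num_bijector_params = 3 * num_bins + 1  # widths + heights + knot slopes


@hk.without_apply_rng
@hk.transform
def conditioner_fn(x):
    conditioner = make_conditioner(
        event_shape=MNIST_IMAGE_SHAPE, hidden_sizes=[500, 500], num_bijector_params=num_bijector_params
    )
    return conditioner(x)


dummy = np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32)
cond_params = conditioner_fn.init(jax.random.PRNGKey(0), dummy)
cond_out = conditioner_fn.apply(cond_params, dummy)
print(cond_out.shape)  # Expected: (1, 28, 28, 1, 13), i.e. one 13-vector per pixel.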

Flow Model

Next we implement the bijector $\hat{f}$ using distrax.RationalQuadraticSpline and the masked coupling layer $f$ using distrax.MaskedCoupling.
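To get a feel for the spline bijector on its own, here is a small standalone sketch (ours, not from the original notebook) that builds a distrax.RationalQuadraticSpline on [0, 1] from random unconstrained parameters and checks that inverse undoes forward:

import distrax
import jax
import jax.numpy as jnp

num_bins = 4
# Unconstrained parameters for one scalar: num_bins widths + num_bins heights
# + (num_bins + 1) knot slopes = 3 * num_bins + 1 values.
spline_params = jax.random.normal(jax.random.PRNGKey(0), (3 * num_bins + 1,))
spline = distrax.RationalQuadraticSpline(spline_params, range_min=0.0, range_max=1.0)

x = jnp.array(0.3)
y, log_det = spline.forward_and_log_det(x)
print(y, log_det, spline.inverse(y))  # inverse(forward(x)) should recover x ≈ 0.3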

We chain a number of masked coupling layers sequentially to define the complete spline flow.

We use a uniform distribution on $[0, 1)^D$ as the base distribution of the flow.

def make_flow_model(
    event_shape: Sequence[int], num_layers: int, hidden_sizes: Sequence[int], num_bins: int
) -> distrax.Transformed:
    """Creates the flow model."""
    # Alternating binary mask.
    mask = jnp.arange(0, np.prod(event_shape)) % 2
    mask = jnp.reshape(mask, event_shape)
    mask = mask.astype(bool)

    def bijector_fn(params: Array):
        return distrax.RationalQuadraticSpline(params, range_min=0.0, range_max=1.0)

    # Number of parameters for the rational-quadratic spline:
    # - `num_bins` bin widths
    # - `num_bins` bin heights
    # - `num_bins + 1` knot slopes
    # for a total of `3 * num_bins + 1` parameters.
    num_bijector_params = 3 * num_bins + 1

    layers = []
    for _ in range(num_layers):
        layer = distrax.MaskedCoupling(
            mask=mask,
            bijector=bijector_fn,
            conditioner=make_conditioner(event_shape, hidden_sizes, num_bijector_params),
        )
        layers.append(layer)
        # Flip the mask after each layer.
        mask = jnp.logical_not(mask)

    # We invert the flow so that the `forward` method is called with `log_prob`.
    flow = distrax.Inverse(distrax.Chain(layers))
    base_distribution = distrax.Independent(
        distrax.Uniform(low=jnp.zeros(event_shape), high=jnp.ones(event_shape)),
        reinterpreted_batch_ndims=len(event_shape),
    )

    return distrax.Transformed(base_distribution, flow)
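Because make_flow_model builds Haiku modules internally, it must be called inside hk.transform. Here is a minimal usage sketch (ours, using an illustrative 2-D event shape rather than MNIST images) that checks log_prob returns one log-density per example:

import haiku as hk
import jax

toy_event_shape = (2,)  # toy 2-D events, just to illustrate the API


@hk.without_apply_rng
@hk.transform
def toy_log_prob(x):
    model = make_flow_model(toy_event_shape, num_layers=2, hidden_sizes=[32], num_bins=4)
    return model.log_prob(x)


key = jax.random.PRNGKey(0)
x_toy = jax.random.uniform(key, (5,) + toy_event_shape)  # batch of 5 points in [0, 1)^2
toy_params = toy_log_prob.init(key, x_toy)
print(toy_log_prob.apply(toy_params, x_toy).shape)  # Expected: (5,)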

Data Loading and preparation

In this cell, we define a function to load the MNIST dataset using the TFDS (TensorFlow Datasets) package.

We also define a function prepare_data to:

  1. dequantize the data, i.e., convert the integer pixel values in {0, 1, ..., 255} to real values in [0, 256) by adding uniform noise in [0, 1); and

  2. normalize the pixel values from [0, 256) to [0, 1).

Dequantization is applied only at training time; a toy illustration of these two steps is sketched below.
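The following numeric sketch (ours, not part of the notebook) shows the two steps on a 2x2 toy "image"; the prepare_data function in the next cell does the same thing for real batches:

import jax
import numpy as np

pixels = np.array([[0.0, 128.0], [255.0, 7.0]], dtype=np.float32)  # integer pixel values
noise = jax.random.uniform(jax.random.PRNGKey(0), pixels.shape)    # uniform noise in [0, 1)
dequantized = pixels + noise      # real values in [0, 256)
normalized = dequantized / 256.0  # real values in [0, 1), the support of the flow
print(normalized)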

def load_dataset(split: tfds.Split, batch_size: int) -> Iterator[Batch]:
    ds = tfds.load("mnist", split=split, shuffle_files=True)
    ds = ds.shuffle(buffer_size=10 * batch_size)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=5)
    ds = ds.repeat()
    return iter(tfds.as_numpy(ds))


def prepare_data(batch: Batch, prng_key: Optional[PRNGKey] = None) -> Array:
    data = batch["image"].astype(np.float32)
    if prng_key is not None:
        # Dequantize pixel values {0, 1, ..., 255} with uniform noise [0, 1).
        data += jax.random.uniform(prng_key, data.shape)
    return data / 256.0  # Normalize pixel values from [0, 256) to [0, 1).

Log Probability, Sampling, and Training Loss Functions

Next we define log_prob, model_sample, and loss_fn. log_prob computes the log-probability of the data under the flow; inside loss_fn we minimize its negative mean over a batch of MNIST images, i.e., we maximize the likelihood.

model_sample allows us to sample new data points after the model has been trained on MNIST. For a well-trained model, these samples should look like synthetically generated MNIST digits.

flow_num_layers = 8
mlp_num_layers = 2
hidden_size = 500
num_bins = 4
learning_rate = 1e-4

# Using 100,000 steps takes a long time (about 2 hours) but gives better results.
# You can try 10,000 steps for a faster run, but the results may not be as good.
training_steps = 10000
eval_frequency = 1000
@hk.without_apply_rng
@hk.transform
def log_prob(data: Array) -> Array:
    model = make_flow_model(
        event_shape=MNIST_IMAGE_SHAPE,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
    )
    return model.log_prob(data)


@hk.without_apply_rng
@hk.transform
def model_sample(key: PRNGKey, num_samples: int) -> Array:
    model = make_flow_model(
        event_shape=MNIST_IMAGE_SHAPE,
        num_layers=flow_num_layers,
        hidden_sizes=[hidden_size] * mlp_num_layers,
        num_bins=num_bins,
    )
    return model.sample(seed=key, sample_shape=[num_samples])


def loss_fn(params: hk.Params, prng_key: PRNGKey, batch: Batch) -> Array:
    data = prepare_data(batch, prng_key)
    # Loss is the average negative log-likelihood.
    loss = -jnp.mean(log_prob.apply(params, data))
    return loss


@jax.jit
def eval_fn(params: hk.Params, batch: Batch) -> Array:
    data = prepare_data(batch)  # We don't dequantize during evaluation.
    loss = -jnp.mean(log_prob.apply(params, data))
    return loss
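Before training, a quick sanity check (ours, using a randomly generated fake batch rather than real MNIST data) confirms that initialization and loss evaluation run without errors:

import jax
import numpy as np

# A fake batch of 4 random "images" shaped like MNIST, used only to exercise the functions.
fake_batch = {"image": np.random.randint(0, 256, size=(4, *MNIST_IMAGE_SHAPE)).astype(np.uint8)}

init_params = log_prob.init(jax.random.PRNGKey(0), prepare_data(fake_batch))
print(eval_fn(init_params, fake_batch))  # a single scalar: the average negative log-likelihood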

Training

Next we define the update function, which performs a single gradient step. We use jax.grad to compute the gradient of the loss with respect to the model parameters and optax.apply_updates to apply the Adam update.

optimizer = optax.adam(learning_rate)


@jax.jit
def update(
    params: hk.Params, prng_key: PRNGKey, opt_state: OptState, batch: Batch
) -> Tuple[hk.Params, OptState]:
    """Single SGD update step."""
    grads = jax.grad(loss_fn)(params, prng_key, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state
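A common variant (a sketch we add here, not used in the notebook) computes the loss and gradient together with jax.value_and_grad, so the training loss can be logged without a second forward pass:

@jax.jit
def update_with_loss(
    params: hk.Params, prng_key: PRNGKey, opt_state: OptState, batch: Batch
) -> Tuple[Array, hk.Params, OptState]:
    """Like `update`, but also returns the training loss for logging."""
    loss, grads = jax.value_and_grad(loss_fn)(params, prng_key, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, new_opt_state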

Now we carry out the training of the model.

prng_seq = hk.PRNGSequence(42)
params = log_prob.init(next(prng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
opt_state = optimizer.init(params)

train_ds = load_dataset(tfds.Split.TRAIN, batch_size)
valid_ds = load_dataset(tfds.Split.TEST, batch_size)

for step in range(training_steps):
    params, opt_state = update(params, next(prng_seq), opt_state, next(train_ds))

    if step % eval_frequency == 0:
        val_loss = eval_fn(params, next(valid_ds))
        print(f"STEP: {step:5d}; Validation loss: {val_loss:.3f}")
STEP: 0; Validation loss: -5.243
STEP: 1000; Validation loss: -3330.548
STEP: 2000; Validation loss: -3316.721
STEP: 3000; Validation loss: -3319.474
STEP: 4000; Validation loss: -3321.137
STEP: 5000; Validation loss: -3297.604
STEP: 6000; Validation loss: -3295.713
STEP: 7000; Validation loss: -3260.267
STEP: 8000; Validation loss: -3263.344
STEP: 9000; Validation loss: -3286.821
STEP: 10000; Validation loss: -3301.308
STEP: 11000; Validation loss: -3296.085
STEP: 12000; Validation loss: -3310.035
STEP: 13000; Validation loss: -3326.781
STEP: 14000; Validation loss: -3305.100
STEP: 15000; Validation loss: -3317.956
STEP: 16000; Validation loss: -3365.609
STEP: 17000; Validation loss: -3339.233
STEP: 18000; Validation loss: -3346.478
STEP: 19000; Validation loss: -3325.882
STEP: 20000; Validation loss: -3340.056
STEP: 21000; Validation loss: -3342.137
STEP: 22000; Validation loss: -3338.241
STEP: 23000; Validation loss: -3354.057
STEP: 24000; Validation loss: -3384.099
STEP: 25000; Validation loss: -3326.196
STEP: 26000; Validation loss: -3367.737
STEP: 27000; Validation loss: -3353.810
STEP: 28000; Validation loss: -3403.825
STEP: 29000; Validation loss: -3367.897
STEP: 30000; Validation loss: -3384.129
STEP: 31000; Validation loss: -3395.217
STEP: 32000; Validation loss: -3426.372
STEP: 33000; Validation loss: -3381.938
STEP: 34000; Validation loss: -3397.225
STEP: 35000; Validation loss: -3406.573
STEP: 36000; Validation loss: -3395.962
STEP: 37000; Validation loss: -3371.118
STEP: 38000; Validation loss: -3405.979
STEP: 39000; Validation loss: -3394.809
STEP: 40000; Validation loss: -3373.334
STEP: 41000; Validation loss: -3388.248
STEP: 42000; Validation loss: -3407.083
STEP: 43000; Validation loss: -3404.585
STEP: 44000; Validation loss: -3400.513
STEP: 45000; Validation loss: -3403.710
STEP: 46000; Validation loss: -3420.211
STEP: 47000; Validation loss: -3408.354
STEP: 48000; Validation loss: -3442.830
STEP: 49000; Validation loss: -3444.429
STEP: 50000; Validation loss: -3388.476
STEP: 51000; Validation loss: -3432.037
STEP: 52000; Validation loss: -3424.902
STEP: 53000; Validation loss: -3415.527
STEP: 54000; Validation loss: -3410.981
STEP: 55000; Validation loss: -3437.098
STEP: 56000; Validation loss: -3404.357
STEP: 57000; Validation loss: -3397.224
STEP: 58000; Validation loss: -3414.235
STEP: 59000; Validation loss: -3448.812
STEP: 60000; Validation loss: -3392.534
STEP: 61000; Validation loss: -3390.733
STEP: 62000; Validation loss: -3419.801
STEP: 63000; Validation loss: -3416.396
STEP: 64000; Validation loss: -3416.169
STEP: 65000; Validation loss: -3384.039
STEP: 66000; Validation loss: -3419.544
STEP: 67000; Validation loss: -3408.180
STEP: 68000; Validation loss: -3410.440
STEP: 69000; Validation loss: -3414.480
STEP: 70000; Validation loss: -3399.575
STEP: 71000; Validation loss: -3412.096
STEP: 72000; Validation loss: -3440.463
STEP: 73000; Validation loss: -3420.625
STEP: 74000; Validation loss: -3425.068
STEP: 75000; Validation loss: -3431.229
STEP: 76000; Validation loss: -3415.992
STEP: 77000; Validation loss: -3404.405
STEP: 78000; Validation loss: -3469.461
STEP: 79000; Validation loss: -3391.149
STEP: 80000; Validation loss: -3392.615
STEP: 81000; Validation loss: -3423.543
STEP: 82000; Validation loss: -3413.071
STEP: 83000; Validation loss: -3441.366
STEP: 84000; Validation loss: -3414.967
STEP: 85000; Validation loss: -3396.721
STEP: 86000; Validation loss: -3409.096
STEP: 87000; Validation loss: -3431.324
STEP: 88000; Validation loss: -3437.945
STEP: 89000; Validation loss: -3433.234
STEP: 90000; Validation loss: -3422.854
STEP: 91000; Validation loss: -3404.757
STEP: 92000; Validation loss: -3427.891
STEP: 93000; Validation loss: -3431.066
STEP: 94000; Validation loss: -3439.144
STEP: 95000; Validation loss: -3432.875
STEP: 96000; Validation loss: -3410.505
STEP: 97000; Validation loss: -3375.825
STEP: 98000; Validation loss: -3402.304
STEP: 99000; Validation loss: -3391.469

Sampling from Trained Flow Model

Plot new samples

After the model has been trained on MNIST, we draw new samples and plot them. Once the model has been trained long enough, these should look like digits from the MNIST dataset.

def plot_batch(batch: Batch) -> None:
    """Plots a batch of MNIST digits."""
    images = batch.reshape((-1,) + MNIST_IMAGE_SHAPE)
    plt.figure(figsize=(10, 4))
    for i in range(10):
        plt.subplot(2, 5, i + 1)
        plt.imshow(np.squeeze(images[i]), cmap="gray")
        plt.axis("off")
    plt.show()


sample = model_sample.apply(params, next(prng_seq), num_samples=10)
plot_batch(sample)
[Output: a 2x5 grid of images sampled from the trained flow.]