Path: blob/master/deprecated/notebooks/flow_spline_mnist_jax.ipynb
Spline Flow using JAX, Haiku, Optax and Distrax
In this notebook we will implement a spline flow to fit a distribution to the MNIST dataset. We will be using RationalQuadraticSpline, a piecewise rational quadratic spline, and masked couplings, as explained in the paper Neural Spline Flows by Conor Durkan, Artur Bekasov, Iain Murray, and George Papamakarios.
This notebook replicates the original Distrax example code with suitable minor modifications.
For implementing the quadratic splines with coupling flows, we will be using the following libraries:
JAX - NumPy on GPU and TPU, with automatic differentiation.
Haiku - a JAX-based neural network library.
Optax - a gradient processing and optimization library for JAX.
Distrax - a lightweight library of probability distributions and bijectors.
Installing required libraries in Colab
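A minimal sketch of the installation cell, assuming a Colab runtime where JAX and TensorFlow Datasets are already available (package names are the usual PyPI ones, not copied from the original cell):

```python
# Install the remaining dependencies (assumed to be missing in a fresh Colab runtime).
%pip install -q distrax optax dm-haiku tensorflow-datasets
```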
Importing all required libraries and packages
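A sketch of the imports this notebook relies on (the exact set in the original cell may differ slightly):

```python
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import haiku as hk
import optax
import distrax
import tensorflow_datasets as tfds
```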
Conditioner
Let $x \in \mathbb{R}^D$ be the input. The input is split into two equal sub-spaces $x_1, x_2$, each of size $d = D/2$, such that $x = (x_1, x_2)$.
Let us assume we have a bijection $g(\,\cdot\,;\theta)$ parameterized by $\theta$.
We define a single coupling layer as a function $f: \mathbb{R}^D \to \mathbb{R}^D$ given by:
$$y_1 = x_1, \qquad y_2 = g\big(x_2;\ \theta = m(x_1)\big).$$
In other words, the input $x$ is split into $(x_1, x_2)$ and the output is combined back into $y = (y_1, y_2)$ using a binary mask $b$, where $b_i = 1$ marks the components that pass through unchanged. Therefore, the single coupling layer $f$ is defined in a single equation as:
$$y = b \odot x + (1 - b) \odot g\big(x;\ m(b \odot x)\big).$$
We will implement the full flow by chaining multiple coupling layers. The mask will be flipped between layers to ensure we capture dependencies in a more expressive way.
The function $m(\cdot)$ is called the Conditioner, which we implement with a set of Linear layers and ReLU activation functions.
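As an illustration, a conditioner along these lines (following the Distrax example this notebook replicates; the argument names are illustrative, not necessarily the exact ones used here) could be written with Haiku as:

```python
def make_conditioner(event_shape, hidden_sizes, num_bijector_params):
  """MLP conditioner m(.) that maps the masked input to spline parameters."""
  return hk.Sequential([
      hk.Flatten(preserve_dims=-len(event_shape)),
      hk.nets.MLP(hidden_sizes, activate_final=True),
      # Initialize the final linear layer to zero so the flow starts close
      # to the identity transformation.
      hk.Linear(
          np.prod(event_shape) * num_bijector_params,
          w_init=jnp.zeros,
          b_init=jnp.zeros),
      hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1),
  ])
```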
Flow Model
Next we implement the bijector using distrax.RationalQuadraticSpline and the masked coupling using distrax.MaskedCoupling.
We join a number of masked coupling layers sequentially to define the complete spline flow.
We define the base distribution of our flow as a Uniform distribution.
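A sketch of how the flow model can be assembled, in the spirit of the original Distrax example (the helper name make_flow_model and its arguments are illustrative):

```python
def make_flow_model(event_shape, num_layers, hidden_sizes, num_bins):
  """Builds the masked-coupling spline flow as a distrax distribution."""
  # Alternating binary mask over the flattened event dimensions.
  mask = jnp.arange(0, np.prod(event_shape)) % 2
  mask = jnp.reshape(mask, event_shape).astype(bool)

  def bijector_fn(params):
    return distrax.RationalQuadraticSpline(params, range_min=0., range_max=1.)

  # Spline parameters per dimension: num_bins widths, num_bins heights,
  # and num_bins + 1 knot slopes.
  num_bijector_params = 3 * num_bins + 1

  layers = []
  for _ in range(num_layers):
    layers.append(distrax.MaskedCoupling(
        mask=mask,
        bijector=bijector_fn,
        conditioner=make_conditioner(event_shape, hidden_sizes,
                                     num_bijector_params)))
    # Flip the mask after each layer.
    mask = jnp.logical_not(mask)

  # Invert the chain so that `log_prob` uses the forward pass.
  flow = distrax.Inverse(distrax.Chain(layers))
  base_distribution = distrax.Independent(
      distrax.Uniform(low=jnp.zeros(event_shape), high=jnp.ones(event_shape)),
      reinterpreted_batch_ndims=len(event_shape))

  return distrax.Transformed(base_distribution, flow)
```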
Data Loading and preparation
In this cell, we define a function to load the MNIST dataset using the TFDS (TensorFlow Datasets) package.
We also have a function prepare_data to:
dequantize the data, i.e. convert the integer pixel values from {0, 1, ..., 255} to real values in [0, 256) by adding random uniform noise in [0, 1); and
normalize the pixel values from [0, 256) to [0, 1).
The dequantization of data is done only at training time.
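A sketch of the data pipeline under these assumptions (batching and shuffling settings are illustrative):

```python
def load_dataset(split, batch_size):
  """Returns an infinite iterator of NumPy batches from the MNIST split."""
  ds = tfds.load('mnist', split=split, shuffle_files=True)
  ds = ds.shuffle(buffer_size=10 * batch_size)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=5)
  ds = ds.repeat()
  return iter(tfds.as_numpy(ds))


def prepare_data(batch, prng_key=None):
  """Dequantizes (only when a PRNG key is given, i.e. at training time) and normalizes."""
  data = batch['image'].astype(np.float32)
  if prng_key is not None:
    # Dequantize pixel values {0, 1, ..., 255} to [0, 256) with uniform noise.
    data += jax.random.uniform(prng_key, data.shape)
  return data / 256.  # Normalize from [0, 256) to [0, 1).
```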
Log Probability, Sample, and Training Loss Functions
Next we define log_prob, model_sample, and loss_fn. log_prob calculates the log of the probability of the data, which we want to maximize for the MNIST data inside loss_fn. model_sample allows us to sample new data points after the model has been trained on MNIST. For a well-trained model, these samples will look like synthetically generated MNIST digits.
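A sketch of these three functions, assuming the helpers defined above; the hyperparameter values at the top are illustrative placeholders, not the original settings:

```python
# Illustrative hyperparameters (assumed values).
MNIST_IMAGE_SHAPE = (28, 28, 1)
FLOW_NUM_LAYERS = 8
MLP_NUM_LAYERS = 2
HIDDEN_SIZE = 500
NUM_BINS = 4


@hk.without_apply_rng
@hk.transform
def log_prob(data):
  model = make_flow_model(
      event_shape=MNIST_IMAGE_SHAPE,
      num_layers=FLOW_NUM_LAYERS,
      hidden_sizes=[HIDDEN_SIZE] * MLP_NUM_LAYERS,
      num_bins=NUM_BINS)
  return model.log_prob(data)


@hk.without_apply_rng
@hk.transform
def model_sample(key, num_samples):
  model = make_flow_model(
      event_shape=MNIST_IMAGE_SHAPE,
      num_layers=FLOW_NUM_LAYERS,
      hidden_sizes=[HIDDEN_SIZE] * MLP_NUM_LAYERS,
      num_bins=NUM_BINS)
  return model.sample(seed=key, sample_shape=[num_samples])


def loss_fn(params, prng_key, batch):
  data = prepare_data(batch, prng_key)
  # Average negative log-likelihood: minimizing this maximizes log_prob.
  return -jnp.mean(log_prob.apply(params, data))
```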
Training
Next we define the update function for the gradient update. We use jax.grad to calculate the gradient of the loss with respect to the model parameters.
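A minimal sketch of the update step, assuming an Optax Adam optimizer (the learning rate is an assumed value):

```python
LEARNING_RATE = 1e-4  # Assumed value for illustration.
optimizer = optax.adam(LEARNING_RATE)


@jax.jit
def update(params, prng_key, opt_state, batch):
  """Single gradient update of the flow parameters."""
  grads = jax.grad(loss_fn)(params, prng_key, batch)
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state
```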
Now we carry out the training of the model.
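A compact training-loop sketch under the assumptions above (seed, step count, batch size, and logging frequency are all illustrative):

```python
prng_seq = hk.PRNGSequence(42)
train_ds = load_dataset('train', batch_size=128)

params = log_prob.init(next(prng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
opt_state = optimizer.init(params)

for step in range(5000):
  params, opt_state = update(params, next(prng_seq), opt_state, next(train_ds))
  if step % 500 == 0:
    train_loss = loss_fn(params, next(prng_seq), next(train_ds))
    print(f'step {step:5d}: loss {train_loss:.3f}')
```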
Sampling from Trained Flow Model
Plot new samples
After the model has been trained on MNIST, we draw new samples and plot them. Once the model has been trained enough, these should look like digits from the MNIST dataset.
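A sketch of the sampling and plotting cell (the grid size is an arbitrary choice for illustration):

```python
# Draw 25 samples from the trained flow and show them on a 5x5 grid.
samples = model_sample.apply(params, next(prng_seq), num_samples=25)

fig, axes = plt.subplots(5, 5, figsize=(6, 6))
for ax, img in zip(axes.ravel(), samples):
  ax.imshow(np.clip(img[..., 0], 0., 1.), cmap='gray')
  ax.axis('off')
plt.show()
```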