Path: blob/master/notebooks/book2/23/two_moons_nsf_normalizing_flow.ipynb
Two Moons Normalizing Flow Using Distrax + Haiku
Here is an implementation of the Neural Spline Flow (NSF), based on the distrax documentation's example of a flow. For a more detailed walkthrough of the math behind normalizing flows, see this notebook's example using MNIST. The code to load the 2 moons example dataset is sourced from Chris Waites's jax-flows demo.
Plotting the 2 moons dataset
The code below is taken directly from Chris Waites's jax-flows demo. This is the target distribution: we want to learn a bijection to it from a simple base distribution, such as a Gaussian.
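A minimal sketch of the data-loading step along the lines of that demo (the sample count, noise level, and plot range are assumed values):

```python
import matplotlib.pyplot as plt
from sklearn import datasets, preprocessing

# Sample the two moons dataset and standardize it to zero mean / unit variance.
n_samples = 10000  # assumed sample count
X, _ = datasets.make_moons(n_samples=n_samples, noise=0.05)
X = preprocessing.StandardScaler().fit_transform(X)

# Visualize the target density as a 2D histogram.
plt.hist2d(X[:, 0], X[:, 1], bins=100, range=[(-2, 2), (-2, 2)])
plt.show()
```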
Creating the normalizing flow in distrax+haiku
Instead of a uniform distribution, we use a Gaussian as the base distribution. This is the natural choice for the two moons dataset once it has been standardized to zero mean and unit variance with sklearn's StandardScaler(); using a uniform base distribution here results in inf and nan losses.
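Below is a condensed sketch of the flow construction, following the distrax documentation's flow example: a stack of masked coupling layers whose conditioner MLPs output rational quadratic spline parameters, on top of a Gaussian base. The hyperparameters (layer count, bin count, hidden sizes, spline range) are placeholder assumptions, and for brevity the sketch omits the per-layer affine transformations discussed further below.

```python
import distrax
import haiku as hk
import jax.numpy as jnp
import numpy as np

def make_conditioner(event_shape, hidden_sizes, num_bijector_params):
    # MLP conditioner; zero-initializing the last layer makes each coupling
    # layer start out as the identity transformation.
    return hk.Sequential([
        hk.Flatten(preserve_dims=-len(event_shape)),
        hk.nets.MLP(hidden_sizes, activate_final=True),
        hk.Linear(np.prod(event_shape) * num_bijector_params,
                  w_init=jnp.zeros, b_init=jnp.zeros),
        hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1),
    ])

def make_flow_model(event_shape, num_layers=4, hidden_sizes=(64, 64), num_bins=8):
    # Alternate the binary coupling mask so every dimension gets transformed.
    mask = (jnp.arange(np.prod(event_shape)) % 2).astype(bool).reshape(event_shape)

    def bijector_fn(params):
        return distrax.RationalQuadraticSpline(params, range_min=-2.0, range_max=2.0)

    # A rational quadratic spline with K bins needs 3K + 1 parameters per dim.
    num_bijector_params = 3 * num_bins + 1

    layers = []
    for _ in range(num_layers):
        layers.append(distrax.MaskedCoupling(
            mask=mask,
            bijector=bijector_fn,
            conditioner=make_conditioner(event_shape, hidden_sizes,
                                         num_bijector_params)))
        mask = jnp.logical_not(mask)

    # Invert the chain so that log_prob uses the forward pass, as in the
    # distrax example.
    flow = distrax.Inverse(distrax.Chain(layers))
    base = distrax.MultivariateNormalDiag(
        loc=jnp.zeros(event_shape), scale_diag=jnp.ones(event_shape))
    return distrax.Transformed(base, flow)

# Haiku-transformed log-density; the flow's weights live in `params`.
@hk.without_apply_rng
@hk.transform
def log_prob(x):
    return make_flow_model(event_shape=(2,)).log_prob(x)
```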
Setting up the optimizer
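A minimal sketch of the setup, assuming optax's Adam optimizer (the seed and learning rate are placeholder values):

```python
import numpy as np
import optax

prng_seq = hk.PRNGSequence(42)  # assumed seed
params = log_prob.init(next(prng_seq), np.zeros((1, 2), dtype=np.float32))

optimizer = optax.adam(1e-4)  # assumed learning rate
opt_state = optimizer.init(params)
```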
Training the flow
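We train by maximum likelihood, i.e. by minimizing the negative mean log-density of the data under the flow. A sketch of a jitted update step, with an assumed step count and full-batch updates for simplicity:

```python
import jax

@jax.jit
def update(params, opt_state, batch):
    # Negative log-likelihood of the batch under the flow.
    loss_fn = lambda p: -jnp.mean(log_prob.apply(p, batch))
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for step in range(5000):  # assumed number of steps
    params, opt_state, loss = update(params, opt_state, X)
```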
We can now look at the trained flow's output. It should be the same as, or close to, the original two moons dataset that we are trying to model.
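To draw samples, we can wrap the flow's sample method in its own haiku transform; because it instantiates the same modules in the same order as log_prob, the trained params can be reused. A sketch:

```python
@hk.transform
def sample_flow(n):
    # Rebuild the same flow and sample n points from it.
    model = make_flow_model(event_shape=(2,))
    return model.sample(seed=hk.next_rng_key(), sample_shape=(n,))

samples = sample_flow.apply(params, next(prng_seq), 10000)
plt.hist2d(samples[:, 0], samples[:, 1], bins=100, range=[(-2, 2), (-2, 2)])
plt.show()
```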
To sample intermediate layers, we need to get the haiku parameters and match them to the corresponding intermediate layers of the NSF. This isn't trivial, but we can use the module names to group the parameters by layer and then use those groups to sample from the base Gaussian. To get the parameters, we need to parse the keys that haiku uses to store each module's parameter values. In this NSF flow, each layer is composed of an affine transformation and a series of MLPs that condition the masked coupling layer. To parse the keys, we'll use some regex to extract each module's name and use that to find its indices in the ordered dictionary for each layer. We'll collect the module names for each layer in a list, inside a larger list covering all the layers, as sketched below.
Maybe we could have used a better naming convention, but we leave that as an exercise for the reader.
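Here is a hypothetical sketch of that grouping step. It assumes haiku's default behavior of numbering repeated top-level modules with suffixes (mlp, mlp_1, mlp_2, ...), so the suffix identifies the flow layer; the exact key format depends on how the modules were built.

```python
import re
from collections import defaultdict

buckets = defaultdict(list)
for module_name in params:
    # Take the top-level module component and read its numeric suffix, if any.
    top = module_name.split("/")[0]
    match = re.search(r"_(\d+)$", top)
    idx = int(match.group(1)) if match else 0
    buckets[idx].append(module_name)

# One list of module names per flow layer, ordered by layer index.
layer_module_names = [buckets[i] for i in sorted(buckets)]
```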
We can now use a haiku data-structure utility (hk.data_structures.filter) to get the params according to the lists of module names we've collected. This gives us a list of each layer's parameters.
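A sketch using hk.data_structures.filter, which keeps the entries whose (module_name, name, value) triple satisfies a predicate; layer_module_names is the nested list built above:

```python
layer_params = [
    hk.data_structures.filter(
        lambda module_name, name, value, keep=frozenset(names): module_name in keep,
        params,
    )
    for names in layer_module_names
]
```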
To sample sequential layers, we need the cumulative parameters up to and including each layer. Those are the parameters we'll use in our sample function, whose skeleton is built for the number of layers we want to sample.
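A sketch of the cumulative merge and the sampling skeleton. The sample_partial helper is hypothetical: it rebuilds the flow truncated to its first i + 1 layers, which assumes the truncated model creates its modules under the same names as the full one, so the merged params line up.

```python
# Cumulative parameters: everything up to and including layer i.
cumulative_params, running = [], {}
for p in layer_params:
    running = hk.data_structures.merge(running, p)
    cumulative_params.append(running)

def sample_partial(n, num_layers):
    # Hypothetical helper: the same flow truncated to its first layers.
    model = make_flow_model(event_shape=(2,), num_layers=num_layers)
    return model.sample(seed=hk.next_rng_key(), sample_shape=(n,))

# Sample each intermediate stage of the flow.
intermediate_samples = []
for i, p in enumerate(cumulative_params):
    fn = hk.transform(lambda n, i=i: sample_partial(n, i + 1))
    intermediate_samples.append(fn.apply(p, next(prng_seq), 10000))
```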