Path: blob/master/deprecated/notebooks/two_moons_normalizingFlow.ipynb
Kernel: Python 3
Two Moons Normalizing Flow Using Distrax + Haiku
A neural spline flow based on the flow example in the distrax documentation. The code that loads the two moons example dataset is sourced from Chris Waites's jax-flows demo.
In [ ]:
Collecting dm-haiku
Downloading dm_haiku-0.0.6-py3-none-any.whl (309 kB)
|████████████████████████████████| 309 kB 13.1 MB/s
Collecting distrax
Downloading distrax-0.1.2-py3-none-any.whl (272 kB)
|████████████████████████████████| 272 kB 14.5 MB/s
Collecting optax
Downloading optax-0.1.1-py3-none-any.whl (136 kB)
|████████████████████████████████| 136 kB 72.1 MB/s
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (3.10.0.2)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (1.21.5)
Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (0.8.9)
Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (1.0.0)
Collecting jmp>=0.0.2
Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->dm-haiku) (1.15.0)
Requirement already satisfied: tensorflow-probability>=0.15.0 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.16.0)
Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.3.4)
Collecting chex>=0.0.7
Downloading chex-0.1.1-py3-none-any.whl (70 kB)
|████████████████████████████████| 70 kB 8.3 MB/s
Requirement already satisfied: jaxlib>=0.1.67 in /usr/local/lib/python3.7/dist-packages (from distrax) (0.3.2+cuda11.cudnn805)
Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.7->distrax) (0.1.6)
Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.7->distrax) (0.11.2)
Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->distrax) (1.4.1)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->distrax) (3.3.0)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.67->distrax) (2.0)
Requirement already satisfied: gast>=0.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.15.0->distrax) (0.5.3)
Requirement already satisfied: cloudpickle>=1.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.15.0->distrax) (1.3.0)
Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from tensorflow-probability>=0.15.0->distrax) (4.4.2)
Installing collected packages: jmp, chex, optax, dm-haiku, distrax
Successfully installed chex-0.1.1 distrax-0.1.2 dm-haiku-0.0.6 jmp-0.0.2 optax-0.1.1
In [ ]:
Plotting the two moons dataset
Code taken directly from Chris Waites's jax-flows demo. This is the target distribution: we want to learn a bijection to it from a simple base distribution, such as a Gaussian.
In [ ]:
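The plotting cell itself is not shown above; the sketch below reconstructs the usual two moons setup (the sample count and noise level are assumptions, not values confirmed by the notebook):

```python
import matplotlib.pyplot as plt
from sklearn import datasets, preprocessing

# Sample points from the two moons distribution and standardize them
# to zero mean and unit variance per feature.
n_samples = 10_000
X, _ = datasets.make_moons(n_samples=n_samples, noise=0.05, random_state=0)
X = preprocessing.StandardScaler().fit_transform(X)

plt.scatter(X[:, 0], X[:, 1], s=5, color="red")
plt.title("Standardized two moons dataset")
plt.show()
```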
Creating the normalizing flow in Distrax + Haiku
Instead of a uniform distribution, we use a normal distribution as the base distribution. This is the natural choice for the two moons dataset once it has been standardized to zero mean and unit variance with sklearn's StandardScaler(). Using a uniform base distribution results in inf and nan losses.
In [ ]:
Setting up the optimizer
In [ ]:
Training the flow
In [ ]:
STEP: 0; Validation loss: 2.799
STEP: 100; Validation loss: 1.549
STEP: 200; Validation loss: 1.405
STEP: 300; Validation loss: 1.332
STEP: 400; Validation loss: 1.309
STEP: 500; Validation loss: 1.252
STEP: 600; Validation loss: 1.304
STEP: 700; Validation loss: 1.291
STEP: 800; Validation loss: 1.304
STEP: 900; Validation loss: 1.294
In [ ]: