Copyright 2018 The TensorFlow Probability Authors.
Licensed under the Apache License, Version 2.0 (the "License");
Setup
First, install the packages used in this demo.
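A setup cell along these lines should suffice (a minimal sketch; the exact package list and versions for this demo are assumptions):

```python
# Install TensorFlow and TensorFlow Probability (versions illustrative).
!pip install -q tensorflow tensorflow-probability
```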
FFJORD bijector
In this colab we demonstrate the FFJORD bijector, originally proposed in the paper by Grathwohl et al. (arXiv:1810.01367).
In a nutshell, the idea behind this approach is to establish a correspondence between a known base distribution and the data distribution.
To establish this connection, we need to
1. Define a bijective map $\mathcal{T}_{\theta}: \mathbf{x} \rightarrow \mathbf{y}$, $\mathcal{T}_{\theta}^{-1}: \mathbf{y} \rightarrow \mathbf{x}$ between the space $\mathcal{Y}$ on which the base distribution is defined and the space $\mathcal{X}$ of the data domain.
2. Efficiently keep track of the deformations we perform to transfer the notion of probability onto $\mathcal{X}$.

The second condition is formalized in the following expression for the probability distribution defined on $\mathcal{X}$:

$$\log p_{\mathbf{x}}(\mathbf{x}) = \log p_{\mathbf{y}}(\mathbf{y}) - \log \det \left| \frac{\partial \mathbf{y}}{\partial \mathbf{x}} \right|$$
FFJORD bijector accomplishes this by defining a transformation

$$\mathbf{y} = \mathcal{T}_{\theta}(\mathbf{x}) = \mathbf{x} + \int_{t_0}^{t_1} f(t, \mathbf{x}(t))\,dt$$

This transformation is invertible as long as the function $f$ describing the evolution of the state $\mathbf{x}(t)$ is well behaved, and the log_det_jacobian can be calculated by integrating the following expression:

$$\log \det \left| \frac{\partial \mathbf{y}}{\partial \mathbf{x}} \right| = -\int_{t_0}^{t_1} \mathrm{Tr}\left(\frac{\partial f}{\partial \mathbf{x}(t)}\right)\,dt$$
In this demo we will train a FFJORD bijector to warp a Gaussian distribution onto the distribution defined by the moons dataset. This will be done in 3 steps:
1. Define the base distribution
2. Define the FFJORD bijector
3. Minimize the exact log-likelihood of the dataset
First, we load the data
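A sketch of the loading step, assuming the two-moons dataset from scikit-learn (the sample count and noise level are illustrative):

```python
import numpy as np
from sklearn import datasets as skd

DATASET_SIZE = 1024 * 8

# Two interleaving half-circles; `moons` has shape [DATASET_SIZE, 2].
moons = skd.make_moons(n_samples=DATASET_SIZE, noise=0.06)[0].astype(np.float32)
```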
Next, we instantiate a base distribution
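A standard 2-D Gaussian is a natural choice here (a minimal sketch using tfp.distributions):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Isotropic standard normal in 2 dimensions.
base_distribution = tfd.MultivariateNormalDiag(loc=tf.zeros(2, dtype=tf.float32))
```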
We use a multi-layer perceptron to model state_derivative_fn. While not necessary for this dataset, it is often beneficial to make state_derivative_fn dependent on time. Here we achieve this by concatenating t to the inputs of our network, as shown in the sketch below.
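A sketch of such a model, assuming a small Keras MLP (the MLP_ODE name and layer sizes are illustrative):

```python
class MLP_ODE(tf.keras.Model):
  """Multi-layer perceptron modeling the state derivative f(t, x)."""

  def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
    super().__init__(name=name)
    layers = [tf.keras.layers.Dense(num_hidden, activation='tanh')
              for _ in range(num_layers - 1)]
    layers.append(tf.keras.layers.Dense(num_output))
    self._model = tf.keras.Sequential(layers)

  def __call__(self, t, inputs):
    # Concatenate the (broadcast) time t to the state so f depends on time.
    t = tf.broadcast_to(t, [tf.shape(inputs)[0], 1])
    return self._model(tf.concat([t, inputs], axis=-1))
```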
Now we construct a stack of FFJORD bijectors. Each bijector is provided with ode_solve_fn and trace_augmentation_fn and its own state_derivative_fn model, so that they represent a sequence of different transformations.
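A sketch of that construction; the DormandPrince solver and the exact trace computation below are assumptions based on the public TFP API, and the hyperparameters are illustrative:

```python
tfb = tfp.bijectors

NUM_BIJECTORS = 4

# Shared ODE solver and trace function; each bijector gets its own MLP.
solver = tfp.math.ode.DormandPrince(atol=1e-5)

bijectors = []
for _ in range(NUM_BIJECTORS):
  mlp_model = MLP_ODE(num_hidden=8, num_layers=3, num_output=2)
  bijectors.append(
      tfb.FFJORD(
          state_time_derivative_fn=mlp_model,
          ode_solve_fn=solver.solve,
          trace_augmentation_fn=tfb.ffjord.trace_jacobian_exact))

# Chain applies the bijectors in reverse order of the list.
stacked_ffjord = tfb.Chain(bijectors[::-1])
```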
Now we can use TransformedDistribution, which is the result of warping base_distribution with the stacked_ffjord bijector.
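For instance:

```python
# Pushforward of the base distribution through the stacked FFJORD bijector.
transformed_distribution = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=stacked_ffjord)
```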
Now we define our training procedure. We simply minimize the negative log-likelihood of the data.
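A minimal sketch of one training step (the optimizer choice, learning rate, and batching are illustrative):

```python
LEARNING_RATE = 1e-2

optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

@tf.function
def train_step(target_sample):
  """One gradient step on the negative log-likelihood of a batch."""
  with tf.GradientTape() as tape:
    loss = -tf.reduce_mean(transformed_distribution.log_prob(target_sample))
  variables = tape.watched_variables()
  optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
  return loss
```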
Plot samples from base and transformed distributions.
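A plotting sketch (the matplotlib usage and sample sizes are assumptions):

```python
import matplotlib.pyplot as plt

base_samples = base_distribution.sample(1000)
transformed_samples = transformed_distribution.sample(1000)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].scatter(base_samples[:, 0], base_samples[:, 1], s=10)
axes[0].set_title('Base distribution')
axes[1].scatter(transformed_samples[:, 0], transformed_samples[:, 1], s=10)
axes[1].set_title('Transformed distribution')
plt.show()
```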
Training it for longer with learning rate decay results in further improvements.
Though not covered in this example, the FFJORD bijector also supports Hutchinson's stochastic trace estimator. The particular estimator can be provided via trace_augmentation_fn. Similarly, alternative integrators can be used by defining a custom ode_solve_fn.
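For example (a sketch; tfb.ffjord.trace_jacobian_hutchinson and the BDF solver are taken from the public TFP API, and the model hyperparameters are illustrative):

```python
# Hutchinson's stochastic trace estimator with an alternative (implicit) solver.
hutchinson_ffjord = tfb.FFJORD(
    state_time_derivative_fn=MLP_ODE(num_hidden=8, num_layers=3, num_output=2),
    ode_solve_fn=tfp.math.ode.BDF().solve,
    trace_augmentation_fn=tfb.ffjord.trace_jacobian_hutchinson)
```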