Path: blob/master/notebooks/book2/26/ipm_divergences.ipynb
Critics in IPMs and variational bounds on $f$-divergences
Author: Mihaela Rosca
This colab uses a simple example (two 1-d distributions) to show what the critics of various IPMs (the Wasserstein distance and the MMD) look like. We also look at how smooth estimators (neural networks) can estimate density ratios which are not themselves smooth, and how that can provide a good learning signal for a model.
!pip install dm-haiku optax
KL and non-overlapping distributions
We start by visualising two non-overlapping distributions. Wherever $q(x) = 0$ but $p(x) > 0$, the density ratio $p(x)/q(x)$ is infinite, so the integral defining $\mathrm{KL}(p \| q)$ diverges. Moving the distributions closer together does not change this: as long as their supports do not overlap, the KL provides no learning signal.
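As a concrete illustration (this sketch is not the notebook's original code; the means, scale, and plotting grid are arbitrary choices), we can plot two narrow 1-d Gaussians with effectively disjoint support next to their exact log density ratio:

import jax.numpy as jnp
from jax.scipy.stats import norm
import matplotlib.pyplot as plt

# Two narrow 1-d Gaussians with effectively disjoint support.
p_mean, q_mean, scale = -2.0, 2.0, 0.1

xs = jnp.linspace(-4.0, 4.0, 500)
p_density = norm.pdf(xs, loc=p_mean, scale=scale)
q_density = norm.pdf(xs, loc=q_mean, scale=scale)

# The exact log density ratio log p(x)/q(x) blows up wherever q(x) is
# (numerically) zero, so E_p[log p(x)/q(x)] is effectively infinite and
# carries no useful signal for moving q towards p.
log_ratio = norm.logpdf(xs, loc=p_mean, scale=scale) - norm.logpdf(xs, loc=q_mean, scale=scale)

fig, axes = plt.subplots(1, 2, figsize=(10, 3))
axes[0].plot(xs, p_density, label='p')
axes[0].plot(xs, q_density, label='q')
axes[0].legend()
axes[1].plot(xs, log_ratio)
axes[1].set_title('log p(x) / q(x)')
plt.show()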
Approximation of the ratio using the f-GAN approach
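A minimal sketch of this approach, assuming a small haiku MLP critic trained with the KL variational lower bound $\mathbb{E}_{p(x)}[f(x)] - \mathbb{E}_{q(x)}[e^{f(x) - 1}]$, whose maximiser satisfies $f^*(x) = 1 + \log p(x)/q(x)$. The architecture, optimiser settings, and the reuse of p_mean, q_mean and scale from the previous sketch are illustrative assumptions, not the notebook's exact code:

import jax
import jax.numpy as jnp
import haiku as hk
import optax

def critic_fn(x):
    # A small MLP critic; the smoothness of the network is what produces a
    # smooth estimate of the (non-smooth) density ratio.
    return hk.nets.MLP([64, 64, 1], activation=jax.nn.relu)(x)

critic = hk.without_apply_rng(hk.transform(critic_fn))

def neg_kl_bound(params, p_samples, q_samples):
    # Negative of the f-GAN lower bound on KL(p || q):
    #   E_p[f(x)] - E_q[exp(f(x) - 1)], maximised at f*(x) = 1 + log p(x)/q(x).
    f_p = critic.apply(params, p_samples)
    f_q = critic.apply(params, q_samples)
    return -(jnp.mean(f_p) - jnp.mean(jnp.exp(f_q - 1.0)))

key = jax.random.PRNGKey(0)
params = critic.init(key, jnp.zeros((1, 1)))
optimiser = optax.adam(1e-3)
opt_state = optimiser.init(params)

@jax.jit
def update(params, opt_state, p_samples, q_samples):
    grads = jax.grad(neg_kl_bound)(params, p_samples, q_samples)
    updates, opt_state = optimiser.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

for _ in range(2000):
    key, p_key, q_key = jax.random.split(key, 3)
    p_samples = p_mean + scale * jax.random.normal(p_key, (128, 1))
    q_samples = q_mean + scale * jax.random.normal(q_key, (128, 1))
    params, opt_state = update(params, opt_state, p_samples, q_samples)

# The learned estimate of 1 + log p(x)/q(x) is then critic.apply(params, x).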
Gradients
To see why the learned density ratio has useful properties for learning, we can plot its gradients across the input space.
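Continuing the sketch above (reusing its critic and params, an assumption of this illustration rather than the notebook's code), the input gradients can be computed with jax.grad and jax.vmap:

# Gradient of the learned critic with respect to its input, evaluated on a
# grid. The exact ratio has zero gradient almost everywhere when the supports
# do not overlap, but the smooth neural estimate has non-zero gradients that
# point from q towards p.
critic_grad = jax.vmap(jax.grad(lambda x: critic.apply(params, x[None, :])[0, 0]))
xs_grid = jnp.linspace(-4.0, 4.0, 200)[:, None]
grad_values = critic_grad(xs_grid)

plt.plot(xs_grid[:, 0], grad_values[:, 0])
plt.title('df(x)/dx for the learned critic')
plt.show()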
Wasserstein distance for the same two distributions
Computing the Wasserstein critic in 1 dimension. Recall that the Wasserstein distance is defined as:
$$W(p, q) = \sup_{\|f\|_{L} \leq 1} \mathbb{E}_{p(x)}[f(x)] - \mathbb{E}_{q(x)}[f(x)].$$
The code below finds the values of $f$ evaluated at the samples of the two distributions. This vector is computed to maximise the empirical (Monte Carlo) estimate of the IPM:
$$\max_f \frac{1}{N} \sum_{i=1}^{N} f(x_i) - \frac{1}{M} \sum_{j=1}^{M} f(y_j),$$
where $x_i \sim p$ are samples from the first distribution, while $y_j \sim q$ are samples from the second distribution. The maximisation needs to occur under the constraint that the function is 1-Lipschitz; this is ensured by adding an inequality constraint to the linear program for every pair of points $a, b$ among the pooled samples:
$$|f(a) - f(b)| \leq |a - b|.$$
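The notebook's linear-program code is not reproduced in this extract; a rough sketch of the same idea using scipy.optimize.linprog (the function name and constraint construction below are illustrative, not the original implementation) might look like:

import numpy as np
from scipy.optimize import linprog

def wasserstein_critic_1d(x, y):
    """Solve for critic values f(z) at the pooled 1-d samples z = [x, y].

    Maximises mean(f(x)) - mean(f(y)) subject to the pairwise constraints
    f(z_a) - f(z_b) <= |z_a - z_b|, which ensure f can be extended to a
    1-Lipschitz function on the whole real line.
    """
    z = np.concatenate([x, y])
    n, m = len(x), len(y)
    # linprog minimises, so we negate the IPM objective.
    c = np.concatenate([-np.ones(n) / n, np.ones(m) / m])

    # One inequality row per ordered pair (a, b): f_a - f_b <= |z_a - z_b|.
    # The number of constraints is quadratic in the sample size, which is
    # why this approach does not scale to large datasets.
    rows, b_ub = [], []
    for a in range(len(z)):
        for b in range(len(z)):
            if a == b:
                continue
            row = np.zeros(len(z))
            row[a], row[b] = 1.0, -1.0
            rows.append(row)
            b_ub.append(abs(z[a] - z[b]))

    # Apart from the Lipschitz rows the critic values are unconstrained
    # (the objective is invariant to adding a constant to f).
    res = linprog(c, A_ub=np.stack(rows), b_ub=np.array(b_ub), bounds=(None, None))
    return res.x[:n], res.x[n:], -res.fun  # f at x, f at y, distance estimate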
Note: This approach does not scale to large datasets.
Thank you to Arthur Gretton and Dougal J Sutherland for this version of the code.
MMD computation
The MMD is an IPM defined as:
$$MMD(p, q) = \sup_{\|f\|_{\mathcal{H}} \leq 1} \mathbb{E}_{p(x)}[f(x)] - \mathbb{E}_{q(y)}[f(y)],$$
where $\mathcal{H}$ is an RKHS with kernel $k$. Using the mean embedding operators in an RKHS, $\mu_p = \mathbb{E}_{p(x)}[k(x, \cdot)]$ and $\mu_q = \mathbb{E}_{q(y)}[k(y, \cdot)]$, we can write:
$$\mathbb{E}_{p(x)}[f(x)] = \langle f, \mu_p \rangle_{\mathcal{H}},$$
replacing in the MMD:
$$MMD(p, q) = \sup_{\|f\|_{\mathcal{H}} \leq 1} \langle f, \mu_p - \mu_q \rangle_{\mathcal{H}} = \|\mu_p - \mu_q\|_{\mathcal{H}},$$
which means that the optimal critic is
$$f^*(\cdot) = \frac{\mu_p - \mu_q}{\|\mu_p - \mu_q\|_{\mathcal{H}}}.$$
To obtain an estimate of $f^*$ evaluated at a point $t$ we use that:
$$f^*(t) = \langle f^*, k(t, \cdot) \rangle_{\mathcal{H}} \propto \mu_p(t) - \mu_q(t);$$
to estimate $\mu_p(t)$ we use:
$$\mu_p(t) \approx \frac{1}{N} \sum_{i=1}^{N} k(x_i, t), \quad x_i \sim p.$$
To estimate the dot products, we use:
$$\langle \mu_p, \mu_q \rangle_{\mathcal{H}} \approx \frac{1}{NM} \sum_{i=1}^{N} \sum_{j=1}^{M} k(x_i, y_j),$$
and similarly for $\langle \mu_p, \mu_p \rangle_{\mathcal{H}}$ and $\langle \mu_q, \mu_q \rangle_{\mathcal{H}}$, which together give the squared MMD estimate $MMD^2 = \langle \mu_p, \mu_p \rangle - 2\langle \mu_p, \mu_q \rangle + \langle \mu_q, \mu_q \rangle$.
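A short sketch of these estimators for 1-d samples, assuming an RBF kernel with a fixed bandwidth (the kernel choice, bandwidth, and function names are illustrative):

import jax.numpy as jnp

def rbf_kernel(a, b, bandwidth=1.0):
    # k(a, b) = exp(-(a - b)^2 / (2 * bandwidth^2)) for 1-d inputs.
    sq_dists = (a[:, None] - b[None, :]) ** 2
    return jnp.exp(-sq_dists / (2.0 * bandwidth ** 2))

def mmd_witness(eval_points, x, y, bandwidth=1.0):
    # Empirical mean embeddings evaluated at the given points:
    #   mu_p(t) ~= (1/N) sum_i k(x_i, t),  mu_q(t) ~= (1/M) sum_j k(y_j, t).
    mu_p = rbf_kernel(eval_points, x, bandwidth).mean(axis=1)
    mu_q = rbf_kernel(eval_points, y, bandwidth).mean(axis=1)
    # The optimal critic (up to normalisation) is mu_p - mu_q.
    return mu_p - mu_q

def mmd_squared(x, y, bandwidth=1.0):
    # <mu_p, mu_p> - 2 <mu_p, mu_q> + <mu_q, mu_q>, with each dot product
    # estimated by averaging the kernel over pairs of samples.
    k_xx = rbf_kernel(x, x, bandwidth).mean()
    k_xy = rbf_kernel(x, y, bandwidth).mean()
    k_yy = rbf_kernel(y, y, bandwidth).mean()
    return k_xx - 2.0 * k_xy + k_yy

This plug-in estimator keeps the diagonal $k(x_i, x_i)$ terms and is therefore slightly biased; the unbiased U-statistic version drops them.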
For more details see the slides here: http://www.gatsby.ucl.ac.uk/~gretton/coursefiles/lecture5_distribEmbed_1.pdf