GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/gan_mixture_of_gaussians.ipynb
Kernel: Python 3


This notebook implements a generative adversarial network (GAN) in JAX/Flax and fits it to a synthetic dataset drawn from a 2D mixture of Gaussians.

The code was adapted from the ODE-GAN notebook at https://github.com/deepmind/deepmind-research/blob/master/ode_gan/odegan_mog16.ipynb. The original notebook was created by Chongli Qin.

Some modifications made by Mihaela Rosca were also incorporated.

Imports

!pip install -q flax
from typing import Sequence
import functools
import math

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax
import scipy as sp
import scipy.stats  # ensure sp.stats is available for gaussian_kde below

rng = jax.random.PRNGKey(0)

Data Generation

Data is generated from a 2D mixture of 16 Gaussians whose means lie on a 4x4 grid, each with standard deviation 0.02.

@functools.partial(jax.jit, static_argnums=(1,))
def real_data(rng, batch_size):
    mog_mean = jnp.array(
        [
            [1.50, 1.50],
            [1.50, 0.50],
            [1.50, -0.50],
            [1.50, -1.50],
            [0.50, 1.50],
            [0.50, 0.50],
            [0.50, -0.50],
            [0.50, -1.50],
            [-1.50, 1.50],
            [-1.50, 0.50],
            [-1.50, -0.50],
            [-1.50, -1.50],
            [-0.50, 1.50],
            [-0.50, 0.50],
            [-0.50, -0.50],
            [-0.50, -1.50],
        ]
    )
    temp = jnp.tile(mog_mean, (batch_size // 16 + 1, 1))
    mus = temp[0:batch_size, :]
    return mus + 0.02 * jax.random.normal(rng, shape=(batch_size, 2))
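
As a quick sanity check (a minimal sketch, assuming the imports and the real_data function above), we can draw one batch and confirm its shape and spread:

# Sanity-check sketch: sample one batch of real data and inspect it
check_rng = jax.random.PRNGKey(42)
batch = real_data(check_rng, 512)
print(batch.shape)               # (512, 2)
print(batch.min(), batch.max())  # roughly within [-1.6, 1.6]: means at +-0.5/+-1.5 plus 0.02-std noise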

Plotting

def plot_on_ax(ax, values, contours=None, bbox=None, xlabel="", ylabel="", title="", cmap="Blues"):
    kernel = sp.stats.gaussian_kde(values.T)
    ax.axis(bbox)
    ax.set_aspect(abs(bbox[1] - bbox[0]) / abs(bbox[3] - bbox[2]))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xticks([])
    ax.set_yticks([])
    xx, yy = jnp.mgrid[bbox[0] : bbox[1] : 300j, bbox[2] : bbox[3] : 300j]
    positions = jnp.vstack([xx.ravel(), yy.ravel()])
    f = jnp.reshape(kernel(positions).T, xx.shape)
    cfset = ax.contourf(xx, yy, f, cmap=cmap)
    if contours is not None:
        x = jnp.arange(-2.0, 2.0, 0.1)
        y = jnp.arange(-2.0, 2.0, 0.1)
        cx, cy = jnp.meshgrid(x, y)
        new_set = ax.contour(
            cx, cy, contours.squeeze().reshape(cx.shape), levels=20, colors="k", linewidths=0.8, alpha=0.5
        )
    ax.set_title(title)

Models and Training

A multilayer perceptron (MLP) with ReLU activations; the same module is used for both the discriminator and the generator, with different output sizes.

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = jax.nn.relu(nn.Dense(features=feat)(x))
        x = nn.Dense(features=self.features[-1])(x)
        return x
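
As a quick shape check (a sketch assuming the MLP class and the imports above), a discriminator-shaped MLP maps a batch of 2D points to one logit per example:

# Shape-check sketch for the MLP module
dummy_mlp = MLP(features=[25, 25, 1])
dummy_x = jnp.ones((4, 2))                           # a batch of four 2D points
dummy_params = dummy_mlp.init(jax.random.PRNGKey(1), dummy_x)
print(dummy_mlp.apply(dummy_params, dummy_x).shape)  # (4, 1)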

The loss function for the discriminator is:

$$L_D(\phi, \theta) = \mathbb{E}_{p^*(x)}\, g(D_\phi(x)) + \mathbb{E}_{q(z)}\, h(D_\phi(G_\theta(z)))$$

where $g(t) = -\log t$ and $h(t) = -\log(1 - t)$, as in the original GAN.
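
Since the networks below output raw logits rather than probabilities, it is convenient to write $D_\phi = \sigma(f_\phi)$, where $\sigma$ is the logistic sigmoid and $f_\phi$ denotes the logit (this notation is ours, not from the original notebook). Using the identity $1 - \sigma(t) = \sigma(-t)$, the loss becomes

$$L_D(\phi, \theta) = \mathbb{E}_{p^*(x)}\bigl[-\log \sigma(f_\phi(x))\bigr] + \mathbb{E}_{q(z)}\bigl[-\log \sigma(-f_\phi(G_\theta(z)))\bigr],$$

which is what the calls to jax.nn.log_sigmoid on the real logits and the negated fake logits compute in the cell below.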

@jax.jit
def discriminator_step(disc_state, gen_state, latents, real_examples):
    def loss_fn(disc_params):
        fake_examples = gen_state.apply_fn(gen_state.params, latents)
        real_logits = disc_state.apply_fn(disc_params, real_examples)
        fake_logits = disc_state.apply_fn(disc_params, fake_examples)
        disc_real = -jax.nn.log_sigmoid(real_logits)
        # log(1 - sigmoid(x)) = log_sigmoid(-x)
        disc_fake = -jax.nn.log_sigmoid(-fake_logits)
        return jnp.mean(disc_real + disc_fake)

    disc_loss, disc_grad = jax.value_and_grad(loss_fn)(disc_state.params)
    disc_state = disc_state.apply_gradients(grads=disc_grad)
    return disc_state, disc_loss
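
The in-line comment uses the identity $\log(1 - \sigma(x)) = \log \sigma(-x)$; it can also be verified numerically (a sketch assuming the imports above):

# Numerical check of log(1 - sigmoid(x)) == log_sigmoid(-x)
xs = jnp.linspace(-5.0, 5.0, 11)
lhs = jnp.log(1.0 - jax.nn.sigmoid(xs))
rhs = jax.nn.log_sigmoid(-xs)
print(jnp.max(jnp.abs(lhs - rhs)))  # ~0, up to floating-point error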

The loss function for the generator is:

$$L_G(\phi, \theta) = \mathbb{E}_{q(z)}\, l(D_\phi(G_\theta(z)))$$

where $l(t) = -\log t$ for the non-saturating generator loss.
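
In terms of logits (the same $f_\phi$ notation introduced above, which is our addition), this is

$$L_G(\phi, \theta) = \mathbb{E}_{q(z)}\bigl[-\log \sigma(f_\phi(G_\theta(z)))\bigr],$$

i.e. -jax.nn.log_sigmoid applied to the fake logits in the cell below. Compared with the saturating alternative $\mathbb{E}_{q(z)} \log\bigl(1 - D_\phi(G_\theta(z))\bigr)$, the non-saturating form gives the generator stronger gradients early in training, when the discriminator confidently rejects generated samples.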

@jax.jit
def generator_step(disc_state, gen_state, latents):
    def loss_fn(gen_params):
        fake_examples = gen_state.apply_fn(gen_params, latents)
        fake_logits = disc_state.apply_fn(disc_state.params, fake_examples)
        disc_fake = -jax.nn.log_sigmoid(fake_logits)
        return jnp.mean(disc_fake)

    gen_loss, gen_grad = jax.value_and_grad(loss_fn)(gen_state.params)
    gen_state = gen_state.apply_gradients(grads=gen_grad)
    return gen_state, gen_loss

Perform a training step by first updating the discriminator parameters $\phi$ using the gradient $\nabla_\phi L_D(\phi, \theta)$ and then updating the generator parameters $\theta$ using the gradient $\nabla_\theta L_G(\phi, \theta)$.

@jax.jit
def train_step(disc_state, gen_state, latents, real_examples):
    disc_state, disc_loss = discriminator_step(disc_state, gen_state, latents, real_examples)
    gen_state, gen_loss = generator_step(disc_state, gen_state, latents)
    return disc_state, gen_state, disc_loss, gen_loss
batch_size = 512
latent_size = 32

discriminator = MLP(features=[25, 25, 1])
generator = MLP(features=[25, 25, 2])
# Initialize parameters for the discriminator and the generator
latents = jax.random.normal(rng, shape=(batch_size, latent_size))
real_examples = real_data(rng, batch_size)

disc_params = discriminator.init(rng, real_examples)
gen_params = generator.init(rng, latents)
# Plot real examples
bbox = [-2, 2, -2, 2]
plot_on_ax(plt.gca(), real_examples, bbox=bbox, title="Data")
plt.tight_layout()
plt.savefig("gan_gmm_data.pdf")
plt.show()
[Figure: kernel-density plot of the real data, titled "Data".]
# Create train states for the discriminator and the generator
lr = 0.05
disc_state = train_state.TrainState.create(
    apply_fn=discriminator.apply, params=disc_params, tx=optax.sgd(learning_rate=lr)
)
gen_state = train_state.TrainState.create(
    apply_fn=generator.apply, params=gen_params, tx=optax.sgd(learning_rate=lr)
)
# x and y grid for plotting discriminator contours
x = jnp.arange(-2.0, 2.0, 0.1)
y = jnp.arange(-2.0, 2.0, 0.1)
X, Y = jnp.meshgrid(x, y)
pairs = jnp.stack((X, Y), axis=-1)
pairs = jnp.reshape(pairs, (-1, 2))

# Latents for testing the generator
test_latents = jax.random.normal(rng, shape=(batch_size * 10, latent_size))
num_iters = 20001
n_save = 2000
draw_contours = False
history = []

for i in range(num_iters):
    rng_iter = jax.random.fold_in(rng, i)
    data_rng, latent_rng = jax.random.split(rng_iter)
    # Sample a minibatch of real examples
    real_examples = real_data(data_rng, batch_size)
    # Sample a minibatch of latents
    latents = jax.random.normal(latent_rng, shape=(batch_size, latent_size))
    # Update the discriminator and then the generator
    disc_state, gen_state, disc_loss, gen_loss = train_step(disc_state, gen_state, latents, real_examples)
    if i % n_save == 0:
        print(f"i = {i}, Discriminator Loss = {disc_loss}, " + f"Generator Loss = {gen_loss}")
        # Generate examples using the test latents
        fake_examples = gen_state.apply_fn(gen_state.params, test_latents)
        if draw_contours:
            real_logits = disc_state.apply_fn(disc_state.params, pairs)
            disc_contour = -real_logits + jax.nn.log_sigmoid(real_logits)
        else:
            disc_contour = None
        history.append((i, fake_examples, disc_contour, disc_loss, gen_loss))
i = 0, Discriminator Loss = 1.399444580078125, Generator Loss = 0.7146230340003967
i = 2000, Discriminator Loss = 1.2572847604751587, Generator Loss = 0.7397961616516113
i = 4000, Discriminator Loss = 1.2612197399139404, Generator Loss = 0.6012262105941772
i = 6000, Discriminator Loss = 1.1314020156860352, Generator Loss = 0.7101864814758301
i = 8000, Discriminator Loss = 1.03669011592865, Generator Loss = 1.6814203262329102
i = 10000, Discriminator Loss = 0.9454075694084167, Generator Loss = 0.8876760005950928
i = 12000, Discriminator Loss = 0.8018237352371216, Generator Loss = 3.261043071746826
i = 14000, Discriminator Loss = 0.6465665102005005, Generator Loss = 1.6940462589263916
i = 16000, Discriminator Loss = 0.3630889654159546, Generator Loss = 2.8523144721984863
i = 18000, Discriminator Loss = 0.43770408630371094, Generator Loss = 2.005770683288574
i = 20000, Discriminator Loss = 0.25742101669311523, Generator Loss = 6.7867584228515625
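
For reference when reading these numbers (a standard GAN fact, not stated in the original notebook): at the idealized equilibrium where the discriminator outputs $D = 1/2$ everywhere, the two losses are

$$L_D = 2\log 2 \approx 1.386, \qquad L_G = \log 2 \approx 0.693,$$

so the run starts near equilibrium, while the later drop in the discriminator loss and rise in the generator loss indicate that the discriminator becomes increasingly able to separate real from generated samples.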

Plot Results

Plot the data and the examples generated by the generator.

# Plot generated examples from history
for i, hist in enumerate(history):
    iter, fake_examples, contours, disc_loss, gen_loss = hist
    plot_on_ax(
        plt.gca(),
        fake_examples,
        contours=contours,
        bbox=bbox,
        xlabel=f"Disc Loss: {disc_loss:.3f} | Gen Loss: {gen_loss:.3f}",
        title=f"Samples at Iteration {iter}",
    )
    plt.tight_layout()
    plt.savefig(f"gan_gmm_iter_{iter}.pdf")
    plt.show()
[Figures: one kernel-density plot of the generated samples per saved iteration (0, 2000, ..., 20000), each titled "Samples at Iteration i" and annotated with the discriminator and generator losses.]
cols = 3
rows = math.ceil((len(history) + 1) / cols)
bbox = [-2, 2, -2, 2]
fig, axs = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), dpi=200)
axs = axs.flatten()

# Plot real examples
plot_on_ax(axs[0], real_examples, bbox=bbox, title="Data")

# Plot generated examples from history
for i, hist in enumerate(history):
    iter, fake_examples, contours, disc_loss, gen_loss = hist
    plot_on_ax(
        axs[i + 1],
        fake_examples,
        contours=contours,
        bbox=bbox,
        xlabel=f"Disc Loss: {disc_loss:.3f} | Gen Loss: {gen_loss:.3f}",
        title=f"Samples at Iteration {iter}",
    )

# Remove extra plots from the figure
for i in range(len(history) + 1, len(axs)):
    axs[i].remove()

plt.tight_layout()
plt.show()
[Figure: grid of plots showing the real data ("Data") followed by the generated samples at each saved iteration.]