Path: blob/master/notebooks/misc/gan_mog_mode_hopping.ipynb
Kernel: Python 3
Mixture of Gaussians example with GANs
This code was adapted from the ODEGAN code here: https://github.com/deepmind/deepmind-research/blob/master/ode_gan/odegan_mog16.ipynb
The original colab was created by Chongli Qin and adapted by Mihaela Rosca.
This code implements GANs for Mixture of Gaussians.
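The target distribution in the original colab (`odegan_mog16`) is a grid of 16 Gaussians in 2D. A minimal sketch of such a sampler, assuming a 4×4 grid with illustrative spacing and per-mode standard deviation (these values are not taken from the original code):

```python
import numpy as np

def sample_mog16(n, spacing=1.5, std=0.05, seed=0):
    """Draw n points from a 4x4 grid mixture of Gaussians.

    spacing and std are illustrative choices, not the original
    notebook's hyperparameters.
    """
    rng = np.random.default_rng(seed)
    # 16 mode centres laid out on a 4x4 grid
    centers = np.array([(i, j) for i in range(-2, 2) for j in range(-2, 2)],
                       dtype=float) * spacing
    idx = rng.integers(0, len(centers), size=n)  # pick a mode per sample
    return centers[idx] + rng.normal(scale=std, size=(n, 2))

x = sample_mog16(1000)
```

A generator that captures only a few of the 16 modes exhibits the mode collapse / mode hopping behaviour this notebook studies.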
It also provides an implementation of ODE-GAN (Training Generative Adversarial Networks by Solving Ordinary Differential Equations, by Qin et al.).
The ODE-GAN approach was described in the book as using higher-order integrators such as Runge-Kutta 4.
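ODE-GAN views simultaneous gradient descent/ascent on the two players as integrating an ODE, and replaces the implicit Euler-style update of standard GAN training with a higher-order integrator. A minimal sketch on a toy bilinear min-max game f(x, y) = x·y (the game and step size are illustrative, not the notebook's model), where Euler integration spirals outward but a Runge-Kutta 4 step tracks the true circular trajectory:

```python
import numpy as np

def rk4_step(v, z, h):
    """One classical Runge-Kutta 4 step for dz/dt = v(z)."""
    k1 = v(z)
    k2 = v(z + 0.5 * h * k1)
    k3 = v(z + 0.5 * h * k2)
    k4 = v(z + h * k3)
    return z + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

# For min_x max_y f(x, y) = x*y, simultaneous gradient flow is
# dx/dt = -y, dy/dt = x, whose exact trajectories are circles.
v = lambda z: np.array([-z[1], z[0]])

z_euler = np.array([1.0, 0.0])
z_rk4 = np.array([1.0, 0.0])
h = 0.1
for _ in range(200):
    z_euler = z_euler + h * v(z_euler)  # Euler: norm grows every step
    z_rk4 = rk4_step(v, z_rk4, h)       # RK4: stays near the unit circle
```

The divergence of the Euler iterate on this toy game is the standard motivation for using higher-order integrators in GAN training.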
In [ ]:
Collecting dm-haiku
Downloading dm_haiku-0.0.4-py3-none-any.whl (284 kB)
|████████████████████████████████| 284 kB 10.6 MB/s eta 0:00:01
Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (0.8.9)
Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (0.12.0)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (3.7.4.3)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from dm-haiku) (1.19.5)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->dm-haiku) (1.15.0)
Installing collected packages: dm-haiku
Successfully installed dm-haiku-0.0.4
In [ ]:
/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:3176: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax._check_user_dtype_supported(dtype, "zeros")
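The warning above can be silenced by enabling 64-bit types in JAX, as the warning message itself suggests. A minimal config fragment (must run before any JAX arrays are created):

```python
# Enable double precision in JAX; it is off by default, so requests
# for float64 are otherwise silently truncated to float32.
from jax import config
config.update("jax_enable_x64", True)
```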
i = 0, discriminant loss = 1.5134295, generator loss = 0.70949596
i = 2000, discriminant loss = 1.3139982, generator loss = 0.7433742
i = 4000, discriminant loss = 1.1802745, generator loss = 1.3980222
i = 6000, discriminant loss = 0.99104536, generator loss = 1.8252661
i = 8000, discriminant loss = 0.6599817, generator loss = 1.4244308
i = 10000, discriminant loss = 0.74728256, generator loss = 2.8712406
i = 12000, discriminant loss = 0.28268716, generator loss = 2.6096854
i = 14000, discriminant loss = 0.42307052, generator loss = 2.5968158
i = 16000, discriminant loss = 0.5030432, generator loss = 1.4615185
i = 18000, discriminant loss = 0.38424677, generator loss = 2.1412582
i = 20000, discriminant loss = 0.53288984, generator loss = 2.3324332
In [ ]: