Kernel: Python 3
Diffusion generative model for MNIST
Author: Winnie Xu.
In [1]:
Out[1]:
Successfully installed diffrax-0.2.0 equinox-0.5.6
Successfully installed einops-0.4.1
Successfully installed chex-0.1.3 optax-0.1.3
Download Dataset
In [2]:
In [3]:
Out[3]:
Downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /content/data/mnist/train-images-idx3-ubyte.gz
In [4]:
Out[4]:
(60000, 1, 28, 28)
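The loading code is not shown in this export, but the downloaded file `train-images-idx3-ubyte.gz` uses the IDX format. It starts with a 16-byte big-endian header (magic number 2051, image count, rows, cols), followed by one unsigned byte per pixel. The following is a minimal NumPy sketch of reading it into the `(60000, 1, 28, 28)` layout printed above. The function names are hypothetical. The dataset-wide standardization in `standardize` is an assumed preprocessing step and is not confirmed by the notebook.

```python
import gzip
import struct

import numpy as np


def load_idx_images(path):
    """Read a gzipped IDX3 image file into a (N, 1, rows, cols) uint8 array."""
    with gzip.open(path, "rb") as f:
        # Header: four big-endian unsigned 32-bit integers.
        magic, n, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != 2051:
            raise ValueError(f"not an IDX3 image file (magic={magic})")
        pixels = np.frombuffer(f.read(n * rows * cols), dtype=np.uint8)
    # Add a singleton channel axis so the layout is (N, C=1, H, W).
    return pixels.reshape(n, 1, rows, cols)


def standardize(images):
    """Hypothetical preprocessing: zero mean and unit std over the whole dataset."""
    x = images.astype(np.float32)
    mean, std = x.mean(), x.std()
    # Keep mean/std so generated samples can be mapped back to pixel space.
    return (x - mean) / std, mean, std
```

Keeping `mean` and `std` lets you undo the scaling (`x * std + mean`) before plotting generated samples.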
In [5]:
Out[5]:
<matplotlib.image.AxesImage at 0x7f2fc063d5d0>
Define Score Model
In [6]:
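The score model's code is not shown in this export. In a score-based diffusion model, the network s_θ(t, y) takes a diffusion time t and a noisy image y. It returns an array with the same shape as y that estimates ∇_y log p_t(y). The NumPy sketch below illustrates only that interface. It uses a hypothetical two-layer MLP (`init_mlp_score_model`, `score_model`) that appends t as one extra input feature. It is not the notebook's architecture, which would be written in JAX/Equinox and trained.

```python
import numpy as np


def init_mlp_score_model(data_shape, hidden=256, seed=0):
    """Hypothetical MLP parameters for s_theta(t, y); untrained, for shape checks only."""
    rng = np.random.default_rng(seed)
    d = int(np.prod(data_shape))
    # Input: flattened image plus one scalar time feature.
    w1 = rng.normal(0.0, 1.0 / np.sqrt(d + 1), size=(d + 1, hidden))
    b1 = np.zeros(hidden)
    w2 = rng.normal(0.0, 1.0 / np.sqrt(hidden), size=(hidden, d))
    b2 = np.zeros(d)
    return {"w1": w1, "b1": b1, "w2": w2, "b2": b2, "shape": tuple(data_shape)}


def score_model(params, t, y):
    """Map (t, noisy image y) to a score estimate with the same shape as y."""
    h = np.concatenate([y.reshape(-1), np.array([t], dtype=y.dtype)])
    h = np.tanh(h @ params["w1"] + params["b1"])
    out = h @ params["w2"] + params["b2"]
    return out.reshape(params["shape"])
```

For example, `score_model(init_mlp_score_model((1, 28, 28)), 0.5, np.zeros((1, 28, 28)))` returns an array of shape `(1, 28, 28)`. The same-shape output is what lets the loss and the sampler treat the score as a vector field over image space.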
Define Training Objective
In [7]:
In [8]:
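The objective's code is not shown either. Because the notebook installs diffrax and Equinox, one plausible objective (unconfirmed) is a weighted denoising score-matching loss for a variance-preserving SDE. It perturbs a clean image x₀ to y = exp(−½∫₀ᵗβ) x₀ + σ_t ε with ε ~ N(0, I), then regresses the model onto the conditional score −ε/σ_t. The sketch makes several assumptions: β(t) = 1 (so ∫₀ᵗβ = t), the weighting λ(t) = 1 − exp(−t), and stratified sampling of one time per equal sub-interval of [0, t₁]. The names `int_beta`, `weight`, `single_loss`, and `batch_loss` are hypothetical.

```python
import numpy as np


def int_beta(t):
    # Assumed constant noise rate beta(t) = 1, so its integral from 0 to t is t.
    return t


def weight(t):
    # Assumed weighting lambda(t) = 1 - exp(-t), equal to the perturbation variance.
    return 1.0 - np.exp(-int_beta(t))


def single_loss(score_fn, x0, t, rng):
    """Weighted denoising score-matching loss for one clean image x0 at time t."""
    mean = x0 * np.exp(-0.5 * int_beta(t))
    var = np.maximum(1.0 - np.exp(-int_beta(t)), 1e-5)  # clamp avoids division by ~0
    std = np.sqrt(var)
    noise = rng.standard_normal(x0.shape)
    y = mean + std * noise
    # The score of N(mean, var) at y is -noise / std, so the residual is pred + noise / std.
    return weight(t) * np.mean((score_fn(t, y) + noise / std) ** 2)


def batch_loss(score_fn, batch, t1, rng):
    """Mean loss over a batch, drawing one time from each of len(batch) strata of [0, t1]."""
    b = batch.shape[0]
    # Stratified sampling: t_i ~ Uniform(i * t1 / b, (i + 1) * t1 / b).
    t = rng.uniform(0.0, t1 / b, size=b) + (t1 / b) * np.arange(b)
    return float(np.mean([single_loss(score_fn, x, ti, rng) for x, ti in zip(batch, t)]))
```

With an untrained model whose output is close to zero, the weight cancels the 1/σ_t² factor, so each term reduces to mean(ε²) ≈ 1. If the notebook uses an objective of this form, that would account for the logged `Step=0 Loss=1.003...`.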
Train Score-Based Model
In [9]:
In [10]:
Out[10]:
Step=0 Loss=1.003075122833252
Step=10000 Loss=0.029320005891285836
Step=20000 Loss=0.019951716899499296
Step=30000 Loss=0.018376670414488764
Step=40000 Loss=0.017525794696807862
Step=50000 Loss=0.01694488068567589
Step=60000 Loss=0.01650137415379286
Step=70000 Loss=0.016147429181076586
Step=80000 Loss=0.015854475118219854
Step=90000 Loss=0.015603535754419863
Step=99999 Loss=0.015381233805701314
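The log reports step 0, then every 10,000 steps, then step 99,999. That pattern is consistent with a 100,000-step loop (steps 0 to 99,999) that reports when `step % 10000 == 0` and again at the final step. The export does not show whether each printed loss is a single batch or an average. The sketch below assumes it is the mean loss since the previous report, which fits the smooth decrease. `train` and `step_fn` are hypothetical names, and `step_fn` stands in for one optimizer update that returns its loss.

```python
def train(step_fn, num_steps, print_every):
    """Call step_fn(step) -> loss and print the mean loss since the last report."""
    total, count = 0.0, 0
    reports = []
    for step in range(num_steps):
        total += float(step_fn(step))
        count += 1
        # Report on a fixed cadence and always on the last step.
        if step % print_every == 0 or step == num_steps - 1:
            mean_loss = total / count
            print(f"Step={step} Loss={mean_loss}")
            reports.append((step, mean_loss))
            total, count = 0.0, 0
    return reports
```

With `num_steps=100_000` and `print_every=10_000`, this prints exactly the step numbers shown above. The `Step=0` line is a single-step value, and every later line averages the steps since the previous report.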
In [11]:
Out[11]:
<Figure size 432x288 with 0 Axes>
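The sampling code behind this figure is not shown. A common way to draw samples from a trained score model is to integrate the probability-flow ODE backwards from t₁ to 0, starting from N(0, I). Under the same assumed β(t) = 1 schedule, this ODE is dy/dt = −½(y + s_θ(t, y)). diffrax, which the notebook installs, provides adaptive ODE solvers for this step. The sketch below substitutes a plain fixed-step Euler loop. `sample_probability_flow` is a hypothetical name. To check correctness, the test uses the exact score of a 1-D Gaussian in place of a trained network.

```python
import numpy as np


def sample_probability_flow(score_fn, shape, t1=10.0, n_steps=2000, seed=0):
    """Fixed-step Euler integration of the VP probability-flow ODE from t1 down to 0.

    Assumes beta(t) = 1, so dy/dt = -0.5 * (y + score_fn(t, y)).
    """
    rng = np.random.default_rng(seed)
    y = rng.standard_normal(shape)  # y(t1) ~ N(0, I), the approximate prior
    dt = t1 / n_steps
    t = t1
    for _ in range(n_steps):
        drift = -0.5 * (y + score_fn(t, y))
        y = y - dt * drift  # step backwards in time: t -> t - dt
        t -= dt
    return y
```

If the data were N(μ, s²), the exact score at time t would be −(y − μe^{−t/2}) / (s²e^{−t} + 1 − e^{−t}). Plugging that score into the sampler should recover samples with mean ≈ μ and std ≈ s, which makes a useful sanity check before substituting the trained model.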