GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book2/25/diffusion_mnist.ipynb
Kernel: Python 3

Diffusion generative model for MNIST

Author: Winnie Xu.
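This notebook trains a continuous-time score-based diffusion model on MNIST: an MLP-Mixer score network is fit with a denoising score-matching objective, and samples are drawn by solving the probability-flow ODE in reverse time with diffrax. It closely follows the score-based diffusion example from the diffrax documentation.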

import array
import functools as ft
import gzip
import os
import struct
import urllib.request

!pip install diffrax
import diffrax as dfx  # https://github.com/patrick-kidger/diffrax

!pip install einops
import einops  # https://github.com/arogozhnikov/einops

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.random as jr

!pip install optax
import optax  # https://github.com/deepmind/optax

!pip install equinox
import equinox as eqx
Successfully installed diffrax-0.2.0 equinox-0.5.6 einops-0.4.1 chex-0.1.3 optax-0.1.3
(pip install logs truncated)

Download Dataset

def mnist():
    filename = "train-images-idx3-ubyte.gz"
    url_dir = "https://storage.googleapis.com/cvdf-datasets/mnist"
    target_dir = os.getcwd() + "/data/mnist"
    url = f"{url_dir}/{filename}"
    target = f"{target_dir}/{filename}"

    if not os.path.exists(target):
        os.makedirs(target_dir, exist_ok=True)
        urllib.request.urlretrieve(url, target)
        print(f"Downloaded {url} to {target}")

    with gzip.open(target, "rb") as fh:
        # IDX format: 16-byte big-endian header (magic, #images, #rows, #cols), then raw pixels.
        _, batch, rows, cols = struct.unpack(">IIII", fh.read(16))
        shape = (batch, 1, rows, cols)
        dataset = jnp.array(array.array("B", fh.read()), dtype=jnp.uint8).reshape(shape)
    # Standardize to zero mean and unit variance.
    data_mean, data_std = jnp.mean(dataset), jnp.std(dataset)
    dataset = (dataset - data_mean) / data_std
    return dataset


def dataloader(data, batch_size, *, key):
    dataset_size = data.shape[0]
    indices = jnp.arange(dataset_size)
    while True:
        # Reshuffle every epoch with a fresh key; yields batches forever.
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield data[batch_perm]
            start = end
            end = start + batch_size
train_ds = mnist()
Downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /content/data/mnist/train-images-idx3-ubyte.gz
train_ds.shape
(60000, 1, 28, 28)
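As a quick illustration (a sketch, not a cell in the original notebook), the infinite dataloader defined above can be consumed like any Python generator; the batch shape follows from the dataset shape just printed. The key value here is a hypothetical choice.

loader = dataloader(train_ds, 256, key=jr.PRNGKey(42))  # hypothetical key, for illustration
batch = next(loader)
batch.shape  # (256, 1, 28, 28)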
plt.imshow(train_ds[0].squeeze(), cmap="Greys")
<matplotlib.image.AxesImage at 0x7f2fc063d5d0>
[Image: the first MNIST training digit]

Define Score Model

class MixerBlock(eqx.Module):
    """One MLP-Mixer block: patch (token) mixing followed by channel mixing."""

    patch_mixer: eqx.nn.MLP
    hidden_mixer: eqx.nn.MLP
    norm1: eqx.nn.LayerNorm
    norm2: eqx.nn.LayerNorm

    def __init__(self, num_patches, hidden_size, mix_patch_size, mix_hidden_size, *, key):
        tkey, ckey = jr.split(key, 2)
        self.patch_mixer = eqx.nn.MLP(num_patches, num_patches, mix_patch_size, depth=1, key=tkey)
        self.hidden_mixer = eqx.nn.MLP(hidden_size, hidden_size, mix_hidden_size, depth=1, key=ckey)
        self.norm1 = eqx.nn.LayerNorm((hidden_size, num_patches))
        self.norm2 = eqx.nn.LayerNorm((num_patches, hidden_size))

    def __call__(self, y):
        # Mix across patches (applied per channel), with a residual connection.
        y = y + jax.vmap(self.patch_mixer)(self.norm1(y))
        y = einops.rearrange(y, "c p -> p c")
        # Mix across channels (applied per patch), with a residual connection.
        y = y + jax.vmap(self.hidden_mixer)(self.norm2(y))
        y = einops.rearrange(y, "p c -> c p")
        return y


class Mixer2d(eqx.Module):
    """MLP-Mixer score network; the time t enters as an extra input channel."""

    conv_in: eqx.nn.Conv2d
    conv_out: eqx.nn.ConvTranspose2d
    blocks: list
    norm: eqx.nn.LayerNorm
    t1: float

    def __init__(
        self,
        img_size,
        patch_size,
        hidden_size,
        mix_patch_size,
        mix_hidden_size,
        num_blocks,
        t1,
        *,
        key,
    ):
        input_size, height, width = img_size
        assert (height % patch_size) == 0
        assert (width % patch_size) == 0
        num_patches = (height // patch_size) * (width // patch_size)
        inkey, outkey, *bkeys = jr.split(key, 2 + num_blocks)

        # input_size + 1: one extra channel carries the (broadcast) time embedding.
        self.conv_in = eqx.nn.Conv2d(input_size + 1, hidden_size, patch_size, stride=patch_size, key=inkey)
        self.conv_out = eqx.nn.ConvTranspose2d(hidden_size, input_size, patch_size, stride=patch_size, key=outkey)
        self.blocks = [
            MixerBlock(num_patches, hidden_size, mix_patch_size, mix_hidden_size, key=bkey) for bkey in bkeys
        ]
        self.norm = eqx.nn.LayerNorm((hidden_size, num_patches))
        self.t1 = t1

    def __call__(self, t, y):
        t = t / self.t1  # normalize time to [0, 1]
        _, height, width = y.shape
        t = einops.repeat(t, "-> 1 h w", h=height, w=width)
        y = jnp.concatenate([y, t])
        y = self.conv_in(y)
        _, patch_height, patch_width = y.shape
        y = einops.rearrange(y, "c h w -> c (h w)")
        for block in self.blocks:
            y = block(y)
        y = self.norm(y)
        y = einops.rearrange(y, "c (h w) -> c h w", h=patch_height, w=patch_width)
        return self.conv_out(y)
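As a shape sanity check (a sketch, not part of the original notebook; the hyperparameters mirror the defaults used in main() below, and the keys are arbitrary):

test_model = Mixer2d((1, 28, 28), 4, 64, 512, 512, 4, 10.0, key=jr.PRNGKey(0))
test_y = jr.normal(jr.PRNGKey(1), (1, 28, 28))
test_model(jnp.asarray(0.5), test_y).shape  # (1, 28, 28): the score has the input's shape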

Define Training Objective

def single_loss_fn(model, weight, int_beta, data, t, key):
    # Forward (noising) marginal of the variance-preserving SDE:
    # y(t) = x * exp(-0.5 * int_beta(t)) + sqrt(1 - exp(-int_beta(t))) * noise.
    mean = data * jnp.exp(-0.5 * int_beta(t))
    var = jnp.maximum(1 - jnp.exp(-int_beta(t)), 1e-5)
    std = jnp.sqrt(var)
    noise = jr.normal(key, data.shape)
    y = mean + std * noise
    pred = model(t, y)
    # Denoising score matching: the score of y given x is -noise / std.
    return weight(t) * jnp.mean((pred + noise / std) ** 2)


def batch_loss_fn(model, weight, int_beta, data, t1, key):
    batch_size = data.shape[0]
    tkey, losskey = jr.split(key)
    losskey = jr.split(losskey, batch_size)
    # Low-discrepancy sampling over t to reduce variance
    t = jr.uniform(tkey, (batch_size,), minval=0, maxval=t1 / batch_size)
    t = t + (t1 / batch_size) * jnp.arange(batch_size)
    loss_fn = ft.partial(single_loss_fn, model, weight, int_beta)
    loss_fn = jax.vmap(loss_fn)
    return jnp.mean(loss_fn(data, t, losskey))


@eqx.filter_jit
def single_sample_fn(model, int_beta, data_shape, dt0, t1, key):
    def drift(t, y, args):
        # beta(t) = d/dt int_beta(t), obtained by differentiating int_beta with a JVP.
        _, beta = jax.jvp(int_beta, (t,), (jnp.ones_like(t),))
        # Probability-flow ODE: dy/dt = -0.5 * beta(t) * (y + score(t, y)).
        return -0.5 * beta * (y + model(t, y))

    term = dfx.ODETerm(drift)
    solver = dfx.Tsit5()
    t0 = 0
    y1 = jr.normal(key, data_shape)
    # reverse time, solve from t1 to t0
    sol = dfx.diffeqsolve(term, solver, t1, t0, -dt0, y1, adjoint=dfx.NoAdjoint())
    return sol.ys[0]
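In symbols, this is a restatement of what the cell above implements, not an addition to it. Writing $B(t) = \int_0^t \beta(s)\,ds$ for int_beta, the forward noising process has marginal

y_t = x\, e^{-\frac{1}{2} B(t)} + \sigma_t\, \varepsilon,
\qquad \sigma_t^2 = 1 - e^{-B(t)},
\qquad \varepsilon \sim \mathcal{N}(0, I),

so the score of $y_t$ given $x$ is $-\varepsilon / \sigma_t$, and the training loss is

\mathcal{L}(\theta) = \mathbb{E}_{t,\, x,\, \varepsilon}
\big[\, w(t)\, \| s_\theta(t, y_t) + \varepsilon / \sigma_t \|^2 \,\big].

Sampling then solves the probability-flow ODE $\dot{y} = -\tfrac{1}{2}\beta(t)\,(y + s_\theta(t, y))$ backwards from $t_1$ to $0$, starting from $y(t_1) \sim \mathcal{N}(0, I)$.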
# Note: this cell redefines mnist() to return the *raw* uint8 images, without
# normalization; main() below computes the mean/std itself (it also needs the
# data min/max for clipping generated samples back to the pixel range).
def mnist():
    filename = "train-images-idx3-ubyte.gz"
    url_dir = "https://storage.googleapis.com/cvdf-datasets/mnist"
    target_dir = os.getcwd() + "/data/mnist"
    url = f"{url_dir}/{filename}"
    target = f"{target_dir}/{filename}"

    if not os.path.exists(target):
        os.makedirs(target_dir, exist_ok=True)
        urllib.request.urlretrieve(url, target)
        print(f"Downloaded {url} to {target}")

    with gzip.open(target, "rb") as fh:
        _, batch, rows, cols = struct.unpack(">IIII", fh.read(16))
        shape = (batch, 1, rows, cols)
        return jnp.array(array.array("B", fh.read()), dtype=jnp.uint8).reshape(shape)


def dataloader(data, batch_size, *, key):
    # Identical to the dataloader defined above; repeated so this cell is self-contained.
    dataset_size = data.shape[0]
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield data[batch_perm]
            start = end
            end = start + batch_size

Train Score-Based Model

@eqx.filter_jit
def make_step(model, weight, int_beta, data, t1, key, opt_state, opt_update):
    loss_fn = eqx.filter_value_and_grad(batch_loss_fn)
    loss, grads = loss_fn(model, weight, int_beta, data, t1, key)
    updates, opt_state = opt_update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    key = jr.split(key, 1)[0]
    return loss, model, key, opt_state


def main(
    # Model hyperparameters
    patch_size=4,
    hidden_size=64,
    mix_patch_size=512,
    mix_hidden_size=512,
    num_blocks=4,
    t1=10.0,
    # Optimisation hyperparameters
    num_steps=100_000,  # the output below is from this setting; try 50_000 (faster) or 1_000_000
    lr=3e-4,
    batch_size=256,
    print_every=10_000,
    # Sampling hyperparameters
    dt0=0.1,
    sample_size=10,
    # Seed
    seed=5678,
):
    key = jr.PRNGKey(seed)
    model_key, train_key, loader_key, sample_key = jr.split(key, 4)

    data = mnist()
    data_mean = jnp.mean(data)
    data_std = jnp.std(data)
    data_max = jnp.max(data)
    data_min = jnp.min(data)
    data_shape = data.shape[1:]
    data = (data - data_mean) / data_std

    model = Mixer2d(
        data_shape,
        patch_size,
        hidden_size,
        mix_patch_size,
        mix_hidden_size,
        num_blocks,
        t1,
        key=model_key,
    )
    int_beta = lambda t: t  # Try experimenting with other options here!
    weight = lambda t: 1 - jnp.exp(-int_beta(t))  # Just chosen to upweight the region near t=0.

    opt = optax.adabelief(lr)
    # Optax will update the floating-point JAX arrays in the model.
    opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    total_value = 0
    total_size = 0
    for step, data in zip(range(num_steps), dataloader(data, batch_size, key=loader_key)):
        value, model, train_key, opt_state = make_step(
            model, weight, int_beta, data, t1, train_key, opt_state, opt.update
        )
        total_value += value.item()
        total_size += 1
        if (step % print_every) == 0 or step == num_steps - 1:
            print(f"Step={step} Loss={total_value / total_size}")
            total_value = 0
            total_size = 0

    # Draw a sample_size x sample_size grid of images by solving the
    # probability-flow ODE backwards from t1 to 0, one sample per key.
    sample_key = jr.split(sample_key, sample_size**2)
    sample_fn = ft.partial(single_sample_fn, model, int_beta, data_shape, dt0, t1)
    sample = jax.vmap(sample_fn)(sample_key)
    # Undo the normalization and clip back to the original pixel range.
    sample = data_mean + data_std * sample
    sample = jnp.clip(sample, data_min, data_max)
    sample = einops.rearrange(sample, "(n1 n2) 1 h w -> (n1 h) (n2 w)", n1=sample_size, n2=sample_size)
    plt.imshow(sample, cmap="Greys")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
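The int_beta line above invites experimentation. A hedged sketch of alternatives (not run in this notebook): any noise schedule can be swapped in by supplying its integral, since both the loss and single_sample_fn recover beta(t) from int_beta by differentiation (the jax.jvp call).

# Hypothetical alternative schedules, each written as B(t) = integral of beta from 0 to t:
int_beta_const = lambda t: t           # beta(t) = 1 (the default above)
int_beta_linear = lambda t: t**2 / 2   # beta(t) = t: gentler noising early, stronger late
# To try one, replace the int_beta line inside main(); nothing else needs to change.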
main()
Step=0 Loss=1.003075122833252
Step=10000 Loss=0.029320005891285836
Step=20000 Loss=0.019951716899499296
Step=30000 Loss=0.018376670414488764
Step=40000 Loss=0.017525794696807862
Step=50000 Loss=0.01694488068567589
Step=60000 Loss=0.01650137415379286
Step=70000 Loss=0.016147429181076586
Step=80000 Loss=0.015854475118219854
Step=90000 Loss=0.015603535754419863
Step=99999 Loss=0.015381233805701314
[Image: 10x10 grid of MNIST digits generated by the trained model]
# NB: main() already called plt.show(), which finalizes the figure, so this call
# creates and saves an empty canvas (see the output below). To save the sample
# grid, call plt.savefig("diffusion_mnist.pdf") before plt.show() inside main().
plt.savefig("diffusion_mnist.pdf")
<Figure size 432x288 with 0 Axes>