GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/14/supplementary/lecun1989_with_commentary.ipynb

Annotated MNIST

This tutorial demonstrates how to construct the original convolutional neural network (CNN) proposed by LeCun et al. in http://yann.lecun.com/exdb/publis/pdf/lecun-89e.pdf

The original PyTorch implementation is at https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py.

This version is converted to JAX/Flax and is based on Flax's official Annotated MNIST notebook.

1. Imports

Import JAX, Flax, ordinary NumPy, Optax, and torchvision's datasets module. Flax can use any data-loading pipeline; this example uses torchvision to download MNIST.

!pip install -q flax
import jax
import jax.numpy as jnp  # JAX NumPy

from flax import linen as nn  # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np  # Ordinary NumPy
import optax  # Optimizers
from torchvision import datasets  # torchvision.datasets for MNIST

2. Define network

Create the original convolutional neural network with the Linen API by subclassing Module. Because the architecture in this example is fairly complex (the connection between the first and second hidden layers is quite unusual from a modern point of view), we cannot simply define the submodules inline within the __call__ method wrapped with the @nn.compact decorator; instead, we declare them in setup().

The most notable difference between LeCun1989 and recent CNNs is that the "units" in the original architecture share their weights but do not share their biases (thresholds), whereas modern CNNs share both weights and biases across spatial locations. We define a custom LocalBias layer to capture this peculiarity.

class LocalBias(nn.Module):
    @nn.compact
    def __call__(self, x):
        bias_shape = x.shape[1:]
        bias = self.param("bias", nn.initializers.zeros, bias_shape, jnp.float32)
        bias = jnp.asarray(bias, jnp.float32)
        bias = bias.reshape((1,) * (x.ndim - bias.ndim) + bias.shape)
        return x + bias
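As a quick sanity check (a sketch that assumes the LocalBias class above), we can initialize the layer on a dummy H1-sized activation map and confirm that there is one bias per spatial location and per plane, rather than one per plane:

demo_rng = jax.random.PRNGKey(0)  # throwaway key, used only for this illustration
dummy = jnp.zeros((1, 8, 8, 12))  # a batch with one H1-sized activation map
variables = LocalBias().init(demo_rng, dummy)
print(variables["params"]["bias"].shape)  # (8, 8, 12), i.e. 768 local biases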

Now we need to write our own __call__ method. In particular, each H2 unit connects to only 8 of the 12 H1 planes. We implement this with three separate convolutions whose outputs we concatenate. Additionally, we define a custom weight initializer lecun1989_uniform and a static method pad that pads images with -1 at the edges.

class LeCun1989(nn.Module):
    """1989 LeCun ConvNet per description in the paper"""

    def setup(self):
        # The variance of Uniform[-2.4/sqrt(fan_in), 2.4/sqrt(fan_in)]
        # is 2.4**2/3/fan_in
        lecun1989_uniform = jax.nn.initializers.variance_scaling(2.4**2 / 3, "fan_in", "uniform")
        self.H1w = nn.Conv(12, (5, 5), 2, use_bias=False, kernel_init=lecun1989_uniform, padding="VALID")
        self.H1b = LocalBias()
        # Each slice looks at 8 planes and outputs 4 planes, so when we
        # concatenate the output planes together, we get a total of 12 planes.
        self.H2s1w = nn.Conv(4, (5, 5), 2, use_bias=False, kernel_init=lecun1989_uniform, padding="VALID")
        self.H2s2w = nn.Conv(4, (5, 5), 2, use_bias=False, kernel_init=lecun1989_uniform, padding="VALID")
        self.H2s3w = nn.Conv(4, (5, 5), 2, use_bias=False, kernel_init=lecun1989_uniform, padding="VALID")
        self.H2b = LocalBias()
        self.H3 = nn.Dense(30, kernel_init=lecun1989_uniform)
        self.Out = nn.Dense(10, kernel_init=lecun1989_uniform, bias_init=jax.nn.initializers.constant(-1))

    @staticmethod
    def pad(x):
        return jnp.pad(x, ((0, 0), (2, 2), (2, 2), (0, 0)), constant_values=-1.0)

    def __call__(self, x):
        x = self.pad(x)
        x = self.H1w(x)
        x = self.H1b(x)
        x = jnp.tanh(x)
        x1 = self.pad(x[..., 0:8])
        x2 = self.pad(x[..., 2:10])
        x3 = self.pad(x[..., 4:12])
        x = jnp.concatenate([self.H2s1w(x1), self.H2s2w(x2), self.H2s3w(x3)], axis=-1)
        x = self.H2b(x)
        x = jnp.tanh(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = self.H3(x)
        x = jnp.tanh(x)
        x = self.Out(x)
        x = jnp.tanh(x)
        return x
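As a quick check (a sketch, assuming the LeCun1989 class above), we can run a dummy 16x16 image through the network and confirm that it produces one tanh score per digit class:

demo_rng = jax.random.PRNGKey(0)  # throwaway key for this illustration
model = LeCun1989()
dummy_image = jnp.ones([1, 16, 16, 1])
variables = model.init(demo_rng, dummy_image)
out = model.apply(variables, dummy_image)
print(out.shape)  # (1, 10): one tanh "score" in [-1, 1] per digit class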

3. Define loss

Define a mean squared error (MSE) loss function using just jax.numpy. It takes the model's outputs and the integer labels and returns a scalar loss. The labels are one-hot encoded with jax.nn.one_hot and then rescaled to {-1, 1} so that they match the tanh output range of the network.

Note that, unlike Flax's Annotated MNIST example (which returns unnormalized logits and computes the loss with optax.softmax_cross_entropy()), we follow the 1989 paper and minimize the squared error between the tanh outputs and the {-1, 1} targets.

def mse_loss(*, logits, labels):
    one_hot_labels = 2 * jax.nn.one_hot(labels, num_classes=10) - 1
    return jnp.mean((logits - one_hot_labels) ** 2)
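A tiny worked example (it assumes the mse_loss function above): a perfect prediction gives zero loss, while predicting all zeros gives a loss of 1 because every target entry is either -1 or 1:

demo_labels = jnp.array([3])
perfect_logits = 2 * jax.nn.one_hot(demo_labels, 10) - 1  # exactly the {-1, 1} target
print(mse_loss(logits=perfect_logits, labels=demo_labels))      # 0.0
print(mse_loss(logits=jnp.zeros((1, 10)), labels=demo_labels))  # 1.0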

4. Metric computation

For loss and accuracy metrics, create a separate function:

def compute_metrics(*, logits, labels):
    loss = mse_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        "loss": loss,
        "accuracy": accuracy,
    }
    return metrics
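Continuing the toy example above (a sketch that assumes compute_metrics), taking the argmax of the {-1, 1} one-hot target recovers the label, so a perfect prediction gives zero loss and 100% accuracy:

demo_labels = jnp.array([3, 7])
demo_logits = 2 * jax.nn.one_hot(demo_labels, 10) - 1
print(compute_metrics(logits=demo_logits, labels=demo_labels))  # {'loss': 0.0, 'accuracy': 1.0}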

5. Loading data

Define a function that loads the MNIST dataset with torchvision, converts the images to floating point, rescales them to the range [-1, 1], and downsamples them to 16x16 pixels to approximate the 1989 setup.

def get_dataset(train: bool, size: int):
    data = datasets.MNIST("/tmp/mnist", train=train, download=True)
    X = data.data[:size].float() / 127.5 - 1.0
    X = jnp.float32(X)
    X = jax.image.resize(X, (size, 16, 16), "bilinear")
    X = jnp.expand_dims(X, 3)
    Y = jnp.float32(data.targets[:size])
    return {"image": X, "label": Y}


def get_datasets():
    """Preprocess today's MNIST dataset into the 1989 version's size/format (approximately).

    Some relevant notes for this part:
    - the first 7291 digits from the training set are used for training
    - the first 2007 digits from the test set are used for testing
    - each image is 16x16 pixels, grayscale (not binary)
    - images are scaled to the range [-1, 1]
    - the paper doesn't say exactly, but reading between the lines we take the label targets to be {-1, 1}

    >>> from contextlib import redirect_stdout
    >>> with redirect_stdout(None):
    ...     train_ds, test_ds = get_datasets()
    >>>
    >>> type(train_ds['image'])
    <class 'jaxlib.xla_extension.DeviceArray'>
    >>> train_ds['image'].shape
    (7291, 16, 16, 1)
    >>> train_ds['label'].shape
    (7291,)
    """
    train_ds = get_dataset(True, 7291)
    test_ds = get_dataset(False, 2007)
    return train_ds, test_ds
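An optional inspection (a sketch that assumes get_datasets above; it downloads MNIST to /tmp/mnist): the images come out as 16x16 single-channel arrays scaled to roughly [-1, 1]:

train_ds, test_ds = get_datasets()
print(train_ds["image"].shape, test_ds["image"].shape)  # (7291, 16, 16, 1) (2007, 16, 16, 1)
print(float(train_ds["image"].min()), float(train_ds["image"].max()))  # close to -1.0 and 1.0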

6. Create train state

A common pattern in Flax is to create a single dataclass that represents the entire training state, including step number, parameters, and optimizer state.

Adding the optimizer and model to this state also has the advantage that we only need to pass a single argument to functions like train_step() (see below).

Because this is such a common pattern, Flax provides the class flax.training.train_state.TrainState that serves most basic use cases. Usually one would subclass it to add more data to be tracked, but in this example we can use it without any modifications.

def create_train_state(rng, learning_rate):
    """Creates initial `TrainState`."""
    cnn = LeCun1989()
    params = cnn.init(rng, jnp.ones([1, 16, 16, 1]))["params"]
    tx = optax.sgd(learning_rate)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
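A small sketch (it assumes create_train_state above and uses a throwaway PRNG key) showing that the returned TrainState bundles the step counter, the parameters, and the optax optimizer in one object:

demo_rng = jax.random.PRNGKey(42)  # throwaway key, only for this illustration
demo_state = create_train_state(demo_rng, learning_rate=0.03)
print(demo_state.step)                 # 0: no gradient step has been applied yet
print(list(demo_state.params.keys()))  # the submodule names defined in setup()
print(demo_state.tx)                   # the optax SGD GradientTransformation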

7. Training step

A function that:

  • Evaluates the neural network given the parameters and a batch of input images with the Module.apply method.

  • Computes the loss with the mse_loss function defined earlier.

  • Evaluates the loss function and its gradient using jax.value_and_grad.

  • Applies a pytree of gradients to the optimizer to update the model's parameters.

  • Computes the metrics using compute_metrics (defined earlier).

Use JAX's @jit decorator to trace the entire train_step function and just-in-time compile it with XLA into fused device operations that run faster and more efficiently on hardware accelerators.

@jax.jit
def train_step(state, batch):
    """Train for a single step."""

    def loss_fn(params):
        logits = LeCun1989().apply({"params": params}, batch["image"])
        loss = mse_loss(logits=logits, labels=batch["label"])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch["label"])
    return state, metrics
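A standalone usage sketch with a hypothetical dummy batch (the real data pipeline is set up in sections 11-13 below), showing that one call to train_step applies a single gradient update and returns the new state plus the batch metrics:

demo_rng = jax.random.PRNGKey(0)  # throwaway key for this illustration
demo_state = create_train_state(demo_rng, learning_rate=0.03)
dummy_batch = {
    "image": jnp.zeros((4, 16, 16, 1)),           # four blank 16x16 "digits"
    "label": jnp.zeros((4,), dtype=jnp.float32),  # all labelled as digit 0
}
demo_state, demo_metrics = train_step(demo_state, dummy_batch)
print(demo_state.step)  # 1: one gradient step has been applied
print(demo_metrics)     # loss and accuracy on the dummy batch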

8. Evaluation step

Create a function that evaluates the model on the test set with Module.apply:

@jax.jit
def eval_step(params, batch):
    logits = LeCun1989().apply({"params": params}, batch["image"])
    return compute_metrics(logits=logits, labels=batch["label"])

9. Train function

Define a training function that:

  • Shuffles the training data before each epoch using jax.random.permutation, which takes a PRNGKey as a parameter (see JAX - The Sharp Bits).

  • Runs an optimization step for each batch.

  • Retrieves the training metrics from the device with jax.device_get and computes their mean across each batch in an epoch.

  • Prints the epoch's training loss and accuracy and returns the updated training state.

def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["image"])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {k: np.mean([metrics[k] for metrics in batch_metrics_np]) for k in batch_metrics_np[0]}

    print(
        "train epoch: %d, loss: %.4f, accuracy: %.2f"
        % (epoch, epoch_metrics_np["loss"], epoch_metrics_np["accuracy"] * 100)
    )

    return state

10. Eval function

Create a model evaluation function that:

  • Retrieves the evaluation metrics from the device with jax.device_get.

  • Converts the metrics, stored in a JAX pytree, into plain Python numbers.

def eval_model(params, test_ds):
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary["loss"], summary["accuracy"]

11. Download data

train_ds, test_ds = get_datasets()

12. Seed randomness

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

13. Initialize train state

Recall that this function initializes both the model parameters and the optimizer, and puts both into the training state dataclass that it returns.

learning_rate = 0.03
state = create_train_state(init_rng, learning_rate)
del init_rng  # Must not be used anymore.

We can verify that the parameters have the correct shapes.

assert state.params["H1w"]["kernel"].shape == (5, 5, 1, 12)
assert state.params["H1b"]["bias"].shape == (8, 8, 12)
assert state.params["H2s1w"]["kernel"].shape == (5, 5, 8, 4)
assert state.params["H2b"]["bias"].shape == (4, 4, 12)
assert state.params["H3"]["kernel"].shape == (4 * 4 * 12, 30)
assert state.params["H3"]["bias"].shape == (30,)
assert state.params["Out"]["kernel"].shape == (30, 10)
assert state.params["Out"]["bias"].shape == (10,)
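We can also count the trainable parameters (a sketch that assumes the state object above). Summing the sizes of all leaves of the parameter pytree gives 9760, the number of independent parameters reported in the paper:

n_params = sum(p.size for p in jax.tree_util.tree_leaves(state.params))
print(n_params)  # 9760 trainable parameters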

14. Train and evaluate

After 23 epochs of training and evaluation, the output should show that the model reaches roughly 94% accuracy on the test set. This may not seem very impressive, but remember that this network is from 1989!

num_epochs = 23
batch_size = 1
for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    # Run an optimization step over a training batch
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    # Evaluate on the test set after each training epoch
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print(" test epoch: %d, loss: %.2f, accuracy: %.2f" % (epoch, test_loss, test_accuracy * 100))
train epoch: 1, loss: 0.1088, accuracy: 82.57
 test epoch: 1, loss: 0.08, accuracy: 88.29
train epoch: 2, loss: 0.0499, accuracy: 92.90
 test epoch: 2, loss: 0.07, accuracy: 89.44
train epoch: 3, loss: 0.0377, accuracy: 94.46
 test epoch: 3, loss: 0.06, accuracy: 90.33
train epoch: 4, loss: 0.0320, accuracy: 95.49
 test epoch: 4, loss: 0.05, accuracy: 92.33
train epoch: 5, loss: 0.0275, accuracy: 96.09
 test epoch: 5, loss: 0.05, accuracy: 91.68
train epoch: 6, loss: 0.0247, accuracy: 96.58
 test epoch: 6, loss: 0.05, accuracy: 92.78
train epoch: 7, loss: 0.0215, accuracy: 97.02
 test epoch: 7, loss: 0.04, accuracy: 93.77
train epoch: 8, loss: 0.0197, accuracy: 97.28
 test epoch: 8, loss: 0.05, accuracy: 93.42
train epoch: 9, loss: 0.0179, accuracy: 97.54
 test epoch: 9, loss: 0.05, accuracy: 93.17
train epoch: 10, loss: 0.0154, accuracy: 97.94
 test epoch: 10, loss: 0.05, accuracy: 92.63
train epoch: 11, loss: 0.0138, accuracy: 98.26
 test epoch: 11, loss: 0.05, accuracy: 92.97
train epoch: 12, loss: 0.0120, accuracy: 98.49
 test epoch: 12, loss: 0.04, accuracy: 94.37
train epoch: 13, loss: 0.0111, accuracy: 98.56
 test epoch: 13, loss: 0.04, accuracy: 93.57
train epoch: 14, loss: 0.0098, accuracy: 98.77
 test epoch: 14, loss: 0.04, accuracy: 93.57
train epoch: 15, loss: 0.0088, accuracy: 98.86
 test epoch: 15, loss: 0.04, accuracy: 93.67
train epoch: 16, loss: 0.0075, accuracy: 98.99
 test epoch: 16, loss: 0.04, accuracy: 93.92
train epoch: 17, loss: 0.0074, accuracy: 98.97
 test epoch: 17, loss: 0.04, accuracy: 93.92
train epoch: 18, loss: 0.0063, accuracy: 99.15
 test epoch: 18, loss: 0.04, accuracy: 94.07
train epoch: 19, loss: 0.0053, accuracy: 99.27
 test epoch: 19, loss: 0.04, accuracy: 94.02
train epoch: 20, loss: 0.0051, accuracy: 99.25
 test epoch: 20, loss: 0.04, accuracy: 94.32
train epoch: 21, loss: 0.0046, accuracy: 99.36
 test epoch: 21, loss: 0.04, accuracy: 94.02
train epoch: 22, loss: 0.0040, accuracy: 99.41
 test epoch: 22, loss: 0.04, accuracy: 94.07
train epoch: 23, loss: 0.0039, accuracy: 99.47
 test epoch: 23, loss: 0.04, accuracy: 93.72

Congrats! You made it to the end of the annotated LeCun1989 example. You can revisit a closely related example (Flax's standard MNIST example), structured differently as a couple of Python modules, test modules, config files, another Colab, and documentation, in Flax's Git repo:

https://github.com/google/flax/tree/main/examples/mnist