GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/notebooks/lecun1989_flax.ipynb
Kernel: Python 3

Open In Colab

Backpropagation Applied to MNIST

Based on LeCun (1989): http://yann.lecun.com/exdb/publis/pdf/lecun-89e.pdf

Adapted to JAX from https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py

Author: Peter G. Chang (@peterchang0414)

1989 Reproduction

!pip install -q flax
import jax
import jax.numpy as jnp
from flax import linen as nn
from torchvision import datasets


def get_datasets(n_tr, n_te):
    train_test = {}
    for split in ("train", "test"):
        data = datasets.MNIST("./data", train=split == "train", download=True)
        n = n_tr if split == "train" else n_te
        key = jax.random.PRNGKey(42)
        rp = jax.random.permutation(key, len(data))[:n]
        X = jnp.full((n, 16, 16, 1), 0.0, dtype=jnp.float32)
        Y = jnp.full((n, 10), -1.0, dtype=jnp.float32)
        for i, ix in enumerate(rp):
            I, yint = data[int(ix)]
            xi = jnp.array(I, dtype=jnp.float32) / 127.5 - 1.0  # scale pixels to [-1, 1]
            xi = jax.image.resize(xi, (16, 16), "bilinear")  # downsample 28x28 -> 16x16
            X = X.at[i].set(jnp.expand_dims(xi, axis=2))
            Y = Y.at[i, yint].set(1.0)  # targets in {-1, +1} for the tanh output layer
        train_test[split] = (X, Y)
    return train_test
from typing import Callable

import optax
from flax import linen as nn
from flax.linen.activation import tanh
from flax.training import train_state


class Net(nn.Module):
    bias_init: Callable = nn.initializers.zeros
    # he_uniform() draws from U(±sqrt(6/fan_in)); sqrt(6) ≈ 2.449 approximates
    # the 2.4/sqrt(fan_in) bound used in Karpathy's repro
    kernel_init: Callable = nn.initializers.he_uniform()

    @nn.compact
    def __call__(self, x):
        # H1: pad 16x16 -> 20x20 with the background value (-1), then 5x5/stride-2 conv -> 8x8x12
        x = jnp.pad(x, [(0, 0), (2, 2), (2, 2), (0, 0)], constant_values=-1.0)
        x = nn.Conv(features=12, kernel_size=(5, 5), strides=2, padding="VALID",
                    use_bias=False, kernel_init=self.kernel_init)(x)
        bias1 = self.param("bias1", self.bias_init, (8, 8, 12))
        x = tanh(x + bias1)
        # H2: pad 8x8 -> 12x12, then three 5x5/stride-2 convs over overlapping 8-channel slices -> 4x4x12
        x = jnp.pad(x, [(0, 0), (2, 2), (2, 2), (0, 0)], constant_values=-1.0)
        x1, x2, x3 = (x[..., 0:8], x[..., 4:12],
                      jnp.concatenate((x[..., 0:4], x[..., 8:12]), axis=-1))
        slice1 = nn.Conv(features=4, kernel_size=(5, 5), strides=2, padding="VALID",
                         use_bias=False, kernel_init=self.kernel_init)(x1)
        slice2 = nn.Conv(features=4, kernel_size=(5, 5), strides=2, padding="VALID",
                         use_bias=False, kernel_init=self.kernel_init)(x2)
        slice3 = nn.Conv(features=4, kernel_size=(5, 5), strides=2, padding="VALID",
                         use_bias=False, kernel_init=self.kernel_init)(x3)
        x = jnp.concatenate((slice1, slice2, slice3), axis=-1)
        bias2 = self.param("bias2", self.bias_init, (4, 4, 12))
        x = tanh(x + bias2)
        # H3: flatten and fully connect to 30 units
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=30, use_bias=False)(x)
        bias3 = self.param("bias3", self.bias_init, (30,))
        x = tanh(x + bias3)
        # Output: 10 tanh units, biases initialized to -1 to match the -1 target background
        x = nn.Dense(features=10, use_bias=False)(x)
        bias4 = self.param("bias4", nn.initializers.constant(-1.0), (10,))
        x = tanh(x + bias4)
        return x
@jax.jit
def eval_step(params, X, Y):
    Yhat = Net().apply({"params": params}, X)
    loss = jnp.mean((Yhat - Y) ** 2)
    err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
    return loss, err
def eval_split(data, split, params):
    X, Y = data[split]
    loss, err = eval_step(params, X, Y)
    print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {int(err*Y.shape[0])}")
from jax import value_and_grad
import optax
from flax.training import train_state


def create_train_state(key, lr, X):
    model = Net()
    params = model.init(key, X)["params"]
    sgd_opt = optax.sgd(lr)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=sgd_opt)


@jax.jit
def train_step(state, X, Y):
    def loss_fn(params):
        Yhat = Net().apply({"params": params}, X)
        loss = jnp.mean((Yhat - Y) ** 2)
        err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
        return loss, err

    (_, _err), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    return state


def train_one_epoch(state, X, Y):
    # Online SGD: one example per step, as in the original paper
    for step_num in range(X.shape[0]):
        x, y = jnp.expand_dims(X[step_num], 0), jnp.expand_dims(Y[step_num], 0)
        state = train_step(state, x, y)
    return state


def train(key, data, epochs, lr):
    Xtr, Ytr = data["train"]
    Xte, Yte = data["test"]
    state = create_train_state(key, lr, Xtr)
    for epoch in range(epochs):
        print(f"epoch {epoch+1}")
        state = train_one_epoch(state, Xtr, Ytr)
        for split in ["train", "test"]:
            eval_split(data, split, state.params)
data = get_datasets(7291, 2007)
key, _ = jax.random.split(jax.random.PRNGKey(42))
train(key, data, 23, 0.03)
epoch 1
eval: split train. loss 5.576071e-02. error 8.11%. misses: 591
eval: split test . loss 5.287848e-02. error 7.37%. misses: 148
epoch 2
eval: split train. loss 4.097378e-02. error 5.80%. misses: 423
eval: split test . loss 4.257497e-02. error 6.08%. misses: 122
epoch 3
eval: split train. loss 3.390130e-02. error 4.92%. misses: 359
eval: split test . loss 3.796291e-02. error 5.48%. misses: 110
epoch 4
eval: split train. loss 2.989994e-02. error 4.38%. misses: 319
eval: split test . loss 3.480190e-02. error 5.23%. misses: 105
epoch 5
eval: split train. loss 2.566473e-02. error 3.77%. misses: 275
eval: split test . loss 3.232093e-02. error 4.73%. misses: 95
epoch 6
eval: split train. loss 2.348944e-02. error 3.33%. misses: 242
eval: split test . loss 3.208887e-02. error 4.58%. misses: 92
epoch 7
eval: split train. loss 2.151174e-02. error 3.09%. misses: 225
eval: split test . loss 3.206819e-02. error 4.93%. misses: 99
epoch 8
eval: split train. loss 1.941714e-02. error 2.77%. misses: 202
eval: split test . loss 3.061979e-02. error 4.73%. misses: 95
epoch 9
eval: split train. loss 1.694829e-02. error 2.41%. misses: 176
eval: split test . loss 2.916610e-02. error 4.38%. misses: 88
epoch 10
eval: split train. loss 1.605429e-02. error 2.22%. misses: 162
eval: split test . loss 2.967581e-02. error 4.58%. misses: 92
epoch 11
eval: split train. loss 1.565071e-02. error 2.18%. misses: 159
eval: split test . loss 3.011220e-02. error 4.58%. misses: 92
epoch 12
eval: split train. loss 1.397184e-02. error 1.93%. misses: 141
eval: split test . loss 2.919692e-02. error 4.53%. misses: 91
epoch 13
eval: split train. loss 1.240323e-02. error 1.59%. misses: 116
eval: split test . loss 2.727516e-02. error 3.64%. misses: 73
epoch 14
eval: split train. loss 1.198561e-02. error 1.56%. misses: 114
eval: split test . loss 2.697299e-02. error 3.89%. misses: 78
epoch 15
eval: split train. loss 1.133908e-02. error 1.44%. misses: 105
eval: split test . loss 2.733141e-02. error 3.94%. misses: 79
epoch 16
eval: split train. loss 1.065093e-02. error 1.47%. misses: 107
eval: split test . loss 2.849034e-02. error 4.09%. misses: 82
epoch 17
eval: split train. loss 9.458693e-03. error 1.26%. misses: 92
eval: split test . loss 2.668566e-02. error 3.79%. misses: 76
epoch 18
eval: split train. loss 7.680640e-03. error 1.08%. misses: 79
eval: split test . loss 2.510950e-02. error 3.74%. misses: 75
epoch 19
eval: split train. loss 6.790097e-03. error 1.00%. misses: 73
eval: split test . loss 2.578570e-02. error 3.69%. misses: 74
epoch 20
eval: split train. loss 6.345607e-03. error 0.93%. misses: 68
eval: split test . loss 2.508449e-02. error 3.54%. misses: 71
epoch 21
eval: split train. loss 5.988171e-03. error 0.92%. misses: 66
eval: split test . loss 2.509341e-02. error 3.59%. misses: 72
epoch 22
eval: split train. loss 5.771732e-03. error 0.88%. misses: 64
eval: split test . loss 2.479325e-02. error 3.54%. misses: 71
epoch 23
eval: split train. loss 5.265484e-03. error 0.82%. misses: 60
eval: split test . loss 2.467080e-02. error 3.69%. misses: 74

Results:

epoch 23
eval: split train. loss 5.265484e-03. error 0.82%. misses: 60
eval: split test . loss 2.467080e-02. error 3.69%. misses: 74

"Modern" Adjustments

!pip install -q flax
import jax
import jax.numpy as jnp
from flax import linen as nn
from torchvision import datasets


def get_datasets(n_tr, n_te):
    train_test = {}
    for split in ("train", "test"):
        data = datasets.MNIST("./data", train=split == "train", download=True)
        n = n_tr if split == "train" else n_te
        key = jax.random.PRNGKey(42)
        rp = jax.random.permutation(key, len(data))[:n]
        X = jnp.full((n, 16, 16, 1), 0.0, dtype=jnp.float32)
        Y = jnp.full((n, 10), 0.0, dtype=jnp.float32)
        for i, ix in enumerate(rp):
            I, yint = data[int(ix)]
            xi = jnp.array(I, dtype=jnp.float32) / 127.5 - 1.0  # scale pixels to [-1, 1]
            xi = jax.image.resize(xi, (16, 16), "bilinear")  # downsample 28x28 -> 16x16
            X = X.at[i].set(jnp.expand_dims(xi, axis=2))
            Y = Y.at[i, yint].set(1.0)  # one-hot targets for softmax cross-entropy
        train_test[split] = (X, Y)
    return train_test
from typing import Callable

import optax
from flax import linen as nn
from flax.training import train_state


class Net(nn.Module):
    training: bool
    bias_init: Callable = nn.initializers.zeros
    # he_uniform() draws from U(±sqrt(6/fan_in)); sqrt(6) ≈ 2.449 approximates
    # the 2.4/sqrt(fan_in) bound used in Karpathy's repro
    kernel_init: Callable = nn.initializers.he_uniform()

    @nn.compact
    def __call__(self, x):
        if self.training:
            # Data augmentation: random shift of at most 1 pixel in x and y
            augment_rng = self.make_rng("aug")
            shift_x, shift_y = jax.random.randint(augment_rng, (2,), -1, 2)
            x = jnp.roll(x, (shift_x, shift_y), (1, 2))
        x = jnp.pad(x, [(0, 0), (2, 2), (2, 2), (0, 0)], constant_values=-1.0)
        x = nn.Conv(features=12, kernel_size=(5, 5), strides=2, padding="VALID",
                    use_bias=False, kernel_init=self.kernel_init)(x)
        bias1 = self.param("bias1", self.bias_init, (8, 8, 12))
        x = nn.relu(x + bias1)
        x = jnp.pad(x, [(0, 0), (2, 2), (2, 2), (0, 0)], constant_values=-1.0)
        x1, x2, x3 = (x[..., 0:8], x[..., 4:12],
                      jnp.concatenate((x[..., 0:4], x[..., 8:12]), axis=-1))
        slice1 = nn.Conv(features=4, kernel_size=(5, 5), strides=2, padding="VALID",
                         use_bias=False, kernel_init=self.kernel_init)(x1)
        slice2 = nn.Conv(features=4, kernel_size=(5, 5), strides=2, padding="VALID",
                         use_bias=False, kernel_init=self.kernel_init)(x2)
        slice3 = nn.Conv(features=4, kernel_size=(5, 5), strides=2, padding="VALID",
                         use_bias=False, kernel_init=self.kernel_init)(x3)
        x = jnp.concatenate((slice1, slice2, slice3), axis=-1)
        bias2 = self.param("bias2", self.bias_init, (4, 4, 12))
        x = nn.relu(x + bias2)
        x = nn.Dropout(0.25, deterministic=not self.training)(x)  # dropout at H3 input
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=30, use_bias=False)(x)
        bias3 = self.param("bias3", self.bias_init, (30,))
        x = nn.relu(x + bias3)
        # Output layer: raw logits for softmax cross-entropy (no tanh)
        x = nn.Dense(features=10, use_bias=False)(x)
        bias4 = self.param("bias4", self.bias_init, (10,))
        x = x + bias4
        return x
from jax import value_and_grad
import optax
from flax.training import train_state


def learning_rate_fn(initial_rate, epochs, steps_per_epoch):
    # Linearly decay from initial_rate to initial_rate / 3 over the full run
    return optax.linear_schedule(
        init_value=initial_rate,
        end_value=initial_rate / 3,
        transition_steps=epochs * steps_per_epoch,
    )


def create_train_state(key, X, lr_fn):
    model = Net(training=True)
    key1, key2, key3 = jax.random.split(key, 3)
    params = model.init({"params": key1, "aug": key2, "dropout": key3}, X)["params"]
    opt = optax.adamw(lr_fn)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=opt)


@jax.jit
def train_step(state, X, Y, rng=jax.random.PRNGKey(0)):
    # Derive fresh per-step rngs for augmentation and dropout
    aug_rng, dropout_rng = jax.random.split(jax.random.fold_in(rng, state.step))

    def loss_fn(params):
        Yhat = Net(training=True).apply(
            {"params": params}, X, rngs={"aug": aug_rng, "dropout": dropout_rng}
        )
        loss = jnp.mean(optax.softmax_cross_entropy(logits=Yhat, labels=Y))
        err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
        return loss, err

    (_, _err), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    return state


def train_one_epoch(state, X, Y):
    for step_num in range(X.shape[0]):
        x, y = jnp.expand_dims(X[step_num], 0), jnp.expand_dims(Y[step_num], 0)
        state = train_step(state, x, y)
    return state


def train(key, data, epochs, lr):
    Xtr, Ytr = data["train"]
    Xte, Yte = data["test"]
    lr_fn = learning_rate_fn(lr, epochs, Xtr.shape[0])
    state = create_train_state(key, Xtr, lr_fn)
    for epoch in range(epochs):
        print(f"epoch {epoch+1} with learning rate {lr_fn(state.step):.6f}")
        state = train_one_epoch(state, Xtr, Ytr)
        for split in ["train", "test"]:
            eval_split(data, split, state.params)
@jax.jit
def eval_step(params, X, Y):
    Yhat = Net(training=False).apply({"params": params}, X)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=Yhat, labels=Y))
    err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
    return loss, err


def eval_split(data, split, params):
    X, Y = data[split]
    loss, err = eval_step(params, X, Y)
    print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {int(err*Y.shape[0])}")
data = get_datasets(7291, 2007)
key, _ = jax.random.split(jax.random.PRNGKey(42))
train(key, data, 80, 3e-4)
epoch 1 with learning rate 0.000300
eval: split train. loss 4.722151e-01. error 12.73%. misses: 928
eval: split test . loss 4.376389e-01. error 11.81%. misses: 237
epoch 2 with learning rate 0.000297
eval: split train. loss 3.456218e-01. error 9.77%. misses: 712
eval: split test . loss 3.105372e-01. error 8.87%. misses: 178
epoch 3 with learning rate 0.000295
eval: split train. loss 2.216365e-01. error 6.45%. misses: 469
eval: split test . loss 1.981873e-01. error 5.53%. misses: 111
epoch 4 with learning rate 0.000292
eval: split train. loss 2.072843e-01. error 5.99%. misses: 437
eval: split test . loss 1.910520e-01. error 5.48%. misses: 110
epoch 5 with learning rate 0.000290
eval: split train. loss 1.750381e-01. error 5.49%. misses: 399
eval: split test . loss 1.611853e-01. error 4.93%. misses: 99
epoch 6 with learning rate 0.000288
eval: split train. loss 1.538368e-01. error 4.42%. misses: 321
eval: split test . loss 1.411121e-01. error 4.19%. misses: 84
epoch 7 with learning rate 0.000285
eval: split train. loss 1.451264e-01. error 4.62%. misses: 337
eval: split test . loss 1.325464e-01. error 4.09%. misses: 82
epoch 8 with learning rate 0.000282
eval: split train. loss 1.257392e-01. error 3.52%. misses: 257
eval: split test . loss 1.164299e-01. error 3.34%. misses: 67
epoch 9 with learning rate 0.000280
eval: split train. loss 1.177755e-01. error 3.40%. misses: 248
eval: split test . loss 1.107324e-01. error 3.69%. misses: 74
epoch 10 with learning rate 0.000277
eval: split train. loss 1.129500e-01. error 3.26%. misses: 237
eval: split test . loss 1.068543e-01. error 3.14%. misses: 63
epoch 11 with learning rate 0.000275
eval: split train. loss 1.157665e-01. error 3.36%. misses: 245
eval: split test . loss 1.119875e-01. error 3.34%. misses: 67
epoch 12 with learning rate 0.000273
eval: split train. loss 1.185108e-01. error 3.61%. misses: 263
eval: split test . loss 1.146749e-01. error 3.69%. misses: 74
epoch 13 with learning rate 0.000270
eval: split train. loss 9.700271e-02. error 2.94%. misses: 214
eval: split test . loss 9.375140e-02. error 3.04%. misses: 61
epoch 14 with learning rate 0.000267
eval: split train. loss 1.081733e-01. error 3.10%. misses: 226
eval: split test . loss 1.054694e-01. error 3.24%. misses: 65
epoch 15 with learning rate 0.000265
eval: split train. loss 9.071133e-02. error 2.76%. misses: 201
eval: split test . loss 8.586112e-02. error 2.64%. misses: 53
epoch 16 with learning rate 0.000262
eval: split train. loss 9.541860e-02. error 2.80%. misses: 203
eval: split test . loss 9.335707e-02. error 3.19%. misses: 64
epoch 17 with learning rate 0.000260
eval: split train. loss 8.359449e-02. error 2.80%. misses: 203
eval: split test . loss 8.113335e-02. error 2.79%. misses: 56
epoch 18 with learning rate 0.000258
eval: split train. loss 8.313517e-02. error 2.46%. misses: 179
eval: split test . loss 8.725357e-02. error 2.64%. misses: 53
epoch 19 with learning rate 0.000255
eval: split train. loss 8.930960e-02. error 2.78%. misses: 203
eval: split test . loss 8.548871e-02. error 2.79%. misses: 56
epoch 20 with learning rate 0.000253
eval: split train. loss 7.986999e-02. error 2.48%. misses: 181
eval: split test . loss 7.389561e-02. error 2.64%. misses: 53
epoch 21 with learning rate 0.000250
eval: split train. loss 7.751217e-02. error 2.30%. misses: 168
eval: split test . loss 7.085717e-02. error 2.44%. misses: 49
epoch 22 with learning rate 0.000247
eval: split train. loss 6.842067e-02. error 2.15%. misses: 157
eval: split test . loss 6.652185e-02. error 2.29%. misses: 46
epoch 23 with learning rate 0.000245
eval: split train. loss 7.121788e-02. error 2.17%. misses: 158
eval: split test . loss 6.131270e-02. error 1.79%. misses: 36
epoch 24 with learning rate 0.000242
eval: split train. loss 7.509596e-02. error 2.46%. misses: 179
eval: split test . loss 6.493099e-02. error 2.14%. misses: 43
epoch 25 with learning rate 0.000240
eval: split train. loss 7.613951e-02. error 2.61%. misses: 190
eval: split test . loss 7.143638e-02. error 2.19%. misses: 44
epoch 26 with learning rate 0.000238
eval: split train. loss 7.980061e-02. error 2.65%. misses: 193
eval: split test . loss 7.566121e-02. error 2.34%. misses: 47
epoch 27 with learning rate 0.000235
eval: split train. loss 6.504884e-02. error 2.13%. misses: 155
eval: split test . loss 5.958221e-02. error 1.99%. misses: 40
epoch 28 with learning rate 0.000232
eval: split train. loss 6.683959e-02. error 2.24%. misses: 163
eval: split test . loss 6.922408e-02. error 2.59%. misses: 52
epoch 29 with learning rate 0.000230
eval: split train. loss 6.794566e-02. error 2.17%. misses: 158
eval: split test . loss 6.709250e-02. error 2.44%. misses: 49
epoch 30 with learning rate 0.000227
eval: split train. loss 6.295200e-02. error 1.96%. misses: 143
eval: split test . loss 5.890007e-02. error 2.54%. misses: 51
epoch 31 with learning rate 0.000225
eval: split train. loss 6.818665e-02. error 2.25%. misses: 164
eval: split test . loss 6.444851e-02. error 2.24%. misses: 45
epoch 32 with learning rate 0.000223
eval: split train. loss 6.571253e-02. error 2.18%. misses: 159
eval: split test . loss 6.434719e-02. error 2.44%. misses: 49
epoch 33 with learning rate 0.000220
eval: split train. loss 6.399426e-02. error 2.18%. misses: 159
eval: split test . loss 6.240412e-02. error 2.24%. misses: 45
epoch 34 with learning rate 0.000217
eval: split train. loss 5.683114e-02. error 1.80%. misses: 131
eval: split test . loss 5.610501e-02. error 1.84%. misses: 37
epoch 35 with learning rate 0.000215
eval: split train. loss 5.706797e-02. error 1.77%. misses: 129
eval: split test . loss 6.036913e-02. error 2.34%. misses: 47
epoch 36 with learning rate 0.000212
eval: split train. loss 5.528478e-02. error 1.95%. misses: 142
eval: split test . loss 5.302548e-02. error 2.04%. misses: 41
epoch 37 with learning rate 0.000210
eval: split train. loss 5.490229e-02. error 1.84%. misses: 133
eval: split test . loss 5.376581e-02. error 1.94%. misses: 39
epoch 38 with learning rate 0.000208
eval: split train. loss 5.350880e-02. error 1.67%. misses: 122
eval: split test . loss 5.158291e-02. error 1.79%. misses: 36
epoch 39 with learning rate 0.000205
eval: split train. loss 5.476158e-02. error 1.77%. misses: 129
eval: split test . loss 5.336771e-02. error 1.69%. misses: 34
epoch 40 with learning rate 0.000202
eval: split train. loss 5.242018e-02. error 1.67%. misses: 122
eval: split test . loss 5.161439e-02. error 1.89%. misses: 38
epoch 41 with learning rate 0.000200
eval: split train. loss 5.457530e-02. error 1.74%. misses: 126
eval: split test . loss 6.135549e-02. error 2.44%. misses: 49
epoch 42 with learning rate 0.000197
eval: split train. loss 5.634554e-02. error 1.91%. misses: 139
eval: split test . loss 6.446160e-02. error 2.34%. misses: 47
epoch 43 with learning rate 0.000195
eval: split train. loss 5.192847e-02. error 1.81%. misses: 132
eval: split test . loss 6.171136e-02. error 2.14%. misses: 43
epoch 44 with learning rate 0.000192
eval: split train. loss 5.048798e-02. error 1.66%. misses: 121
eval: split test . loss 5.762529e-02. error 1.94%. misses: 39
epoch 45 with learning rate 0.000190
eval: split train. loss 5.038778e-02. error 1.58%. misses: 114
eval: split test . loss 5.986194e-02. error 2.09%. misses: 42
epoch 46 with learning rate 0.000188
eval: split train. loss 4.796446e-02. error 1.69%. misses: 122
eval: split test . loss 5.005924e-02. error 1.84%. misses: 37
epoch 47 with learning rate 0.000185
eval: split train. loss 4.932489e-02. error 1.71%. misses: 125
eval: split test . loss 5.289536e-02. error 2.24%. misses: 45
epoch 48 with learning rate 0.000182
eval: split train. loss 5.115648e-02. error 1.78%. misses: 129
eval: split test . loss 5.819925e-02. error 2.04%. misses: 41
epoch 49 with learning rate 0.000180
eval: split train. loss 5.329847e-02. error 1.80%. misses: 131
eval: split test . loss 5.682039e-02. error 2.09%. misses: 42
epoch 50 with learning rate 0.000177
eval: split train. loss 4.632418e-02. error 1.59%. misses: 116
eval: split test . loss 5.570131e-02. error 2.09%. misses: 42
epoch 51 with learning rate 0.000175
eval: split train. loss 5.221667e-02. error 1.73%. misses: 126
eval: split test . loss 6.282473e-02. error 2.14%. misses: 43
epoch 52 with learning rate 0.000173
eval: split train. loss 4.739231e-02. error 1.73%. misses: 126
eval: split test . loss 5.634123e-02. error 1.99%. misses: 40
epoch 53 with learning rate 0.000170
eval: split train. loss 5.621015e-02. error 2.07%. misses: 151
eval: split test . loss 6.867130e-02. error 2.19%. misses: 44
epoch 54 with learning rate 0.000167
eval: split train. loss 4.532041e-02. error 1.60%. misses: 117
eval: split test . loss 5.811055e-02. error 1.99%. misses: 40
epoch 55 with learning rate 0.000165
eval: split train. loss 4.347728e-02. error 1.55%. misses: 113
eval: split test . loss 5.601728e-02. error 2.14%. misses: 43
epoch 56 with learning rate 0.000162
eval: split train. loss 4.743553e-02. error 1.60%. misses: 117
eval: split test . loss 6.145428e-02. error 2.24%. misses: 45
epoch 57 with learning rate 0.000160
eval: split train. loss 4.246239e-02. error 1.56%. misses: 114
eval: split test . loss 5.335664e-02. error 1.79%. misses: 36
epoch 58 with learning rate 0.000158
eval: split train. loss 4.323665e-02. error 1.45%. misses: 106
eval: split test . loss 5.636141e-02. error 1.89%. misses: 38
epoch 59 with learning rate 0.000155
eval: split train. loss 4.607718e-02. error 1.69%. misses: 122
eval: split test . loss 5.969046e-02. error 2.14%. misses: 43
epoch 60 with learning rate 0.000152
eval: split train. loss 4.451877e-02. error 1.48%. misses: 108
eval: split test . loss 5.823955e-02. error 1.99%. misses: 40
epoch 61 with learning rate 0.000150
eval: split train. loss 4.184551e-02. error 1.40%. misses: 101
eval: split test . loss 5.383835e-02. error 1.69%. misses: 34
epoch 62 with learning rate 0.000148
eval: split train. loss 4.327311e-02. error 1.49%. misses: 109
eval: split test . loss 5.188924e-02. error 1.84%. misses: 37
epoch 63 with learning rate 0.000145
eval: split train. loss 3.812368e-02. error 1.34%. misses: 97
eval: split test . loss 4.565141e-02. error 1.69%. misses: 34
epoch 64 with learning rate 0.000142
eval: split train. loss 4.123368e-02. error 1.43%. misses: 103
eval: split test . loss 5.299970e-02. error 1.84%. misses: 37
epoch 65 with learning rate 0.000140
eval: split train. loss 4.013669e-02. error 1.32%. misses: 95
eval: split test . loss 5.678133e-02. error 1.99%. misses: 40
epoch 66 with learning rate 0.000137
eval: split train. loss 3.984843e-02. error 1.34%. misses: 97
eval: split test . loss 5.329720e-02. error 2.09%. misses: 42
epoch 67 with learning rate 0.000135
eval: split train. loss 4.191425e-02. error 1.39%. misses: 101
eval: split test . loss 5.370571e-02. error 1.99%. misses: 40
epoch 68 with learning rate 0.000133
eval: split train. loss 4.354529e-02. error 1.45%. misses: 106
eval: split test . loss 5.472580e-02. error 1.99%. misses: 40
epoch 69 with learning rate 0.000130
eval: split train. loss 3.600218e-02. error 1.25%. misses: 91
eval: split test . loss 5.039397e-02. error 2.14%. misses: 43
epoch 70 with learning rate 0.000127
eval: split train. loss 3.712326e-02. error 1.14%. misses: 83
eval: split test . loss 4.781391e-02. error 1.79%. misses: 36
epoch 71 with learning rate 0.000125
eval: split train. loss 4.377073e-02. error 1.43%. misses: 103
eval: split test . loss 5.955317e-02. error 2.04%. misses: 41
epoch 72 with learning rate 0.000123
eval: split train. loss 4.096783e-02. error 1.41%. misses: 103
eval: split test . loss 5.084800e-02. error 1.94%. misses: 39
epoch 73 with learning rate 0.000120
eval: split train. loss 3.989225e-02. error 1.32%. misses: 95
eval: split test . loss 5.022623e-02. error 2.09%. misses: 42
epoch 74 with learning rate 0.000117
eval: split train. loss 3.819638e-02. error 1.43%. misses: 103
eval: split test . loss 4.982632e-02. error 1.99%. misses: 40
epoch 75 with learning rate 0.000115
eval: split train. loss 3.834034e-02. error 1.29%. misses: 94
eval: split test . loss 4.789943e-02. error 1.64%. misses: 33
epoch 76 with learning rate 0.000112
eval: split train. loss 3.586408e-02. error 1.21%. misses: 88
eval: split test . loss 4.683260e-02. error 1.69%. misses: 34
epoch 77 with learning rate 0.000110
eval: split train. loss 3.496870e-02. error 1.12%. misses: 82
eval: split test . loss 4.608429e-02. error 1.54%. misses: 31
epoch 78 with learning rate 0.000108
eval: split train. loss 3.359542e-02. error 1.07%. misses: 78
eval: split test . loss 4.598244e-02. error 1.69%. misses: 34
epoch 79 with learning rate 0.000105
eval: split train. loss 3.431604e-02. error 1.15%. misses: 84
eval: split test . loss 4.517807e-02. error 1.59%. misses: 32
epoch 80 with learning rate 0.000102
eval: split train. loss 3.316079e-02. error 1.06%. misses: 77
eval: split test . loss 4.969697e-02. error 1.74%. misses: 35

Change 1: replace the tanh on the last layer with a plain fully connected (linear) output and train with softmax cross-entropy; lower the learning rate to 0.01. The relevant snippets are shown below.
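This change is visible at the end of the modern Net above: the output layer emits raw logits, and the loss switches from MSE to softmax cross-entropy. A minimal excerpt (assuming jnp, nn, and optax are imported as in the cells above):

# Output layer: raw logits instead of tanh
x = nn.Dense(features=10, use_bias=False)(x)
bias4 = self.param("bias4", self.bias_init, (10,))
x = x + bias4

# Loss: softmax cross-entropy over one-hot targets
loss = jnp.mean(optax.softmax_cross_entropy(logits=Yhat, labels=Y))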

epoch 23
eval: split train. loss 7.162272e-03. error 0.05%. misses: 4
eval: split test . loss 1.687743e-01. error 4.14%. misses: 83

Change 2: switch the optimizer from SGD to AdamW with LR 3e-4, double the epochs to 46, and decay the LR to 1e-4 over the course of training; see the excerpt below.
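In optax terms, this is a linear schedule threaded through adamw, as in the training cell above. Schematically (epochs and steps_per_epoch as defined there; 3e-4 / 3 = 1e-4):

lr_fn = optax.linear_schedule(init_value=3e-4, end_value=1e-4,
                              transition_steps=epochs * steps_per_epoch)
opt = optax.adamw(lr_fn)  # schedules can be passed directly as the learning rate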

epoch 46
eval: split train. loss 1.890260e-03. error 0.04%. misses: 2
eval: split test . loss 1.953933e-01. error 4.04%. misses: 81

Change 3: introduce data augmentation, e.g. a shift of at most 1 pixel in both the x and y directions, and increase training time to 60 epochs; the shift logic is excerpted below.
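The shift is sampled per step from Flax's "aug" rng stream and applied with jnp.roll, as at the top of the Net module above (excerpt):

if self.training:
    augment_rng = self.make_rng("aug")
    # dx, dy drawn uniformly from {-1, 0, 1}
    shift_x, shift_y = jax.random.randint(augment_rng, (2,), -1, 2)
    x = jnp.roll(x, (shift_x, shift_y), (1, 2))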

epoch 60
eval: split train. loss 5.098452e-02. error 1.65%. misses: 120
eval: split test . loss 9.166716e-02. error 2.59%. misses: 52

Change 4: add dropout at layer H3, switch the activation function to ReLU, and increase training to 80 epochs; see the excerpt below.
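Both pieces appear in the modern Net above: nn.relu replaces tanh after each hidden layer, and nn.Dropout is active only in training mode (excerpt):

x = nn.relu(x + bias2)
# Dropout at H3; disabled at eval time via the deterministic flag
x = nn.Dropout(0.25, deterministic=not self.training)(x)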

epoch 80
eval: split train. loss 3.316079e-02. error 1.06%. misses: 77
eval: split test . loss 4.969697e-02. error 1.74%. misses: 35