Backpropagation Applied to MNIST
Based on LeCun 1989: http://yann.lecun.com/exdb/publis/pdf/lecun-89e.pdf
Adapted to JAX from https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py
Author: Peter G. Chang (@peterchang0414)
1989 Reproduction
epoch 1
eval: split train. loss 5.576071e-02. error 8.11%. misses: 591
eval: split test . loss 5.287848e-02. error 7.37%. misses: 148
epoch 2
eval: split train. loss 4.097378e-02. error 5.80%. misses: 423
eval: split test . loss 4.257497e-02. error 6.08%. misses: 122
epoch 3
eval: split train. loss 3.390130e-02. error 4.92%. misses: 359
eval: split test . loss 3.796291e-02. error 5.48%. misses: 110
epoch 4
eval: split train. loss 2.989994e-02. error 4.38%. misses: 319
eval: split test . loss 3.480190e-02. error 5.23%. misses: 105
epoch 5
eval: split train. loss 2.566473e-02. error 3.77%. misses: 275
eval: split test . loss 3.232093e-02. error 4.73%. misses: 95
epoch 6
eval: split train. loss 2.348944e-02. error 3.33%. misses: 242
eval: split test . loss 3.208887e-02. error 4.58%. misses: 92
epoch 7
eval: split train. loss 2.151174e-02. error 3.09%. misses: 225
eval: split test . loss 3.206819e-02. error 4.93%. misses: 99
epoch 8
eval: split train. loss 1.941714e-02. error 2.77%. misses: 202
eval: split test . loss 3.061979e-02. error 4.73%. misses: 95
epoch 9
eval: split train. loss 1.694829e-02. error 2.41%. misses: 176
eval: split test . loss 2.916610e-02. error 4.38%. misses: 88
epoch 10
eval: split train. loss 1.605429e-02. error 2.22%. misses: 162
eval: split test . loss 2.967581e-02. error 4.58%. misses: 92
epoch 11
eval: split train. loss 1.565071e-02. error 2.18%. misses: 159
eval: split test . loss 3.011220e-02. error 4.58%. misses: 92
epoch 12
eval: split train. loss 1.397184e-02. error 1.93%. misses: 141
eval: split test . loss 2.919692e-02. error 4.53%. misses: 91
epoch 13
eval: split train. loss 1.240323e-02. error 1.59%. misses: 116
eval: split test . loss 2.727516e-02. error 3.64%. misses: 73
epoch 14
eval: split train. loss 1.198561e-02. error 1.56%. misses: 114
eval: split test . loss 2.697299e-02. error 3.89%. misses: 78
epoch 15
eval: split train. loss 1.133908e-02. error 1.44%. misses: 105
eval: split test . loss 2.733141e-02. error 3.94%. misses: 79
epoch 16
eval: split train. loss 1.065093e-02. error 1.47%. misses: 107
eval: split test . loss 2.849034e-02. error 4.09%. misses: 82
epoch 17
eval: split train. loss 9.458693e-03. error 1.26%. misses: 92
eval: split test . loss 2.668566e-02. error 3.79%. misses: 76
epoch 18
eval: split train. loss 7.680640e-03. error 1.08%. misses: 79
eval: split test . loss 2.510950e-02. error 3.74%. misses: 75
epoch 19
eval: split train. loss 6.790097e-03. error 1.00%. misses: 73
eval: split test . loss 2.578570e-02. error 3.69%. misses: 74
epoch 20
eval: split train. loss 6.345607e-03. error 0.93%. misses: 68
eval: split test . loss 2.508449e-02. error 3.54%. misses: 71
epoch 21
eval: split train. loss 5.988171e-03. error 0.92%. misses: 66
eval: split test . loss 2.509341e-02. error 3.59%. misses: 72
epoch 22
eval: split train. loss 5.771732e-03. error 0.88%. misses: 64
eval: split test . loss 2.479325e-02. error 3.54%. misses: 71
epoch 23
eval: split train. loss 5.265484e-03. error 0.82%. misses: 60
eval: split test . loss 2.467080e-02. error 3.69%. misses: 74
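Each eval line above reports the mean loss, the error rate, and the number of misclassified examples for one split. The percentages are consistent with 7291 training and 2007 test images (591/7291 ≈ 8.11%, 148/2007 ≈ 7.37%). The following is a minimal sketch of how such a line could be computed in JAX. It assumes a mean-squared-error loss against targets of +1 for the correct class and -1 elsewhere, as in the 1989 setup; the function name `eval_metrics` and the toy inputs are hypothetical and are not taken from the notebook's code.

```python
import jax.numpy as jnp


def eval_metrics(outputs, targets):
    """Return (loss, error rate, misses) for arrays of shape (N, 10).

    Assumes MSE loss and +1/-1 targets; this is a sketch, not the
    notebook's actual evaluation code.
    """
    loss = jnp.mean((targets - outputs) ** 2)
    # An example is a miss when the predicted class differs from the target class.
    misses = jnp.sum(jnp.argmax(outputs, axis=1) != jnp.argmax(targets, axis=1))
    error = misses / outputs.shape[0]
    return float(loss), float(error), int(misses)


# Toy data: four examples whose true classes are 0, 1, 2, 3.
targets = jnp.full((4, 10), -1.0).at[jnp.arange(4), jnp.array([0, 1, 2, 3])].set(1.0)
# Perfect outputs for the first three examples, fully inverted for the last one.
outputs = targets.at[3].set(-targets[3])

loss, error, misses = eval_metrics(outputs, targets)
print(f"eval: split toy  . loss {loss:e}. error {error * 100:.2f}%. misses: {misses}")
```

In this toy case one of the four examples is misclassified, so the error rate is 25%.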
Results:
"Modern" Adjustments
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
epoch 1 with learning rate 0.000300
eval: split train. loss 4.722151e-01. error 12.73%. misses: 928
eval: split test . loss 4.376389e-01. error 11.81%. misses: 237
epoch 2 with learning rate 0.000297
eval: split train. loss 3.456218e-01. error 9.77%. misses: 712
eval: split test . loss 3.105372e-01. error 8.87%. misses: 178
epoch 3 with learning rate 0.000295
eval: split train. loss 2.216365e-01. error 6.45%. misses: 469
eval: split test . loss 1.981873e-01. error 5.53%. misses: 111
epoch 4 with learning rate 0.000292
eval: split train. loss 2.072843e-01. error 5.99%. misses: 437
eval: split test . loss 1.910520e-01. error 5.48%. misses: 110
epoch 5 with learning rate 0.000290
eval: split train. loss 1.750381e-01. error 5.49%. misses: 399
eval: split test . loss 1.611853e-01. error 4.93%. misses: 99
epoch 6 with learning rate 0.000288
eval: split train. loss 1.538368e-01. error 4.42%. misses: 321
eval: split test . loss 1.411121e-01. error 4.19%. misses: 84
epoch 7 with learning rate 0.000285
eval: split train. loss 1.451264e-01. error 4.62%. misses: 337
eval: split test . loss 1.325464e-01. error 4.09%. misses: 82
epoch 8 with learning rate 0.000282
eval: split train. loss 1.257392e-01. error 3.52%. misses: 257
eval: split test . loss 1.164299e-01. error 3.34%. misses: 67
epoch 9 with learning rate 0.000280
eval: split train. loss 1.177755e-01. error 3.40%. misses: 248
eval: split test . loss 1.107324e-01. error 3.69%. misses: 74
epoch 10 with learning rate 0.000277
eval: split train. loss 1.129500e-01. error 3.26%. misses: 237
eval: split test . loss 1.068543e-01. error 3.14%. misses: 63
epoch 11 with learning rate 0.000275
eval: split train. loss 1.157665e-01. error 3.36%. misses: 245
eval: split test . loss 1.119875e-01. error 3.34%. misses: 67
epoch 12 with learning rate 0.000273
eval: split train. loss 1.185108e-01. error 3.61%. misses: 263
eval: split test . loss 1.146749e-01. error 3.69%. misses: 74
epoch 13 with learning rate 0.000270
eval: split train. loss 9.700271e-02. error 2.94%. misses: 214
eval: split test . loss 9.375140e-02. error 3.04%. misses: 61
epoch 14 with learning rate 0.000267
eval: split train. loss 1.081733e-01. error 3.10%. misses: 226
eval: split test . loss 1.054694e-01. error 3.24%. misses: 65
epoch 15 with learning rate 0.000265
eval: split train. loss 9.071133e-02. error 2.76%. misses: 201
eval: split test . loss 8.586112e-02. error 2.64%. misses: 53
epoch 16 with learning rate 0.000262
eval: split train. loss 9.541860e-02. error 2.80%. misses: 203
eval: split test . loss 9.335707e-02. error 3.19%. misses: 64
epoch 17 with learning rate 0.000260
eval: split train. loss 8.359449e-02. error 2.80%. misses: 203
eval: split test . loss 8.113335e-02. error 2.79%. misses: 56
epoch 18 with learning rate 0.000258
eval: split train. loss 8.313517e-02. error 2.46%. misses: 179
eval: split test . loss 8.725357e-02. error 2.64%. misses: 53
epoch 19 with learning rate 0.000255
eval: split train. loss 8.930960e-02. error 2.78%. misses: 203
eval: split test . loss 8.548871e-02. error 2.79%. misses: 56
epoch 20 with learning rate 0.000253
eval: split train. loss 7.986999e-02. error 2.48%. misses: 181
eval: split test . loss 7.389561e-02. error 2.64%. misses: 53
epoch 21 with learning rate 0.000250
eval: split train. loss 7.751217e-02. error 2.30%. misses: 168
eval: split test . loss 7.085717e-02. error 2.44%. misses: 49
epoch 22 with learning rate 0.000247
eval: split train. loss 6.842067e-02. error 2.15%. misses: 157
eval: split test . loss 6.652185e-02. error 2.29%. misses: 46
epoch 23 with learning rate 0.000245
eval: split train. loss 7.121788e-02. error 2.17%. misses: 158
eval: split test . loss 6.131270e-02. error 1.79%. misses: 36
epoch 24 with learning rate 0.000242
eval: split train. loss 7.509596e-02. error 2.46%. misses: 179
eval: split test . loss 6.493099e-02. error 2.14%. misses: 43
epoch 25 with learning rate 0.000240
eval: split train. loss 7.613951e-02. error 2.61%. misses: 190
eval: split test . loss 7.143638e-02. error 2.19%. misses: 44
epoch 26 with learning rate 0.000238
eval: split train. loss 7.980061e-02. error 2.65%. misses: 193
eval: split test . loss 7.566121e-02. error 2.34%. misses: 47
epoch 27 with learning rate 0.000235
eval: split train. loss 6.504884e-02. error 2.13%. misses: 155
eval: split test . loss 5.958221e-02. error 1.99%. misses: 40
epoch 28 with learning rate 0.000232
eval: split train. loss 6.683959e-02. error 2.24%. misses: 163
eval: split test . loss 6.922408e-02. error 2.59%. misses: 52
epoch 29 with learning rate 0.000230
eval: split train. loss 6.794566e-02. error 2.17%. misses: 158
eval: split test . loss 6.709250e-02. error 2.44%. misses: 49
epoch 30 with learning rate 0.000227
eval: split train. loss 6.295200e-02. error 1.96%. misses: 143
eval: split test . loss 5.890007e-02. error 2.54%. misses: 51
epoch 31 with learning rate 0.000225
eval: split train. loss 6.818665e-02. error 2.25%. misses: 164
eval: split test . loss 6.444851e-02. error 2.24%. misses: 45
epoch 32 with learning rate 0.000223
eval: split train. loss 6.571253e-02. error 2.18%. misses: 159
eval: split test . loss 6.434719e-02. error 2.44%. misses: 49
epoch 33 with learning rate 0.000220
eval: split train. loss 6.399426e-02. error 2.18%. misses: 159
eval: split test . loss 6.240412e-02. error 2.24%. misses: 45
epoch 34 with learning rate 0.000217
eval: split train. loss 5.683114e-02. error 1.80%. misses: 131
eval: split test . loss 5.610501e-02. error 1.84%. misses: 37
epoch 35 with learning rate 0.000215
eval: split train. loss 5.706797e-02. error 1.77%. misses: 129
eval: split test . loss 6.036913e-02. error 2.34%. misses: 47
epoch 36 with learning rate 0.000212
eval: split train. loss 5.528478e-02. error 1.95%. misses: 142
eval: split test . loss 5.302548e-02. error 2.04%. misses: 41
epoch 37 with learning rate 0.000210
eval: split train. loss 5.490229e-02. error 1.84%. misses: 133
eval: split test . loss 5.376581e-02. error 1.94%. misses: 39
epoch 38 with learning rate 0.000208
eval: split train. loss 5.350880e-02. error 1.67%. misses: 122
eval: split test . loss 5.158291e-02. error 1.79%. misses: 36
epoch 39 with learning rate 0.000205
eval: split train. loss 5.476158e-02. error 1.77%. misses: 129
eval: split test . loss 5.336771e-02. error 1.69%. misses: 34
epoch 40 with learning rate 0.000202
eval: split train. loss 5.242018e-02. error 1.67%. misses: 122
eval: split test . loss 5.161439e-02. error 1.89%. misses: 38
epoch 41 with learning rate 0.000200
eval: split train. loss 5.457530e-02. error 1.74%. misses: 126
eval: split test . loss 6.135549e-02. error 2.44%. misses: 49
epoch 42 with learning rate 0.000197
eval: split train. loss 5.634554e-02. error 1.91%. misses: 139
eval: split test . loss 6.446160e-02. error 2.34%. misses: 47
epoch 43 with learning rate 0.000195
eval: split train. loss 5.192847e-02. error 1.81%. misses: 132
eval: split test . loss 6.171136e-02. error 2.14%. misses: 43
epoch 44 with learning rate 0.000192
eval: split train. loss 5.048798e-02. error 1.66%. misses: 121
eval: split test . loss 5.762529e-02. error 1.94%. misses: 39
epoch 45 with learning rate 0.000190
eval: split train. loss 5.038778e-02. error 1.58%. misses: 114
eval: split test . loss 5.986194e-02. error 2.09%. misses: 42
epoch 46 with learning rate 0.000188
eval: split train. loss 4.796446e-02. error 1.69%. misses: 122
eval: split test . loss 5.005924e-02. error 1.84%. misses: 37
epoch 47 with learning rate 0.000185
eval: split train. loss 4.932489e-02. error 1.71%. misses: 125
eval: split test . loss 5.289536e-02. error 2.24%. misses: 45
epoch 48 with learning rate 0.000182
eval: split train. loss 5.115648e-02. error 1.78%. misses: 129
eval: split test . loss 5.819925e-02. error 2.04%. misses: 41
epoch 49 with learning rate 0.000180
eval: split train. loss 5.329847e-02. error 1.80%. misses: 131
eval: split test . loss 5.682039e-02. error 2.09%. misses: 42
epoch 50 with learning rate 0.000177
eval: split train. loss 4.632418e-02. error 1.59%. misses: 116
eval: split test . loss 5.570131e-02. error 2.09%. misses: 42
epoch 51 with learning rate 0.000175
eval: split train. loss 5.221667e-02. error 1.73%. misses: 126
eval: split test . loss 6.282473e-02. error 2.14%. misses: 43
epoch 52 with learning rate 0.000173
eval: split train. loss 4.739231e-02. error 1.73%. misses: 126
eval: split test . loss 5.634123e-02. error 1.99%. misses: 40
epoch 53 with learning rate 0.000170
eval: split train. loss 5.621015e-02. error 2.07%. misses: 151
eval: split test . loss 6.867130e-02. error 2.19%. misses: 44
epoch 54 with learning rate 0.000167
eval: split train. loss 4.532041e-02. error 1.60%. misses: 117
eval: split test . loss 5.811055e-02. error 1.99%. misses: 40
epoch 55 with learning rate 0.000165
eval: split train. loss 4.347728e-02. error 1.55%. misses: 113
eval: split test . loss 5.601728e-02. error 2.14%. misses: 43
epoch 56 with learning rate 0.000162
eval: split train. loss 4.743553e-02. error 1.60%. misses: 117
eval: split test . loss 6.145428e-02. error 2.24%. misses: 45
epoch 57 with learning rate 0.000160
eval: split train. loss 4.246239e-02. error 1.56%. misses: 114
eval: split test . loss 5.335664e-02. error 1.79%. misses: 36
epoch 58 with learning rate 0.000158
eval: split train. loss 4.323665e-02. error 1.45%. misses: 106
eval: split test . loss 5.636141e-02. error 1.89%. misses: 38
epoch 59 with learning rate 0.000155
eval: split train. loss 4.607718e-02. error 1.69%. misses: 122
eval: split test . loss 5.969046e-02. error 2.14%. misses: 43
epoch 60 with learning rate 0.000152
eval: split train. loss 4.451877e-02. error 1.48%. misses: 108
eval: split test . loss 5.823955e-02. error 1.99%. misses: 40
epoch 61 with learning rate 0.000150
eval: split train. loss 4.184551e-02. error 1.40%. misses: 101
eval: split test . loss 5.383835e-02. error 1.69%. misses: 34
epoch 62 with learning rate 0.000148
eval: split train. loss 4.327311e-02. error 1.49%. misses: 109
eval: split test . loss 5.188924e-02. error 1.84%. misses: 37
epoch 63 with learning rate 0.000145
eval: split train. loss 3.812368e-02. error 1.34%. misses: 97
eval: split test . loss 4.565141e-02. error 1.69%. misses: 34
epoch 64 with learning rate 0.000142
eval: split train. loss 4.123368e-02. error 1.43%. misses: 103
eval: split test . loss 5.299970e-02. error 1.84%. misses: 37
epoch 65 with learning rate 0.000140
eval: split train. loss 4.013669e-02. error 1.32%. misses: 95
eval: split test . loss 5.678133e-02. error 1.99%. misses: 40
epoch 66 with learning rate 0.000137
eval: split train. loss 3.984843e-02. error 1.34%. misses: 97
eval: split test . loss 5.329720e-02. error 2.09%. misses: 42
epoch 67 with learning rate 0.000135
eval: split train. loss 4.191425e-02. error 1.39%. misses: 101
eval: split test . loss 5.370571e-02. error 1.99%. misses: 40
epoch 68 with learning rate 0.000133
eval: split train. loss 4.354529e-02. error 1.45%. misses: 106
eval: split test . loss 5.472580e-02. error 1.99%. misses: 40
epoch 69 with learning rate 0.000130
eval: split train. loss 3.600218e-02. error 1.25%. misses: 91
eval: split test . loss 5.039397e-02. error 2.14%. misses: 43
epoch 70 with learning rate 0.000127
eval: split train. loss 3.712326e-02. error 1.14%. misses: 83
eval: split test . loss 4.781391e-02. error 1.79%. misses: 36
epoch 71 with learning rate 0.000125
eval: split train. loss 4.377073e-02. error 1.43%. misses: 103
eval: split test . loss 5.955317e-02. error 2.04%. misses: 41
epoch 72 with learning rate 0.000123
eval: split train. loss 4.096783e-02. error 1.41%. misses: 103
eval: split test . loss 5.084800e-02. error 1.94%. misses: 39
epoch 73 with learning rate 0.000120
eval: split train. loss 3.989225e-02. error 1.32%. misses: 95
eval: split test . loss 5.022623e-02. error 2.09%. misses: 42
epoch 74 with learning rate 0.000117
eval: split train. loss 3.819638e-02. error 1.43%. misses: 103
eval: split test . loss 4.982632e-02. error 1.99%. misses: 40
epoch 75 with learning rate 0.000115
eval: split train. loss 3.834034e-02. error 1.29%. misses: 94
eval: split test . loss 4.789943e-02. error 1.64%. misses: 33
epoch 76 with learning rate 0.000112
eval: split train. loss 3.586408e-02. error 1.21%. misses: 88
eval: split test . loss 4.683260e-02. error 1.69%. misses: 34
epoch 77 with learning rate 0.000110
eval: split train. loss 3.496870e-02. error 1.12%. misses: 82
eval: split test . loss 4.608429e-02. error 1.54%. misses: 31
epoch 78 with learning rate 0.000108
eval: split train. loss 3.359542e-02. error 1.07%. misses: 78
eval: split test . loss 4.598244e-02. error 1.69%. misses: 34
epoch 79 with learning rate 0.000105
eval: split train. loss 3.431604e-02. error 1.15%. misses: 84
eval: split test . loss 4.517807e-02. error 1.59%. misses: 32
epoch 80 with learning rate 0.000102
eval: split train. loss 3.316079e-02. error 1.06%. misses: 77
eval: split test . loss 4.969697e-02. error 1.74%. misses: 35
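The learning rates printed in the log above are consistent with a linear decay from 3e-4 toward 1e-4 across the 80 epochs, evaluated at the start of each epoch. A minimal sketch of such a schedule follows; the helper name `learning_rate` and the zero-based epoch index are assumptions made for illustration, not the notebook's own code.

```python
def learning_rate(epoch, num_epochs=80, lr_start=3e-4, lr_end=1e-4):
    """Linearly interpolate from lr_start toward lr_end.

    `epoch` is zero-based, so the first epoch trains at exactly lr_start
    and the last epoch stops slightly short of lr_end.
    """
    alpha = epoch / num_epochs
    return (1.0 - alpha) * lr_start + alpha * lr_end


# Spot-check a few epochs against the values printed in the log.
for epoch in (0, 20, 40, 60):
    print(f"epoch {epoch + 1} with learning rate {learning_rate(epoch):.6f}")
```

Under this assumption, epochs 1, 21, 41, and 61 train at 3.0e-4, 2.5e-4, 2.0e-4, and 1.5e-4, which matches the corresponding log lines.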
Change 1: replace the tanh on the last layer with a plain FC output and use softmax. Lower the learning rate to 0.01.
Change 2: switch from SGD to AdamW with LR 3e-4, double the number of epochs to 46, and decay the LR to 1e-4 over the course of training.
Change 3: introduce data augmentation, e.g. a shift by at most 1 pixel in both x/y directions, and increase training time to 60 epochs.
Change 4: add dropout at layer H3, switch the activation function to ReLU, and increase the number of epochs to 80.
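The ingredients of Changes 1, 3, and 4 can be sketched in JAX as follows. This is an illustrative sketch under stated assumptions rather than the notebook's implementation: the function names are hypothetical, the dropout rate of 0.25 is a placeholder (the text does not give one), images are assumed to be 16x16, the shift wraps pixels around the border via `jnp.roll`, and one shift is shared by the whole batch.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy(logits, labels):
    """Change 1: softmax over FC outputs, scored with cross-entropy.

    logits: (N, 10) raw outputs; labels: (N,) integer class ids.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(log_probs[jnp.arange(labels.shape[0]), labels])


def random_shift(key, images, max_shift=1):
    """Change 3: shift a batch (N, H, W) by at most max_shift pixels in x and y.

    Assumption: pixels wrap around the border, and the same shift is
    applied to every image in the batch. Not intended for use under jit,
    because the sampled shifts are converted to Python ints.
    """
    dy, dx = (int(s) for s in jax.random.randint(key, (2,), -max_shift, max_shift + 1))
    return jnp.roll(images, (dy, dx), axis=(1, 2))


def dropout(key, x, rate=0.25, train=True):
    """Change 4: inverted dropout; `rate` is an assumed placeholder value."""
    if not train or rate == 0.0:
        return x
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    # Rescale kept units so the expected activation is unchanged.
    return jnp.where(keep, x / (1.0 - rate), 0.0)


key = jax.random.PRNGKey(0)
k_aug, k_drop = jax.random.split(key)

# Two toy 16x16 images with a single bright pixel at the centre.
images = jnp.zeros((2, 16, 16)).at[:, 8, 8].set(1.0)
shifted = random_shift(k_aug, images)

# ReLU activations standing in for layer H3, then dropout in training mode.
h3 = jax.nn.relu(jnp.array([[-1.0, 0.5, 2.0]]))
h3_train = dropout(k_drop, h3)
h3_eval = dropout(k_drop, h3, train=False)

# Uniform logits give a loss of log(10) for any labels.
loss = softmax_cross_entropy(jnp.zeros((2, 10)), jnp.array([3, 7]))
print(shifted.shape, h3_train, float(loss))
```

At evaluation time, calling `dropout` with `train=False` returns the activations unchanged, and no augmentation would be applied to the test images.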