Optimization using Flax
Flax is a library built on top of JAX for creating deep neural networks. It also includes a simple optimization library. Below we show how to fit a multi-class logistic regression model using Flax.
In [ ]:
In [ ]:
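The setup cell is not shown; a minimal sketch that would print the version string below:

import jax
print('jax version {}'.format(jax.__version__))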
jax version 0.2.7
Import code
In [ ]:
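The install cell is not shown; a plausible sketch (installing Flax from its GitHub source is an assumption, but it would explain the wheel being built from setup.py below):

# Install flax quietly; building from source produces the wheel message below.
!pip install -q git+https://github.com/google/flax.git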
Building wheel for flax (setup.py) ... done
In [ ]:
In [ ]:
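The clone cell is not shown; a sketch consistent with the output below:

# Clone the pyprobml repository, which contains the fitting code used later.
!git clone https://github.com/probml/pyprobml.git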
Cloning into 'pyprobml'...
remote: Enumerating objects: 165, done.
remote: Counting objects: 100% (165/165), done.
remote: Compressing objects: 100% (113/113), done.
remote: Total 5880 (delta 99), reused 89 (delta 51), pack-reused 5715
Receiving objects: 100% (5880/5880), 200.88 MiB | 31.76 MiB/s, done.
Resolving deltas: 100% (3347/3347), done.
Checking out files: 100% (484/484), done.
In [ ]:
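The test cell is not shown; a plausible reconstruction, assuming the fitting code lives in a fit_flax script inside the cloned repository (the path and the test() entry point are assumptions):

import sys
sys.path.append('pyprobml/scripts')   # assumed location of the fitting code
import fit_flax
fit_flax.test()                       # assumed self-test that prints the trace below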
testing fit-flax
train step: 0, loss: 2.5385, accuracy: 0.00
train step: 1, loss: 2.3316, accuracy: 0.00
FrozenDict({
Dense_0: {
bias: DeviceArray([-0.07041515, 0.02852547, 0.02282 , 0.02621041,
0.04047536, 0.0282678 , 0.02441146, 0.0334954 ,
-0.16449487, 0.03070411], dtype=float32),
kernel: DeviceArray([[-1.3670444e-01, 3.4958541e-02, -7.1266890e-03,
2.5135964e-02, 2.7813643e-02, 5.0281063e-03,
2.9270068e-02, 3.3206850e-02, 7.8319311e-03,
-1.9413888e-02],
[ 2.3842663e-02, -1.4299959e-02, 6.5577030e-04,
-6.1702281e-03, -1.8243194e-02, -3.5261810e-03,
-9.2503726e-03, -7.1800947e-03, 2.8626800e-02,
5.5447742e-03],
[ 1.2182933e-01, -1.1020064e-02, 6.6978633e-03,
-1.8483013e-02, 7.7033639e-03, 1.0741353e-03,
-1.6080946e-02, -2.6806772e-02, -7.5067461e-02,
1.0153547e-02],
[-6.5444291e-02, -1.2274325e-02, -2.4778858e-02,
-2.5078654e-03, -4.0774703e-02, -2.3661405e-02,
-3.6602318e-03, 3.6180019e-05, 2.0730698e-01,
-3.4241520e-02],
[-8.0969959e-02, 7.9388879e-03, 3.2993864e-02,
2.8362900e-02, 8.9270771e-03, 3.1862110e-02,
1.6838670e-02, 3.9151609e-02, -1.4136736e-01,
5.6262240e-02]], dtype=float32),
},
})
test passed
Now we show the source code for the fitting function in the file editor on the right-hand side.
In [ ]:
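The code for this cell is not shown; in Colab this can be done with the files.view helper (the file path is an assumption):

from google.colab import files
files.view('pyprobml/scripts/fit_flax.py')   # assumed path; opens the file in the side editor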
<IPython.core.display.Javascript object>
Data
We use the TensorFlow datasets library to make it easy to create minibatches. We now switch to the multi-class version of the Iris dataset.
In [ ]:
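The data cell is not shown; a sketch, assuming sklearn's Iris data is split into 100 training and 50 test examples and batched with tf.data (the batch size of 30 is inferred from the shapes printed below):

import numpy as np
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X = iris['data'].astype(np.float32)
y = iris['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=50, random_state=0)

batch_size = 30
train_ds = tf.data.Dataset.from_tensor_slices({'X': X_train, 'y': y_train})
train_ds = train_ds.shuffle(buffer_size=100).batch(batch_size)

batch = next(iter(train_ds.as_numpy_iterator()))
print(batch['X'].shape)
print(batch['y'].shape)
print(X_test.shape)
print(y_test.shape)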
(30, 4)
(30,)
(50, 4)
(50,)
Model
In [ ]:
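The model cell is not shown; the Dense_0 entry in the fitted parameters below suggests a single dense layer producing the class logits, i.e. multinomial logistic regression. A minimal sketch using flax.linen:

import flax.linen as nn

class LogisticRegression(nn.Module):
    # Multi-class logistic regression is a single dense layer mapping
    # the input features to unnormalized class logits.
    nclasses: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.nclasses)(x)

model = LogisticRegression(nclasses=3)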
Training loop
In [ ]:
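The training loop cell is not shown; a sketch, assuming the legacy flax.optim interface that shipped with this Flax version, the model defined above, and the train_ds batches from the data cell (the learning rate and number of epochs are assumptions):

import jax
import jax.numpy as jnp
from flax import optim   # legacy flax.optim API (removed in later Flax versions)

def loss_fn(params, X, y):
    # Cross-entropy loss for multi-class logistic regression.
    logits = model.apply({'params': params}, X)
    labels = jax.nn.one_hot(y, logits.shape[-1])
    loss = -jnp.mean(jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
    return loss, logits

@jax.jit
def train_step(optimizer, X, y):
    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target, X, y)
    optimizer = optimizer.apply_gradient(grads)
    acc = jnp.mean(jnp.argmax(logits, -1) == y)
    return optimizer, loss, acc

variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
optimizer = optim.Momentum(learning_rate=0.1, beta=0.9).create(variables['params'])

num_epochs = 50   # assumed; gives roughly the number of steps printed below
step = 0
for epoch in range(num_epochs):
    for batch in train_ds.as_numpy_iterator():
        optimizer, loss, acc = train_step(optimizer, batch['X'], batch['y'])
        if step % 20 == 0:
            print('train step: {}, loss: {:.4f}, accuracy: {:.2f}'.format(
                step, float(loss), float(acc)))
        step += 1

params = optimizer.target   # final fitted parameters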
train step: 0, loss: 4.2830, accuracy: 0.35
train step: 20, loss: 0.9130, accuracy: 0.65
train step: 40, loss: 0.1380, accuracy: 0.96
train step: 60, loss: 0.1236, accuracy: 0.96
train step: 80, loss: 0.1094, accuracy: 0.98
train step: 100, loss: 0.1041, accuracy: 0.98
train step: 120, loss: 0.1002, accuracy: 0.98
train step: 140, loss: 0.0969, accuracy: 0.98
train step: 160, loss: 0.0942, accuracy: 0.99
train step: 180, loss: 0.0917, accuracy: 0.99
In [ ]:
Compare to sklearn
In [ ]:
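The sklearn cell is not shown; a sketch that would produce this kind of output (the regularization strength and other settings are assumptions; a large C approximates an unregularized fit):

import numpy as np
from sklearn.linear_model import LogisticRegression

np.set_printoptions(precision=5, suppress=True)   # assumed, to match the formatting below

log_reg = LogisticRegression(multi_class='multinomial', C=1e5, max_iter=1000)
log_reg.fit(X_train, y_train)

print(log_reg.coef_.flatten())      # class weights, flattened (3 classes x 4 features)
print(log_reg.intercept_)           # per-class bias terms
probs_sklearn = log_reg.predict_proba(X_test)
print(probs_sklearn.shape)
print(probs_sklearn[:10])
preds_sklearn = log_reg.predict(X_test)
print(preds_sklearn.shape)
print(preds_sklearn[:10])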
[5.69473 8.89993 -12.90385 -6.59589 -1.40077 1.88896 0.08464 -14.39687
-4.29397 -10.78889 12.81921 20.99277]
[3.97582 32.52712 -36.50294]
(50, 3)
[[0.00000 1.00000 0.00000]
[1.00000 0.00000 0.00000]
[0.00000 0.00000 1.00000]
[0.00000 0.99999 0.00000]
[0.00001 0.99999 0.00000]
[1.00000 0.00000 0.00000]
[0.00605 0.99395 0.00000]
[0.00000 0.00000 1.00000]
[0.00000 0.98867 0.01133]
[0.00006 0.99994 0.00000]]
(50,)
[1 0 2 1 1 0 1 2 1 1]
In [ ]:
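A sketch of the comparison cell, assuming the fitted params, the model, X_test, and the sklearn predictions from the previous cells; exactly which consistency checks produced the two True lines below is an assumption:

import jax
import numpy as np

print(params)   # fitted Flax parameters

logits = model.apply({'params': params}, X_test)
probs_flax = jax.nn.softmax(logits, axis=-1)
print(probs_flax.shape)
print(np.round(np.array(probs_flax[:10]), 5))
print(bool(np.allclose(probs_flax.sum(axis=-1), 1.0)))   # probabilities normalized (assumed check)

preds_flax = np.argmax(np.array(probs_flax), axis=-1)
print(preds_flax[:10])
print(bool(np.array_equal(preds_flax, preds_sklearn)))   # hard predictions agree with sklearn (assumed check)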
FrozenDict({
Dense_0: {
bias: DeviceArray([0.67322, 1.05858, -1.73180], dtype=float32),
kernel: DeviceArray([[1.13125, 0.95521, -2.30099],
[2.86528, -0.09720, -3.41202],
[-3.86220, -0.45856, 4.46158],
[-1.57317, -1.24467, 3.70615]], dtype=float32),
},
})
(50, 3)
[[0.00057 0.94575 0.05368]
[0.99750 0.00250 0.00000]
[0.00000 0.00014 0.99986]
[0.00131 0.91360 0.08509]
[0.00045 0.97463 0.02493]
[0.99551 0.00449 0.00000]
[0.02960 0.96893 0.00147]
[0.00008 0.27976 0.72016]
[0.00012 0.66911 0.33077]
[0.00644 0.98950 0.00406]]
True
[1 0 2 1 1 0 1 2 1 1]
True