Path: blob/master/notebooks/book1/09/naive_bayes_mnist_jax.ipynb
A PyTorch implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/09/naive_bayes_mnist_torch.ipynb
Naive Bayes classifiers
We show how to implement a Naive Bayes classifier from scratch, using binary features and 10 classes. Based on sec. 18.9 of http://d2l.ai/chapter_appendix-mathematics-for-deep-learning/naive-bayes.html.
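The model applies Bayes' rule under the "naive" assumption that pixels are conditionally independent given the class:

$$p(y = c \mid \mathbf{x}) \propto p(y = c) \prod_{i=1}^{D} p(x_i \mid y = c),$$

where each $p(x_i \mid y = c)$ is a Bernoulli distribution over the binarized pixel $x_i$, and prediction takes the argmax over classes of the log of the right-hand side.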
In [1]:
Out[1]:
mkdir: cannot create directory ‘figures’: File exists
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [2]:
Get data
We use a binarized version of MNIST.
In [3]:
In [4]:
Out[4]:
Dataset MNIST
Number of datapoints: 60000
Root location: ./temp
Split: Train
StandardTransform
Transform: <function <lambda> at 0x7f9b642bfa60>
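The transform printed above is an anonymous lambda, so its body is not shown. A plausible sketch of what it does, given the `(1, 28, 28)` float output seen below (the function name and exact scaling are assumptions):

```python
import numpy as np

# Hypothetical transform mirroring what the notebook's lambda likely does:
# convert a uint8 image to a float32 array in [0, 1] with a leading
# channel axis, matching the (1, 28, 28) shape shown below.
def to_float_array(img):
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return arr[None, :, :]  # add channel dimension -> (1, H, W)

# Demo on a synthetic 28x28 uint8 "image" standing in for an MNIST digit.
rng = np.random.default_rng(0)
fake_img = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)
x = to_float_array(fake_img)
print(x.shape)  # (1, 28, 28)
```

The sketch uses NumPy; `jax.numpy` exposes the same array API, so the notebook's version is essentially identical.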
In [5]:
Out[5]:
<class 'jaxlib.xla_extension.DeviceArray'>
(1, 28, 28)
<class 'int'>
4
In [6]:
Out[6]:
DeviceArray([[0.35686275, 0.10980392, 0.01960784, 0.9137255 , 0.98039216],
[0. , 0. , 0.4 , 0.99607843, 0.8627451 ],
[0. , 0. , 0.6627451 , 0.99607843, 0.5372549 ],
[0. , 0. , 0.6627451 , 0.99607843, 0.22352941],
[0. , 0. , 0.6627451 , 0.99607843, 0.22352941]], dtype=float32)
In [7]:
Out[7]:
[DeviceArray(0., dtype=float32), DeviceArray(1., dtype=float32)]
In [8]:
Out[8]:
(1, 28, 28)
(2, 1, 28, 28)
(1, 2, 28, 28)
(2, 28, 28)
In [9]:
In [10]:
Out[10]:
[(10, 28, 28), (10,)]
[[0.7294118 0.99215686 0.99215686 0.5882353 0.10588235]
[0.0627451 0.3647059 0.9882353 0.99215686 0.73333335]
[0. 0. 0.9764706 0.99215686 0.9764706 ]
[0.50980395 0.7176471 0.99215686 0.99215686 0.8117647 ]
[0.99215686 0.99215686 0.99215686 0.98039216 0.7137255 ]]
In [11]:
Out[11]:
[(10, 28, 28), (10,)]
[[ True True True True False]
[False False True True True]
[False False True True True]
[ True True True True True]
[ True True True True True]]
In [12]:
In [13]:
Out[13]:
(60000, 28, 28)
<class 'jaxlib.xla_extension.DeviceArray'>
[[ True True True True False]
[False False True True True]
[False False True True True]
[ True True True True True]
[ True True True True True]]
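Binarization is a simple elementwise threshold. A minimal sketch on synthetic data (the cutoff of 0.5 is an assumption; the notebook does not print the exact threshold it uses):

```python
import numpy as np

# Binarize pixel intensities: True where the pixel exceeds the threshold.
# The 0.5 cutoff is an assumption, not shown in the notebook's output.
def binarize(images, thresh=0.5):
    return images > thresh

# Synthetic stand-in for a batch of grayscale images in [0, 1].
imgs = np.random.default_rng(0).random((10, 28, 28)).astype(np.float32)
xb = binarize(imgs)
print(xb.shape, xb.dtype)  # (10, 28, 28) bool
```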
Training
In [14]:
Out[14]:
DeviceArray([0.09871667, 0.11236667, 0.0993 , 0.10218333, 0.09736667,
0.09035 , 0.09863333, 0.10441667, 0.09751666, 0.09915 ], dtype=float32)
In [15]:
Out[15]:
[0 1 2 3 4 5 6 7 8 9]
dict_keys([5, 0, 4, 1, 9, 2, 3, 6, 7, 8])
dict_values([5421, 5923, 5842, 6742, 5949, 5958, 6131, 5918, 6265, 5851])
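The class prior is the maximum-likelihood estimate $\pi_c = N_c / N$, i.e. the fraction of training labels in each class, which is what the `Out[14]` vector above shows. A sketch on toy labels:

```python
import numpy as np

# Class prior pi_c = N_c / N, estimated from label counts.
labels = np.array([5, 0, 4, 1, 9] * 4)  # toy labels standing in for y_train
n_classes = 10
counts = np.bincount(labels, minlength=n_classes)  # N_c per class
priors = counts / labels.size                      # normalize to probabilities
print(priors.sum())  # 1.0
```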
We use add-one (Laplace) smoothing for the class-conditional Bernoulli distributions, so that no pixel probability is exactly 0 or 1.
In [16]:
Out[16]:
(10, 28, 28)
<class 'jaxlib.xla_extension.DeviceArray'>
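The smoothed Bernoulli parameter for class $c$ and pixel $(i,j)$ is $\theta_{c,ij} = (1 + n_{c,ij}) / (2 + N_c)$, where $n_{c,ij}$ counts how many class-$c$ images have that pixel on. A sketch on synthetic data, producing the same `(10, 28, 28)` shape as above:

```python
import numpy as np

# Add-one smoothed Bernoulli parameters per class and pixel:
# theta[c, i, j] = (1 + count of pixel (i, j) on in class c) / (2 + N_c).
rng = np.random.default_rng(0)
X = rng.random((100, 28, 28)) > 0.5        # toy binarized images
y = rng.integers(0, 10, size=100)          # toy labels
n_classes = 10
theta = np.empty((n_classes, 28, 28))
for c in range(n_classes):
    Xc = X[y == c]
    theta[c] = (1 + Xc.sum(axis=0)) / (2 + Xc.shape[0])
print(theta.shape)  # (10, 28, 28)
```

Smoothing keeps every entry strictly inside (0, 1), so the log-probabilities used at test time are always finite.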
In [17]:
Out[17]:
Testing
In [18]:
Out[18]:
[-268.9725 -301.7044 -245.19514 -218.87386 -193.45703 -206.09087
-292.52264 -114.625656 -220.33133 -163.17844 ]
ytrue 7 yhat 7
[7]
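The scores printed in `Out[18]` are per-class log joint probabilities: $\log p(y=c) + \sum_{ij} \big[x_{ij}\log\theta_{c,ij} + (1-x_{ij})\log(1-\theta_{c,ij})\big]$, with the prediction being the argmax. A sketch with synthetic parameters (the real `theta` and priors come from the training cells above):

```python
import numpy as np

# Log joint score per class for one binarized image x, then argmax.
rng = np.random.default_rng(0)
theta = rng.uniform(0.1, 0.9, size=(10, 28, 28))  # toy Bernoulli params
log_prior = np.log(np.full(10, 0.1))              # uniform toy prior
x = rng.random((28, 28)) > 0.5                    # one binarized test image

# Bernoulli log-likelihood, broadcast over the class axis, summed over pixels.
log_lik = (x * np.log(theta) + (~x) * np.log(1 - theta)).sum(axis=(1, 2))
scores = log_prior + log_lik
yhat = int(np.argmax(scores))
print(scores.shape)  # (10,)
```

Working in log space avoids the underflow that multiplying 784 probabilities would cause.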
In [19]:
Out[19]:
In [20]:
Out[20]:
In [21]:
Out[21]:
In [22]:
Out[22]:
0.8427
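The 0.8427 above is the fraction of test images classified correctly; accuracy is just the mean of a correctness indicator. A minimal sketch on toy predictions:

```python
import numpy as np

# Accuracy = mean of (prediction == label) over the test set.
y_true = np.array([7, 2, 1, 0, 4])
y_pred = np.array([7, 2, 1, 0, 9])  # one mistake out of five
accuracy = np.mean(y_true == y_pred)
print(accuracy)  # 0.8
```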
In [ ]: