GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/13/mlp_mnist_tf.ipynb
Kernel: Python 3 (ipykernel)


MLP on (Fashion) MNIST using TF 2.0

try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
    IS_COLAB = True
except Exception:
    IS_COLAB = False

# TensorFlow ≥2.0 is required
try:
    import tensorflow as tf
except ModuleNotFoundError:
    %pip install -qq tensorflow
    import tensorflow as tf
from tensorflow import keras

assert tf.__version__ >= "2.0"

if not tf.config.list_physical_devices("GPU"):
    print("No GPU was detected. DNNs can be very slow without a GPU.")
    if IS_COLAB:
        print("Go to Runtime > Change runtime type and select a GPU hardware accelerator.")
No GPU was detected. DNNs can be very slow without a GPU.
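The setup cell checks the version with `tf.__version__ >= "2.0"`, which compares strings character by character. That happens to work for every 2.x release, but lexicographic order is not version order (for example `"10.0" < "2.0"` as strings). Below is a minimal sketch of a numeric comparison; `version_tuple` is a hypothetical helper, not part of TensorFlow, and the `packaging` library's `Version` class is a more complete alternative.

```python
def version_tuple(version: str, n: int = 2) -> tuple:
    """Parse the first n numeric components of a version string such as '2.8.0'."""
    parts = []
    for piece in version.split(".")[:n]:
        # Keep only the leading digits so suffixes like '0rc1' do not break parsing
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# String comparison is lexicographic, so a two-digit major version sorts too low
print("10.0" >= "2.0")                   # False
print(version_tuple("10.0") >= (2, 0))   # True
print(version_tuple("2.8.0") >= (2, 0))  # True
```

Comparing tuples of integers orders versions component by component, which is what the original string check intends.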
# Standard Python libraries
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL

try:
    import imageio
except ModuleNotFoundError:
    %pip install -qq imageio
    import imageio

from IPython import display
import sklearn
from time import time  # binds `time` to the function time.time(), used for timing below

np.random.seed(0)

try:
    from probml_utils.mnist_helper_tf import *
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    from probml_utils.mnist_helper_tf import *
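`np.random.seed(0)` above seeds only NumPy's global generator. Keras initializes layer weights with TensorFlow's own random ops, so this call alone should not be expected to make training runs repeatable. A minimal sketch of seeding all three generators, assuming TensorFlow may or may not be installed; `set_all_seeds` is a hypothetical helper name. Even with every seed set, some GPU kernels can still be nondeterministic.

```python
import random

import numpy as np


def set_all_seeds(seed: int) -> None:
    """Seed Python's, NumPy's and (if installed) TensorFlow's global RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import tensorflow as tf
    except ImportError:
        return  # TensorFlow not available; only Python and NumPy were seeded
    tf.random.set_seed(seed)  # global seed for TF random ops such as weight initializers


set_all_seeds(0)
first = np.random.rand(3)
set_all_seeds(0)
second = np.random.rand(3)
print(np.array_equal(first, second))  # True
```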
train_images, train_labels, test_images, test_labels, class_names = get_dataset(FASHION=False)
print(train_images.shape)
plot_dataset(train_images, train_labels, class_names)
(60000, 28, 28)
Image in a Jupyter notebook
model = keras.Sequential(
    [
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation=tf.nn.relu),
        keras.layers.Dense(128, activation=tf.nn.relu),
        keras.layers.Dense(10, activation=tf.nn.softmax),
    ]
)
model.summary()

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# We train for just 1 epoch because (1) it is faster, and
# (2) it produces more errors, which makes for a more interesting plot :)
time_start = time()
model.fit(train_images, train_labels, epochs=1)
print("time spent training {:0.3f}".format(time() - time_start))
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 flatten (Flatten)           (None, 784)               0
 dense (Dense)               (None, 128)               100480
 dense_1 (Dense)             (None, 128)               16512
 dense_2 (Dense)             (None, 10)                1290
=================================================================
Total params: 118,282
Trainable params: 118,282
Non-trainable params: 0
_________________________________________________________________
1875/1875 [==============================] - 2s 871us/step - loss: 0.2288 - accuracy: 0.9319
time spent training 2.223
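To make the model definition concrete, here is a NumPy sketch of the forward pass of the same architecture (Flatten, two 128-unit ReLU layers, a 10-way softmax). The weights are random placeholders rather than the trained Keras weights, and their initialization differs from Keras's default Glorot-uniform kernels; only the shapes and the parameter count are meant to match `model.summary()`. Each Dense layer has `n_in * n_out` weights plus `n_out` biases, e.g. 784 × 128 + 128 = 100,480 for the first one. The 1875 steps in the progress bar are consistent with Keras's default `batch_size` of 32, since 60,000 / 32 = 1875.

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(z):
    return np.maximum(z, 0.0)


def softmax(z):
    # Subtract the row-wise max for numerical stability before exponentiating
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def init_dense(n_in, n_out):
    # Placeholder random weights (illustration only) and zero biases
    W = rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out))
    b = np.zeros(n_out)
    return W, b


layer_sizes = [28 * 28, 128, 128, 10]
params = [init_dense(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]


def mlp_forward(images, params):
    h = images.reshape(images.shape[0], -1)  # Flatten: (N, 28, 28) -> (N, 784)
    for W, b in params[:-1]:
        h = relu(h @ W + b)                  # hidden Dense layers with ReLU
    W, b = params[-1]
    return softmax(h @ W + b)                # output layer: class probabilities


n_params = sum(W.size + b.size for W, b in params)
print(n_params)      # 118282, matching model.summary()
print(60000 // 32)   # 1875 steps per epoch at batch size 32

fake_batch = rng.random((32, 28, 28))        # stand-in for one batch of images
probs = mlp_forward(fake_batch, params)
print(probs.shape)   # (32, 10)
```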
# Overall accuracy
train_loss, train_acc = model.evaluate(train_images, train_labels)
print("Train accuracy:", train_acc)
test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy:", test_acc)
1875/1875 [==============================] - 1s 519us/step - loss: 0.1004 - accuracy: 0.9700
Train accuracy: 0.9699666500091553
313/313 [==============================] - 0s 791us/step - loss: 0.1162 - accuracy: 0.9646
Test accuracy: 0.9646000266075134
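`model.evaluate` reports the mean loss and accuracy over a whole dataset with the weights held fixed. That is why its training accuracy (0.9700) is higher than the 0.9319 printed by `fit`: the progress bar during `fit` averages over batches while the weights are still being updated. The sketch below writes out, in NumPy, the two reported quantities for integer labels. It illustrates the formulas rather than reproducing Keras's implementation; the `eps` clip is a safeguard added in this sketch.

```python
import numpy as np


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    # Integer labels index directly into each row of probabilities (no one-hot needed)
    p_true = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(np.clip(p_true, eps, 1.0))))


def accuracy(labels, probs):
    # The predicted class is the argmax of each row
    return float(np.mean(np.argmax(probs, axis=1) == labels))


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
labels = np.array([0, 1, 2, 1])

acc = accuracy(labels, probs)
loss = sparse_categorical_crossentropy(labels, probs)
print(acc)  # 0.75 (the last row is predicted as class 0, but its label is 1)
print(loss)
```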
# To apply prediction to a single image, we need to reshape it to an (N, D, D) tensor
# with N=1, i.e. add a leading batch dimension
img = test_images[0]
img = np.expand_dims(img, 0)
print(img.shape)
predictions_single = model.predict(img)
print(predictions_single.shape)
(1, 28, 28)
(1, 10)
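Keras models expect a leading batch axis, so a single 28 × 28 image has to become a batch of one. The three NumPy expressions below produce the same shape; slicing with `test_images[0:1]` would also keep the axis. The zero array is only a stand-in for `test_images[0]`.

```python
import numpy as np

img = np.zeros((28, 28), dtype=np.float32)  # stand-in for test_images[0]

a = np.expand_dims(img, 0)  # insert a new leading axis
b = img[np.newaxis, ...]    # same result via indexing
c = img.reshape(1, 28, 28)  # same result via reshape

print(a.shape, b.shape, c.shape)  # (1, 28, 28) (1, 28, 28) (1, 28, 28)
```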
predictions = model.predict(test_images)
print(np.shape(predictions))
ndx = find_interesting_test_images(predictions, test_labels)
plot_interesting_test_results(test_images, test_labels, predictions, class_names, ndx)
(10000, 10)
(354,)
Image in a Jupyter notebook
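The 354 indices returned by `find_interesting_test_images` match the number of test errors implied by the 96.46% test accuracy (10,000 × 0.0354 = 354), which suggests the helper selects misclassified images; we have not checked its source. A hypothetical sketch of that selection, with toy data:

```python
import numpy as np


def misclassified_indices(probs, labels):
    # Hypothetical helper: indices of examples whose predicted class is wrong
    predicted = np.argmax(probs, axis=1)
    return np.flatnonzero(predicted != labels)


probs = np.array([[0.9, 0.1],
                  [0.4, 0.6],
                  [0.2, 0.8],
                  [0.7, 0.3]])
labels = np.array([0, 0, 1, 1])

wrong = misclassified_indices(probs, labels)
print(wrong)  # [1 3]
```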
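# Keep an independent copy of the 1-epoch model. A plain assignment such as
# `model_epoch = model` only adds a second name for the same object, so it
# would keep changing as training continues.
model_epoch = keras.models.clone_model(model)
model_epoch.set_weights(model.get_weights())

# Train for 1 more epoch
time_start = time()
model.fit(train_images, train_labels, epochs=1)
print("time spent training {:0.3f}".format(time() - time_start))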
1875/1875 [==============================] - 2s 861us/step - loss: 0.0948 - accuracy: 0.9711
time spent training 1.875
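A plain assignment such as `model_epoch = model` does not save a snapshot: both names refer to the same object, so further calls to `fit` change what either name sees. The toy class below, a hypothetical stand-in for a Keras model, makes the difference between an alias and a copy visible. For an actual Keras model, `keras.models.clone_model(model)` followed by `set_weights(model.get_weights())` is one way to take a snapshot; `copy.deepcopy` is used here only on the toy class.

```python
import copy


class TinyModel:
    # Toy stand-in for a Keras model: "training" updates its weights
    def __init__(self):
        self.weights = [0.0, 0.0]

    def fit_one_epoch(self):
        self.weights = [w + 1.0 for w in self.weights]


model = TinyModel()
alias = model                    # a second name for the same object
snapshot = copy.deepcopy(model)  # an independent copy taken before more training

model.fit_one_epoch()
print(alias.weights)     # [1.0, 1.0] (the "saved" model changed too)
print(snapshot.weights)  # [0.0, 0.0] (the copy kept the old weights)
print(alias is model)    # True
```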
# Overall accuracy
train_loss, train_acc = model.evaluate(train_images, train_labels)
print("Train accuracy:", train_acc)
test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy:", test_acc)
1875/1875 [==============================] - 1s 530us/step - loss: 0.0622 - accuracy: 0.9810
Train accuracy: 0.9809666872024536
313/313 [==============================] - 0s 744us/step - loss: 0.0896 - accuracy: 0.9732
Test accuracy: 0.9732000231742859
predictions = model.predict(test_images)
print(np.shape(predictions))
# test_ndx = find_interesting_test_images(predictions, test_labels)
# Re-use the old indices so we can see how the same images fare after more training
plot_interesting_test_results(test_images, test_labels, predictions, class_names, ndx)
(10000, 10)
Image in a Jupyter notebook
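Because the last cell reuses `ndx` from the first epoch, the plot shows how the same previously selected images are classified after a second epoch. If the first-epoch predictions are kept under another name before `predictions` is reassigned (for example a hypothetical `predictions_epoch1`), the two can be compared directly. A NumPy sketch with made-up toy arrays; `fixed_and_broken` is a hypothetical helper:

```python
import numpy as np


def fixed_and_broken(old_probs, new_probs, labels, ndx):
    # Compare correctness on the previously selected indices before and after training
    old_ok = np.argmax(old_probs[ndx], axis=1) == labels[ndx]
    new_ok = np.argmax(new_probs[ndx], axis=1) == labels[ndx]
    fixed = ndx[~old_ok & new_ok]   # wrong after epoch 1, right after epoch 2
    broken = ndx[old_ok & ~new_ok]  # right after epoch 1, wrong after epoch 2
    return fixed, broken


labels = np.array([0, 1, 1, 0])
old_probs = np.array([[0.4, 0.6], [0.7, 0.3], [0.2, 0.8], [0.1, 0.9]])
new_probs = np.array([[0.8, 0.2], [0.6, 0.4], [0.3, 0.7], [0.9, 0.1]])
ndx = np.array([0, 1, 3])  # indices flagged after the first epoch

fixed, broken = fixed_and_broken(old_probs, new_probs, labels, ndx)
print(fixed, broken)  # [0 3] []
```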