GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/01/emnist_viz_jax.ipynb
Kernel: Python 3

try:
    import torchvision
except ModuleNotFoundError:
    %pip install -qq torchvision
    import torchvision
from torchvision import datasets
from torchvision import transforms
import numpy as np
import jax
import jax.numpy as jnp
import itertools

try:
    from bokeh.io import output_notebook, show
except ModuleNotFoundError:
    %pip install -qq bokeh
    from bokeh.io import output_notebook, show
from bokeh.layouts import gridplot
from bokeh.plotting import figure

The following call is required for Bokeh to display plots inline in notebooks.

output_notebook()

According to NIST,

The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset. Further information on the dataset contents and conversion process can be found in the paper available at https://arxiv.org/abs/1702.05373v1.

Since we are going to work with JAX, let's transform the PyTorch tensors into JAX DeviceArrays. The transform also rotates each image by 90 degrees before the conversion.

transform = transforms.Compose(
    [lambda img: torchvision.transforms.functional.rotate(img, 90), transforms.ToTensor(), jnp.array]
)
training_data = datasets.EMNIST(root="~/data", split="byclass", download=True, transform=transform)
Downloading https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip to /root/data/EMNIST/raw/gzip.zip
Extracting /root/data/EMNIST/raw/gzip.zip to /root/data/EMNIST/raw
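
To make the transform above more concrete, here is a minimal NumPy-only sketch of what its steps do to a single image. It assumes a 28x28 8-bit grayscale input and uses np.rot90 as a stand-in for torchvision.transforms.functional.rotate(img, 90); the helper name emulate_transform and the toy input are hypothetical. The notebook's final step, jnp.array, is left out so that the sketch runs without JAX or torchvision, and the result is not guaranteed to match torchvision pixel for pixel.

```python
import numpy as np


def emulate_transform(img_u8: np.ndarray) -> np.ndarray:
    """Approximate the notebook's transform for one H x W uint8 grayscale image.

    Hypothetical helper, not part of torchvision or JAX.
    """
    # Step 1: rotate 90 degrees counter-clockwise (stand-in for rotate(img, 90)).
    rotated = np.rot90(img_u8, k=1)
    # Step 2: like ToTensor, scale uint8 values in [0, 255] to float32 in [0, 1]
    # and add a leading channel axis, giving shape (1, H, W).
    scaled = rotated.astype(np.float32) / 255.0
    # Step 3 in the notebook would wrap this result with jnp.array.
    return scaled[np.newaxis, ...]


toy = np.zeros((28, 28), dtype=np.uint8)
toy[0, 27] = 255  # a single bright pixel in the top-right corner
out = emulate_transform(toy)
print(out.shape, out.dtype, out[0, 0, 0])
```

After the counter-clockwise rotation, the bright pixel from the top-right corner ends up in the top-left corner, which is a quick way to check the rotation direction on real images as well.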

Here are some examples from the dataset:

def plot_one(item):
    image, raw_label = item
    label = training_data.classes[raw_label]
    p = figure(title=f"label = {label}", tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")], match_aspect=True)
    p.x_range.range_padding = p.y_range.range_padding = 0

    # must give a vector of image data for image parameter
    subplot = p.image(image=[np.array(image.squeeze())], x=0, y=0, dw=1, dh=1, level="image")
    p.title.align = "center"
    p.axis.visible = False
    p.grid.grid_line_width = 0.5
    return p


# Take the first 25 images
subplots = list(map(plot_one, itertools.islice(training_data, 25)))
grid = gridplot(subplots, ncols=5, toolbar_location=None, plot_width=150, plot_height=150)
show(grid)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
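
The label lookup in plot_one relies on training_data.classes, which maps an integer label to a character, and itertools.islice takes the first few items without walking the whole dataset. Both ideas can be sketched with the standard library alone. The sketch assumes that the 62 classes of the byclass split are ordered as digits, then uppercase letters, then lowercase letters; this ordering is an assumption worth checking against training_data.classes directly. ToyDataset and its labels are hypothetical stand-ins for the real dataset.

```python
import itertools
import string

# Assumed ordering of the 62 "byclass" labels: 0-9, then A-Z, then a-z.
classes = sorted(string.digits + string.ascii_letters)


class ToyDataset:
    """Hypothetical stand-in for an indexable dataset of (image, label) pairs."""

    def __init__(self, labels):
        self.labels = labels
        self.loaded = 0  # counts how many items were actually fetched

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if idx >= len(self.labels):
            raise IndexError(idx)
        self.loaded += 1
        return None, self.labels[idx]  # the image is omitted in this sketch


toy_data = ToyDataset(labels=[3, 10, 36, 61] * 100)

# islice stops after 5 items, so only 5 of the 400 entries are fetched.
first_five = [classes[raw_label] for _, raw_label in itertools.islice(toy_data, 5)]
print(first_five, toy_data.loaded)
```

Because islice is lazy, the same pattern keeps the plotting cell cheap even though the full training split is large: only the requested items are indexed and transformed.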