Copyright 2020 The TensorFlow Authors.
MNIST classification
This tutorial builds a quantum neural network (QNN) to classify a simplified version of MNIST, similar to the approach used in Farhi et al. The performance of the quantum neural network on this classical data problem is compared with a classical neural network.
Setup
Install TensorFlow Quantum:
Now import TensorFlow and the module dependencies:
1. Load the data
In this tutorial you will build a binary classifier to distinguish between the digits 3 and 6, following Farhi et al. This section covers the data handling that:
Loads the raw data from Keras.
Filters the dataset to only 3s and 6s.
Downscales the images so they can fit in a quantum computer.
Removes any contradictory examples.
Converts the binary images to Cirq circuits.
Converts the Cirq circuits to TensorFlow Quantum circuits.
1.1 Load the raw data
Load the MNIST dataset distributed with Keras.
Filter the dataset to keep just the 3s and 6s, removing the other classes. At the same time, convert the label, y, to boolean: True for 3 and False for 6.
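A minimal NumPy sketch of this filtering step, assuming the images and integer labels are NumPy arrays. The function name filter_36 and the toy arrays are illustrative, not the tutorial's own code:

```python
import numpy as np

def filter_36(x, y):
    """Keep only examples labeled 3 or 6; relabel as True (3) / False (6)."""
    keep = (y == 3) | (y == 6)   # boolean mask over examples
    x, y = x[keep], y[keep]
    return x, y == 3             # boolean labels: True for 3, False for 6

# Toy data: five 2x2 "images" with mixed labels (hypothetical values).
x = np.arange(5 * 2 * 2).reshape(5, 2, 2)
y = np.array([3, 1, 6, 3, 9])
x36, y36 = filter_36(x, y)
print(x36.shape, y36.tolist())  # (3, 2, 2) [True, False, True]
```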
Show the first example:
1.2 Downscale the images
An image size of 28x28 is much too large for current quantum computers. Resize the image down to 4x4:
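To make the size reduction concrete, the sketch below uses plain block averaging in NumPy: each 7x7 block of a 28x28 image becomes one output pixel. This is not necessarily the resize method the tutorial uses (an interpolating resize will give somewhat different pixel values), and block_average_downscale is a hypothetical helper:

```python
import numpy as np

def block_average_downscale(images, out_size=4):
    """Downscale a batch of square images (n, h, w) by averaging equal-sized blocks."""
    n, h, w = images.shape
    assert h % out_size == 0 and w % out_size == 0, "size must divide evenly"
    fh, fw = h // out_size, w // out_size  # block size, 7x7 for 28 -> 4
    # Element [n, i, a, j, b] of the reshaped array is pixel (i*fh + a, j*fw + b).
    blocks = images.reshape(n, out_size, fh, out_size, fw)
    return blocks.mean(axis=(2, 4))

batch = np.arange(28 * 28, dtype=float).reshape(1, 28, 28)
small = block_average_downscale(batch)
print(small.shape)  # (1, 4, 4)
```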
Again, display the first training example, this time after resizing:
1.3 Remove contradictory examples
From section 3.3 Learning to Distinguish Digits of Farhi et al., filter the dataset to remove images that are labeled as belonging to both classes.
This is not a standard machine-learning procedure, but is included in the interest of following the paper.
The resulting counts do not closely match the reported values, but the exact procedure is not specified.
It is also worth noting that filtering out contradictory examples at this point does not totally prevent the model from receiving contradictory training examples: the next step binarizes the data, which will cause more collisions.
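A sketch of the contradiction filter using only the standard library, assuming each image has already been flattened to a hashable tuple of pixel values. This version also keeps a single copy of each unambiguous image, which is an assumption about how duplicates are handled; the helper name is hypothetical:

```python
import collections

def remove_contradicting(xs, ys):
    """Drop images that appear with more than one label; keep one copy of the rest."""
    labels_for = collections.defaultdict(set)  # image -> set of labels seen
    first_seen = {}                            # image -> first occurrence
    for x, y in zip(xs, ys):
        key = tuple(x)
        labels_for[key].add(y)
        first_seen.setdefault(key, x)
    new_x, new_y = [], []
    for key, labels in labels_for.items():
        if len(labels) == 1:  # unambiguous: exactly one label for this image
            new_x.append(first_seen[key])
            new_y.append(next(iter(labels)))
        # images labeled both True and False are discarded
    return new_x, new_y

xs = [(0, 1), (0, 1), (1, 1), (1, 0)]
ys = [True, False, True, True]
print(remove_contradicting(xs, ys))  # ([(1, 1), (1, 0)], [True, True])
```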
1.4 Encode the data as quantum circuits
To process images using a quantum computer, Farhi et al. proposed representing each pixel with a qubit, with the state depending on the value of the pixel. The first step is to convert to a binary encoding.
If you were to remove contradictory images at this point, you would be left with only 193, likely not enough for effective training.
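A sketch of the binary encoding, assuming a threshold of 0.5 on pixel values scaled to [0, 1] (the threshold value and the helper name are assumptions). It also shows why binarizing creates new collisions: two different grayscale images can map to the same binary image, and if they carry different labels they become a new contradiction:

```python
import numpy as np

THRESHOLD = 0.5  # assumed cutoff for this sketch

def binarize(images, threshold=THRESHOLD):
    """Map each pixel to 1.0 if it exceeds the threshold, else 0.0."""
    return (np.asarray(images) > threshold).astype(np.float32)

a = np.array([[0.6, 0.1], [0.2, 0.9]])
b = np.array([[0.8, 0.3], [0.0, 0.7]])
print(np.array_equal(a, b))                      # False: distinct grayscale images
print(np.array_equal(binarize(a), binarize(b)))  # True: identical after binarizing
```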
The qubits at pixel indices with values that exceed a threshold are rotated through an X gate.
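To see what this encoding does to the quantum state, the NumPy sketch below simulates a tiny 2x2 (4-qubit) image, assuming the rotation is a Pauli X (bit flip) applied to each qubit whose pixel is 1, starting from all qubits in |0>. The result is the single computational basis state whose bit string is the flattened binary image. Treating qubit 0 as the most significant bit is a convention chosen for this sketch:

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)  # Pauli X (bit flip)
I2 = np.eye(2, dtype=complex)

def single_qubit_op(gate, target, n):
    """Full 2^n x 2^n matrix applying `gate` to qubit `target` (qubit 0 most significant)."""
    full = np.array([[1.0 + 0j]])
    for k in range(n):
        full = np.kron(full, gate if k == target else I2)
    return full

def encode_binary_image(bits):
    """Start in |0...0> and apply X to every qubit whose pixel is 1."""
    n = len(bits)
    state = np.zeros(2 ** n, dtype=complex)
    state[0] = 1.0
    for k, b in enumerate(bits):
        if b:
            state = single_qubit_op(X, k, n) @ state
    return state

state = encode_binary_image([1, 0, 0, 1])     # flattened 2x2 binary image
print(int(np.argmax(np.abs(state))))          # 9, i.e. basis state |1001>
```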
Here is the circuit created for the first example (circuit diagrams do not show qubits with zero gates):
Compare this circuit to the indices where the image value exceeds the threshold:
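Reading off those indices can be sketched with np.argwhere on a hypothetical 4x4 binary image. Assuming the data qubits are laid out on a 4x4 grid in row-major order, each (row, col) pair should correspond to one qubit that carries a gate in the circuit diagram:

```python
import numpy as np

# Hypothetical 4x4 binary image (1 = pixel above the threshold).
bin_img = np.array([[0, 1, 0, 0],
                    [0, 1, 1, 0],
                    [0, 0, 0, 0],
                    [1, 0, 0, 0]], dtype=np.float32)

indices = np.argwhere(bin_img > 0)  # (row, col) of each pixel that gets a gate
flat = np.flatnonzero(bin_img)      # the same pixels as row-major flat indices
print(indices.tolist())  # [[0, 1], [1, 1], [1, 2], [3, 0]]
print(flat.tolist())     # [1, 5, 6, 12]
```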
Convert these Cirq circuits to tensors for tfq:
2. Quantum neural network
There is little guidance for a quantum circuit structure that classifies images. Since the classification is based on the expectation of the readout qubit, Farhi et al. propose using two-qubit gates, with the readout qubit always acted upon. This is similar in some ways to running a small Unitary RNN across the pixels.
2.1 Build the model circuit
The following example shows this layered approach. Each layer uses n instances of the same gate, with each of the data qubits acting on the readout qubit.
Start with a simple class that will add a layer of these gates to a circuit:
Build an example circuit layer to see how it looks:
Now build a two-layered model, matching the data-circuit size, and include the preparation and readout operations.
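The structure of the model circuit can be sketched as a plain list of operations, without Cirq. The gate labels XX and ZZ, the symbol names, and the X-then-H readout preparation are assumptions for illustration. The point is the count: with 16 data qubits and two layers, each data qubit contributes one parameterized two-qubit gate per layer.

```python
def build_layer(gate, data_qubits, readout, prefix):
    """One layer: the same two-qubit gate between each data qubit and the readout,
    each with its own trainable symbol (hypothetical naming scheme)."""
    return [(gate, q, readout, f"{prefix}-{i}") for i, q in enumerate(data_qubits)]

def build_model_ops(n_data=16):
    data = [f"d{i}" for i in range(n_data)]
    readout = "r"
    ops = [("X", readout), ("H", readout)]            # prepare the readout qubit
    ops += build_layer("XX", data, readout, "xx1")    # layer 1 (assumed gate)
    ops += build_layer("ZZ", data, readout, "zz1")    # layer 2 (assumed gate)
    ops.append(("H", readout))                        # final rotation before readout
    return ops, ("Z", readout)                        # observable measured on readout

ops, observable = build_model_ops()
symbols = [op[3] for op in ops if len(op) == 4]
print(len(symbols))  # 32
```

Two layers of 16 gates give 32 trainable symbols, consistent with the model's 32 parameters mentioned in the training section.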
2.2 Wrap the model-circuit in a tfq-keras model
Build the Keras model with the quantum components. This model is fed the "quantum data" from x_train_circ, which encodes the classical data. It uses a Parametrized Quantum Circuit layer, tfq.layers.PQC, to train the model circuit on the quantum data.
To classify these images, Farhi et al. proposed taking the expectation of a readout qubit in a parameterized circuit. The expectation returns a value between 1 and -1.
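To make the readout concrete, the NumPy sketch below simulates a miniature version with 2 data qubits and 1 readout qubit and computes the expectation of Z on the readout. Every gate here is an assumption for illustration: the readout is prepared with X then H, each layer applies exp(-i θ P⊗P / 2) between a data qubit and the readout (P = X in the first layer, P = Z in the second), and a final H precedes the measurement. That parametrization is not necessarily how the tutorial's gates are defined. Because Z has eigenvalues ±1, the expectation always lands in [-1, 1].

```python
import numpy as np

I2 = np.eye(2, dtype=complex)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)
H = np.array([[1, 1], [1, -1]], dtype=complex) / np.sqrt(2)

def on_qubits(ops, n):
    """Kronecker product with ops[k] on qubit k and identity elsewhere (qubit 0 most significant)."""
    full = np.array([[1.0 + 0j]])
    for k in range(n):
        full = np.kron(full, ops.get(k, I2))
    return full

def pauli_rotation(p, theta):
    """exp(-i * theta/2 * P) for a Pauli string P with P @ P = I."""
    return np.cos(theta / 2) * np.eye(p.shape[0]) - 1j * np.sin(theta / 2) * p

def qnn_expectation(bits, xx_params, zz_params):
    n_data = len(bits)
    n, r = n_data + 1, n_data                 # readout is the last qubit
    state = np.zeros(2 ** n, dtype=complex)
    state[0] = 1.0
    for k, b in enumerate(bits):              # data encoding: X where pixel is 1
        if b:
            state = on_qubits({k: X}, n) @ state
    state = on_qubits({r: X}, n) @ state      # readout preparation (assumed X, H)
    state = on_qubits({r: H}, n) @ state
    for pauli, params in ((X, xx_params), (Z, zz_params)):  # two layers
        for k, theta in enumerate(params):
            state = pauli_rotation(on_qubits({k: pauli, r: pauli}, n), theta) @ state
    state = on_qubits({r: H}, n) @ state
    z_readout = on_qubits({r: Z}, n)
    return float(np.real(np.vdot(state, z_readout @ state)))  # <psi|Z_r|psi>

print(round(qnn_expectation([1, 0], [0.0, 0.0], [0.0, 0.0]), 6))    # -1.0
print(round(qnn_expectation([0, 1], [0.0, 0.0], [np.pi, 0.0]), 6))  # 1.0
```

With all parameters at zero the readout ends in |1>, giving -1; a ZZ angle of π on one data qubit swings the readout to +1, and intermediate angles land in between.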
Next, describe the training procedure to the model, using the compile method.
Since the expected readout is in the range [-1,1], optimizing the hinge loss is a somewhat natural fit.
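For labels y in {-1, +1} and a prediction ŷ, the hinge loss is max(0, 1 - y·ŷ), averaged over examples. A minimal NumPy sketch of that formula (not the Keras implementation):

```python
import numpy as np

def hinge_loss(y_true, y_pred):
    """Mean of max(0, 1 - y_true * y_pred); y_true must be in {-1, +1}."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.maximum(0.0, 1.0 - y_true * y_pred)))

print(hinge_loss([1, -1], [0.5, -2.0]))  # 0.25
print(hinge_loss([1, -1], [1.0, -1.0]))  # 0.0
```

Because the readout never leaves [-1, 1], y·ŷ reaches 1 only when ŷ equals y exactly, so the loss keeps pushing the expectation toward the correct extreme rather than stopping at the right sign.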
Note: Another valid approach would be to shift the output range to [0,1] and treat it as the probability the model assigns to class 3. This could be used with a standard tf.losses.BinaryCrossentropy loss.
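A sketch of that alternative in NumPy, assuming class 3 corresponds to an expectation of +1: the affine map p = (e + 1) / 2 takes [-1, 1] to [0, 1], and p can then be scored with ordinary binary cross-entropy. The helper names are hypothetical, and clipping by eps is the usual guard against log(0):

```python
import numpy as np

def expectation_to_probability(expectation):
    """Affine map [-1, 1] -> [0, 1]: p = (e + 1) / 2."""
    return (np.asarray(expectation, dtype=float) + 1.0) / 2.0

def binary_crossentropy(y_true, p, eps=1e-7):
    """Mean binary cross-entropy for labels in {0, 1} and probabilities p."""
    y_true = np.asarray(y_true, dtype=float)
    p = np.clip(p, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))

p = expectation_to_probability([0.0, 1.0, -1.0])
print(p.tolist())                                    # [0.5, 1.0, 0.0]
print(round(binary_crossentropy([1.0], [0.5]), 4))   # 0.6931 (ln 2)
```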
To use the hinge loss here, you need to make two small adjustments. First, convert the labels, y_train_nocon, from boolean to [-1,1], as expected by the hinge loss.
Second, use a custom hinge_accuracy metric that correctly handles [-1, 1] as the y_true labels argument. tf.keras.metrics.BinaryAccuracy(threshold=0.0) expects y_true to be a boolean, and so can't be used with hinge loss.
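Both adjustments can be sketched in NumPy. The helper names are hypothetical. naive_binary_accuracy is a rough emulation of a thresholded binary-accuracy metric: it thresholds only the predictions to 0/1 and compares them with the raw labels, so every -1 label counts as a miss.

```python
import numpy as np

def to_hinge_labels(y_bool):
    """True -> +1.0, False -> -1.0."""
    return 2.0 * np.asarray(y_bool, dtype=float) - 1.0

def hinge_accuracy(y_true, y_pred):
    """Fraction of examples where label and prediction have the same sign."""
    y_true = np.squeeze(y_true) > 0.0
    y_pred = np.squeeze(y_pred) > 0.0
    return float(np.mean(y_true == y_pred))

def naive_binary_accuracy(y_true, y_pred, threshold=0.0):
    """Thresholds predictions to 0/1 but leaves labels as-is (wrong for -1 labels)."""
    y_pred01 = (np.asarray(y_pred) > threshold).astype(float)
    return float(np.mean(np.asarray(y_true, dtype=float) == y_pred01))

y = to_hinge_labels([True, False])
pred = [0.3, -0.4]                     # both predictions have the correct sign
print(y.tolist())                      # [1.0, -1.0]
print(hinge_accuracy(y, pred))         # 1.0
print(naive_binary_accuracy(y, pred))  # 0.5
```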
Train the quantum model
Now train the model; this takes about 45 min. If you don't want to wait that long, use a small subset of the data (set NUM_EXAMPLES=500, below). This doesn't really affect the model's progress during training (it only has 32 parameters, and doesn't need much data to constrain these). Using fewer examples just ends training earlier (about 5 min), but runs long enough to show that it is making progress in the validation logs.
Training this model to convergence should achieve >85% accuracy on the test set.
Note: The training accuracy reports the average over the epoch. The validation accuracy is evaluated at the end of each epoch.
3. Classical neural network
While the quantum neural network works for this simplified MNIST problem, a basic classical neural network can easily outperform a QNN on this task. After a single epoch, a classical neural network can achieve >98% accuracy on the holdout set.
In the following example, a classical neural network is used for the 3-6 classification problem using the entire 28x28 image instead of subsampling the image. This easily converges to nearly 100% accuracy on the test set.
The above model has nearly 1.2M parameters. For a fairer comparison, try a 37-parameter model on the subsampled images:
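The parameter counts can be checked with simple arithmetic. The architectures below are assumptions chosen because they reproduce the stated numbers: a 16 -> 2 -> 1 dense network on the flattened 4x4 image gives exactly 37 parameters, and a CNN with two 3x3 "valid" convolutions (32 and 64 filters), 2x2 max pooling, a 128-unit dense layer, and a single output unit gives just under 1.2M.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def conv2d_params(k, c_in, c_out):
    """Weights plus biases of a k x k convolution."""
    return k * k * c_in * c_out + c_out

# Assumed small model on the 4x4 image: Flatten -> Dense(2) -> Dense(1).
small = dense_params(4 * 4, 2) + dense_params(2, 1)

# Assumed CNN on 28x28x1: two valid 3x3 convs shrink 28 -> 24, pooling halves to 12.
side = (28 - 2 - 2) // 2
cnn = (conv2d_params(3, 1, 32) + conv2d_params(3, 32, 64)
       + dense_params(side * side * 64, 128) + dense_params(128, 1))

qnn = 2 * 16  # two layers, one symbol per data qubit per layer

print(small, cnn, qnn)  # 37 1198721 32
```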
4. Comparison
Higher-resolution input and a more powerful model make this problem easy for the CNN. Meanwhile, a classical model of similar power (~32 parameters) trains to a similar accuracy in a fraction of the time. Either way, the classical neural network easily outperforms the quantum neural network. For classical data, it is difficult to beat a classical neural network.