GitHub Repository: UBC-DSCI/dsci-100-assets
Path: blob/master/2020-spring/materials/tutorial_07/tutorial_07.ipynb
Kernel: R

Tutorial 7: Classification (Part II)

Handwritten Digit Classification using R

Source: https://media.giphy.com/media/UwrdbvJz1CNck/giphy.gif

MNIST is a computer vision dataset that consists of images of handwritten digits like these:

It also includes labels for each image, telling us which digit it is. For example, the labels for the above images are 5, 0, 4, and 1.

In this tutorial, we’re going to train a classifier to look at images and predict what digits they are. Our goal isn’t to train a really elaborate model that achieves state-of-the-art performance, but rather to dip a toe into using classification with pixelated images. As such, we’re going to keep working with the simple K-nearest neighbour classifier we have been exploring in the last two weeks.

Using image data for classification

As mentioned earlier, every MNIST data point has two parts: an image of a handwritten digit and a corresponding label. Both the training set and test set contain images and their corresponding labels.

Each image is 28 pixels by 28 pixels. We can interpret this as a big matrix of numbers:

We can flatten this matrix into a vector of 28x28 = 784 numbers and give it a class label (here 1 for the number one). It doesn’t matter how we flatten the array, as long as we’re consistent between images. From this perspective, the MNIST images are just a bunch of points in a 784-dimensional vector space, with a very rich structure.
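As a small illustration of flattening, here is a toy 3x3 matrix standing in for the 28x28 image:

# as.vector() flattens a matrix column by column into one long vector,
# just as a 28 x 28 image becomes 28 * 28 = 784 numbers.
toy_image <- matrix(1:9, nrow = 3)
flattened <- as.vector(toy_image)
length(flattened) # 9, i.e., nrow * ncol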

We do this for every digit image we have, creating a data table like the one shown below that we can use for classification. Note that, as in any other classification problem we have seen before, we need many observations for each class. This problem also differs from the first classification problem we encountered (the Wisconsin breast cancer data set) in that we have more than two classes (here we have 10 classes, one for each digit from 0 to 9).

This information is taken from: https://tensorflow.rstudio.com/tensorflow/articles/tutorial_mnist_beginners.html

###
### Run this cell before continuing.
###
library(repr)
library(tidyverse)
library(caret)
source('tests_tutorial_07.R')
source("cleanup_tutorial_07.R")

# Functions needed to work with images.
# Code below sourced from: https://gist.github.com/daviddalpiaz/ae62ae5ccd0bada4b9acd6dbc9008706

# Helper function for visualization: plots one flattened 784-pixel image.
show_digit <- function(arr784, col = gray(12:1 / 12), ...) {
    image(matrix(as.matrix(arr784[-785]), nrow = 28)[, 28:1], col = col, ...)
}

Question 1.0 Multiple Choice:
{points: 1}

How many rows and columns does the array of an image have?

A. 784 columns and 1 row

B. 28 columns and 1 row

C. 18 columns and 18 rows

D. 28 columns and 28 rows

Assign your answer to an object called answer1.0.

# Make sure the correct answer is an uppercase letter.
# Surround your answer with quotation marks.
# Replace the fail() with your answer.

# your code here
fail() # No Answer - remove if you provide an answer
test_1.0()

Question 1.1 Multiple Choice:
{points: 1}

Once we linearize (flatten) the array, how many rows represent a single number?

A. 28

B. 784

C. 1

D. 18

Assign your answer to an object called answer1.1.

# Make sure the correct answer is an uppercase letter.
# Surround your answer with quotation marks.
# Replace the fail() with your answer.

# your code here
fail() # No Answer - remove if you provide an answer
test_1.1()

2. Exploring the Data

Before we move on to the modeling component, we should always take a look at our data to understand the problem and the structure of the data well. We can start by loading the images and taking a look at the first rows of the dataset. You can load the data set by running the cell below.

# Load images.
# Run this cell.
training_data <- read.csv('data/mnist_train_small.csv')
testing_data <- read.csv('data/mnist_test_small.csv')

Look at the first 6 rows of training_data. What do you notice?

head(training_data)

There are no class labels! This data set has already been split into the X's (which you loaded above) and the labels, which you will load by running the cell below.

# Next, we will load the labels.
# Run this cell.
training_labels <- read_csv('data/mnist_train_label_small.csv')['y'] %>%
    mutate(y = as.factor(y))
testing_labels <- read_csv('data/mnist_test_label_small.csv')['y'] %>%
    mutate(y = as.factor(y))

Look at the first 6 labels of training_labels using the head() function.

# Use this cell to view the first 6 labels.
# Run this cell.
head(training_labels)

Question 2.0
{points: 1}

How many rows does the training data set have? Note, each row is a different handwritten number (as might appear in the postal code system).

Use nrow(). Note, the testing data set should have fewer rows than the training data set.

Assign your answer to an object called number_of_rows.

# your code here
fail() # No Answer - remove if you provide an answer
number_of_rows
test_2.0()

Question 2.1
{points: 1}

For multi-class classification with k-nn, it is important for the classes to have about the same number of observations in each class. For example, if 90% of our training set observations were labeled as 2's, then k-nn classification would predict 2 almost every time, and we would get an accuracy score of 90% even though our classifier wasn't really doing a great job.

Use the group_by and summarize functions to get the counts for each group and see if the data set is balanced across the classes (has roughly equal numbers of observations for each class). Name the output counts. counts should be a data frame with 2 columns, y and n (the column n should hold the count of observations for each class group).
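If you need a reminder of the counting pattern, here is a generic sketch on the built-in iris data (not the MNIST labels):

# group_by() + summarize(n()) counts the rows in each group:
iris %>%
    group_by(Species) %>%
    summarize(n = n())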

# your code here
fail() # No Answer - remove if you provide an answer
counts
test_2.1()

Question 2.2
{points: 3}

Are the classes roughly balanced?

DOUBLE CLICK TO EDIT THIS CELL AND REPLACE THIS TEXT WITH YOUR ANSWER.

To view an image in the notebook, you can use the show_digit function (we gave you the code for this function in the first code cell of the notebook; all you have to do to use it is run the cell below). The show_digit function takes the row from the dataset whose image you want to produce, which you can obtain using the slice function.

The code we provide below will show you the image for the observation in the 200th row from the training data set.

# Run this cell to show the image for the 200th row from the training data set.
options(repr.plot.height = 4, repr.plot.width = 3.3)
show_digit(slice(training_data, 200))

Question 2.3
{points: 3}

Show the image for row 102.

# your code here
fail() # No Answer - remove if you provide an answer

If you are unsure what number the plot is depicting (because the handwriting is messy), you can use slice to get the label from training_labels:

# Run this cell to get the training label for the 200th row.
training_labels %>% slice(200)

Question 2.4
{points: 1}

What is the class label for row 102?

Assign your answer to an object called label_102.

# Replace the fail() with your answer.

# your code here
fail() # No Answer - remove if you provide an answer
label_102
test_2.4()

3. Splitting the Data

Question 3.0
{points: 3}

Split the training data into X_train and Y_train. Do the same for the test set. Remember that the train() function from the caret package requires that the x argument be a data frame and the y argument be a numeric or factor vector. In other words, X_train should be a data.frame and Y_train should be a factor (see the generic sketch after the list below).

At the end of this question you should have the following 4 objects:

  • X_train

  • Y_train

  • X_test

  • Y_test
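For reference, here is a minimal sketch of the shapes train() expects, using toy objects (toy_X, toy_labels, and toy_Y are hypothetical names, not the assignment objects):

# Toy illustration: x stays a data frame, y becomes a factor vector.
toy_X <- data.frame(pixel1 = c(0, 1, 0), pixel2 = c(1, 0, 1))
toy_labels <- data.frame(y = factor(c("0", "1", "0")))
toy_Y <- toy_labels %>% pull(y) # pull() extracts a column as a plain vector
class(toy_X) # "data.frame"
class(toy_Y) # "factor"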

# Set the seed. Don't remove this!
set.seed(9999)

# your code here
fail() # No Answer - remove if you provide an answer

Question 3.1
{points: 3}

We have already split the data into two datasets, one for training purposes and one for testing purposes. Do you think this is a good idea? If yes, why do we do this? If no, explain why this is not a good idea.

DOUBLE CLICK TO EDIT THIS CELL AND REPLACE THIS TEXT WITH YOUR ANSWER.

Which k should we use?

As you learned from the worksheet, we can use cross-validation on the training data set to select the k that works best for k-nn classification on our data set.
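If you are unsure of the general caret pattern for this, here is a minimal sketch on the built-in iris data (a stand-in for the MNIST training objects, using the same fold count and k grid as Question 3.2 below):

# 3-fold cross-validation over a grid of k values, sketched on iris:
# trainControl() describes the resampling scheme and train() fits k-nn
# once per candidate k, recording the cross-validated accuracy.
set.seed(1234)
knn_sketch <- train(x = iris[, 1:4],
                    y = iris$Species,
                    method = "knn",
                    trControl = trainControl(method = "cv", number = 3),
                    tuneGrid = data.frame(k = c(1, 3, 5, 7, 9, 11)))
knn_sketch$results # has a k column and an Accuracy column you can plot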

Question 3.2
{points: 3}

To get all the marks in this question, you will have to:

  • Set a seed to make your analysis reproducible

  • Apply 3-fold cross-validation to our small training data

    • Test the following k's: 1, 3, 5, 7, 9, 11

  • Plot k vs. the accuracy

    • Assign this plot to an object called cross_val_plot

Note: this will take 5-15 minutes to run, so we recommend you put the classifier training and cross-validation in one cell and the plotting in another cell (so you can tweak and re-run the plot code without re-training the classifier each time). Another hint: make your training data very small, get the code working, and then re-run the code with your training data at the size you actually want.

# Set the seed. Don't remove this!
set.seed(1234)

# your code here
fail() # No Answer - remove if you provide an answer

Question 3.3
{points: 3}

Based on the plot from Question 3.2, which k would you choose, and how can you be sure about your decision? In your answer you should reference why we do cross-validation.

DOUBLE CLICK TO EDIT THIS CELL AND REPLACE THIS TEXT WITH YOUR ANSWER.

4. Let's build our model

Question 4.0
{points: 3}

Now that we have explored our data, separated it into training and testing sets, and applied cross-validation to choose the best k, we can build our final model.

# Set the seed. Don't remove this!
set.seed(9999)

# your code here
fail() # No Answer - remove if you provide an answer

Question 4.1
{points: 3}

Use your final model to predict on the test dataset and report the accuracy of this prediction.
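The general shape of this step, sketched with the hypothetical knn_sketch model from the earlier iris example (with your own objects you would predict on X_test and compare against Y_test):

# predict() on a caret model returns a factor of predicted labels;
# accuracy is the proportion of predictions matching the true labels.
sketch_preds <- predict(knn_sketch, newdata = iris[, 1:4])
mean(sketch_preds == iris$Species) # accuracy as a proportion between 0 and 1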

# Set the seed. Don't remove this!
set.seed(9999)

# your code here
fail() # No Answer - remove if you provide an answer

Question 4.2
{points: 3}

Print out 3 images and true labels from the test set that were predicted correctly. Use the show_digit function we gave you above to print out the images.
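A minimal sketch of how you might find the relevant row numbers for this question and Question 4.3, using toy vectors as stand-ins for your predicted and true test labels:

# Toy stand-ins (hypothetical values, not the real test set):
toy_preds <- c("1", "9", "3", "9")
toy_truth <- c("1", "2", "3", "4")
which(toy_preds == toy_truth) # rows predicted correctly: 1 3
which(toy_preds != toy_truth) # rows predicted incorrectly: 2 4
# Row numbers like these can be passed to slice() and show_digit()
# to display the corresponding images and look up their labels.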

# Set the seed. Don't remove this!
set.seed(1000)

# your code here
fail() # No Answer - remove if you provide an answer

Question 4.3
{points: 3}

Print out 3 images and true labels from the test set that were NOT predicted correctly. For the incorrectly labelled images, also print out the predicted labels. Use the show_digit function we gave you above to print out the images.

# Set the seed. Don't remove this!
set.seed(300)

# your code here
fail() # No Answer - remove if you provide an answer

Question 4.4
{points: 3}

Do you notice any differences between the images that were predicted correctly versus the images that were not?

DOUBLE CLICK TO EDIT THIS CELL AND REPLACE THIS TEXT WITH YOUR ANSWER.

Question 4.5
{points: 3}

What does this accuracy mean? Is it good enough that you would use this model for Canada Post? Can you imagine a way we might improve our classifier's accuracy?

DOUBLE CLICK TO EDIT THIS CELL AND REPLACE THIS TEXT WITH YOUR ANSWER.