Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
uob-COMS30035
GitHub Repository: uob-COMS30035/lab_sheets_public
Path: blob/main/lab2/utils.py
340 views
1
import numpy as np
2
import matplotlib.pyplot as plt
3
import h5py
4
5
6
7
def load_data():
    """Load the cat / non-cat training and test sets from their HDF5 files.

    Reads ``datasets/train_catvnoncat.h5`` and ``datasets/test_catvnoncat.h5``
    (paths are relative to the current working directory). Each file is
    opened inside a ``with`` block so its handle is closed as soon as the
    arrays have been copied into memory, even if a dataset key is missing.

    Returns:
        A 5-tuple ``(train_set_x_orig, train_set_y_orig, test_set_x_orig,
        test_set_y_orig, classes)`` of NumPy arrays. The two label arrays
        are reshaped to a single row, i.e. shape ``(1, m)`` where ``m`` is
        the length of the stored label vector. ``classes`` is read from the
        ``list_classes`` dataset of the test file.

    Raises:
        OSError: if either HDF5 file cannot be opened (raised by h5py).
        KeyError: if an expected dataset name is absent from a file.
    """
    with h5py.File('datasets/train_catvnoncat.h5', "r") as train_dataset:
        # ``[:]`` reads the full dataset; np.array copies it so the data
        # stays valid after the file is closed.
        train_set_x_orig = np.array(train_dataset["train_set_x"][:])  # train set features
        train_set_y_orig = np.array(train_dataset["train_set_y"][:])  # train set labels

    with h5py.File('datasets/test_catvnoncat.h5', "r") as test_dataset:
        test_set_x_orig = np.array(test_dataset["test_set_x"][:])  # test set features
        test_set_y_orig = np.array(test_dataset["test_set_y"][:])  # test set labels
        classes = np.array(test_dataset["list_classes"][:])  # the list of classes

    # Turn the label vectors into row vectors of shape (1, m).
    train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
    test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))

    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes
22
23
24
def print_mislabeled_images(classes, X, y, p):
    """Plot every example whose prediction disagrees with its true label.

    A sample counts as mislabeled when ``p + y == 1``, i.e. exactly one of
    the prediction and the true label is 1 (this presumes both hold only
    0/1 values).

    classes -- indexable collection of class names; each entry must support
               ``.decode("utf-8")`` (so presumably bytes)
    X -- dataset; column ``j`` is reshaped to a 64x64x3 image, so each
         column must hold 64*64*3 values
    y -- true labels, indexed as ``y[0, j]`` and used directly as an index
         into ``classes``
    p -- predictions, indexed as ``p[0, j]`` and cast to int before indexing
         ``classes``

    Side effect: sets the global ``plt.rcParams['figure.figsize']`` to
    (40.0, 40.0).
    """
    # Exactly one of (prediction, truth) being 1 gives a sum of 1.
    label_sum = p + y
    # Row 0 holds the row indices of the mismatches, row 1 the column indices.
    wrong = np.asarray(np.where(label_sum == 1))
    plt.rcParams['figure.figsize'] = (40.0, 40.0)  # set default size of plots
    wrong_count = len(wrong[0])

    for position in range(wrong_count):
        column = wrong[1, position]

        # Subplot slots are numbered from 1 on a 2 x wrong_count grid.
        plt.subplot(2, wrong_count, position + 1)
        picture = X[:, column].reshape((64, 64, 3))
        plt.imshow(picture, interpolation='nearest')
        plt.axis('off')

        predicted_name = classes[int(p[0, column])].decode("utf-8")
        actual_name = classes[y[0, column]].decode("utf-8")
        plt.title("Prediction: " + predicted_name + " \n Class: " + actual_name)
42
43