Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
machinecurve
GitHub Repository: machinecurve/extra_keras_datasets
Path: blob/master/extra_keras_datasets/iris.py
153 views
1
"""
2
Import the Iris dataset
3
Source: http://archive.ics.uci.edu/ml/datasets/Iris
4
Description: The data set contains 3 classes of 50 instances each, where
5
each class refers to a type of iris plant.
6
7
~~~ Important note ~~~
8
Please cite the following paper when using or referencing the dataset:
9
Fisher,R.A. "The use of multiple measurements in taxonomic problems"
10
Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions
11
to Mathematical Statistics" (John Wiley, NY, 1950).
12
"""
13
14
from tensorflow.keras.utils import get_file
15
import numpy as np
16
import math
17
import logging
18
19
20
def warn_citation():
    """Emit log warnings asking users to cite the dataset's source paper.

    # Returns
        Void
    """
    # Full citation for Fisher (1936), the origin of the Iris data.
    citation = (
        "Fisher,R.A. \"The use of multiple measurements in taxonomic "
        "problems\" Annual Eugenics, 7, Part II, 179-188 (1936); also "
        "in \"Contributions to Mathematical Statistics\" (John Wiley"
        ", NY, 1950)."
    )
    logging.warning(
        "Please cite the following paper when using or"
        " referencing this Extra Keras Dataset:"
    )
    logging.warning(citation)
33
34
35
def load_data(path="iris.npz", test_split=0.2):
36
"""Loads the Iris dataset.
37
# Arguments
38
path: path where to cache the dataset locally
39
(relative to ~/.keras/datasets).
40
test_split: percentage of data to use for testing (by default 20%)
41
# Returns
42
Tuple of Numpy arrays: `(input_train, target_train),
43
(input_test, target_test)`.
44
Input structure: (sepal length, sepal width, petal length,
45
petal width)
46
Target structure: 0 = iris setosa; 1 = iris versicolor;
47
2 = iris virginica.
48
"""
49
# Log about loading
50
logging.basicConfig(level=logging.INFO)
51
logging.info('Loading dataset = iris')
52
53
# Load data
54
path = get_file(
55
path,
56
origin=("http://archive.ics.uci.edu/ml/machine-learning-databases/"
57
"iris/iris.data")
58
)
59
60
# Read data from file
61
f = open(path, "r")
62
lines = f.readlines()
63
64
# Process each line into input/target structure
65
samples = []
66
for line in lines:
67
sample = line_to_list(line)
68
if sample is not None:
69
samples.append(sample)
70
f.close()
71
72
# Randomly shuffle the data
73
np.random.shuffle(samples)
74
75
# Compute test_split in length
76
num_test_samples = math.floor(len(samples) * test_split)
77
78
# Split data
79
training_data = samples[num_test_samples:]
80
testing_data = samples[:num_test_samples]
81
82
# Split into inputs and targets
83
input_train = np.array([i[0:4] for i in training_data])
84
input_test = np.array([i[0:4] for i in testing_data])
85
target_train = np.array([i[4] for i in training_data])
86
target_test = np.array([i[4] for i in testing_data])
87
88
# Warn about citation
89
warn_citation()
90
91
# Return data
92
return (input_train, target_train), (input_test, target_test)
93
94
95
def line_to_list(line):
    """
    Convert a String-based line into a list with input and target data.

    Returns a 5-tuple (4 float features + integer class label), or None
    for lines that do not contain comma-separated fields.
    """
    parts = line.split(",")
    if len(parts) <= 1:
        # Blank or malformed line (e.g. the file's trailing newline).
        return None
    label = target_string_to_int(parts[4])
    features = [float(value) for value in parts[0:4]]
    return tuple(features + [label])
107
108
109
def target_string_to_int(target_value):
    """
    Convert a String-based into an Integer-based target value.

    # Arguments
        target_value: raw class-label field from one CSV line, usually
            carrying a trailing newline (e.g. "Iris-setosa\\n").

    # Returns
        0 = iris setosa; 1 = iris versicolor; 2 = iris virginica.
        Unrecognized labels also map to 2, matching the original
        catch-all behaviour.
    """
    # Strip surrounding whitespace so a final file line without a
    # trailing newline is still classified correctly — the original
    # compared against "Iris-setosa\n" literally and would silently
    # misclassify such a line as virginica.
    label = target_value.strip()
    if label == "Iris-setosa":
        return 0
    if label == "Iris-versicolor":
        return 1
    return 2
119
120