Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ethen8181
GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/rnn/dataloader.py
1480 views
1
import numpy as np
2
from keras.utils import to_categorical
3
4
5
# Names exported by a wildcard import of this module.
__all__ = ['DataLoader']
6
7
8
class DataLoader:
    """Container for a dataset that serves mini-batches across epochs.

    Parameters
    ----------
    images : numpy.ndarray
        Feature array; the first axis indexes examples.
    labels : numpy.ndarray
        Class labels; the first axis must have the same length as ``images``.
        They are one-hot encoded with ``to_categorical`` for every batch.
    num_classes : int
        Number of classes, forwarded to ``to_categorical``.

    Raises
    ------
    ValueError
        If ``images`` and ``labels`` disagree on the number of examples.
    """

    def __init__(self, images, labels, num_classes):
        if images.shape[0] != labels.shape[0]:
            raise ValueError('images.shape: %s labels.shape: %s' % (images.shape, labels.shape))

        self.num_classes = num_classes
        self._images = images
        self._labels = labels

        self._num_examples = images.shape[0]
        self._epochs_completed = 0
        self._index_in_epoch = 0

    def next_batch(self, batch_size, shuffle = True):
        """Return the next `batch_size` examples from this data set.

        If the current epoch has fewer than ``batch_size`` examples left,
        the leftover examples are combined with the start of the next epoch.
        Every batch therefore holds exactly ``batch_size`` examples.

        Parameters
        ----------
        batch_size : int
            Number of examples to return. It must not exceed the number of
            examples in the dataset.
        shuffle : bool, default True
            Whether to shuffle the examples before the first epoch and at
            every epoch boundary.

        Returns
        -------
        images : numpy.ndarray
            The next ``batch_size`` examples.
        labels : numpy.ndarray
            The matching labels, one-hot encoded via ``to_categorical``.

        Raises
        ------
        ValueError
            If ``batch_size`` is larger than the number of examples.
        """
        if batch_size > self._num_examples:
            # The wrap-around below borrows from only one following epoch, so
            # a larger batch would come back short and push the epoch
            # position past the end of the data.
            raise ValueError('batch_size (%s) cannot exceed the number of examples (%s)'
                             % (batch_size, self._num_examples))

        start = self._index_in_epoch
        # shuffle for the first epoch
        if self._epochs_completed == 0 and start == 0 and shuffle:
            self._shuffle_images_and_labels()

        end = start + batch_size
        if end <= self._num_examples:
            # the current epoch still has enough examples for a full batch
            self._index_in_epoch = end
            return (self._images[start:end],
                    to_categorical(self._labels[start:end], self.num_classes))

        # The current epoch is exhausted. Keep the examples that do not add up
        # to a full batch, then start a new (optionally reshuffled) epoch.
        self._epochs_completed += 1
        rest_num_examples = self._num_examples - start
        rest_images = self._images[start:self._num_examples]
        rest_labels = self._labels[start:self._num_examples]
        if shuffle:
            self._shuffle_images_and_labels()

        # complete the batch size from the next epoch
        end = batch_size - rest_num_examples
        self._index_in_epoch = end
        images = np.concatenate((rest_images, self._images[:end]), axis = 0)
        labels = np.concatenate((rest_labels, self._labels[:end]), axis = 0)
        return images, to_categorical(labels, self.num_classes)

    def _shuffle_images_and_labels(self):
        """Apply one random permutation to images and labels together."""
        permutation = np.random.permutation(self._num_examples)
        # Advanced indexing returns shuffled copies. Rebinding them, rather
        # than shuffling in place, leaves the caller's arrays untouched and
        # keeps slices taken before the shuffle (the ``rest_*`` arrays in
        # ``next_batch``) pointing at the previous epoch's data.
        self._images = self._images[permutation]
        self._labels = self._labels[permutation]
60
61