# Path: blob/master/deep_learning/rnn/dataloader.py
# (page residue: 1480 views)
import numpy as np

try:
    from keras.utils import to_categorical
except ImportError:  # keras is only needed for one-hot encoding; fall back to NumPy
    def to_categorical(y, num_classes=None, dtype='float32'):
        """One-hot encode integer class labels (NumPy stand-in for keras' helper).

        A trailing axis of size 1 is dropped before encoding, so labels of
        shape ``(n,)`` or ``(n, 1)`` both yield an ``(n, num_classes)`` array.
        If ``num_classes`` is falsy it is inferred as ``max(y) + 1``.
        """
        y = np.asarray(y, dtype='int64')
        input_shape = y.shape
        if len(input_shape) > 1 and input_shape[-1] == 1:
            input_shape = input_shape[:-1]
        y = y.ravel()
        if not num_classes:
            num_classes = int(np.max(y)) + 1
        categorical = np.zeros((y.shape[0], num_classes), dtype=dtype)
        categorical[np.arange(y.shape[0]), y] = 1
        return categorical.reshape(input_shape + (num_classes,))


__all__ = ['DataLoader']


class DataLoader:
    """In-memory container that serves (optionally shuffled) mini-batches.

    Parameters
    ----------
    images : numpy.ndarray
        Inputs; the first axis indexes examples.
    labels : numpy.ndarray
        Integer class labels; the first axis indexes examples and must have
        the same length as ``images``.
    num_classes : int
        Number of classes used to one-hot encode ``labels`` in ``next_batch``.

    Raises
    ------
    ValueError
        If ``images`` and ``labels`` hold a different number of examples.
    """

    def __init__(self, images, labels, num_classes):
        if images.shape[0] != labels.shape[0]:
            raise ValueError('images.shape: %s labels.shape: %s' % (images.shape, labels.shape))

        self.num_classes = num_classes
        self._images = images
        self._labels = labels

        self._num_examples = images.shape[0]
        self._epochs_completed = 0
        self._index_in_epoch = 0

    def next_batch(self, batch_size, shuffle=True):
        """Return the next `batch_size` examples from this data set.

        Examples are served sequentially. When fewer than ``batch_size``
        examples remain in the current epoch, the leftovers are combined with
        the first examples of the next epoch, so every batch has exactly
        ``batch_size`` rows.

        Parameters
        ----------
        batch_size : int
            Number of examples to return; must be between 0 and the number
            of examples in the data set.
        shuffle : bool, default True
            Shuffle the data before the first batch and at every epoch
            boundary.

        Returns
        -------
        tuple of numpy.ndarray
            ``(images, one_hot_labels)``; labels are one-hot encoded with
            ``self.num_classes`` columns.

        Raises
        ------
        ValueError
            If ``batch_size`` is negative or larger than the data set. A batch
            would then span more than two epochs, which this bookkeeping does
            not support; previously it silently returned short batches and
            corrupted the epoch index.
        """
        if batch_size < 0 or batch_size > self._num_examples:
            raise ValueError('batch_size must be between 0 and %d, got %r'
                             % (self._num_examples, batch_size))

        start = self._index_in_epoch
        # shuffle once before the very first batch
        if self._epochs_completed == 0 and start == 0 and shuffle:
            self._shuffle_images_and_labels()

        end = start + batch_size
        if end <= self._num_examples:
            # the whole batch fits in the current epoch
            self._index_in_epoch = end
            return (self._images[start:end],
                    to_categorical(self._labels[start:end], self.num_classes))

        # The epoch runs out mid-batch: keep its remaining examples ...
        self._epochs_completed += 1
        rest_images = self._images[start:]
        rest_labels = self._labels[start:]
        # ... reshuffle (this rebinds the arrays, so the slices above still
        # refer to the previous epoch's order) ...
        if shuffle:
            self._shuffle_images_and_labels()
        # ... and top the batch up from the start of the next epoch.
        self._index_in_epoch = batch_size - (self._num_examples - start)
        end = self._index_in_epoch
        images = np.concatenate((rest_images, self._images[:end]), axis=0)
        labels = np.concatenate((rest_labels, self._labels[:end]), axis=0)
        return images, to_categorical(labels, self.num_classes)

    def _shuffle_images_and_labels(self):
        """Apply one random permutation to both images and labels.

        Fancy indexing returns new arrays, which are stored back on the
        instance (the previous code discarded them, so nothing was ever
        shuffled). The arrays passed to the constructor are never modified.
        """
        permutation = np.random.permutation(self._num_examples)
        self._images = self._images[permutation]
        self._labels = self._labels[permutation]