# Path: blob/master/Efficient-image-loading/loader.py
# (3118 views)
"""Image loaders used to benchmark different image-reading/decoding backends.

Every loader is an iterator that yields ``(image, elapsed)`` pairs, where
``image`` is a numpy array in the requested channel order (``"BGR"`` or
``"RGB"``) and ``elapsed`` is the time (``timeit.default_timer`` seconds)
the loader measured for reading, decoding and, if needed, converting it.
"""
import os

# TensorFlow reads this variable when it is first imported, so it has to be
# set *before* ``import tensorflow`` below; setting it afterwards is a no-op.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from abc import ABC, abstractmethod
from timeit import default_timer as timer

import cv2
import lmdb
import numpy as np
import tensorflow as tf
from PIL import Image
from turbojpeg import TurboJPEG

# Channel orders every loader knows how to produce.
_SUPPORTED_MODES = ("BGR", "RGB")


class ImageLoader(ABC):
    """Base class for iterable image loaders.

    Subclasses implement ``__next__`` and return ``(image, elapsed)``.

    Args:
        path: a single image or ``.tfrecords`` file, a directory of images,
            or a directory holding an LMDB environment (``*.mdb`` files).
        mode: channel order of the returned images, ``"BGR"`` or ``"RGB"``.

    Raises:
        ValueError: ``mode`` is not one of the supported channel orders.
    """

    extensions: tuple = (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif", ".tfrecords")

    def __init__(self, path: str, mode: str = "BGR"):
        # Validate early: an unknown mode used to be silently ignored by some
        # loaders and crash others (TurboJPEG) only on the first image.
        if mode not in _SUPPORTED_MODES:
            raise ValueError(f"Unsupported mode {mode!r}, please, use one of {_SUPPORTED_MODES}")
        self.path = path
        self.mode = mode
        self.dataset = self.parse_input(self.path)
        self.sample_idx = 0

    def parse_input(self, path):
        """Resolve ``path`` into the dataset this loader iterates over.

        Returns:
            ``[path]`` for a single file, ``path`` itself for a directory that
            contains ``*.mdb`` files (LMDB), otherwise a sorted list of the
            supported image files found directly inside the directory.

        Raises:
            ValueError: ``path`` is a file with an unsupported extension.
            FileNotFoundError: ``path`` is neither a file nor a directory.
        """
        # single image or tfrecords file
        if os.path.isfile(path):
            # explicit raise instead of ``assert``: asserts vanish under ``python -O``
            if not path.lower().endswith(self.extensions):
                raise ValueError(f"Unsupportable extension, please, use one of {self.extensions}")
            return [path]

        if os.path.isdir(path):
            # sorted so the iteration order is deterministic across platforms
            names = sorted(os.listdir(path))
            # lmdb environment: the LMDB loader opens the directory itself
            if any(name.endswith(".mdb") for name in names):
                return path
            # folder with images; skip anything without a supported extension
            return [
                os.path.join(path, name)
                for name in names
                if name.lower().endswith(self.extensions)
            ]

        raise FileNotFoundError(f"No such file or directory: {path}")

    def _current_path(self):
        """Return the path at ``sample_idx``; raise StopIteration past the end.

        Only meaningful for loaders whose ``dataset`` is a list of paths.
        """
        if self.sample_idx >= len(self.dataset):
            raise StopIteration
        return self.dataset[self.sample_idx]

    def __iter__(self):
        self.sample_idx = 0
        return self

    def __len__(self):
        return len(self.dataset)

    @abstractmethod
    def __next__(self):
        raise NotImplementedError


class CV2Loader(ImageLoader):
    """Read image files with ``cv2.imread``, which returns BGR data."""

    def __next__(self):
        path = self._current_path()  # raises StopIteration when exhausted
        start = timer()
        image = cv2.imread(path)  # read the image
        full_time = timer() - start
        if self.mode == "RGB":
            start = timer()
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # change color mode
            full_time += timer() - start
        self.sample_idx += 1
        return image, full_time


class PILLoader(ImageLoader):
    """Read image files with Pillow, converting to BGR when requested."""

    def __next__(self):
        path = self._current_path()  # raises StopIteration when exhausted
        start = timer()
        # ``with`` releases the file handle Pillow keeps for lazy loading;
        # np.asarray forces the pixel data to be loaded before it closes.
        with Image.open(path) as pil_image:
            image = np.asarray(pil_image)  # read the image as numpy array
        full_time = timer() - start
        if self.mode == "BGR":
            start = timer()
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # change color mode
            full_time += timer() - start
        self.sample_idx += 1
        return image, full_time


class TurboJpegLoader(ImageLoader):
    """Decode JPEG files with libjpeg-turbo via ``TurboJPEG.decode``."""

    # Pixel-format codes passed to ``TurboJPEG.decode`` for each mode.
    _pixel_formats = {"RGB": 0, "BGR": 1}

    def __init__(self, path, **kwargs):
        super().__init__(path, **kwargs)
        self.jpeg_reader = TurboJPEG()  # create TurboJPEG object for image reading

    def __next__(self):
        path = self._current_path()  # raises StopIteration when exhausted
        pixel_format = self._pixel_formats[self.mode]  # mode validated in __init__
        start = timer()
        # ``with`` closes the file right away instead of leaking a handle per image
        with open(path, "rb") as file:
            raw_image = file.read()
        image = self.jpeg_reader.decode(raw_image, pixel_format)  # decode raw image
        full_time = timer() - start
        self.sample_idx += 1
        return image, full_time


class LmdbLoader(ImageLoader):
    """Decode encoded images stored as the values of an LMDB environment."""

    def __init__(self, path, **kwargs):
        super().__init__(path, **kwargs)
        self._dataset_size = 0
        self.dataset = self.open_database()
        # True while the cursor sits on a record that has not been returned yet.
        self._has_item = self.dataset.first()

    def open_database(self):
        """Open the environment at ``self.path`` and return a read cursor.

        Also records the number of entries for ``__len__``.
        """
        # keep env/transaction referenced for as long as the cursor is in use
        self._env = lmdb.open(self.path)  # open the environment by path
        self._txn = self._env.begin()  # start reading
        self._dataset_size = self._env.stat()["entries"]  # number of records
        return self._txn.cursor()  # cursor to iterate through the database

    def __iter__(self):
        # rewind to the first record; ``first()`` is False for an empty database
        self._has_item = self.dataset.first()
        return self

    def __next__(self):
        # Past the last record the cursor is unpositioned and ``value()`` would
        # yield an empty buffer, so stop explicitly instead of decoding it.
        if not self._has_item:
            raise StopIteration
        start = timer()
        raw_image = self.dataset.value()  # get raw image
        image = np.frombuffer(raw_image, dtype=np.uint8)  # convert it to numpy
        image = cv2.imdecode(image, cv2.IMREAD_COLOR)  # decode image (BGR)
        full_time = timer() - start
        if self.mode == "RGB":
            start = timer()
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            full_time += timer() - start
        start = timer()
        # step to the next record; ``next()`` returns False once past the end
        self._has_item = self.dataset.next()
        full_time += timer() - start
        return image, full_time

    def __len__(self):
        return self._dataset_size  # get dataset length


class TFRecordsLoader(ImageLoader):
    """Decode JPEG images stored under ``image_raw`` in a TFRecord file."""

    def __init__(self, path, **kwargs):
        super().__init__(path, **kwargs)
        self._dataset = self.open_database()

    def open_database(self):
        """Return a ``tf.data`` dataset of parsed ``label``/``image_raw`` examples."""
        # dataset structure description
        image_feature_description = {
            "label": tf.io.FixedLenFeature([], tf.int64),
            "image_raw": tf.io.FixedLenFeature([], tf.string),
        }

        def _parse_image_function(example_proto):
            return tf.io.parse_single_example(example_proto, image_feature_description)

        raw_image_dataset = tf.data.TFRecordDataset(self.path)  # open dataset by path
        # parse dataset using structure description
        return raw_image_dataset.map(_parse_image_function)

    def __iter__(self):
        self.dataset = self._dataset.as_numpy_iterator()
        return self

    def __next__(self):
        start = timer()
        # ``next`` raises StopIteration by itself once the records run out
        value = next(self.dataset)["image_raw"]
        image = tf.image.decode_jpeg(value).numpy()  # decode raw image (RGB)
        full_time = timer() - start
        if self.mode == "BGR":
            start = timer()
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            full_time += timer() - start
        return image, full_time

    def __len__(self):
        # counting requires a full pass over the file; cast the numpy scalar
        # to a plain int as expected from ``__len__``
        return int(self._dataset.reduce(np.int64(0), lambda x, _: x + 1).numpy())


methods = {
    "cv2": CV2Loader,
    "pil": PILLoader,
    "turbojpeg": TurboJpegLoader,
    "lmdb": LmdbLoader,
    "tfrecords": TFRecordsLoader,
}