Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/Efficient-image-loading/loader.py
3118 views
1
import os
2
from abc import abstractmethod
3
from timeit import default_timer as timer
4
5
import cv2
6
import lmdb
7
import numpy as np
8
import tensorflow as tf
9
from PIL import Image
10
from turbojpeg import TurboJPEG
11
12
# Presumably quiets TensorFlow's native (C++) log output; "3" is the value written here.
# NOTE(review): this assignment runs after `import tensorflow` above — confirm the
# variable is still honoured at that point; if not, it has to be set before the import.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
13
14
15
class ImageLoader:
    """Base class for the image loaders in this module.

    A loader is an iterator: ``__iter__`` rewinds it and every subclass
    implements ``__next__``.  ``self.dataset`` holds whatever
    ``parse_input`` produced for ``path`` (see that method).

    NOTE(review): ``__next__`` is decorated with ``@abstractmethod`` but the
    class does not derive from ``abc.ABC``, so Python does not enforce the
    override; the base implementation therefore raises explicitly.
    """

    # File extensions accepted when ``path`` points at a single file.
    extensions: tuple = (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif", ".tfrecords")

    def __init__(self, path: str, mode: str = "BGR"):
        """Remember ``path`` and the colour ``mode`` and resolve the dataset.

        Args:
            path: A single file, a folder of images, or an LMDB folder.
            mode: Channel order requested from ``__next__`` (subclasses in
                this file compare it against ``"RGB"`` / ``"BGR"``).
        """
        self.path = path
        self.mode = mode
        self.dataset = self.parse_input(self.path)
        self.sample_idx = 0

    def parse_input(self, path):
        """Turn ``path`` into the dataset description stored on the loader.

        Returns:
            * ``[path]`` when ``path`` is a single file with a supported
              extension (an image or a ``.tfrecords`` file);
            * ``path`` itself when it is a folder containing a ``*.mdb``
              file (treated as an LMDB environment);
            * otherwise, for a folder, the sorted list of the joined paths
              of all its entries.

        Raises:
            AssertionError: ``path`` is a file with an unsupported extension.
                The type matches the ``assert`` this check replaces, but it
                is raised explicitly so it still fires under ``python -O``.
            FileNotFoundError: ``path`` is neither a file nor a directory.
        """
        # single image or tfrecords file
        if os.path.isfile(path):
            if not path.lower().endswith(self.extensions):
                raise AssertionError(
                    f"Unsupportable extension, please, use one of {self.extensions}"
                )
            return [path]

        if os.path.isdir(path):
            # List the folder once; sorting makes the sample order
            # reproducible (os.listdir returns entries in arbitrary order).
            entries = sorted(os.listdir(path))
            # lmdb environment
            if any(entry.endswith(".mdb") for entry in entries):
                return path
            # folder with images
            return [os.path.join(path, entry) for entry in entries]

        # Fail here with a clear message instead of returning None and
        # crashing later inside len() or indexing.
        raise FileNotFoundError(f"No such file or directory: {path!r}")

    def __iter__(self):
        """Rewind to the first sample and return the loader itself."""
        self.sample_idx = 0
        return self

    def __len__(self):
        """Return ``len(self.dataset)``."""
        return len(self.dataset)

    @abstractmethod
    def __next__(self):
        """Return the next sample; must be provided by subclasses."""
        raise NotImplementedError("subclasses must implement __next__")
52
53
54
class CV2Loader(ImageLoader):
    """Loader that reads images with ``cv2.imread``."""

    def __next__(self):
        """Read the next image and report how long reading it took.

        Returns:
            ``(image, full_time)`` where ``full_time`` is the time measured
            with ``timer`` for the read plus, in ``"RGB"`` mode, the colour
            conversion.

        Raises:
            StopIteration: every path in ``self.dataset`` has been consumed.
        """
        # Follow the iterator protocol so ``for image, t in loader`` ends
        # cleanly instead of failing with IndexError past the last sample.
        if self.sample_idx >= len(self.dataset):
            raise StopIteration

        start = timer()
        path = self.dataset[self.sample_idx]  # get image path by index from the dataset
        image = cv2.imread(path)  # read the image
        full_time = timer() - start

        if self.mode == "RGB":
            start = timer()
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # change color mode
            full_time += timer() - start

        self.sample_idx += 1
        return image, full_time
66
67
68
class PILLoader(ImageLoader):
    """Loader that reads images with ``PIL.Image.open``."""

    def __next__(self):
        """Read the next image as a numpy array and report the elapsed time.

        Returns:
            ``(image, full_time)`` where ``full_time`` is the time measured
            with ``timer`` for the read plus, in ``"BGR"`` mode, the colour
            conversion.

        Raises:
            StopIteration: every path in ``self.dataset`` has been consumed.
        """
        # Follow the iterator protocol so ``for image, t in loader`` ends
        # cleanly instead of failing with IndexError past the last sample.
        if self.sample_idx >= len(self.dataset):
            raise StopIteration

        start = timer()
        path = self.dataset[self.sample_idx]  # get image path by index from the dataset
        # The context manager closes the underlying file deterministically;
        # the pixels are copied into the array before the block exits.
        with Image.open(path) as pil_image:
            image = np.asarray(pil_image)  # read the image as numpy array
        full_time = timer() - start

        if self.mode == "BGR":
            start = timer()
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # change color mode
            full_time += timer() - start

        self.sample_idx += 1
        return image, full_time
80
81
82
class TurboJpegLoader(ImageLoader):
    """Loader that decodes JPEG files with a ``TurboJPEG`` reader."""

    # Pixel-format codes handed to ``TurboJPEG.decode`` for each colour mode.
    # These are the literal values the loader has always passed; presumably
    # they correspond to turbojpeg's TJPF_RGB / TJPF_BGR — verify if upgrading.
    _PIXEL_FORMATS = {"RGB": 0, "BGR": 1}

    def __init__(self, path, **kwargs):
        """Resolve the dataset and create the TurboJPEG reader."""
        super().__init__(path, **kwargs)
        self.jpeg_reader = TurboJPEG()  # create TurboJPEG object for image reading

    def __next__(self):
        """Decode the next file and report the elapsed time.

        Returns:
            ``(image, full_time)`` where ``full_time`` is the time measured
            with ``timer`` for opening, reading and decoding the file.

        Raises:
            StopIteration: every path in ``self.dataset`` has been consumed.
            ValueError: ``self.mode`` is neither ``"RGB"`` nor ``"BGR"``
                (previously this surfaced as an ``UnboundLocalError``).
        """
        # Follow the iterator protocol so ``for image, t in loader`` ends
        # cleanly instead of failing with IndexError past the last sample.
        if self.sample_idx >= len(self.dataset):
            raise StopIteration

        # Resolve the pixel format outside the timed section, as before.
        try:
            pixel_format = self._PIXEL_FORMATS[self.mode]
        except KeyError:
            raise ValueError(
                f"Unsupported mode {self.mode!r}, expected one of "
                f"{tuple(self._PIXEL_FORMATS)}"
            ) from None

        start = timer()
        # ``with`` guarantees the handle is closed; it used to leak one
        # open file per sample.
        with open(self.dataset[self.sample_idx], "rb") as file:  # open the input file as bytes
            image = self.jpeg_reader.decode(file.read(), pixel_format)  # decode raw image
        full_time = timer() - start

        self.sample_idx += 1
        return image, full_time
100
101
102
class LmdbLoader(ImageLoader):
    """Loader that reads encoded images from an LMDB environment."""

    def __init__(self, path, **kwargs):
        """Open the LMDB environment at ``path`` and position the cursor.

        ``self.dataset`` is replaced by the database cursor returned from
        ``open_database``.
        """
        super().__init__(path, **kwargs)  # also stores ``self.path``
        self._dataset_size = 0
        self.dataset = self.open_database()
        # Start on the first record so ``next(loader)`` also works without
        # a preceding ``iter(loader)``; False when the database is empty.
        self._has_record = self.dataset.first()

    # we need to open the database to read images from it
    def open_database(self):
        """Open the environment and return a cursor over its records.

        Also stores the number of entries reported by the environment in
        ``self._dataset_size``.
        """
        lmdb_env = lmdb.open(self.path)  # open the environment by path
        lmdb_txn = lmdb_env.begin()  # start reading
        lmdb_cursor = lmdb_txn.cursor()  # create cursor to iterate through the database
        self._dataset_size = lmdb_env.stat()["entries"]  # get number of items in full dataset
        return lmdb_cursor

    def __iter__(self):
        """Move the cursor back to the first record and return the loader."""
        # ``first()`` reports whether a record exists at all.
        self._has_record = self.dataset.first()
        return self

    def __next__(self):
        """Decode the record under the cursor, then advance the cursor.

        Returns:
            ``(image, full_time)`` where ``full_time`` is the time measured
            with ``timer`` for fetching and decoding the record, the
            optional conversion to RGB, and advancing the cursor.

        Raises:
            StopIteration: the cursor has moved past the last record (or the
                database is empty).
        """
        # Without this guard the loader kept reading after the last record
        # and tried to decode an empty buffer.
        if not self._has_record:
            raise StopIteration

        start = timer()
        raw_image = self.dataset.value()  # get raw image
        image = np.frombuffer(raw_image, dtype=np.uint8)  # convert it to numpy
        image = cv2.imdecode(image, cv2.IMREAD_COLOR)  # decode image
        full_time = timer() - start

        if self.mode == "RGB":
            start = timer()
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            full_time += timer() - start

        start = timer()
        # step to the next element in database; remember whether one exists
        self._has_record = self.dataset.next()
        full_time += timer() - start
        return image, full_time

    def __len__(self):
        """Return the entry count recorded when the database was opened."""
        return self._dataset_size  # get dataset length
140
141
142
class TFRecordsLoader(ImageLoader):
    """Loader that reads JPEG-encoded images from a TFRecords file."""

    def __init__(self, path, **kwargs):
        """Build the parsed ``tf.data`` dataset for ``path``."""
        super().__init__(path, **kwargs)
        self._dataset = self.open_database()
        self._length = None  # filled lazily by __len__
        # The base class left ``self.dataset`` as ``[path]``, which is not an
        # iterator; create one now so ``next(loader)`` also works without a
        # preceding ``iter(loader)``.
        self.dataset = self._dataset.as_numpy_iterator()

    def open_database(self):
        """Return a dataset of parsed examples read from ``self.path``.

        Each example is parsed into a dict with an int64 ``"label"`` and a
        string ``"image_raw"`` feature.
        """
        # dataset structure description
        image_feature_description = {
            "label": tf.io.FixedLenFeature([], tf.int64),
            "image_raw": tf.io.FixedLenFeature([], tf.string),
        }

        def _parse_image_function(example_proto):
            return tf.io.parse_single_example(example_proto, image_feature_description)

        raw_image_dataset = tf.data.TFRecordDataset(self.path)  # open dataset by path
        # parse dataset using structure description
        parsed_image_dataset = raw_image_dataset.map(_parse_image_function)

        return parsed_image_dataset

    def __iter__(self):
        """Start a fresh pass over the dataset and return the loader."""
        self.dataset = self._dataset.as_numpy_iterator()
        return self

    def __next__(self):
        """Decode the next record and report the elapsed time.

        Returns:
            ``(image, full_time)`` where ``full_time`` is the time measured
            with ``timer`` for fetching and decoding the record plus, in
            ``"BGR"`` mode, the colour conversion.

        Raises:
            StopIteration: propagated from the underlying iterator once the
                dataset is exhausted.
        """
        start = timer()
        # step to the next element in database and get new image
        value = next(self.dataset)["image_raw"]
        image = tf.image.decode_jpeg(value).numpy()  # decode raw image
        full_time = timer() - start

        if self.mode == "BGR":
            start = timer()
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            full_time += timer() - start
        return image, full_time

    def __len__(self):
        """Return the number of records in the dataset.

        Counting needs a full pass over the file, so the result is computed
        on first use and cached instead of being recomputed on every call.
        """
        if self._length is None:
            count = self._dataset.reduce(np.int64(0), lambda x, _: x + 1)
            self._length = int(count.numpy())  # get dataset length
        return self._length
184
185
186
# Registry mapping a loader's short name to the class that implements it.
methods = {
    "cv2": CV2Loader,
    "pil": PILLoader,
    "turbojpeg": TurboJpegLoader,
    "lmdb": LmdbLoader,
    "tfrecords": TFRecordsLoader,
}
193
194