Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/Efficient-image-loading/create_tfrecords.py
3118 views
1
import os
2
from argparse import ArgumentParser
3
4
import tensorflow as tf
5
6
from tools import get_images_paths
7
8
9
def _byte_feature(value):
    """Wrap a single string / bytes value in a ``bytes_list`` feature.

    A value whose type is the same as that of ``tf.constant(0)`` is first
    converted with ``.numpy()``, because BytesList can't unpack a string
    from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)
14
15
16
def _int64_feature(value):
    """Wrap a single bool / enum / int / uint value in an ``int64_list`` feature."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
19
20
21
def image_example(image_string, label):
    """Build a ``tf.train.Example`` holding one image and its label.

    The example carries two features: ``"label"`` (built with
    ``_int64_feature``) and ``"image_raw"`` (built with ``_byte_feature``).
    """
    features = tf.train.Features(
        feature={
            "label": _int64_feature(label),
            "image_raw": _byte_feature(image_string),
        },
    )
    return tf.train.Example(features=features)
27
28
29
def store_many_tfrecords(images_list, save_file):
    """Write every image in *images_list* into a single TFRecord file.

    Each file is read as raw bytes and serialized as a ``tf.train.Example``
    via ``image_example``. The label stored with an image is its position
    in *images_list*.

    Args:
        images_list: iterable of paths to the image files to store.
        save_file: output path; must end with ``.tfrecords``. Missing
            parent directories are created.

    Raises:
        AssertionError: if *save_file* does not end with ``.tfrecords``.
    """
    assert save_file.endswith(
        ".tfrecords",
    ), 'File path is wrong, it should contain "*myname*.tfrecords"'

    # dirname() is "" for a bare file name (e.g. "data.tfrecords");
    # os.makedirs("") raises, so only create a directory when there is one.
    directory = os.path.dirname(save_file)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with tf.io.TFRecordWriter(save_file) as writer:
        for label, filename in enumerate(images_list):
            # Close each image file right after reading instead of leaking
            # one open handle per image.
            with open(filename, "rb") as image_file:
                image_string = image_file.read()
            tf_example = image_example(image_string, label)
            writer.write(tf_example.SerializeToString())
46
47
48
if __name__ == "__main__":
    # Command-line entry point: collect the image paths under --path and
    # store them in the TFRecord file given by --output.
    cli = ArgumentParser()
    cli.add_argument(
        "--path", "-p", type=str, required=True,
        help="path to the images folder to collect",
    )
    cli.add_argument(
        "--output", "-o", type=str, required=True,
        help='path to the output tfrecords file i.e. "path/to/folder/myname.tfrecords"',
    )
    options = cli.parse_args()

    store_many_tfrecords(get_images_paths(options.path), options.output)
68
69