Path: blob/master/Efficient-image-loading/create_tfrecords.py
3118 views
"""Pack a list of image files into a single TFRecords file.

Run as a script: ``--path`` is handed to ``tools.get_images_paths`` and the
result is written to the file named by ``--output``.
"""
import os
from argparse import ArgumentParser

import tensorflow as tf

from tools import get_images_paths


def _byte_feature(value):
    """Convert string / byte into bytes_list."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList can't unpack string from EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    """Convert bool / enum / int / uint into int64_list."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def image_example(image_string, label):
    """Build a ``tf.train.Example`` holding one image and its label.

    Args:
        image_string: Image file content; stored under the key ``"image_raw"``.
        label: Integer label; stored under the key ``"label"``.

    Returns:
        A ``tf.train.Example`` with the two features above.
    """
    feature = {
        "label": _int64_feature(label),
        "image_raw": _byte_feature(image_string),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def store_many_tfrecords(images_list, save_file):
    """Write every file in ``images_list`` into one TFRecords file.

    Each file is read as raw bytes and stored together with its position in
    ``images_list`` (0, 1, 2, ...) as the label.

    Args:
        images_list: Iterable of paths to the files to store.
        save_file: Output path; must end with ``".tfrecords"``. Missing
            parent directories are created.

    Raises:
        AssertionError: If ``save_file`` does not end with ``".tfrecords"``.
        OSError: If an input file cannot be read or the output directory
            cannot be created.
    """
    # NOTE: kept as an assert so callers see the same exception type as
    # before; be aware the check is skipped when Python runs with -O.
    assert save_file.endswith(
        ".tfrecords",
    ), 'File path is wrong, it should contain "*myname*.tfrecords"'

    directory = os.path.dirname(save_file)
    # dirname is "" for a bare file name (e.g. "myname.tfrecords"); calling
    # os.makedirs("") raises FileNotFoundError, so only create real paths.
    if directory:
        os.makedirs(directory, exist_ok=True)

    with tf.io.TFRecordWriter(save_file) as writer:
        for label, filename in enumerate(images_list):
            # Close each input promptly instead of leaking one handle per image.
            with open(filename, "rb") as image_file:
                image_string = image_file.read()
            tf_example = image_example(image_string, label)
            writer.write(tf_example.SerializeToString())


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--path",
        "-p",
        type=str,
        required=True,
        help="path to the images folder to collect",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help='path to the output tfrecords file i.e. "path/to/folder/myname.tfrecords"',
    )

    args = parser.parse_args()
    image_paths = get_images_paths(args.path)
    store_many_tfrecords(image_paths, args.output)