Guest User

Sample TensorFlow XML-to-TFRecord converter

a guest
Sep 14th, 2021
133
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.25 KB | None | 0 0
  1. """ Sample TensorFlow XML-to-TFRecord converter
  2.  
  3. usage: generate_tfrecord.py [-h] [-x XML_DIR] [-l LABELS_PATH] [-o OUTPUT_PATH] [-i IMAGE_DIR] [-c CSV_PATH]
  4.  
  5. optional arguments:
  6.  -h, --help            show this help message and exit
  7.  -x XML_DIR, --xml_dir XML_DIR
  8.                        Path to the folder where the input .xml files are stored.
  9.  -l LABELS_PATH, --labels_path LABELS_PATH
  10.                        Path to the labels (.pbtxt) file.
  11.  -o OUTPUT_PATH, --output_path OUTPUT_PATH
  12.                        Path of output TFRecord (.record) file.
  13.  -i IMAGE_DIR, --image_dir IMAGE_DIR
  14.                        Path to the folder where the input image files are stored. Defaults to the same directory as XML_DIR.
  15.  -c CSV_PATH, --csv_path CSV_PATH
  16.                        Path of output .csv file. If none provided, then no file will be written.
  17. """
  18.  
  19. import os
  20. import glob
  21. import pandas as pd
  22. import io
  23. import xml.etree.ElementTree as ET
  24. import argparse
  25.  
  26. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # Suppress TensorFlow logging (1)
  27. import tensorflow.compat.v1 as tf
  28. from PIL import Image
  29. from object_detection.utils import dataset_util, label_map_util
  30. from collections import namedtuple
  31.  
  32. # Initiate argument parser
  33. parser = argparse.ArgumentParser(
  34.     description="Sample TensorFlow XML-to-TFRecord converter")
  35. parser.add_argument("-x",
  36.                     "--xml_dir",
  37.                     help="Path to the folder where the input .xml files are stored.",
  38.                     type=str)
  39. parser.add_argument("-l",
  40.                     "--labels_path",
  41.                     help="Path to the labels (.pbtxt) file.", type=str)
  42. parser.add_argument("-o",
  43.                     "--output_path",
  44.                     help="Path of output TFRecord (.record) file.", type=str)
  45. parser.add_argument("-i",
  46.                     "--image_dir",
  47.                     help="Path to the folder where the input image files are stored. "
  48.                          "Defaults to the same directory as XML_DIR.",
  49.                     type=str, default=None)
  50. parser.add_argument("-c",
  51.                     "--csv_path",
  52.                     help="Path of output .csv file. If none provided, then no file will be "
  53.                          "written.",
  54.                     type=str, default=None)
  55.  
  56. args = parser.parse_args()
  57.  
  58. if args.image_dir is None:
  59.     args.image_dir = args.xml_dir
  60.  
  61. label_map = label_map_util.load_labelmap(args.labels_path)
  62. label_map_dict = label_map_util.get_label_map_dict(label_map)
  63.  
  64.  
  65. def xml_to_csv(path):
  66.     """Iterates through all .xml files (generated by labelImg) in a given directory and combines
  67.    them in a single Pandas dataframe.
  68.  
  69.    Parameters:
  70.    ----------
  71.    path : str
  72.        The path containing the .xml files
  73.    Returns
  74.    -------
  75.    Pandas DataFrame
  76.        The produced dataframe
  77.    """
  78.  
  79.     xml_list = []
  80.     for xml_file in glob.glob(path + '/*.xml'):
  81.         tree = ET.parse(xml_file)
  82.         root = tree.getroot()
  83.         for member in root.findall('object'):
  84.             value = (root.find('filename').text,
  85.                      int(root.find('size')[0].text),
  86.                      int(root.find('size')[1].text),
  87.                      member[0].text,
  88.                      int(member[4][0].text),
  89.                      int(member[4][1].text),
  90.                      int(member[4][2].text),
  91.                      int(member[4][3].text)
  92.                      )
  93.             xml_list.append(value)
  94.     column_name = ['filename', 'width', 'height',
  95.                    'class', 'xmin', 'ymin', 'xmax', 'ymax']
  96.     xml_df = pd.DataFrame(xml_list, columns=column_name)
  97.     return xml_df
  98.  
  99.  
  100. def class_text_to_int(row_label):
  101.     return label_map_dict[row_label]
  102.  
  103.  
  104. def split(df, group):
  105.     data = namedtuple('data', ['filename', 'object'])
  106.     gb = df.groupby(group)
  107.     return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
  108.  
  109.  
  110. def create_tf_example(group, path):
  111.     with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
  112.         encoded_jpg = fid.read()
  113.     encoded_jpg_io = io.BytesIO(encoded_jpg)
  114.     image = Image.open(encoded_jpg_io)
  115.     width, height = image.size
  116.  
  117.     filename = group.filename.encode('utf8')
  118.     image_format = b'jpg'
  119.     xmins = []
  120.     xmaxs = []
  121.     ymins = []
  122.     ymaxs = []
  123.     classes_text = []
  124.     classes = []
  125.  
  126.     for index, row in group.object.iterrows():
  127.         xmins.append(row['xmin'] / width)
  128.         xmaxs.append(row['xmax'] / width)
  129.         ymins.append(row['ymin'] / height)
  130.         ymaxs.append(row['ymax'] / height)
  131.         classes_text.append(row['class'].encode('utf8'))
  132.         classes.append(class_text_to_int(row['class']))
  133.  
  134.     tf_example = tf.train.Example(features=tf.train.Features(feature={
  135.         'image/height': dataset_util.int64_feature(height),
  136.         'image/width': dataset_util.int64_feature(width),
  137.         'image/filename': dataset_util.bytes_feature(filename),
  138.         'image/source_id': dataset_util.bytes_feature(filename),
  139.         'image/encoded': dataset_util.bytes_feature(encoded_jpg),
  140.         'image/format': dataset_util.bytes_feature(image_format),
  141.         'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
  142.         'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
  143.         'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
  144.         'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
  145.         'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
  146.         'image/object/class/label': dataset_util.int64_list_feature(classes),
  147.     }))
  148.     return tf_example
  149.  
  150.  
  151. def main(_):
  152.  
  153.     writer = tf.python_io.TFRecordWriter(args.output_path)
  154.     path = os.path.join(args.image_dir)
  155.     examples = xml_to_csv(args.xml_dir)
  156.     grouped = split(examples, 'filename')
  157.     for group in grouped:
  158.         tf_example = create_tf_example(group, path)
  159.         writer.write(tf_example.SerializeToString())
  160.     writer.close()
  161.     print('Successfully created the TFRecord file: {}'.format(args.output_path))
  162.     if args.csv_path is not None:
  163.         examples.to_csv(args.csv_path, index=None)
  164.         print('Successfully created the CSV file: {}'.format(args.csv_path))
  165.  
  166.  
  167. if __name__ == '__main__':
  168.     tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment