Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- """ Sample TensorFlow XML-to-TFRecord converter
- usage: generate_tfrecord.py [-h] [-x XML_DIR] [-l LABELS_PATH] [-o OUTPUT_PATH] [-i IMAGE_DIR] [-c CSV_PATH]
- optional arguments:
- -h, --help show this help message and exit
- -x XML_DIR, --xml_dir XML_DIR
- Path to the folder where the input .xml files are stored.
- -l LABELS_PATH, --labels_path LABELS_PATH
- Path to the labels (.pbtxt) file.
- -o OUTPUT_PATH, --output_path OUTPUT_PATH
- Path of output TFRecord (.record) file.
- -i IMAGE_DIR, --image_dir IMAGE_DIR
- Path to the folder where the input image files are stored. Defaults to the same directory as XML_DIR.
- -c CSV_PATH, --csv_path CSV_PATH
- Path of output .csv file. If none provided, then no file will be written.
- """
- import os
- import glob
- import pandas as pd
- import io
- import xml.etree.ElementTree as ET
- import argparse
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging (1)
- import tensorflow.compat.v1 as tf
- from PIL import Image
- from object_detection.utils import dataset_util, label_map_util
- from collections import namedtuple
- # Initiate argument parser
- parser = argparse.ArgumentParser(
- description="Sample TensorFlow XML-to-TFRecord converter")
- parser.add_argument("-x",
- "--xml_dir",
- help="Path to the folder where the input .xml files are stored.",
- type=str)
- parser.add_argument("-l",
- "--labels_path",
- help="Path to the labels (.pbtxt) file.", type=str)
- parser.add_argument("-o",
- "--output_path",
- help="Path of output TFRecord (.record) file.", type=str)
- parser.add_argument("-i",
- "--image_dir",
- help="Path to the folder where the input image files are stored. "
- "Defaults to the same directory as XML_DIR.",
- type=str, default=None)
- parser.add_argument("-c",
- "--csv_path",
- help="Path of output .csv file. If none provided, then no file will be "
- "written.",
- type=str, default=None)
- args = parser.parse_args()
- if args.image_dir is None:
- args.image_dir = args.xml_dir
- label_map = label_map_util.load_labelmap(args.labels_path)
- label_map_dict = label_map_util.get_label_map_dict(label_map)
- def xml_to_csv(path):
- """Iterates through all .xml files (generated by labelImg) in a given directory and combines
- them in a single Pandas dataframe.
- Parameters:
- ----------
- path : str
- The path containing the .xml files
- Returns
- -------
- Pandas DataFrame
- The produced dataframe
- """
- xml_list = []
- for xml_file in glob.glob(path + '/*.xml'):
- tree = ET.parse(xml_file)
- root = tree.getroot()
- for member in root.findall('object'):
- value = (root.find('filename').text,
- int(root.find('size')[0].text),
- int(root.find('size')[1].text),
- member[0].text,
- int(member[4][0].text),
- int(member[4][1].text),
- int(member[4][2].text),
- int(member[4][3].text)
- )
- xml_list.append(value)
- column_name = ['filename', 'width', 'height',
- 'class', 'xmin', 'ymin', 'xmax', 'ymax']
- xml_df = pd.DataFrame(xml_list, columns=column_name)
- return xml_df
- def class_text_to_int(row_label):
- return label_map_dict[row_label]
- def split(df, group):
- data = namedtuple('data', ['filename', 'object'])
- gb = df.groupby(group)
- return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
- def create_tf_example(group, path):
- with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
- encoded_jpg = fid.read()
- encoded_jpg_io = io.BytesIO(encoded_jpg)
- image = Image.open(encoded_jpg_io)
- width, height = image.size
- filename = group.filename.encode('utf8')
- image_format = b'jpg'
- xmins = []
- xmaxs = []
- ymins = []
- ymaxs = []
- classes_text = []
- classes = []
- for index, row in group.object.iterrows():
- xmins.append(row['xmin'] / width)
- xmaxs.append(row['xmax'] / width)
- ymins.append(row['ymin'] / height)
- ymaxs.append(row['ymax'] / height)
- classes_text.append(row['class'].encode('utf8'))
- classes.append(class_text_to_int(row['class']))
- tf_example = tf.train.Example(features=tf.train.Features(feature={
- 'image/height': dataset_util.int64_feature(height),
- 'image/width': dataset_util.int64_feature(width),
- 'image/filename': dataset_util.bytes_feature(filename),
- 'image/source_id': dataset_util.bytes_feature(filename),
- 'image/encoded': dataset_util.bytes_feature(encoded_jpg),
- 'image/format': dataset_util.bytes_feature(image_format),
- 'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
- 'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
- 'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
- 'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
- 'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
- 'image/object/class/label': dataset_util.int64_list_feature(classes),
- }))
- return tf_example
- def main(_):
- writer = tf.python_io.TFRecordWriter(args.output_path)
- path = os.path.join(args.image_dir)
- examples = xml_to_csv(args.xml_dir)
- grouped = split(examples, 'filename')
- for group in grouped:
- tf_example = create_tf_example(group, path)
- writer.write(tf_example.SerializeToString())
- writer.close()
- print('Successfully created the TFRecord file: {}'.format(args.output_path))
- if args.csv_path is not None:
- examples.to_csv(args.csv_path, index=None)
- print('Successfully created the CSV file: {}'.format(args.csv_path))
- if __name__ == '__main__':
- tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment