Advertisement
Guest User

generate_tfrecord.py

a guest
Feb 19th, 2018
235
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.52 KB | None | 0 0
  1. """
  2. Usage:
  3.  # From tensorflow/models/
  4.  # Create train data:
  5.  python generate_tfrecord.py --csv_input=data/train_labels.csv  --output_path=train.record
  6.  
  7.  # Create test data:
  8.  python generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=test.record
  9. """
  10. from __future__ import division
  11. from __future__ import print_function
  12. from __future__ import absolute_import
  13.  
  14. import os
  15. import io
  16. import pandas as pd
  17. import tensorflow as tf
  18.  
  19. import sys
  20. sys.path.append('../models-master/research') # point to your tensorflow dir
  21. sys.path.append('../models-master/research/slim') # point to your slim dir
  22.  
  23. from PIL import Image
  24. from object_detection.utils import dataset_util
  25. from collections import namedtuple, OrderedDict
  26.  
  27. flags = tf.app.flags
  28. flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
  29. flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
  30. FLAGS = flags.FLAGS
  31.  
  32.  
  33. # TO-DO replace this with label map
  34. def class_text_to_int(row_label):
  35.     if row_label == 'tennisplayer':
  36.         return 1
  37.     else:
  38.         None
  39.  
  40.  
  41. def split(df, group):
  42.     data = namedtuple('data', ['filename', 'object'])
  43.     gb = df.groupby(group)
  44.     return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
  45.  
  46.  
  47. def create_tf_example(group, path):
  48.     with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
  49.         encoded_jpg = fid.read()
  50.     encoded_jpg_io = io.BytesIO(encoded_jpg)
  51.     image = Image.open(encoded_jpg_io)
  52.     width, height = image.size
  53.  
  54.     filename = group.filename.encode('utf8')
  55.     image_format = b'jpg'
  56.     xmins = []
  57.     xmaxs = []
  58.     ymins = []
  59.     ymaxs = []
  60.     classes_text = []
  61.     classes = []
  62.  
  63.     for index, row in group.object.iterrows():
  64.         xmins.append(row['xmin'] / width)
  65.         xmaxs.append(row['xmax'] / width)
  66.         ymins.append(row['ymin'] / height)
  67.         ymaxs.append(row['ymax'] / height)
  68.         classes_text.append(row['class'].encode('utf8'))
  69.         classes.append(class_text_to_int(row['class']))
  70.  
  71.     tf_example = tf.train.Example(features=tf.train.Features(feature={
  72.         'image/height': dataset_util.int64_feature(height),
  73.         'image/width': dataset_util.int64_feature(width),
  74.         'image/filename': dataset_util.bytes_feature(filename),
  75.         'image/source_id': dataset_util.bytes_feature(filename),
  76.         'image/encoded': dataset_util.bytes_feature(encoded_jpg),
  77.         'image/format': dataset_util.bytes_feature(image_format),
  78.         'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
  79.         'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
  80.         'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
  81.         'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
  82.         'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
  83.         'image/object/class/label': dataset_util.int64_list_feature(classes),
  84.     }))
  85.     return tf_example
  86.  
  87.  
  88. def main(_):
  89.     writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  90.     path = os.path.join(os.getcwd(), 'images')
  91.     examples = pd.read_csv(FLAGS.csv_input)
  92.     grouped = split(examples, 'filename')
  93.     for group in grouped:
  94.         tf_example = create_tf_example(group, path)
  95.         writer.write(tf_example.SerializeToString())
  96.  
  97.     writer.close()
  98.     output_path = os.path.join(os.getcwd(), FLAGS.output_path)
  99.     print('Successfully created the TFRecords: {}'.format(output_path))
  100.  
  101.  
  102. if __name__ == '__main__':
  103.     tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement