Advertisement
Guest User

Untitled

a guest
May 19th, 2019
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.71 KB | None | 0 0
  1. """
  2. Usage:
  3.  # From tensorflow/models/
  4.  # Create train data:
  5.  python generate_tfrecord.py --csv_input=images/train_labels.csv --image_dir=images/train --output_path=train.record
  6.  # Create test data:
  7.  python generate_tfrecord.py --csv_input=images/test_labels.csv  --image_dir=images/test --output_path=test.record
  8. """
  9. from __future__ import division
  10. from __future__ import print_function
  11. from __future__ import absolute_import
  12.  
  13. import os
  14. import io
  15. import pandas as pd
  16. import tensorflow as tf
  17.  
  18. from PIL import Image
  19. from object_detection.utils import dataset_util
  20. from collections import namedtuple, OrderedDict
  21.  
  22. flags = tf.app.flags
  23. flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
  24. flags.DEFINE_string('image_dir', '', 'Path to the image directory')
  25. flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
  26. FLAGS = flags.FLAGS
  27.  
  28.  
  29. # TODO: replace this hard-coded mapping with a label map
  30. def class_text_to_int(row_label):
  31.     if row_label == 'nine':
  32.         return 1
  33.     elif row_label == 'ten':
  34.         return 2
  35.     elif row_label == 'jack':
  36.         return 3
  37.     elif row_label == 'queen':
  38.         return 4
  39.     elif row_label == 'king':
  40.         return 5
  41.     elif row_label == 'ace':
  42.         return 6
  43.     else:
  44.         None
  45.  
  46.  
  47. def split(df, group):
  48.     data = namedtuple('data', ['filename', 'object'])
  49.     gb = df.groupby(group)
  50.     return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
  51.  
  52.  
  53. def create_tf_example(group, path):
  54.     with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
  55.         encoded_jpg = fid.read()
  56.     encoded_jpg_io = io.BytesIO(encoded_jpg)
  57.     image = Image.open(encoded_jpg_io)
  58.     width, height = image.size
  59.  
  60.     filename = group.filename.encode('utf8')
  61.     image_format = b'jpg'
  62.     xmins = []
  63.     xmaxs = []
  64.     ymins = []
  65.     ymaxs = []
  66.     classes_text = []
  67.     classes = []
  68.  
  69.     for index, row in group.object.iterrows():
  70.         xmins.append(row['xmin'] / width)
  71.         xmaxs.append(row['xmax'] / width)
  72.         ymins.append(row['ymin'] / height)
  73.         ymaxs.append(row['ymax'] / height)
  74.         classes_text.append(row['class'].encode('utf8'))
  75.         classes.append(class_text_to_int(row['class']))
  76.  
  77.     tf_example = tf.train.Example(features=tf.train.Features(feature={
  78.         'image/height': dataset_util.int64_feature(height),
  79.         'image/width': dataset_util.int64_feature(width),
  80.         'image/filename': dataset_util.bytes_feature(filename),
  81.         'image/source_id': dataset_util.bytes_feature(filename),
  82.         'image/encoded': dataset_util.bytes_feature(encoded_jpg),
  83.         'image/format': dataset_util.bytes_feature(image_format),
  84.         'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
  85.         'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
  86.         'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
  87.         'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
  88.         'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
  89.         'image/object/class/label': dataset_util.int64_list_feature(classes),
  90.     }))
  91.     return tf_example
  92.  
  93.  
  94. def main(_):
  95.     writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  96.     path = os.path.join(os.getcwd(), FLAGS.image_dir)
  97.     examples = pd.read_csv(FLAGS.csv_input)
  98.     grouped = split(examples, 'filename')
  99.     for group in grouped:
  100.         tf_example = create_tf_example(group, path)
  101.         writer.write(tf_example.SerializeToString())
  102.  
  103.     writer.close()
  104.     output_path = os.path.join(os.getcwd(), FLAGS.output_path)
  105.     print('Successfully created the TFRecords: {}'.format(output_path))
  106.  
  107.  
  108. if __name__ == '__main__':
  109.     tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement