Guest User

Untitled

a guest
Apr 22nd, 2018
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.14 KB | None | 0 0
  1. """
  2. Usage:
  3. # From the data set dir
  4. # Create train data:
  5. python ../generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record
  6.  
  7. """
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import absolute_import
  11.  
  12. import os
  13. import io
  14. import pandas as pd
  15. import tensorflow as tf
  16.  
  17. from PIL import Image
  18. from object_detection.utils import dataset_util
  19. from collections import namedtuple, OrderedDict
  20.  
  21. flags = tf.app.flags
  22. flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
  23. flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
  24. FLAGS = flags.FLAGS
  25.  
  26.  
  27. # TO-DO replace this with label map
  28. def class_text_to_int(row_label):
  29. if row_label == 'license_plate':
  30. return 1
  31. else:
  32. None
  33.  
  34.  
  35. def split(df, group):
  36. data = namedtuple('data', ['filename', 'object'])
  37. gb = df.groupby(group)
  38. return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
  39.  
  40.  
  41. def create_tf_example(group, path):
  42. with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
  43. encoded_jpg = fid.read()
  44. encoded_jpg_io = io.BytesIO(encoded_jpg)
  45. image = Image.open(encoded_jpg_io)
  46. width, height = image.size
  47.  
  48. filename = group.filename.encode('utf8')
  49.  
  50. image_format = b'jpg'
  51. xmins = []
  52. xmaxs = []
  53. ymins = []
  54. ymaxs = []
  55. classes_text = []
  56. classes = []
  57.  
  58. for index, row in group.object.iterrows():
  59. xmins.append(row['xmin'] / width)
  60. xmaxs.append(row['xmax'] / width)
  61. ymins.append(row['ymin'] / height)
  62. ymaxs.append(row['ymax'] / height)
  63. classes_text.append(row['class'].encode('utf8'))
  64. classes.append(class_text_to_int(row['class']))
  65.  
  66. tf_example = tf.train.Example(features=tf.train.Features(feature={
  67. 'image/height': dataset_util.int64_feature(height),
  68. 'image/width': dataset_util.int64_feature(width),
  69. 'image/filename': dataset_util.bytes_feature(filename),
  70. 'image/source_id': dataset_util.bytes_feature(filename),
  71. 'image/encoded': dataset_util.bytes_feature(encoded_jpg),
  72. 'image/format': dataset_util.bytes_feature(image_format),
  73. 'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
  74. 'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
  75. 'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
  76. 'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
  77. 'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
  78. 'image/object/class/label': dataset_util.int64_list_feature(classes),
  79. }))
  80. return tf_example
  81.  
  82.  
  83. def main(_):
  84. writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  85. path = os.path.join(os.getcwd())
  86. examples = pd.read_csv(FLAGS.csv_input)
  87. grouped = split(examples, 'filename')
  88. for group in grouped:
  89. tf_example = create_tf_example(group, path)
  90. writer.write(tf_example.SerializeToString())
  91.  
  92. writer.close()
  93. output_path = os.path.join(os.getcwd(), FLAGS.output_path)
  94. print('Successfully created the TFRecords: {}'.format(output_path))
  95.  
  96.  
  97. if __name__ == '__main__':
  98. tf.app.run()
Add Comment
Please, Sign In to add comment