Advertisement
Guest User

Untitled

a guest
Jul 23rd, 2019
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.40 KB | None | 0 0
  1. import hashlib
  2. import io
  3. import logging
  4. import os
  5.  
  6. import contextlib2
  7. from tqdm import tqdm
  8.  
  9. import numpy as np
  10. import pandas as pd
  11.  
  12. import PIL.Image
  13. import tensorflow as tf
  14.  
  15. from object_detection.dataset_tools import tf_record_creation_util
  16. from object_detection.utils import dataset_util
  17. from object_detection.utils import label_map_util
  18.  
  19. flags = tf.app.flags
  20. flags.DEFINE_string('train_image_dir', '', 'Root directory to raw dataset.')
  21. flags.DEFINE_string('valid_image_dir', '', 'Root directory to raw dataset.')
  22. flags.DEFINE_string('boxable_csv', 'class-descriptions-boxable.csv', 'Label csv file of open images v4.')
  23. flags.DEFINE_string('train_annotation_csv', 'train-annotations-bbox.csv', 'Box csv file of open images v4 for train.')
  24. flags.DEFINE_string('valid_annotation_csv', 'validation-annotations-bbox.csv', 'Box csv file of open images v4 for valid.')
  25. flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
  26. flags.DEFINE_string('label_map_path', 'data/pet_label_map.pbtxt', 'Path to label map proto')
  27. flags.DEFINE_integer('num_shards', 10, 'Number of TFRecord shards')
  28.  
  29. FLAGS = flags.FLAGS
  30.  
  31. def get_label_names(label_csv, names=['Human face']):
  32. df = pd.read_csv(label_csv, names=['tag', 'label'])
  33. df = df.loc[df['label'].isin(names)]
  34. return df['tag'].tolist()
  35.  
  36. def create_face_tf_example(
  37. image_path, classes_text, classes, xmins, xmaxs, ymins, ymaxs, occluded, truncated):
  38. filename = os.path.basename(image_path)
  39. with tf.gfile.GFile(image_path, 'rb') as fid:
  40. encoded_jpg = fid.read()
  41. encoded_jpg_io = io.BytesIO(encoded_jpg)
  42. image = PIL.Image.open(encoded_jpg_io)
  43. width, height = image.size
  44. key = hashlib.sha256(encoded_jpg).hexdigest()
  45.  
  46. feature_dict = {
  47. 'image/height': dataset_util.int64_feature(height),
  48. 'image/width': dataset_util.int64_feature(width),
  49. 'image/filename': dataset_util.bytes_feature(filename.encode('utf8')),
  50. 'image/source_id': dataset_util.bytes_feature(filename.encode('utf8')),
  51. 'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
  52. 'image/encoded': dataset_util.bytes_feature(encoded_jpg),
  53. 'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
  54. 'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
  55. 'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
  56. 'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
  57. 'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
  58. 'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
  59. 'image/object/class/label': dataset_util.int64_list_feature(classes),
  60. 'image/object/occluded': dataset_util.int64_list_feature(occluded),
  61. 'image/object/truncated': dataset_util.int64_list_feature(truncated),
  62. }
  63. example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
  64. return example
  65.  
  66. def create_tf_record(output_filename, num_shards, label_map_dict, annotation_csv, image_dir, label_dict):
  67. logging.info('Creating tfrecord: {}'.format(output_filename))
  68. logging.info('Label map: {}'.format(label_map_dict))
  69.  
  70. label_names = label_dict.keys()
  71. df = pd.read_csv(annotation_csv)
  72.  
  73. df = df.loc[df['LabelName'].isin(label_names)].reset_index(drop=True)
  74. for k, v in label_dict.items():
  75. df['LabelName'] = df['LabelName'].replace(k, v)
  76. df['ClassId'] = [label_map_dict[name] for name in df['LabelName']]
  77.  
  78. examples = list(set(df['ImageID'].tolist()))
  79.  
  80. with contextlib2.ExitStack() as tf_record_close_stack:
  81. output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
  82. tf_record_close_stack, output_filename, num_shards)
  83. for idx, example in enumerate(tqdm(examples)):
  84. image_path = os.path.join(image_dir, '{}.jpg'.format(example))
  85. if not os.path.exists(image_path):
  86. # logging.warning('Could not find %s, ignoring example.', example)
  87. continue
  88. example_df = df[df['ImageID'] == example].reset_index(drop=True)
  89. classes_text = [l.encode('utf8') for l in example_df['LabelName'].tolist()]
  90. classes = example_df['ClassId'].tolist()
  91. xmins = example_df['XMin'].tolist()
  92. xmaxs = example_df['XMax'].tolist()
  93. ymins = example_df['YMin'].tolist()
  94. ymaxs = example_df['YMax'].tolist()
  95. occluded = example_df['IsOccluded'].tolist()
  96. truncated = example_df['IsTruncated'].tolist()
  97.  
  98. try:
  99. tf_example = create_face_tf_example(
  100. image_path, classes_text, classes, xmins, xmaxs, ymins, ymaxs, occluded, truncated)
  101. shard_idx = idx % num_shards
  102. output_tfrecords[shard_idx].write(tf_example.SerializeToString())
  103. except Exception as e:
  104. logging.warning('Invalid example: %s, %s', example, e)
  105.  
  106.  
  107. def main(_):
  108. train_image_dir = FLAGS.train_image_dir
  109. valid_image_dir = FLAGS.valid_image_dir
  110. train_annotation_csv = FLAGS.train_annotation_csv
  111. valid_annotation_csv = FLAGS.valid_annotation_csv
  112. label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
  113.  
  114. label_dict = { '/m/0dzct': 'face' }
  115. train_output_path = os.path.join(FLAGS.output_dir, 'faces_train.record')
  116. valid_output_path = os.path.join(FLAGS.output_dir, 'faces_valid.record')
  117.  
  118. create_tf_record(
  119. train_output_path,
  120. FLAGS.num_shards,
  121. label_map_dict,
  122. train_annotation_csv,
  123. train_image_dir,
  124. label_dict)
  125. create_tf_record(
  126. valid_output_path,
  127. 10,
  128. label_map_dict,
  129. valid_annotation_csv,
  130. valid_image_dir,
  131. label_dict)
  132.  
  133. if __name__ == '__main__':
  134. tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement