Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import hashlib
- import io
- import logging
- import os
- import contextlib2
- from tqdm import tqdm
- import numpy as np
- import pandas as pd
- import PIL.Image
- import tensorflow as tf
- from object_detection.dataset_tools import tf_record_creation_util
- from object_detection.utils import dataset_util
- from object_detection.utils import label_map_util
# Command-line flags (TF1 tf.app.flags). Paths default to the standard
# Open Images v4 CSV file names; directories must be supplied by the caller.
flags = tf.app.flags
flags.DEFINE_string('train_image_dir', '', 'Root directory to raw dataset.')
flags.DEFINE_string('valid_image_dir', '', 'Root directory to raw dataset.')
flags.DEFINE_string('boxable_csv', 'class-descriptions-boxable.csv', 'Label csv file of open images v4.')
flags.DEFINE_string('train_annotation_csv', 'train-annotations-bbox.csv', 'Box csv file of open images v4 for train.')
flags.DEFINE_string('valid_annotation_csv', 'validation-annotations-bbox.csv', 'Box csv file of open images v4 for valid.')
flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
# NOTE(review): default label map path still points at the pets example
# ('data/pet_label_map.pbtxt') — presumably overridden on the command line
# for face training; verify against how the script is launched.
flags.DEFINE_string('label_map_path', 'data/pet_label_map.pbtxt', 'Path to label map proto')
flags.DEFINE_integer('num_shards', 10, 'Number of TFRecord shards')
FLAGS = flags.FLAGS
def get_label_names(label_csv, names=None):
    """Return Open Images label tags (e.g. '/m/0dzct') for the given names.

    Args:
        label_csv: Path or file-like object of the class-descriptions CSV.
            The file has no header; columns are (tag, human-readable label).
        names: Iterable of human-readable label names to keep.
            Defaults to ['Human face'] (original behavior preserved).

    Returns:
        List of tag strings whose label column matched one of *names*.
    """
    # Use a None sentinel instead of a mutable default argument
    # (a shared list default is a classic Python pitfall).
    if names is None:
        names = ['Human face']
    df = pd.read_csv(label_csv, names=['tag', 'label'])
    df = df.loc[df['label'].isin(names)]
    return df['tag'].tolist()
def create_face_tf_example(
    image_path, classes_text, classes, xmins, xmaxs, ymins, ymaxs, occluded, truncated):
    """Build a tf.train.Example for one image and its bounding boxes.

    Args:
        image_path: Path to the JPEG image on disk.
        classes_text: List of UTF-8 encoded class name bytes, one per box.
        classes: List of integer class ids, one per box.
        xmins, xmaxs, ymins, ymaxs: Per-box coordinates; presumably already
            normalized to [0, 1] as in the Open Images CSVs — TODO confirm.
        occluded: List of per-box IsOccluded ints.
        truncated: List of per-box IsTruncated ints.

    Returns:
        A populated tf.train.Example proto.
    """
    with tf.gfile.GFile(image_path, 'rb') as image_file:
        encoded_jpg = image_file.read()

    # Decode only to discover the image dimensions; the raw JPEG bytes are
    # what gets stored in the record.
    decoded = PIL.Image.open(io.BytesIO(encoded_jpg))
    width, height = decoded.size

    sha256_digest = hashlib.sha256(encoded_jpg).hexdigest()
    basename = os.path.basename(image_path)

    features = {}
    # Image-level features.
    features['image/height'] = dataset_util.int64_feature(height)
    features['image/width'] = dataset_util.int64_feature(width)
    features['image/filename'] = dataset_util.bytes_feature(basename.encode('utf8'))
    features['image/source_id'] = dataset_util.bytes_feature(basename.encode('utf8'))
    features['image/key/sha256'] = dataset_util.bytes_feature(sha256_digest.encode('utf8'))
    features['image/encoded'] = dataset_util.bytes_feature(encoded_jpg)
    features['image/format'] = dataset_util.bytes_feature('jpeg'.encode('utf8'))
    # Per-box features (parallel lists).
    features['image/object/bbox/xmin'] = dataset_util.float_list_feature(xmins)
    features['image/object/bbox/xmax'] = dataset_util.float_list_feature(xmaxs)
    features['image/object/bbox/ymin'] = dataset_util.float_list_feature(ymins)
    features['image/object/bbox/ymax'] = dataset_util.float_list_feature(ymaxs)
    features['image/object/class/text'] = dataset_util.bytes_list_feature(classes_text)
    features['image/object/class/label'] = dataset_util.int64_list_feature(classes)
    features['image/object/occluded'] = dataset_util.int64_list_feature(occluded)
    features['image/object/truncated'] = dataset_util.int64_list_feature(truncated)

    return tf.train.Example(features=tf.train.Features(feature=features))
def create_tf_record(output_filename, num_shards, label_map_dict, annotation_csv, image_dir, label_dict):
    """Write sharded TFRecords of face examples from an annotation CSV.

    Args:
        output_filename: Base path for the sharded output files.
        num_shards: Number of TFRecord shards to write.
        label_map_dict: Mapping of label name -> integer class id
            (from the label map proto).
        annotation_csv: Open Images bbox annotation CSV path.
        image_dir: Directory containing '<ImageID>.jpg' files.
        label_dict: Mapping of Open Images tag (e.g. '/m/0dzct') ->
            label name used in label_map_dict (e.g. 'face').
    """
    logging.info('Creating tfrecord: {}'.format(output_filename))
    logging.info('Label map: {}'.format(label_map_dict))
    label_names = label_dict.keys()
    df = pd.read_csv(annotation_csv)
    # Keep only rows for the tags we care about, then rename tags to the
    # human-readable names used by the label map.
    df = df.loc[df['LabelName'].isin(label_names)].reset_index(drop=True)
    for k, v in label_dict.items():
        df['LabelName'] = df['LabelName'].replace(k, v)
    df['ClassId'] = [label_map_dict[name] for name in df['LabelName']]
    # Use unique() instead of list(set(...)): set ordering depends on hash
    # randomization, which would shuffle examples across shards
    # non-deterministically between runs. unique() preserves first-appearance
    # order and is stable.
    examples = df['ImageID'].unique().tolist()
    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
            tf_record_close_stack, output_filename, num_shards)
        for idx, example in enumerate(tqdm(examples)):
            image_path = os.path.join(image_dir, '{}.jpg'.format(example))
            if not os.path.exists(image_path):
                # Log instead of silently skipping, so missing images are
                # visible in the run output.
                logging.warning('Could not find %s, ignoring example.', example)
                continue
            # All annotation rows for this image become one example.
            example_df = df[df['ImageID'] == example].reset_index(drop=True)
            classes_text = [l.encode('utf8') for l in example_df['LabelName'].tolist()]
            classes = example_df['ClassId'].tolist()
            xmins = example_df['XMin'].tolist()
            xmaxs = example_df['XMax'].tolist()
            ymins = example_df['YMin'].tolist()
            ymaxs = example_df['YMax'].tolist()
            occluded = example_df['IsOccluded'].tolist()
            truncated = example_df['IsTruncated'].tolist()
            try:
                tf_example = create_face_tf_example(
                    image_path, classes_text, classes, xmins, xmaxs, ymins, ymaxs, occluded, truncated)
                # Round-robin examples across shards.
                shard_idx = idx % num_shards
                output_tfrecords[shard_idx].write(tf_example.SerializeToString())
            except Exception as e:
                # Best-effort conversion: a corrupt image or bad row should
                # not abort the whole run.
                logging.warning('Invalid example: %s, %s', example, e)
def main(_):
    """Convert Open Images train/validation annotations into TFRecords.

    Reads directories and CSV paths from FLAGS, restricts the label set
    to the human face tag, and writes sharded train and validation
    record files under FLAGS.output_dir.
    """
    train_image_dir = FLAGS.train_image_dir
    valid_image_dir = FLAGS.valid_image_dir
    train_annotation_csv = FLAGS.train_annotation_csv
    valid_annotation_csv = FLAGS.valid_annotation_csv
    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
    # Open Images tag for 'Human face' -> label map entry name.
    label_dict = { '/m/0dzct': 'face' }
    train_output_path = os.path.join(FLAGS.output_dir, 'faces_train.record')
    valid_output_path = os.path.join(FLAGS.output_dir, 'faces_valid.record')
    create_tf_record(
        train_output_path,
        FLAGS.num_shards,
        label_map_dict,
        train_annotation_csv,
        train_image_dir,
        label_dict)
    create_tf_record(
        valid_output_path,
        # Use the flag instead of a hard-coded 10 so both splits honor
        # --num_shards consistently (default is still 10).
        FLAGS.num_shards,
        label_map_dict,
        valid_annotation_csv,
        valid_image_dir,
        label_dict)
if __name__ == '__main__':
    # TF1 entry point: parses FLAGS from argv and invokes main().
    tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement