Advertisement
Bart91

generate_tfrecord.py

Dec 2nd, 2018
100
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.25 KB | None | 0 0
  1. def class_text_to_int(klassen_label):
  2.   if klassen_label == 'person':
  3.     return 1
  4.   else:
  5.    
  6. def split(df, gruppiereBei):
  7.   data = namedtuple('data', ['filename', 'object'])
  8.   gb = df.groupby(gruppiereBei)
  9.   return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
  10.  
  11. def bmp_to_jpg(bmp_data):
  12.   bmp_bild = tf.image.decode_bmp(bmp_data, channels = 3)
  13.   return tf.Session().run(tf.image.encode_jpeg(bmp_bild, format='rgb', quality = 100, channels = 3))
  14.  
  15. def is_bmp(bild_datei):
  16.   return ".bmp" in bild_datei
  17.  
  18. def create_tf_example(gruppe, path):
  19.   with tf.gfile.GFile(os.path.join(path, '{}'.format(gruppe.filename)), 'rb') as fid:
  20.   raw_bild = fid.read()
  21.   if(is_bmp(gruppe.filename)):
  22.     raw_bild = bmp_to_jpg(raw_bild)
  23.     decoded_bild = tf.image.decode_jpeg(raw_bild, channels = 3)
  24.     breite = tf.Session().run(tf.shape(decoded_bild)[0])
  25.     hoehe = tf.Session().run(tf.shape(decoded_bild)[1])
  26.     depth = tf.Session().run(tf.shape(decoded_bild)[2])
  27.  
  28.     dateiname = gruppe.filename.encode('utf8')
  29.     bild_format = b'JPEG'
  30.     xmins = []
  31.     xmaxs = []
  32.     ymins = []
  33.     ymaxs = []
  34.     klassen_text = []
  35.     klassen = []
  36.  
  37.     for index, zeile in gruppe.object.iterrows():
  38.       xmins.append(zeile['xmin'] / breite)
  39.       xmaxs.append(zeile['xmax'] / breite)
  40.       ymins.append(zeile['ymin'] / hoehe)
  41.       ymaxs.append(zeile['ymax'] / hoehe)
  42.            
  43.     klassen_text.append(zeile['class'].encode('utf8'))
  44.            
  45.     klassen.append(class_text_to_int(zeile['class']))
  46.  
  47.     tf_example = tf.train.Example(features=tf.train.Features(feature={
  48.     #'image/height': dataset_util.int64_feature(hoehe),
  49.     #'image/width': dataset_util.int64_feature(breite),
  50.     #'image/depth': dataset_util.int64_feature(depth),
  51.     #'image/filename': dataset_util.bytes_feature(dateiname),
  52.     #'image/source_id': dataset_util.bytes_feature(dateiname),
  53.     #'image/encoded': dataset_util.bytes_feature(raw_bild),
  54.     #'image/format': dataset_util.bytes_feature(bild_format),
  55.     'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
  56.     'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
  57.     'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
  58.     'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
  59.     'image/class/text': dataset_util.bytes_list_feature(klassen_text),
  60.     'image/class/label': dataset_util.int64_list_feature(klassen),
  61.         }))
  62.     return tf_example
  63.  
  64.   def main(_):
  65.     current_pfad = os.getcwd()
  66.     input_pfad = os.path.join(current_pfad, 'data/')
  67.     output_pfad = input_pfad
  68.     for csv_datei in os.listdir(input_pfad):
  69.       if csv_datei.endswith(".csv"):
  70.         output_file = csv_datei.replace("_labels.csv",".record")
  71.         output_file = os.path.join(output_pfad, output_file)
  72.         writer = tf.python_io.TFRecordWriter(output_file)
  73.         examples = pd.read_csv(os.path.join(input_pfad, csv_datei))
  74.         gruppiert = split(examples, 'filename')
  75.           for gruppe in gruppiert:
  76.             tf_example = create_tf_example(gruppe,
  77.             os.path.join(current_pfad+"/images/",
  78.             csv_datei.replace("_labels.csv","")))
  79.             writer.write(tf_example.SerializeToString())
  80.           writer.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement