annotategpu

Sep 22nd, 2019
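
Batch image classification with a frozen Inception V3 graph in TensorFlow 1.14 (GPU build): reads a directory of images, preprocesses them on the CPU, runs inference in batches, and records the top-5 labels per image to a results file.
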
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import argparse
import os
import time
import pandas as pd

### NEEDS tensorflow-gpu in venv to work (using 1.14) ###

# restricts tf debug output to terminal (set to 0 for default behavior)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# for performance analysis
start_time = time.time()

def load_labels(label_file):
    label = []
    proto_as_ascii_lines = tf.io.gfile.GFile(label_file).readlines()
    for l in proto_as_ascii_lines:
        label.append(l.rstrip())
    return label

def load_graph(model_file):
    graph = tf.Graph()
    graph_def = tf.compat.v1.GraphDef()
    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def)
        # locks graph to prevent new operations from being added
        graph.finalize()
    return graph


def read_tensor_from_image_file(file_name):
    input_name = "file_reader"

    # adding data processing pipeline to CPU explicitly
    with tf.device('/cpu:0'):
        file_reader = tf.io.read_file(file_name, input_name)
        if file_name.endswith(".png"):
            image_reader = tf.image.decode_png(file_reader, channels=3, name="png_reader")
        elif file_name.endswith(".gif"):
            image_reader = tf.squeeze(tf.image.decode_gif(file_reader, name="gif_reader"))
        elif file_name.endswith(".bmp"):
            image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader")
        else:
            image_reader = tf.image.decode_jpeg(file_reader, channels=3, name="jpeg_reader")
        # cast outside the if/else so every format is converted, not just JPEG
        float_caster = tf.cast(image_reader, tf.float32)
    # tensor dimensions (299, 299, 3)
    return float_caster

def preprocess_image_batch(batch,
                           input_height=299,
                           input_width=299,
                           input_mean=0,
                           input_std=255):
    resized = tf.compat.v1.image.resize_bilinear(batch, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    # run the preprocessing ops and close the session when done
    with tf.compat.v1.Session() as sess:
        result = sess.run(normalized)
    return result

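# With the defaults (input_mean=0, input_std=255) the normalization maps pixel values
# from [0, 255] into [0, 1]; e.g. a pixel value of 51 becomes (51 - 0) / 255 = 0.2.
# (Worked example added for clarity; not part of the original paste.)
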
def generate_batches(filenames, batch_size):
    # list of lists of size batch_size containing file names
    batches = [filenames[i * batch_size:(i + 1) * batch_size] for i in range((len(filenames) + batch_size - 1) // batch_size)]
    num_batches = len(batches)
    return batches, num_batches

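# Illustrative example (not from the original paste): with 5 filenames and batch_size=2,
# generate_batches returns ([[f1, f2], [f3, f4], [f5]], 3) -- the final batch may be short.
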
if __name__ == "__main__":
    # default args
    file_name = "tensorflow/examples/label_image/data/grace_hopper.jpg"
    model_file = \
        "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb"
    label_file = "tensorflow/examples/label_image/data/imagenet_slim_labels.txt"
    input_height = 299
    input_width = 299
    input_mean = 0
    input_std = 255
    input_layer = "input"
    output_layer = "InceptionV3/Predictions/Reshape_1"
    # add arguments for classifier
    parser = argparse.ArgumentParser()
    parser.add_argument("--images", help="image directory to be processed")
    parser.add_argument("--graph", help="graph/model to be executed")
    parser.add_argument("--labels", help="name of file containing labels")
    parser.add_argument("--output_file", help="path and name with which to write result file")
    parser.add_argument("--batch_size", help="image batch size")
    parser.add_argument("--input_height", type=int, help="input height")
    parser.add_argument("--input_width", type=int, help="input width")
    parser.add_argument("--input_mean", type=int, help="input mean")
    parser.add_argument("--input_std", type=int, help="input std")
    parser.add_argument("--input_layer", help="name of input layer")
    parser.add_argument("--output_layer", help="name of output layer")
    args = parser.parse_args()

    if args.graph:
        model_file = args.graph
    if args.images:
        image_directory = args.images
    if args.labels:
        label_file = args.labels
    if args.output_file:
        output_file = args.output_file
    if args.batch_size:
        batch_size = int(args.batch_size)
    else:
        batch_size = 128  # default
    if args.input_height:
        input_height = args.input_height
    if args.input_width:
        input_width = args.input_width
    if args.input_mean:
        input_mean = args.input_mean
    if args.input_std:
        input_std = args.input_std
    if args.input_layer:
        input_layer = args.input_layer
    if args.output_layer:
        output_layer = args.output_layer

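    # Example invocation (paths and filenames here are illustrative, not from the original
    # paste; the flags match the argparse definitions above):
    #   python annotategpu.py --images /data/images --graph inception_v3_2016_08_28_frozen.pb \
    #       --labels imagenet_slim_labels.txt --output_file results.csv --batch_size 64
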
    # Loading tf graph and creating list of files to process
    graph = load_graph(model_file)
    file_list = []
    for root, dirs, files in os.walk(os.path.abspath(image_directory)):
        for file in files:
            file_list.append(os.path.join(root, file))

    input_name = "import/" + input_layer
    output_name = "import/" + output_layer
    input_operation = graph.get_operation_by_name(input_name)
    output_operation = graph.get_operation_by_name(output_name)

    # generating batches from the collected file list
    batches, num_batches = generate_batches(file_list, batch_size)

    # for tracking execution time
    count = 0
    batch_time = time.time()
    batch_times = []

    # load labels once and accumulate one result row per image across batches
    labels = load_labels(label_file)
    result_rows = []

    # preprocess imagery and classify it batch by batch
    for i in range(num_batches):
        files = batches[i]
        # build decode ops for each image in the batch, then resize/normalize on the CPU
        image_batch = list(map(read_tensor_from_image_file, batches[i]))
        image_batch = preprocess_image_batch(image_batch, input_height, input_width,
                                             input_mean, input_std)

        with tf.compat.v1.Session(graph=graph) as sess:
            results = sess.run(output_operation.outputs[0], {
                input_operation.outputs[0]: image_batch})
            # writing result rows: top-5 labels and scores for each image
            file_num = 0
            for result in results:
                top_k = result.argsort()[-5:][::-1]
                row = {'file': files[file_num]}
                file_num += 1
                for j in top_k:
                    print(labels[j], result[j])
                    row[labels[j]] = result[j]
                result_rows.append(row)
            # for logs
            num_processed = len(image_batch)
            count += num_processed
            print("\n\nNum images processed: {}".format(count))
            batch_times.append(time.time() - batch_time)
            print("Time for last {} images: {}\n\n".format(num_processed, time.time() - batch_time))
            batch_time = time.time()

    # final logs and result file
    print("Script took %s seconds to execute" % (time.time() - start_time))
    print("Batch times: ")
    print(batch_times)
    if args.output_file:
        # write one row per image with its top-5 labels/scores
        pd.DataFrame(result_rows).to_csv(output_file, index=False)