Advertisement
Guest User

Untitled

a guest
Apr 19th, 2019
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.99 KB | None | 0 0
  1. from __future__ import print_function
  2.  
  3. import numpy as np
  4. import os
  5. import sys
  6. import tensorflow as tf
  7. import argparse
  8. import cv2
  9. import time
  10.  
  11. from object_detection.utils import visualization_utils as vis_util
  12.  
  13. slim = tf.contrib.slim
  14.  
  15. def GetAllFilesListRecusive(path, extensions):
  16. files_all = []
  17. for root, subFolders, files in os.walk(path):
  18. for name in files:
  19. # linux tricks with .directory that still is file
  20. if not 'directory' in name and sum([ext in name for ext in extensions]) > 0:
  21. files_all.append(os.path.join(root, name))
  22. return files_all
  23.  
  24. snapshot_dir = './snapshots/'
  25.  
  26. SAVE_DIR = './output/'
  27.  
  28. def calculate_perfomance(sess, input, output, shape, runs = 1000, batch_size = 1):
  29.  
  30. start = time.time()
  31.  
  32. print('Calculating inference time on size', shape)
  33.  
  34. # To exclude numpy generating time
  35. N = 100
  36. for i in range(0, N):
  37. img = np.random.random((batch_size, shape[0], shape[1], 3))
  38. stop = time.time()
  39. time_for_generate = (stop - start) / N
  40.  
  41. # warmup
  42. sess.run([output],
  43. feed_dict={input: img})
  44.  
  45. start = time.time()
  46. for i in range(runs):
  47. img = np.random.random((batch_size, shape[0], shape[1], 3))
  48. sess.run([output],
  49. feed_dict={input: img})
  50. stop = time.time()
  51.  
  52. inf_time = ((stop - start) / float(runs)) - time_for_generate
  53.  
  54. print('Average inference time: {}'.format(inf_time))
  55.  
  56.  
  57. def get_arguments():
  58. parser = argparse.ArgumentParser(description="Object Detection Inference")
  59. parser.add_argument("--img-path", type=str, default='./input',
  60. help="Path to the RGB image file.",
  61. required=False)
  62. parser.add_argument("--save-dir", type=str, default=SAVE_DIR,
  63. help="Path to save output.")
  64. parser.add_argument("--snapshots-dir", type=str, default=snapshot_dir,
  65. help="Path to checkpoints.")
  66. parser.add_argument("--pb-file", type=str, default='',
  67. help="Path to to pb file, alternative for checkpoint. If set, checkpoints will be ignored")
  68. parser.add_argument("--weighted", action="store_true", default=False,
  69. help="If true, will output weighted images")
  70. parser.add_argument("--batch-size", type = int, default = 1,
  71. help="Size of batch for time measure")
  72. parser.add_argument("--measure-time", action="store_true", default=False,
  73. help="Evaluate only model inference time")
  74. parser.add_argument("--runs", type=int, default=100,
  75. help="Repeats for time measure. More runs - longer testing - more precise results")
  76. parser.add_argument("--with_score", action="store_true", default=False,
  77. help="If true will try to calculate score basing on dirs as classes")
  78.  
  79.  
  80. return parser.parse_args()
  81.  
  82. def save(saver, sess, logdir, step):
  83. model_name = 'model.ckpt'
  84. checkpoint_path = os.path.join(logdir, model_name)
  85.  
  86. if not os.path.exists(logdir):
  87. os.makedirs(logdir)
  88. saver.save(sess, checkpoint_path, global_step=step)
  89. print('The checkpoint has been created.')
  90.  
  91. def load(saver, sess, ckpt_path):
  92. saver.restore(sess, ckpt_path)
  93. print("Restored model parameters from {}".format(ckpt_path))
  94.  
  95.  
  96. def load_img(img_path, h, w):
  97. if os.path.isfile(img_path):
  98. print('successful load img: {0}'.format(img_path))
  99. else:
  100. print('not found file: {0}'.format(img_path))
  101. sys.exit(0)
  102.  
  103. filename = img_path.split('/')[-1]
  104. img = cv2.imread(img_path)
  105.  
  106. if h and w:
  107. img = cv2.resize(img, (int(w), int(h)))
  108. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  109. img = img / 255.0
  110. img = img - 0.5
  111. img = img * 2.0
  112. print('input image shape: ', img.shape)
  113.  
  114. return img, filename
  115.  
  116. def load_classification_pb(class_filename, mem_frac = 0.5, input_name = 'input',
  117. output_name = 'InceptionV4/Logits/Predictions:0', # 'MobilenetV1/Predictions/Reshape_1:0',
  118. type = tf.float32):
  119.  
  120. # Load classification model
  121. classification_graph = tf.Graph()
  122. with classification_graph.as_default():
  123. class_graph_def = tf.GraphDef()
  124. with tf.gfile.GFile(class_filename, 'rb') as fid:
  125. serialized_graph = fid.read()
  126. class_graph_def.ParseFromString(serialized_graph)
  127.  
  128. class_image = tf.placeholder(type, shape=(None, None, None, 3))
  129.  
  130. tf.import_graph_def(class_graph_def, {input_name : class_image}, name = '')
  131. predictions = classification_graph.get_tensor_by_name(output_name)
  132.  
  133. config = tf.ConfigProto()
  134. config.gpu_options.per_process_gpu_memory_fraction = mem_frac
  135. config.allow_soft_placement = True
  136. config.log_device_placement = False
  137. sess = tf.Session(graph = classification_graph, config = config)
  138.  
  139. width = None
  140. height = None
  141. labels = None
  142. try:
  143. shape_tensor = classification_graph.get_tensor_by_name('input_size:0')
  144. labels_tensor = classification_graph.get_tensor_by_name('label_names:0')
  145. shape, labels = sess.run([shape_tensor, labels_tensor])
  146. width, height, _ = shape
  147. print(shape, labels)
  148. except:
  149. pass
  150.  
  151. return class_image, predictions, classification_graph, sess, width, height, labels
  152.  
  153. def main():
  154.  
  155. args = get_arguments()
  156.  
  157. if args.img_path[-4] != '.':
  158. files = GetAllFilesListRecusive(args.img_path, ['.jpg', '.jpeg', '.png'])
  159. else:
  160. files = [args.img_path]
  161.  
  162. image_tensor, predictions, graph, sess, width, height, labels = load_classification_pb(args.pb_file)
  163.  
  164. if args.measure_time:
  165. calculate_perfomance(sess, image_tensor, predictions, (height, width), args.runs, args.batch_size)
  166. quit()
  167.  
  168.  
  169. total = 0
  170. correct = 0
  171. for path in files:
  172.  
  173. img, filename = load_img(path, height, width)
  174. if args.with_score:
  175. t = path[ : path.rfind('/')]
  176. cl = t[t.rfind('/') + 1 : ]
  177.  
  178. image_np_expanded = np.expand_dims(img, axis = 0)
  179.  
  180. # if args.pb_file != '':
  181. # img = np.expand_dims(img, axis = 0)
  182.  
  183. t = time.time()
  184. preds = sess.run(
  185. [predictions], feed_dict = {image_tensor: image_np_expanded})[0][0]
  186.  
  187. print('time: ', (time.time() - t) * 1000.0)
  188. indx = np.argmax(preds)
  189. if args.with_score:
  190. print(labels[indx].decode("utf-8"), cl)
  191. if labels[indx].decode("utf-8") == cl:
  192. correct = correct + 1
  193. total = total + 1
  194. print('class: ', labels[indx])
  195.  
  196. if args.with_score:
  197. print('Correct score: ', (correct / float(total)) * 100.0)
  198.  
  199. if __name__ == '__main__':
  200. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement